Skip to content

Commit a9fbc84

Browse files
Merge branch 'master' into pl_qualtran_prototype
2 parents ebfc992 + bbea644 commit a9fbc84

File tree

8 files changed

+420
-181
lines changed

8 files changed

+420
-181
lines changed

‎doc/releases/changelog-dev.md‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,10 @@ Here's a list of deprecations made this release. For a more detailed breakdown o
856856

857857
<h3>Documentation 📝</h3>
858858

859+
* The functions in `qml.qchem.vibrational` are updated to include additional information about the
860+
theory and input arguments.
861+
[(#6918)](https://github.com/PennyLaneAI/pennylane/pull/6918)
862+
859863
* The usage examples for `qml.decomposition.DecompositionGraph` have been updated.
860864
[(#7692)](https://github.com/PennyLaneAI/pennylane/pull/7692)
861865

‎pennylane/compiler/python_compiler/jax_utils.py‎

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Callable, TypeAlias
1919

2020
import jaxlib
21+
from catalyst import QJIT
2122
from jaxlib.mlir.dialects import stablehlo as jstablehlo # pylint: disable=no-name-in-module
2223
from jaxlib.mlir.ir import Context as jContext # pylint: disable=no-name-in-module
2324
from jaxlib.mlir.ir import Module as jModule # pylint: disable=no-name-in-module
@@ -30,6 +31,7 @@
3031
from xdsl.dialects import tensor as xtensor
3132
from xdsl.dialects import transform as xtransform
3233
from xdsl.parser import Parser as xParser
34+
from xdsl.traits import SymbolTable as xSymbolTable
3335

3436
from pennylane.compiler.python_compiler.quantum_dialect import QuantumDialect
3537

@@ -42,9 +44,7 @@ def _module_inline(func: JaxJittedFunction, *args, **kwargs) -> jModule:
4244

4345

4446
def module(func: JaxJittedFunction) -> Callable[..., jModule]:
45-
"""
46-
Decorator for _module_inline
47-
"""
47+
"""Decorator for _module_inline"""
4848

4949
@wraps(func)
5050
def wrapper(*args, **kwargs) -> jModule:
@@ -54,18 +54,14 @@ def wrapper(*args, **kwargs) -> jModule:
5454

5555

5656
def _generic_inline(func: JaxJittedFunction, *args, **kwargs) -> str: # pragma: no cover
57-
"""
58-
Create the generic textual representation for the jax.jit'ed function
59-
"""
57+
"""Create the generic textual representation for the jax.jit'ed function"""
6058
lowered = func.lower(*args, **kwargs)
6159
mod = lowered.compiler_ir()
6260
return mod.operation.get_asm(binary=False, print_generic_op_form=True, assume_verified=True)
6361

6462

6563
def generic(func: JaxJittedFunction) -> Callable[..., str]: # pragma: no cover
66-
"""
67-
Decorator for _generic_inline.
68-
"""
64+
"""Decorator for _generic_inline."""
6965

7066
@wraps(func)
7167
def wrapper(*args, **kwargs) -> str:
@@ -125,12 +121,47 @@ def wrapper(*_, **__):
125121

126122

127123
def xdsl_module(func: JaxJittedFunction) -> Callable[..., xbuiltin.ModuleOp]: # pragma: no cover
128-
"""
129-
Decorator for _xdsl_module_inline
130-
"""
124+
"""Decorator for _xdsl_module_inline"""
131125

132126
@wraps(func)
133127
def wrapper(*args, **kwargs) -> xbuiltin.ModuleOp:
134128
return _xdsl_module_inline(func, *args, **kwargs)
135129

136130
return wrapper
131+
132+
133+
def inline_module(
134+
from_mod: xbuiltin.ModuleOp, to_mod: xbuiltin.ModuleOp, change_main_to: str = None
135+
) -> None:
136+
"""Inline the contents of one xDSL module into another xDSL module. The inlined body is appended
137+
to the end of ``to_mod``."""
138+
if change_main_to:
139+
main = xSymbolTable.lookup_symbol(from_mod, "main")
140+
if main is not None:
141+
assert isinstance(main, xfunc.FuncOp)
142+
main.properties["sym_name"] = xbuiltin.StringAttr(change_main_to)
143+
144+
for op in from_mod.body.ops:
145+
xSymbolTable.insert_or_update(to_mod, op.clone())
146+
147+
148+
def inline_jit_to_module(func: JaxJittedFunction, mod: xbuiltin.ModuleOp, *args, **kwargs) -> None:
149+
"""Inline a ``jax.jit``-ed Python function to an xDSL module. The inlined body is appended
150+
to the end of ``mod``."""
151+
func_mod = _xdsl_module_inline(func, *args, **kwargs)
152+
inline_module(func_mod, mod, change_main_to=func.__name__)
153+
154+
155+
def xdsl_from_qjit(func: QJIT) -> Callable[..., xbuiltin.ModuleOp]:
156+
"""Decorator to convert QJIT-ed functions into xDSL modules."""
157+
158+
@wraps(func)
159+
def wrapper(*args, **kwargs):
160+
func.jaxpr, *_ = func.capture(args, **kwargs)
161+
mlir_module = func.generate_ir()
162+
generic_str = mlir_module.operation.get_asm(
163+
binary=False, print_generic_op_form=True, assume_verified=True
164+
)
165+
return parse_generic_to_xdsl_module(generic_str)
166+
167+
return wrapper

‎pennylane/labs/tests/vibrational/test_pes_generator.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
h5py = pytest.importorskip("h5py")
2929

30-
# pylint: disable=too-many-arguments, protected-access, too-many-positional-arguments
30+
# pylint: disable=too-many-arguments, protected-access, too-many-positional-arguments, unsubscriptable-object
3131

3232
ref_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_ref_files")
3333

‎pennylane/qchem/vibrational/localize_modes.py‎

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -174,26 +174,38 @@ def _localize_modes(freqs, vecs):
174174

175175

176176
def localize_normal_modes(freqs, vecs, bins=[2600]):
177-
"""
178-
Localizes vibrational normal modes.
177+
r"""Computes spatially localized vibrational normal modes.
179178
180-
The normal modes are localized by separating frequencies into specified ranges following the
181-
procedure described in `J. Chem. Phys. 141, 104105 (2014)
179+
The vibrational normal modes are localized using a localizing unitary following the procedure
180+
described in `J. Chem. Phys. 141, 104105 (2014)
182181
<https://pubs.aip.org/aip/jcp/article-abstract/141/10/104105/74317/
183-
Efficient-anharmonic-vibrational-spectroscopy-for?redirectedFrom=fulltext>`_.
182+
Efficient-anharmonic-vibrational-spectroscopy-for?redirectedFrom=fulltext>`_. The localizing
183+
unitary :math:`U` is defined in terms of the normal and local coordinates, :math:`q` and
184+
:math:`\tilde{q}`, respectively as:
185+
186+
.. math::
187+
188+
\tilde{q} = \sum_{j=1}^M U_{ij} q_j,
189+
190+
where :math:`M` is the number of modes. The normal modes
191+
can be separately localized, to prevent mixing between specific groups of normal modes, by
192+
defining frequency ranges in ``bins``. For instance, ``bins = [2600]`` allows to separately
193+
localize modes that have frequencies above and below :math:`2600` reciprocal centimetre (:math:`\text{cm}^{-1}`).
194+
Similarly, ``bins = [1300, 2600]`` allows to separately localize modes in three groups that have
195+
frequencies below :math:`1300`, between :math:`1300-2600` and above :math:`2600`.
184196
185197
Args:
186-
freqs (list[float]): normal mode frequencies in ``cm^-1``
187-
vecs (TensorLike[float]): displacement vectors for normal modes
188-
bins (list[float]): List of upper bound frequencies in ``cm^-1`` for creating separation bins .
189-
Default is ``[2600]`` which means having one bin for all frequencies between ``0`` and ``2600 cm^-1``.
198+
freqs (TensorLike[float]): normal mode frequencies in reciprocal centimetre (:math:`\text{cm}^{-1}`).
199+
vecs (TensorLike[float]): displacement vectors of the normal modes
200+
bins (List[float]): grid of frequencies for grouping normal modes.
201+
Default is ``[2600]``.
190202
191203
Returns:
192204
tuple: A tuple containing the following:
193-
- list[float] : localized frequencies
194-
- TensorLike[float] : localized displacement vectors
195-
- TensorLike[float] : localization matrix describing the relationship between
196-
original and localized modes.
205+
- TensorLike[float] : localized frequencies in reciprocal centimetre (:math:`\text{cm}^{-1}`).
206+
- List[TensorLike[float]] : localized displacement vectors
207+
- TensorLike[float] : localization matrix describing the relationship between the
208+
original and the localized modes
197209
198210
**Example**
199211
@@ -209,7 +221,7 @@ def localize_normal_modes(freqs, vecs, bins=[2600]):
209221
... [-5.49709883e-17, 7.49851221e-08, -2.77912798e-02]]])
210222
>>> freqs_loc, vecs_loc, uloc = qml.qchem.localize_normal_modes(freqs, vectors)
211223
>>> freqs_loc
212-
array([1332.62008773, 2296.73455892, 2296.7346082 ])
224+
array([1332.62013257, 2296.73453455, 2296.73460655])
213225
214226
"""
215227
if not bins:

0 commit comments

Comments
 (0)