Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/interface-dependency-versions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ jobs:

- name: Nightly Catalyst Version
id: catalyst
run: echo "nightly=--index https://test.pypi.org/simple/ --prerelease=allow --upgrade-package PennyLane-Catalyst PennyLane-Catalyst" >> $GITHUB_OUTPUT
run: echo "nightly=--upgrade-package PennyLane-Catalyst PennyLane-Catalyst" >> $GITHUB_OUTPUT

- name: PennyLane-Lightning Latest Version
id: pennylane-lightning
run: echo "latest=--index https://test.pypi.org/simple/ --prerelease=allow --upgrade-package PennyLane-Lightning PennyLane-Lightning" >> $GITHUB_OUTPUT
run: echo "latest=--upgrade-package PennyLane-Lightning PennyLane-Lightning" >> $GITHUB_OUTPUT

outputs:
catalyst-jax-version: jax==${{ steps.catalyst-jax.outputs.version }} jaxlib==${{ steps.catalyst-jax.outputs.version }}
Expand Down
2 changes: 2 additions & 0 deletions doc/development/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ Release notes

This page contains the release notes for PennyLane.

.. mdinclude:: ../releases/changelog-0.42.1.md

.. mdinclude:: ../releases/changelog-0.42.0.md

.. mdinclude:: ../releases/changelog-0.41.1.md
Expand Down
2 changes: 1 addition & 1 deletion doc/releases/changelog-0.42.0.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
:orphan:

# Release 0.42.0 (current release)
# Release 0.42.0

<h3>New features since last release</h3>

Expand Down
14 changes: 14 additions & 0 deletions doc/releases/changelog-0.42.1.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
:orphan:

# Release 0.42.1 (current release)

<h3>Bug fixes 🐛</h3>

* A warning is raised if PennyLane is imported and a version of JAX greater than 0.6.2 is installed.
[(#7949)](https://github.com/PennyLaneAI/pennylane/pull/7949)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):

Andrija Paurevic
15 changes: 15 additions & 0 deletions pennylane/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,21 @@
from pennylane.liealg import lie_closure, structure_constants, center
import pennylane.qnn


from importlib.metadata import version as _metadata_version
from importlib.util import find_spec as _find_spec
from packaging.version import Version as _Version

if _find_spec("jax") is not None:
if (jax_version := _Version(_metadata_version("jax"))) > _Version("0.6.2"): # pragma: no cover
warnings.warn(
"PennyLane is not yet compatible with JAX versions > 0.6.2. "
f"You have version {jax_version} installed. "
"Please downgrade JAX to 0.6.2 to avoid runtime errors using "
"python -m pip install jax~=0.6.0 jaxlib~=0.6.0",
RuntimeWarning,
)

# Look for an existing configuration file
default_config = Configuration("config.toml")

Expand Down
2 changes: 1 addition & 1 deletion pennylane/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.42.0"
__version__ = "0.42.1"
18 changes: 14 additions & 4 deletions pennylane/capture/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
"""
This submodule defines a strategy structure for defining custom plxpr interpreters
"""
# pylint: disable=no-self-use
from copy import copy
from functools import partial, wraps

# pylint: disable=no-self-use, wrong-import-position
from importlib.metadata import version
from typing import Callable, Optional, Sequence

import jax
from packaging.version import Version

import pennylane as qml
from pennylane import math
Expand Down Expand Up @@ -635,15 +638,22 @@ class FlattenedInterpreter(PlxprInterpreter):
"""


jax_version = version("jax")
if Version(jax_version) > Version("0.6.2"): # pragma: no cover
from jax._src.pjit import jit_p as pjit_p
else: # pragma: no cover
from jax._src.pjit import pjit_p


# pylint: disable=protected-access
@FlattenedInterpreter.register_primitive(jax._src.pjit.pjit_p)
@FlattenedInterpreter.register_primitive(pjit_p)
def _(self, *invals, jaxpr, **params):
if jax.config.jax_dynamic_shapes:
# just evaluate it so it doesn't throw dynamic shape errors
return copy(self).eval(jaxpr.jaxpr, jaxpr.consts, *invals)

subfuns, params = jax._src.pjit.pjit_p.get_bind_params({"jaxpr": jaxpr, **params})
return jax._src.pjit.pjit_p.bind(*subfuns, *invals, **params)
subfuns, params = pjit_p.get_bind_params({"jaxpr": jaxpr, **params})
return pjit_p.bind(*subfuns, *invals, **params)


@FlattenedInterpreter.register_primitive(while_loop_prim)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/capture/make_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def fn(x):
if not has_jax: # pragma: no cover
raise ImportError(
"Module jax is required for the ``make_plxpr`` function. "
"You can install jax via: pip install jax"
"You can install jax via: pip install jax~=0.6.0"
)

if not qml.capture.enabled():
Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ def _evolve_state_vector_under_parametrized_evolution(
except ImportError as e: # pragma: no cover
raise ImportError(
"Module jax is required for the ``ParametrizedEvolution`` class. "
"You can install jax via: pip install jax"
"You can install jax via: pip install jax~=0.6.0"
) from e

if operation.data is None or operation.t is None:
Expand Down
2 changes: 1 addition & 1 deletion pennylane/gradients/pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _assert_has_jax(transform_name):
if not has_jax: # pragma: no cover
raise ImportError(
f"Module jax is required for the {transform_name} gradient transform. "
"You can install jax via: pip install jax jaxlib"
"You can install jax via: pip install jax~=0.6.0 jaxlib~=0.6.0"
)


Expand Down
4 changes: 2 additions & 2 deletions pennylane/labs/dla/variational_kak.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def Kc(theta_opt):

if not has_jax: # pragma: no cover
raise ImportError(
"jax and optax are required for variational_kak_adj. You can install them with pip install jax jaxlib optax."
"jax and optax are required for variational_kak_adj. You can install them with pip install jax~=0.6.0 jaxlib~=0.6.0 optax."
) # pragma: no cover
if verbose >= 1 and not has_plt: # pragma: no cover
print(
Expand Down Expand Up @@ -392,7 +392,7 @@ def cost(x):

if not has_jax: # pragma: no cover
raise ImportError(
"jax and optax are required for run_opt. You can install them with pip install jax jaxlib optax."
"jax and optax are required for run_opt. You can install them with pip install jax~=0.6.0 jaxlib~=0.6.0 optax."
) # pragma: no cover

if optimizer is None:
Expand Down
6 changes: 3 additions & 3 deletions pennylane/pulse/convenience_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def f(p, t):
if not has_jax:
raise ImportError(
"Module jax is required for any pulse-related convenience function. "
"You can install jax via: pip install jax==0.4.10 jaxlib==0.4.10"
"You can install jax via: pip install jax~=0.6.0 jaxlib~=0.6.0"
)
if windows is not None:
is_nested = any(hasattr(w, "__len__") for w in windows)
Expand Down Expand Up @@ -291,7 +291,7 @@ def wrapped(p, t):
if not has_jax:
raise ImportError(
"Module jax is required for any pulse-related convenience function. "
"You can install jax via: pip install jax==0.4.3 jaxlib==0.4.3"
"You can install jax via: pip install jax~=0.6.0 jaxlib~=0.6.0"
)

if isinstance(timespan, (tuple, list)):
Expand Down Expand Up @@ -365,7 +365,7 @@ def fn(params, t):
if not has_jax:
raise ImportError(
"Module jax is required for any pulse-related convenience function. "
"You can install jax via: pip install jax==0.4.3 jaxlib==0.4.3"
"You can install jax via: pip install jax~=0.6.0 jaxlib~=0.6.0"
)

if isinstance(timespan, tuple):
Expand Down
4 changes: 2 additions & 2 deletions pennylane/pulse/parametrized_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def __call__(
if not has_jax:
raise ImportError(
"Module jax is required for the ``ParametrizedEvolution`` class. "
"You can install jax via: pip install jax"
"You can install jax via: pip install jax~=0.6.0"
)
# Need to cast all elements inside params to `jnp.arrays` to make sure they are not cast
# to `np.arrays` inside `Operator.__init__`
Expand Down Expand Up @@ -511,7 +511,7 @@ def matrix(self, wire_order=None):
if not has_jax:
raise ImportError(
"Module jax is required for the ``ParametrizedEvolution`` class. "
"You can install jax via: pip install jax"
"You can install jax via: pip install jax~=0.6.0"
)
if not self.has_matrix:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions pennylane/qchem/factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def factorize(
.. note::

Packages JAX and Optax are required when performing CDF with ``compressed=True``.
Install them using ``pip install jax optax``.
Install them using ``pip install jax~=0.6.0 optax``.

Args:
two_electron (array[array[float]]): Two-electron integral tensor in the molecular orbital
Expand Down Expand Up @@ -262,7 +262,7 @@ def factorize(
if not has_jax_optax:
raise ImportError(
"Jax and Optax libraries are required for optimizing the factors. Install them via "
"pip install jax optax"
"pip install jax~=0.6.0 optax"
) # pragma: no cover

norm_order = {None: None, "L1": 1, "L2": 2}.get(regularization, "LX")
Expand Down