Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
ae67a0b
Non-functioning first draft of plxpr defer_measurements
mudit2812 Jan 15, 2025
50e62b4
conditions work.. kinda?
mudit2812 Jan 17, 2025
a5a33ea
Classical control works
mudit2812 Jan 22, 2025
5196be3
MCM statistics work
mudit2812 Jan 23, 2025
4b4a741
Remove commented code
mudit2812 Jan 23, 2025
0127f81
Merge branch 'master' into capture-defer-measurements
mudit2812 Jan 23, 2025
6878f38
Add template tests
mudit2812 Jan 27, 2025
8c35f86
Add unit tests
mudit2812 Jan 27, 2025
12ddca5
Add gate parameter support; make items public method
mudit2812 Jan 28, 2025
db7621f
PlxprInterpreter custom registrations handle constants correctly
mudit2812 Jan 28, 2025
8561866
Add changelog entry
mudit2812 Jan 28, 2025
6bdf7b2
Update usage of state
mudit2812 Jan 28, 2025
1c16f1f
Finish adding tests
mudit2812 Jan 28, 2025
ffdb278
Add link to changelog entry
mudit2812 Jan 28, 2025
9eaa4d9
Add program capture usage details to defer_measurements docs
mudit2812 Jan 28, 2025
bc7a504
Merge branch 'master' into capture-defer-measurements
mudit2812 Jan 28, 2025
efffcb9
Update base interpreter custom registrations
mudit2812 Jan 28, 2025
432d0db
Linting
mudit2812 Jan 28, 2025
511f1fe
Linting
mudit2812 Jan 29, 2025
5ddab87
Coverage
mudit2812 Jan 29, 2025
0c84693
Merge branch 'master' into capture-defer-measurements
mudit2812 Jan 29, 2025
cdb161f
Minor warning fix
mudit2812 Jan 29, 2025
c33e387
Merge branch 'master' into capture-defer-measurements
mudit2812 Jan 29, 2025
aca613b
Address code review
mudit2812 Jan 30, 2025
6653763
Merge branch 'master' into capture-defer-measurements
mudit2812 Jan 30, 2025
1c3141e
Merge branch 'master' into capture-defer-measurements
mudit2812 Feb 3, 2025
2e60440
Merge branch 'master' into capture-defer-measurements
mudit2812 Feb 5, 2025
fc2bee5
Move `get_mcm_predicates` to mid_measure.py and expand docstring
mudit2812 Feb 5, 2025
18b13d9
Improve structure and add doc comments to `resolve_mcm_values`
mudit2812 Feb 5, 2025
798b4a6
Update docstring with constraints about control flow
mudit2812 Feb 5, 2025
baa24c8
Address code review
mudit2812 Feb 6, 2025
e5cdf80
Merge branch 'master' into capture-defer-measurements
mudit2812 Feb 6, 2025
704079a
[skip ci] Skip CI
mudit2812 Feb 6, 2025
2c69a4a
Use dict.get instead of pop
mudit2812 Feb 7, 2025
52f60ed
Fix changelog entry
mudit2812 Feb 7, 2025
6503f79
Merge branch 'master' into capture-defer-measurements
mudit2812 Feb 7, 2025
d4e3ad5
Add note about mcms as gate params
mudit2812 Feb 11, 2025
4ad7cb8
Update pennylane/transforms/defer_measurements.py
mudit2812 Feb 11, 2025
74d5ed7
Merge branch 'master' into capture-defer-measurements
mudit2812 Feb 11, 2025
6ba6ca9
Update pennylane/transforms/defer_measurements.py
mudit2812 Feb 13, 2025
54540a8
Add error for multiple MCM parameters
mudit2812 Feb 13, 2025
65d4a1a
Merge branch 'master' into capture-defer-measurements
mudit2812 Feb 13, 2025
9c2d881
Merge branch 'master' into capture-defer-measurements
mudit2812 Feb 18, 2025
d066d05
Fix changelog
mudit2812 Feb 18, 2025
8160a05
DQ works with defer_measurements
mudit2812 Feb 7, 2025
b2f9f13
Fix use of plxpr fixture in dq_interpreter tests
mudit2812 Feb 12, 2025
fff52ab
Add skeleton for tests
mudit2812 Feb 12, 2025
182b383
Add tests for custom registrations
mudit2812 Feb 18, 2025
78700e4
Remove test class
mudit2812 Feb 19, 2025
f70ae01
Add execution tests
mudit2812 Feb 19, 2025
982d1a3
Update changelog
mudit2812 Feb 19, 2025
a73fbb3
Update changelog
mudit2812 Feb 20, 2025
c00d10f
Merge branch 'master' into dq-capture-deferred
mudit2812 Feb 20, 2025
972f053
Fix test merge conflicts
mudit2812 Feb 20, 2025
37f4121
[skip ci] Remove xfails
mudit2812 Feb 20, 2025
19bb7e4
Add skeleton for shots support with defer_measurements on default.qubit
mudit2812 Feb 12, 2025
c6a6b9b
Added support for hw-like postselect_mode
mudit2812 Feb 18, 2025
08bc35b
Partial changes to rework `defer_measurements`; not integrated new ar…
mudit2812 Feb 20, 2025
02cf0ba
switch from qnode_kwargs to execution_config
albi3ro Feb 21, 2025
d879b06
Merge branch 'master' into qnode-prim-execution-config
albi3ro Feb 21, 2025
73e27fc
Apply suggestions from code review
albi3ro Feb 21, 2025
11ba2a9
null qubit and pylint
albi3ro Feb 21, 2025
28e1146
Merge branch 'qnode-prim-execution-config' of https://github.com/Penn…
albi3ro Feb 21, 2025
93d8548
Merge branch 'master' into dq-defer-measurements-shots
mudit2812 Feb 21, 2025
9823ee1
Merge branch 'qnode-prim-execution-config' into dq-defer-measurements…
mudit2812 Feb 21, 2025
c1e946a
use construct_execution_config
albi3ro Feb 21, 2025
103be57
Apply suggestions from code review
albi3ro Feb 21, 2025
6602aa3
Merge branch 'master' into qnode-prim-execution-config
albi3ro Feb 21, 2025
a85f7ad
Finish num_wires support; add reduce_postselected argument
mudit2812 Feb 21, 2025
7d5ff2a
Merge branch 'qnode-prim-execution-config' into dq-defer-measurements…
mudit2812 Feb 21, 2025
0952ae3
Add hw-like support; TODO: figure out how to handle control flow
mudit2812 Feb 21, 2025
c314885
Update mcm_method integration with qnode_prim.impl
mudit2812 Feb 24, 2025
7ced640
Fix tests
mudit2812 Feb 24, 2025
fc07a29
Merge branch 'master' into dq-defer-measurements-shots
mudit2812 Feb 25, 2025
550728d
Remove qnode processing of mcms
mudit2812 Feb 25, 2025
db01d78
Move execution tests to new file
mudit2812 Feb 25, 2025
2aea17a
Add test skeleton
mudit2812 Feb 25, 2025
308b4d8
Add tests
mudit2812 Feb 26, 2025
89ad52a
Merge branch 'master' into dq-defer-measurements-shots
mudit2812 Feb 26, 2025
4406d51
Update defer_measurements for program capture sharp bits
mudit2812 Feb 26, 2025
3e403cf
Add link to changelog entry
mudit2812 Feb 26, 2025
c58c792
Fix tests; add more unit tests
mudit2812 Feb 26, 2025
a7f828a
Merge branch 'master' into dq-defer-measurements-shots
mudit2812 Feb 26, 2025
3905b7d
Add tests that use mocking for checking sampling behaviour
mudit2812 Feb 28, 2025
905f444
Merge branch 'master' into dq-defer-measurements-shots
mudit2812 Feb 28, 2025
f296322
Add tests for coverage
mudit2812 Feb 28, 2025
58e40ca
Merge branch 'master' into dq-defer-measurements-shots
mudit2812 Mar 4, 2025
e4141ab
Linting
mudit2812 Mar 4, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
`qml.defer_measurements` can be executed on `default.qubit`.
[(#6838)](https://github.com/PennyLaneAI/pennylane/pull/6838)
[(#6937)](https://github.com/PennyLaneAI/pennylane/pull/6937)
[(#6961)](https://github.com/PennyLaneAI/pennylane/pull/6961)

Using `qml.defer_measurements` with program capture enables many new features, including:
* Significantly richer variety of classical processing on mid-circuit measurement values.
Expand Down
5 changes: 4 additions & 1 deletion pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,10 @@ def eval_jaxpr(
key = jax.random.PRNGKey(self._rng.integers(100000))

interpreter = DefaultQubitInterpreter(
num_wires=len(self.wires), shots=self.shots.total_shots, key=key
num_wires=len(self.wires),
shots=self.shots.total_shots,
key=key,
execution_config=execution_config,
)
return interpreter.eval(jaxpr, consts, *args)

Expand Down
40 changes: 34 additions & 6 deletions pennylane/devices/qubit/dq_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pennylane.capture import pause
from pennylane.capture.base_interpreter import FlattenedHigherOrderPrimitives, PlxprInterpreter
from pennylane.capture.primitives import adjoint_transform_prim, ctrl_transform_prim, measure_prim
from pennylane.devices import ExecutionConfig
from pennylane.measurements import MidMeasureMP, Shots
from pennylane.ops import adjoint, ctrl
from pennylane.ops.qubit import Projector
Expand All @@ -33,7 +34,7 @@
from .simulate import _postselection_postprocess # pylint: disable=protected-access


# pylint: disable=attribute-defined-outside-init, access-member-before-definition
# pylint: disable=attribute-defined-outside-init, access-member-before-definition,too-many-instance-attributes
class DefaultQubitInterpreter(PlxprInterpreter):
"""Implements a class for interpreting plxpr using python simulation tools.

Expand Down Expand Up @@ -77,11 +78,15 @@ def __copy__(self):
return inst

def __init__(
self, num_wires: int, shots: int | None = None, key: None | jax.numpy.ndarray = None
self,
num_wires: int,
shots: int | None = None,
key: None | jax.numpy.ndarray = None,
execution_config: None | ExecutionConfig = None,
):
self.num_wires = num_wires
self.shots = Shots(shots)
if self.shots.has_partitioned_shots:
self.original_shots = Shots(shots)
if self.original_shots.has_partitioned_shots:
raise NotImplementedError(
"DefaultQubitInterpreter does not yet support partitioned shots."
)
Expand All @@ -90,6 +95,8 @@ def __init__(

self.initial_key = key
self.stateref = None
self.execution_config = execution_config or ExecutionConfig()

super().__init__()

@property
Expand All @@ -107,6 +114,21 @@ def state(self, new_val):
except TypeError as e:
raise AttributeError("execution not yet initialized.") from e

@property
def shots(self):
"""The shots"""
try:
return self.stateref["shots"]
except TypeError as e:
raise AttributeError("execution not yet initialized.") from e

@shots.setter
def shots(self, new_val):
try:
self.stateref["shots"] = new_val
except TypeError as e:
raise AttributeError("execution not yet initialized.") from e

@property
def key(self):
"""A jax PRNGKey for random number generation."""
Expand Down Expand Up @@ -141,6 +163,7 @@ def setup(self) -> None:
if self.stateref is None:
self.stateref = {
"state": create_initial_state(range(self.num_wires), like="jax"),
"shots": self.original_shots,
"key": self.initial_key,
"is_state_batched": False,
}
Expand All @@ -156,8 +179,13 @@ def interpret_operation(self, op):
self.is_state_batched = True

if isinstance(op, Projector):
self.state, _ = _postselection_postprocess(
self.state, self.is_state_batched, self.shots
self.key, new_key = jax.random.split(self.key, 2)
self.state, self.shots = _postselection_postprocess(
self.state,
self.is_state_batched,
self.shots,
prng_key=new_key,
postselect_mode=self.execution_config.mcm_config.postselect_mode,
)

return op
Expand Down
130 changes: 95 additions & 35 deletions pennylane/transforms/defer_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from pennylane.ops.op_math import ctrl
from pennylane.queuing import QueuingManager
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.transforms import transform
from pennylane.transforms import TransformError, transform
from pennylane.typing import PostprocessingFn
from pennylane.wires import Wires

Expand Down Expand Up @@ -119,7 +119,7 @@ def _get_plxpr_defer_measurements():
# pylint: disable=import-outside-toplevel
import jax

from pennylane.capture import CaptureError, PlxprInterpreter
from pennylane.capture import PlxprInterpreter
from pennylane.capture.primitives import cond_prim, ctrl_transform_prim, measure_prim
except ImportError: # pragma: no cover
return None, None
Expand All @@ -131,13 +131,15 @@ class DeferMeasurementsInterpreter(PlxprInterpreter):

# pylint: disable=unnecessary-lambda-assignment,attribute-defined-outside-init,no-self-use

def __init__(self, aux_wires):
def __init__(self, num_wires):
super().__init__()
self._aux_wires = Wires(aux_wires)
self._num_wires = num_wires

# We use a dict here instead of a normal int variable because we want the state to mutate
# when we interpret higher-order primitives
self.state = {"cur_idx": 0}
# We store all used wires rather than just the max because if the wires are tracers, then
# we can't do comparisons to find the max wire
self.state = {"cur_target": num_wires - 1, "used_wires": set()}

def cleanup(self) -> None:
"""Perform any final steps after iterating through all equations.
Expand All @@ -147,7 +149,27 @@ def cleanup(self) -> None:
be used for the target wire of the next mid-circuit measurement's replacement
:class:`~pennylane.CNOT`.
"""
self.state = {"cur_idx": 0}
self.state = {"cur_target": self._num_wires - 1, "used_wires": set()}

def _update_used_wires(self, wires: qml.wires.Wires, cur_target: int):
"""Update the state with the number of wires that have been used and validate that
there is no overlap between the used circuit wires and the mid-circuit measurement
target wires.

Args:
wires (pennylane.wires.Wires): wires to add to the set of used wires
cur_target (int): target wire to be used for a mid-circuit measurement

Raises:
TransformError: if there is overlap between the used circuit wires and mid-circuit
measurement target wires
"""
self.state["used_wires"] |= wires.toset()
if self.state["used_wires"].intersection(range(cur_target, self._num_wires)):
raise TransformError(
"Too many mid-circuit measurements for the specified number of wires "
"with 'defer_measurements'."
)

def interpret_dynamic_operation(self, data, struct, inds):
"""Interpret an operation that uses mid-circuit measurement outcomes as parameters.
Expand All @@ -161,13 +183,14 @@ def interpret_dynamic_operation(self, data, struct, inds):
struct (PyTreeDef): Pytree structure of the operator
inds (Sequence[int]): Indices of mid-circuit measurement values in ``data``

Returns:
None
Raises:
TransformError: if there is overlap between the used circuit wires and mid-circuit
measurement target wires
"""
if len(inds) > 1:
raise CaptureError(
raise TransformError(
"Cannot create operations with multiple parameters based on "
"mid-circuit measurements."
"mid-circuit measurements with 'defer_measurements'."
)

idx = inds[0]
Expand All @@ -192,10 +215,14 @@ def interpret_operation(self, op: "qml.operation.Operator"):
See also: :meth:`~.interpret_operation_eqn`.

"""
# Range for comparison is [cur_target + 1, num_wires) because cur_target
# is the _next_ wire to be used for an MCM. We want to check if the used
# wires overlap with the already applied MCMs.
self._update_used_wires(op.wires, self.state["cur_target"] + 1)

# We treat operators with operators based on mid-circuit measurement values
# separately, and otherwise default to the standard behaviour
data, struct = jax.tree_util.tree_flatten(op)

mcm_data_inds = []
for i, d in enumerate(data):
if isinstance(d, MeasurementValue):
Expand All @@ -219,9 +246,14 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess
kwargs = {"wires": measurement.wires, "eigvals": measurement.eigvals()}
if isinstance(measurement, CountsMP):
kwargs["all_outcomes"] = measurement.all_outcomes

measurement = type(measurement)(**kwargs)

else:
# Range for comparison is [cur_target + 1, num_wires) because cur_target
# is the _next_ wire to be used for an MCM. We want to check if the used
# wires overlap with the already applied MCMs.
self._update_used_wires(measurement.wires, self.state["cur_target"] + 1)

return super().interpret_measurement(measurement)

def resolve_mcm_values(
Expand Down Expand Up @@ -330,24 +362,19 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list:

@DeferMeasurementsInterpreter.register_primitive(measure_prim)
def _(self, wires, reset=False, postselect=None):
if self.state["cur_idx"] >= len(self._aux_wires):
raise ValueError(
"Not enough auxiliary wires provided to apply specified number of mid-circuit "
"measurements using qml.defer_measurements."
)
cur_target = self.state["cur_target"]
# Range for comparison is [cur_target, num_wires) because cur_target
# is the _current_ wire to be used for an MCM.
self._update_used_wires(Wires(wires), cur_target)

# Using type.__call__ instead of normally constructing the class prevents
# the primitive corresponding to the class to get binded. We do not want the
# MidMeasureMP's primitive to get recorded.
meas = type.__call__(
MidMeasureMP,
Wires(self._aux_wires[self.state["cur_idx"]]),
reset=reset,
postselect=postselect,
id=self.state["cur_idx"],
MidMeasureMP, Wires(cur_target), reset=reset, postselect=postselect, id=str(cur_target)
)

cnot_wires = (wires, self._aux_wires[self.state["cur_idx"]])
cnot_wires = (wires, cur_target)
if postselect is not None:
qml.Projector(jax.numpy.array([postselect]), wires=wires)

Expand All @@ -358,7 +385,7 @@ def _(self, wires, reset=False, postselect=None):
elif postselect == 1:
qml.PauliX(wires=wires)

self.state["cur_idx"] += 1
self.state["cur_target"] -= 1
return MeasurementValue([meas], lambda x: x)

@DeferMeasurementsInterpreter.register_primitive(cond_prim)
Expand Down Expand Up @@ -390,6 +417,7 @@ def _(
control_wires = Wires([m.wires[0] for m in condition.measurements])

for branch, value in condition.items():
# When reduce_postselected is True, some branches can be ()
cur_consts = invals[consts_slices[i]]
qml.cond(value, ctrl_transform_prim.bind)(
*cur_consts,
Expand All @@ -407,9 +435,9 @@ def _(
def defer_measurements_plxpr_to_plxpr(jaxpr, consts, targs, tkwargs, *args):
"""Function for applying the ``defer_measurements`` transform on plxpr."""

if not tkwargs.get("aux_wires", None):
if not tkwargs.get("num_wires", None):
raise ValueError(
"'aux_wires' argument for qml.defer_measurements must be provided "
"'num_wires' argument for qml.defer_measurements must be provided "
"when qml.capture.enabled() is True."
)
if tkwargs.pop("reduce_postselected", False):
Expand Down Expand Up @@ -444,7 +472,7 @@ def defer_measurements(
tape: QuantumScript,
reduce_postselected: bool = True,
allow_postselect: bool = True,
aux_wires: Optional[Union[int, Sequence[int], Wires]] = None,
num_wires: Optional[int] = None,
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
"""Quantum function transform that substitutes operations conditioned on
measurement outcomes to controlled operations.
Expand Down Expand Up @@ -503,7 +531,7 @@ def defer_measurements(
allow_postselect (bool): Whether postselection is allowed. In order to perform postselection
with ``defer_measurements``, the device must support the :class:`~.Projector` operation.
Defaults to ``True``. This is currently ignored if program capture is enabled.
aux_wires (Sequence): Optional sequence of wires to use to map mid-circuit measurements. This is
num_wires (int): Optional argument to specify the total number of circuit wires. This is
only used if program capture is enabled.

Returns:
Expand Down Expand Up @@ -619,9 +647,41 @@ def node(x):
:title: Deferred measurements with program capture

``qml.defer_measurements`` can be applied to callables when program capture is enabled. To do so,
the ``aux_wires`` argument must be provided, which should be a sequence of integers to be used
as the target wires for transforming mid-circuit measurements. With program capture enabled, some
new features, as well as new restrictions are introduced, that are detailed below:
the ``num_wires`` argument must be provided, which should be an integer corresponding to the total
number of available wires. For ``m`` mid-circuit measurements, ``range(num_wires - m, num_wires)``
will be the range of wires used to map mid-circuit measurements to ``CNOT`` gates.

.. warning::

While the transform includes validation to avoid overlap between wires of the original
circuit and mid-circuit measurement target wires, if any wires of the original ciruit
are traced, i.e. dependent on dynamic arguments to the transformed workflow, the
validation may not catch overlaps. Consider the following example:

.. code-block:: python

from functools import partial
import jax

qml.capture.enable()

@qml.capture.expand_plxpr_transforms
@partial(qml.defer_measurements, num_wires=1)
def f(n):
qml.measure(n)

>>> jax.make_jaxpr(f)(0)
{ lambda ; a:i64[]. let _:AbstractOperator() = CNOT[n_wires=2] a 0 in () }

The circuit gets transformed without issue because the concrete value of the measured wire
is unknown. However, execution with n = 0 would raise an error, as the CNOT wires would
be (0, 0).

Thus, users must by cautious when transforming a circuit. **For ``n`` total wires and
``c`` circuit wires, the number of mid-circuit measurements allowed is ``n - c``.**

Using ``defer_measurements`` with program capture enabled introduces new features and
restrictions:

**New features**

Expand All @@ -645,7 +705,7 @@ def node(x):
qml.capture.enable()

@qml.capture.expand_plxpr_transforms
@partial(qml.defer_measurements, aux_wires=list(range(5, 10)))
@partial(qml.defer_measurements, num_wires=10)
def f():
m0 = qml.measure(0)

Expand All @@ -655,21 +715,21 @@ def f():

>>> jax.make_jaxpr(f)()
{ lambda ; . let
_:AbstractOperator() = CNOT[n_wires=2] 0 5
_:AbstractOperator() = CNOT[n_wires=2] 0 9
a:f64[] = mul 0.0 3.141592653589793
b:f64[] = sin a
c:AbstractOperator() = RX[n_wires=1] b 0
_:AbstractOperator() = Controlled[
control_values=(False,)
work_wires=Wires([])
] c 5
] c 9
d:f64[] = mul 1.0 3.141592653589793
e:f64[] = sin d
f:AbstractOperator() = RX[n_wires=1] e 0
_:AbstractOperator() = Controlled[
control_values=(True,)
work_wires=Wires([])
] f 5
] f 9
g:AbstractOperator() = PauliZ[n_wires=1] 0
h:AbstractMeasurement(n_wires=None) = expval_obs g
in (h,) }
Expand Down
Loading
Loading