Merged
9bc4fa9
restructure qnode_prim access
albi3ro Jan 16, 2025
b6ce7d3
fix improt
albi3ro Jan 16, 2025
7ee4757
add backprop validation and some jvp structure
albi3ro Jan 16, 2025
d6155de
add finite difference derivatives
albi3ro Jan 16, 2025
a26231e
changelog
albi3ro Jan 17, 2025
d246020
Merge branch 'master' into finite-diff-capture
albi3ro Jan 22, 2025
1ed5e9e
move workflow capture tests
albi3ro Jan 22, 2025
cac689e
adding tests
albi3ro Jan 22, 2025
300fdca
Merge branch 'master' into finite-diff-capture
albi3ro Jan 22, 2025
e056a67
somehow file was properly moved
albi3ro Jan 22, 2025
f34f1ed
Merge branch 'finite-diff-capture' of https://github.com/PennyLaneAI/…
albi3ro Jan 22, 2025
39f371d
fixing test
albi3ro Jan 22, 2025
17b9fcd
minor clean up
albi3ro Jan 23, 2025
0df100c
add finite_diff_jvp to gradients module
albi3ro Jan 23, 2025
ab2e6c6
adding tests for finite_difF_jvp
albi3ro Jan 23, 2025
50425b1
one additional tesT
albi3ro Jan 23, 2025
8f28f06
Merge branch 'master' into finite-diff-capture
albi3ro Jan 23, 2025
fd19014
add strategy and approx_order
albi3ro Jan 23, 2025
dd4a49b
minor efficiency rewriting
albi3ro Jan 23, 2025
09a2671
Apply suggestions from code review
albi3ro Jan 24, 2025
b87978a
responding to feedback
albi3ro Jan 24, 2025
e6ba05a
Apply suggestions from code review
albi3ro Jan 24, 2025
c53f10d
mergeing
albi3ro Jan 24, 2025
a2d622f
Apply suggestions from code review
albi3ro Jan 27, 2025
b045726
Merge branch 'master' into finite-diff-capture
albi3ro Jan 27, 2025
297b88d
Merge branch 'master' into finite-diff-capture
albi3ro Jan 27, 2025
38f1535
black
albi3ro Jan 27, 2025
7ed7f77
Merge branch 'master' into finite-diff-capture
albi3ro Jan 27, 2025
7 changes: 7 additions & 0 deletions doc/releases/changelog-dev.md
@@ -44,6 +44,13 @@
* An informative error is raised when a `QNode` with `diff_method=None` is differentiated.
[(#6770)](https://github.com/PennyLaneAI/pennylane/pull/6770)

* `qml.gradients.finite_diff_jvp` has been added to compute the Jacobian-vector product (JVP) of an
arbitrary numeric function via finite differences.
[(#6853)](https://github.com/PennyLaneAI/pennylane/pull/6853)

* With program capture enabled, `QNode`s can now be differentiated with `diff_method="finite-diff"`.
[(#6853)](https://github.com/PennyLaneAI/pennylane/pull/6853)

* The requested `diff_method` is now validated when program capture is enabled.
[(#6852)](https://github.com/PennyLaneAI/pennylane/pull/6852)

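To illustrate the capture-mode changelog entry above, here is a minimal sketch, not taken from the PR itself: the device, circuit, and values are illustrative, and float64 is enabled because the new finite-difference code warns on float32 parameters.

import jax
import pennylane as qml

jax.config.update("jax_enable_x64", True)  # avoid the float32 finite-difference warning
qml.capture.enable()  # turn on program capture

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, diff_method="finite-diff")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.Z(0))

# d<Z>/dx for RX(x) acting on |0> is -sin(x), so the result should be close to -0.479.
print(jax.grad(circuit)(0.5))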
3 changes: 2 additions & 1 deletion pennylane/gradients/__init__.py
@@ -71,6 +71,7 @@
:toctree: api

finite_diff_coeffs
finite_diff_jvp
generate_shifted_tapes
generate_multishifted_tapes
generate_shift_rule
@@ -345,7 +346,7 @@ def my_custom_gradient(tape: qml.tape.QuantumScript, **kwargs) -> tuple[QuantumS
)
from .adjoint_metric_tensor import adjoint_metric_tensor
from .classical_jacobian import classical_jacobian
from .finite_difference import finite_diff, finite_diff_coeffs
from .finite_difference import finite_diff, finite_diff_coeffs, finite_diff_jvp
from .fisher import classical_fisher, quantum_fisher
from .general_shift_rules import (
eigvals_to_frequencies,
107 changes: 96 additions & 11 deletions pennylane/gradients/finite_difference.py
@@ -1,4 +1,4 @@
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
# Copyright 2018-2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
"""
import functools
from functools import partial
from typing import Callable, Literal

# pylint: disable=protected-access,too-many-arguments,too-many-branches,too-many-statements,unused-argument
from warnings import warn
@@ -166,6 +167,88 @@ def finite_diff_coeffs(n, approx_order, strategy):
return coeffs_and_shifts


def finite_diff_jvp(
f: Callable,
args: tuple,
tangents: tuple,
*,
h: float = 1e-6,
approx_order: int = 1,
strategy: Literal["forward", "backward", "center"] = "forward",
) -> tuple:
r"""Compute the jvp of a generic function using finite differences.

Args:
f (Callable): a generic function that returns a list or tuple of tensors. Note that this
function should not accept keyword arguments.
args (tuple[TensorLike]): The tuple of arguments to ``f``
tangents (tuple[TensorLike]): the tuple of tangents for the arguments.

Keyword Args:
h=1e-6 (float): finite difference method step size
approx_order=1 (int): The approximation order of the finite-difference method to use.
strategy="forward" (str): The strategy of the finite difference method. Must be one of
``"forward"``, ``"center"``, or ``"backward"``.
For the ``"forward"`` strategy, the finite-difference shifts occur at the points
:math:`x_0, x_0+h, x_0+2h,\dots`, where :math:`h` is some small
stepsize. The ``"backwards"`` strategy is similar, but in
reverse: :math:`x_0, x_0-h, x_0-2h, \dots`. Finally, the
``"center"`` strategy results in shifts symmetric around the
unshifted point: :math:`\dots, x_0-2h, x_0-h, x_0, x_0+h, x_0+2h,\dots`.

Returns:
results, d_results: the results of ``f`` and the Jacobian-vector products (jvps) of the results

>>> def f(x, y):
... return 2 * x * y, x**2
>>> args = (0.5, 1.2)
>>> tangents = (1.0, 1.0)
>>> results, dresults = qml.gradients.finite_diff_jvp(f, args, tangents)
>>> results
(1.2, 0.25)
>>> dresults
[np.float64(3.399999999986747), np.float64(1.000001000006634)]

"""
coeffs, shifts = finite_diff_coeffs(n=1, approx_order=approx_order, strategy=strategy)

initial_res = f(*args)
if not isinstance(initial_res, (list, tuple)):
raise ValueError("Input function f must return either a list or tuple.")

jvps = [0 for _ in initial_res]
for i, t in enumerate(tangents):
if type(t).__name__ == "Zero": # Zero = jax.interpreters.ad.Zero
continue
t = np.array(t) if isinstance(t, (int, float)) else t

if qml.math.get_dtype_name(args[i]) == "float32":
warn(
"Detected float32 parameter with finite differences. Recommend use of float64 with finite diff.",
UserWarning,
)

shifted_args = list(args)
for index in np.ndindex(qml.math.shape(args[i])):
ti = t[index]

if not qml.math.is_abstract(ti) and qml.math.allclose(ti, 0):
continue

ti_over_h = ti / h
for coeff, shift in zip(coeffs, shifts):
if shift == 0:
res = initial_res
else:
shifted_args[i] = qml.math.scatter_element_add(args[i], index, h * shift)
res = f(*shifted_args)

for result_idx, r in enumerate(res):
jvps[result_idx] += ti_over_h * coeff * r

return initial_res, jvps


def _processing_fn(results, shots, single_shot_batch_fn):
if not shots.has_partitioned_shots:
return single_shot_batch_fn(results)
@@ -188,15 +271,16 @@ def _finite_diff_stopping_condition(op) -> bool:
return True


# pylint: disable=too-many-positional-arguments
def _expand_transform_finite_diff(
tape: QuantumScript,
argnum=None,
h=1e-7,
approx_order=1,
n=1,
strategy="forward",
h: float = 1e-7,
approx_order: int = 1,
n: int = 1,
strategy: Literal["forward", "backward", "center"] = "forward",
f0=None,
validate_params=True,
validate_params: bool = True,
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
"""Expand function to be applied before finite difference."""
[new_tape], postprocessing = qml.devices.preprocess.decompose(
@@ -213,6 +297,7 @@ def _expand_transform_finite_diff(
return [new_tape], postprocessing


# pylint: disable=too-many-positional-arguments
@partial(
transform,
expand_transform=_expand_transform_finite_diff,
@@ -222,12 +307,12 @@ def _expand_transform_finite_diff(
def finite_diff(
tape: QuantumScript,
argnum=None,
h=1e-7,
approx_order=1,
n=1,
strategy="forward",
h: float = 1e-7,
approx_order: int = 1,
n: int = 1,
strategy: Literal["forward", "backward", "center"] = "forward",
f0=None,
validate_params=True,
validate_params: bool = True,
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
r"""Transform a circuit to compute the finite-difference gradient of all gate parameters with respect to its inputs.

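A brief sketch of the new `approx_order` and `strategy` keyword arguments of `finite_diff_jvp`, checked against the analytic derivative. The function and values mirror the docstring example above; this snippet is illustrative and not part of the PR.

import pennylane as qml

def f(x, y):
    return (2 * x * y, x**2)

args, tangents = (0.5, 1.2), (1.0, 1.0)

# Default: first-order forward differences.
_, jvp_forward = qml.gradients.finite_diff_jvp(f, args, tangents)

# Second-order central differences, typically more accurate for the same step size h.
_, jvp_center = qml.gradients.finite_diff_jvp(
    f, args, tangents, h=1e-6, approx_order=2, strategy="center"
)

# Analytic jvps for comparison: d(2xy) = 2*y*dx + 2*x*dy = 3.4 and d(x**2) = 2*x*dx = 1.0
print(jvp_forward, jvp_center)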
19 changes: 15 additions & 4 deletions pennylane/workflow/_capture_qnode.py
@@ -190,8 +190,6 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_di
raise NotImplementedError(
"Overriding shots is not yet supported with the program capture execution."
)
if qnode_kwargs["diff_method"] not in {"backprop", "best"}:
raise NotImplementedError("Only backpropagation derivatives are supported at this time.")

consts = args[:n_consts]
non_const_args = args[n_consts:]
@@ -251,6 +249,7 @@ def _qnode_batching_rule(
"using parameter broadcasting to a quantum operation that supports batching.",
UserWarning,
)

# To resolve this ambiguity, we might add more properties to the AbstractOperator
# class to indicate which operators support batching and check them here.
# As above, at this stage we raise a warning and give the user full flexibility.
@@ -291,7 +290,14 @@ def _backprop(args, tangents, **impl_kwargs):
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), args, tangents)


diff_method_map = {"backprop": _backprop}
def _finite_diff(args, tangents, **impl_kwargs):
f = partial(qnode_prim.bind, **impl_kwargs)
return qml.gradients.finite_diff_jvp(
f, args, tangents, **impl_kwargs["qnode_kwargs"]["gradient_kwargs"]
)


diff_method_map = {"backprop": _backprop, "finite-diff": _finite_diff}


def _resolve_diff_method(diff_method: str, device) -> str:
@@ -405,7 +411,12 @@ def f(x):

execute_kwargs = copy(qnode.execute_kwargs)
mcm_config = asdict(execute_kwargs.pop("mcm_config"))
qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config}
qnode_kwargs = {
"diff_method": qnode.diff_method,
**execute_kwargs,
"gradient_kwargs": qnode.gradient_kwargs,
**mcm_config,
}

flat_args = jax.tree_util.tree_leaves(args)

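Finally, a sketch of how the `gradient_kwargs` forwarding added above might surface to users: finite-difference options attached to the QNode are passed through `qnode_kwargs["gradient_kwargs"]` to `finite_diff_jvp` when capture is enabled. The exact way of supplying `gradient_kwargs` to the QNode, the device, and the values are assumptions for illustration only.

import jax
import pennylane as qml

jax.config.update("jax_enable_x64", True)
qml.capture.enable()

dev = qml.device("default.qubit", wires=1)

# h, approx_order, and strategy end up as the keyword arguments of finite_diff_jvp.
@qml.qnode(
    dev,
    diff_method="finite-diff",
    gradient_kwargs={"h": 1e-5, "approx_order": 2, "strategy": "center"},
)
def circuit(x):
    qml.RY(x, wires=0)
    return qml.expval(qml.Z(0))

# d<Z>/dx for RY(x) acting on |0> is -sin(x), roughly -0.479 at x = 0.5.
print(jax.grad(circuit)(0.5))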