Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@

<h3>Improvements 🛠</h3>

* `qml.for_loop` will now fall back to a standard python loop if capture of a condensed, structured loop fails
with program capture enabled.
[(#8615)](https://github.com/PennyLaneAI/pennylane/pull/8615)

* The `~.BasisRotation` graph decomposition was re-written in a qjit friendly way with PennyLane control flow.
[(#8560)](https://github.com/PennyLaneAI/pennylane/pull/8560)
[(#8608)](https://github.com/PennyLaneAI/pennylane/pull/8608)
Expand Down
26 changes: 22 additions & 4 deletions pennylane/control_flow/for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
# limitations under the License.
"""For loop."""
import functools
import logging
import warnings
from typing import Literal

from pennylane import capture
from pennylane.capture import FlatFn, enabled
from pennylane.capture.dynamic_shapes import register_custom_staging_rule
from pennylane.compiler.compiler import AvailableCompilers, active_compiler
from pennylane.exceptions import CaptureWarning

from ._loop_abstract_axes import (
add_abstract_shapes,
Expand All @@ -28,6 +31,9 @@
validate_no_resizing_returns,
)

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def for_loop(
start, stop=None, step=1, *, allow_array_resizing: Literal["auto", True, False] = "auto"
Expand Down Expand Up @@ -395,10 +401,21 @@ def _call_capture_enabled(self, *init_state):

import jax # pylint: disable=import-outside-toplevel

jaxpr_body_fn, abstract_shapes, flat_args, out_tree = self._get_jaxpr(
init_state, allow_array_resizing=self.allow_array_resizing
)

try:
jaxpr_body_fn, abstract_shapes, flat_args, out_tree = self._get_jaxpr(
init_state, allow_array_resizing=self.allow_array_resizing
)
except Exception as e: # pylint: disable=broad-exception-caught
logger.exception(e, exc_info=True)
warnings.warn(
(
"Structured capture of qml.for_loop failed with error:"
f"\n\n{e}.\n\nFull error logged at exception level. "
"Use qml.logging.enable_logging() to view."
),
CaptureWarning,
)
return self._call_capture_disabled(*init_state)
for_loop_prim = _get_for_loop_qfunc_prim()

consts_slice = slice(0, len(jaxpr_body_fn.consts))
Expand All @@ -417,6 +434,7 @@ def _call_capture_enabled(self, *init_state):
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)

results = results[-out_tree.num_leaves :]
return jax.tree_util.tree_unflatten(out_tree, results)

Expand Down
4 changes: 4 additions & 0 deletions pennylane/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ class AutoGraphWarning(Warning):
"""Warnings related to PennyLane's AutoGraph submodule."""


class CaptureWarning(Warning):
"""Warnings related to the capture of the program into a condensed PLxPR format."""


# =============================================================================
# Autograph and Compilation Errors
# =============================================================================
Expand Down
5 changes: 2 additions & 3 deletions tests/capture/autograph/test_autograph_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,9 +574,8 @@ def f():
qml.RY(params[i], wires=0)
return qml.expval(qml.PauliZ(0))

with pytest.raises(
AutoGraphError,
match="Make sure that loop variables are not used in tracing-incompatible ways",
with pytest.warns(
qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed"
):
run_autograph(f)()

Expand Down
23 changes: 16 additions & 7 deletions tests/capture/test_capture_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,12 @@ def w(i0):
a0, b0 = jnp.ones(i0), jnp.ones(i0)
return f(a0, b0)

with pytest.raises(ValueError, match="Detected dynamically shaped arrays being resized"):
jax.make_jaxpr(w)(1)
with pytest.warns(
qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed"
):
jaxpr = jax.make_jaxpr(w)(1)

assert for_loop_prim not in {eqn.primitive for eqn in jaxpr.eqns}

def test_error_is_combining_independent_shapes(self):
"""Test that a useful error is raised if two arrays with dynamic shapes are combined."""
Expand All @@ -310,10 +314,12 @@ def w(i0):
a0, b0 = jnp.ones(i0), jnp.ones(i0)
return f(a0, b0)

with pytest.raises(
ValueError, match="attempt to combine arrays with two different dynamic shapes."
with pytest.warns(
qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed"
):
jax.make_jaxpr(w)(2)
jaxpr = jax.make_jaxpr(w)(2)

assert for_loop_prim not in {eqn.primitive for eqn in jaxpr.eqns}

def test_array_initialized_with_size_of_other_arg(self):
"""Test that one argument can have a shape that matches another argument, but
Expand Down Expand Up @@ -347,8 +353,11 @@ def f(i, a):

return f(jnp.arange(i0))

with pytest.raises(ValueError, match="due to a closure variable with a dynamic shape"):
jax.make_jaxpr(w)(3)
with pytest.warns(
qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed"
):
jaxpr = jax.make_jaxpr(w)(3)
assert for_loop_prim not in {eqn.primitive for eqn in jaxpr.eqns}

@pytest.mark.parametrize("allow_array_resizing", ("auto", False))
def test_loop_with_argument_combining(self, allow_array_resizing):
Expand Down
41 changes: 20 additions & 21 deletions tests/ftqc/test_parametric_mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,23 +512,23 @@ def circ():
return qml.expval(qml.Z(2))

plxpr = jax.make_jaxpr(circ)()
captured_measurement = str(plxpr.eqns[0])

# measurement is captured as epxected
assert "measure_in_basis" in captured_measurement
assert f"plane={plane}" in captured_measurement
assert f"postselect={postselect}" in captured_measurement
assert f"reset={reset}" in captured_measurement
assert plxpr.eqns[0].primitive.name == "measure_in_basis"
assert plxpr.eqns[0].params["plane"] == plane
assert plxpr.eqns[0].params["postselect"] == postselect
assert plxpr.eqns[0].params["reset"] == reset

# parameters held in invars
assert jax.numpy.isclose(angle, plxpr.eqns[0].invars[0].val)
assert jax.numpy.isclose(wire, plxpr.eqns[0].invars[1].val)

# measurement value is assigned and passed forward
conditional = str(plxpr.eqns[1])
assert "cond" in conditional
assert captured_measurement[:8] == "a:bool[]"
assert "lambda ; a:i64[]" in conditional
assert plxpr.eqns[1].primitive.name == "cond"
assert plxpr.eqns[1].invars[0] == plxpr.eqns[0].outvars[0]
invar_aval = plxpr.eqns[1].params["jaxpr_branches"][0].invars[0].aval
assert invar_aval.dtype == jax.numpy.int64
assert invar_aval.shape == ()

@pytest.mark.capture
@pytest.mark.parametrize("angle, plane", [(1.23, "XY"), (1.5707, "YZ"), (-0.34, "ZX")])
Expand Down Expand Up @@ -557,13 +557,13 @@ def circ():
return qml.expval(qml.Z(2))

plxpr = jax.make_jaxpr(circ)()
captured_measurement = str(plxpr.eqns[0])

# measurement is captured as expected
assert "measure_in_basis" in captured_measurement
assert f"plane={plane}" in captured_measurement
assert f"postselect={postselect}" in captured_measurement
assert f"reset={reset}" in captured_measurement
assert plxpr.eqns[0].primitive.name == "measure_in_basis"
assert plxpr.eqns[0].params["plane"] == plane
assert plxpr.eqns[0].params["postselect"] == postselect
assert plxpr.eqns[0].params["reset"] == reset
outvar_aval = plxpr.eqns[0].outvars[0].aval
assert outvar_aval.shape == ()
assert outvar_aval.dtype == jax.numpy.bool

# dynamic parameters held in invars for numpy, and consts for jax
if "jax" in angle_type:
Expand All @@ -574,11 +574,10 @@ def circ():
# Wires captured as invars
assert jax.numpy.allclose(wire, plxpr.eqns[0].invars[1].val)

# measurement value is assigned and passed forward
conditional = str(plxpr.eqns[1])
assert "cond" in conditional
assert captured_measurement[:8] == "a:bool[]"
assert "lambda ; a:i64[]" in conditional
assert plxpr.eqns[1].primitive.name == "cond"
invar_aval = plxpr.eqns[1].params["jaxpr_branches"][0].invars[0].aval
assert invar_aval.shape == ()
assert invar_aval.dtype == jax.numpy.int64

@pytest.mark.capture
@pytest.mark.parametrize(
Expand Down
1 change: 1 addition & 0 deletions tests/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ filterwarnings =
error:AutoGraph will not transform the function:pennylane.capture.autograph.AutoGraphWarning
error:.*Deprecated NumPy 1\..*:DeprecationWarning
error:Both 'shots=' parameter and 'set_shots' transform are specified. :UserWarning
error::pennylane.exceptions.CaptureWarning
#addopts = --benchmark-disable
xfail_strict=true
rng_salt = v0.44.0