diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 1b597c63f32..fed94eecb23 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -15,6 +15,10 @@
Improvements ðŸ›
+* `qml.for_loop` will now fall back to a standard Python `for` loop if capturing a condensed, structured loop fails
+ with program capture enabled.
+ [(#8615)](https://github.com/PennyLaneAI/pennylane/pull/8615)
+
* The `~.BasisRotation` graph decomposition was re-written in a qjit friendly way with PennyLane control flow.
[(#8560)](https://github.com/PennyLaneAI/pennylane/pull/8560)
[(#8608)](https://github.com/PennyLaneAI/pennylane/pull/8608)
diff --git a/pennylane/control_flow/for_loop.py b/pennylane/control_flow/for_loop.py
index b6d87f947d2..ba2005ebee3 100644
--- a/pennylane/control_flow/for_loop.py
+++ b/pennylane/control_flow/for_loop.py
@@ -13,12 +13,15 @@
# limitations under the License.
"""For loop."""
import functools
+import logging
+import warnings
from typing import Literal
from pennylane import capture
from pennylane.capture import FlatFn, enabled
from pennylane.capture.dynamic_shapes import register_custom_staging_rule
from pennylane.compiler.compiler import AvailableCompilers, active_compiler
+from pennylane.exceptions import CaptureWarning
from ._loop_abstract_axes import (
add_abstract_shapes,
@@ -28,6 +31,9 @@
validate_no_resizing_returns,
)
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
def for_loop(
start, stop=None, step=1, *, allow_array_resizing: Literal["auto", True, False] = "auto"
@@ -395,10 +401,22 @@ def _call_capture_enabled(self, *init_state):
import jax # pylint: disable=import-outside-toplevel
- jaxpr_body_fn, abstract_shapes, flat_args, out_tree = self._get_jaxpr(
- init_state, allow_array_resizing=self.allow_array_resizing
- )
-
+ try:
+ jaxpr_body_fn, abstract_shapes, flat_args, out_tree = self._get_jaxpr(
+ init_state, allow_array_resizing=self.allow_array_resizing
+ )
+ except Exception as e: # pylint: disable=broad-exception-caught
+ logger.exception(e, exc_info=True)
+ warnings.warn(
+ (
+ "Structured capture of qml.for_loop failed with error:"
+ f"\n\n{e}.\n\nFull error logged at exception level. "
+ "Use qml.logging.enable_logging() to view."
+ "\nFalling back to unrolled Python for loop."
+ ),
+ CaptureWarning,
+ )
+ return self._call_capture_disabled(*init_state)
for_loop_prim = _get_for_loop_qfunc_prim()
consts_slice = slice(0, len(jaxpr_body_fn.consts))
@@ -417,6 +435,7 @@ def _call_capture_enabled(self, *init_state):
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)
+
results = results[-out_tree.num_leaves :]
return jax.tree_util.tree_unflatten(out_tree, results)
diff --git a/pennylane/exceptions.py b/pennylane/exceptions.py
index 054ac965118..07972e7f882 100644
--- a/pennylane/exceptions.py
+++ b/pennylane/exceptions.py
@@ -193,6 +193,10 @@ class AutoGraphWarning(Warning):
"""Warnings related to PennyLane's AutoGraph submodule."""
+class CaptureWarning(Warning):
+ """Warnings related to the capture of the program into a condensed PLxPR format."""
+
+
# =============================================================================
# Autograph and Compilation Errors
# =============================================================================
diff --git a/tests/capture/autograph/test_autograph_for_loop.py b/tests/capture/autograph/test_autograph_for_loop.py
index 98e44cffbd7..6c5cefb0ad0 100644
--- a/tests/capture/autograph/test_autograph_for_loop.py
+++ b/tests/capture/autograph/test_autograph_for_loop.py
@@ -574,9 +574,8 @@ def f():
qml.RY(params[i], wires=0)
return qml.expval(qml.PauliZ(0))
- with pytest.raises(
- AutoGraphError,
- match="Make sure that loop variables are not used in tracing-incompatible ways",
+ with pytest.warns(
+ qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed"
):
run_autograph(f)()
diff --git a/tests/capture/test_capture_for_loop.py b/tests/capture/test_capture_for_loop.py
index 3d587b0faad..1d32a991c87 100644
--- a/tests/capture/test_capture_for_loop.py
+++ b/tests/capture/test_capture_for_loop.py
@@ -296,8 +296,12 @@ def w(i0):
a0, b0 = jnp.ones(i0), jnp.ones(i0)
return f(a0, b0)
- with pytest.raises(ValueError, match="Detected dynamically shaped arrays being resized"):
- jax.make_jaxpr(w)(1)
+ with pytest.warns(
+ qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed"
+ ):
+ jaxpr = jax.make_jaxpr(w)(1)
+
+ assert for_loop_prim not in {eqn.primitive for eqn in jaxpr.eqns}
def test_error_is_combining_independent_shapes(self):
"""Test that a useful error is raised if two arrays with dynamic shapes are combined."""
@@ -310,10 +314,12 @@ def w(i0):
a0, b0 = jnp.ones(i0), jnp.ones(i0)
return f(a0, b0)
- with pytest.raises(
- ValueError, match="attempt to combine arrays with two different dynamic shapes."
+ with pytest.warns(
+ qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed"
):
- jax.make_jaxpr(w)(2)
+ jaxpr = jax.make_jaxpr(w)(2)
+
+ assert for_loop_prim not in {eqn.primitive for eqn in jaxpr.eqns}
def test_array_initialized_with_size_of_other_arg(self):
"""Test that one argument can have a shape that matches another argument, but
@@ -347,8 +353,11 @@ def f(i, a):
return f(jnp.arange(i0))
- with pytest.raises(ValueError, match="due to a closure variable with a dynamic shape"):
- jax.make_jaxpr(w)(3)
+ with pytest.warns(
+ qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed"
+ ):
+ jaxpr = jax.make_jaxpr(w)(3)
+ assert for_loop_prim not in {eqn.primitive for eqn in jaxpr.eqns}
@pytest.mark.parametrize("allow_array_resizing", ("auto", False))
def test_loop_with_argument_combining(self, allow_array_resizing):
diff --git a/tests/ftqc/test_parametric_mid_measure.py b/tests/ftqc/test_parametric_mid_measure.py
index 3c026af8f6e..3267ff68cb1 100644
--- a/tests/ftqc/test_parametric_mid_measure.py
+++ b/tests/ftqc/test_parametric_mid_measure.py
@@ -512,23 +512,23 @@ def circ():
return qml.expval(qml.Z(2))
plxpr = jax.make_jaxpr(circ)()
- captured_measurement = str(plxpr.eqns[0])
# measurement is captured as epxected
- assert "measure_in_basis" in captured_measurement
- assert f"plane={plane}" in captured_measurement
- assert f"postselect={postselect}" in captured_measurement
- assert f"reset={reset}" in captured_measurement
+ assert plxpr.eqns[0].primitive.name == "measure_in_basis"
+ assert plxpr.eqns[0].params["plane"] == plane
+ assert plxpr.eqns[0].params["postselect"] == postselect
+ assert plxpr.eqns[0].params["reset"] == reset
# parameters held in invars
assert jax.numpy.isclose(angle, plxpr.eqns[0].invars[0].val)
assert jax.numpy.isclose(wire, plxpr.eqns[0].invars[1].val)
# measurement value is assigned and passed forward
- conditional = str(plxpr.eqns[1])
- assert "cond" in conditional
- assert captured_measurement[:8] == "a:bool[]"
- assert "lambda ; a:i64[]" in conditional
+ assert plxpr.eqns[1].primitive.name == "cond"
+ assert plxpr.eqns[1].invars[0] == plxpr.eqns[0].outvars[0]
+ invar_aval = plxpr.eqns[1].params["jaxpr_branches"][0].invars[0].aval
+ assert invar_aval.dtype == jax.numpy.int64
+ assert invar_aval.shape == ()
@pytest.mark.capture
@pytest.mark.parametrize("angle, plane", [(1.23, "XY"), (1.5707, "YZ"), (-0.34, "ZX")])
@@ -557,13 +557,13 @@ def circ():
return qml.expval(qml.Z(2))
plxpr = jax.make_jaxpr(circ)()
- captured_measurement = str(plxpr.eqns[0])
-
- # measurement is captured as expected
- assert "measure_in_basis" in captured_measurement
- assert f"plane={plane}" in captured_measurement
- assert f"postselect={postselect}" in captured_measurement
- assert f"reset={reset}" in captured_measurement
+ assert plxpr.eqns[0].primitive.name == "measure_in_basis"
+ assert plxpr.eqns[0].params["plane"] == plane
+ assert plxpr.eqns[0].params["postselect"] == postselect
+ assert plxpr.eqns[0].params["reset"] == reset
+ outvar_aval = plxpr.eqns[0].outvars[0].aval
+ assert outvar_aval.shape == ()
+ assert outvar_aval.dtype == jax.numpy.bool
# dynamic parameters held in invars for numpy, and consts for jax
if "jax" in angle_type:
@@ -574,11 +574,10 @@ def circ():
# Wires captured as invars
assert jax.numpy.allclose(wire, plxpr.eqns[0].invars[1].val)
- # measurement value is assigned and passed forward
- conditional = str(plxpr.eqns[1])
- assert "cond" in conditional
- assert captured_measurement[:8] == "a:bool[]"
- assert "lambda ; a:i64[]" in conditional
+ assert plxpr.eqns[1].primitive.name == "cond"
+ invar_aval = plxpr.eqns[1].params["jaxpr_branches"][0].invars[0].aval
+ assert invar_aval.shape == ()
+ assert invar_aval.dtype == jax.numpy.int64
@pytest.mark.capture
@pytest.mark.parametrize(
diff --git a/tests/pytest.ini b/tests/pytest.ini
index 8c4317837af..43831855274 100644
--- a/tests/pytest.ini
+++ b/tests/pytest.ini
@@ -33,6 +33,7 @@ filterwarnings =
error:AutoGraph will not transform the function:pennylane.capture.autograph.AutoGraphWarning
error:.*Deprecated NumPy 1\..*:DeprecationWarning
error:Both 'shots=' parameter and 'set_shots' transform are specified. :UserWarning
+ error::pennylane.exceptions.CaptureWarning
#addopts = --benchmark-disable
xfail_strict=true
rng_salt = v0.44.0