diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 1b597c63f32..fed94eecb23 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -15,6 +15,10 @@

Improvements 🛠

+* `qml.for_loop` will now fall back to a standard Python `for` loop if capturing a condensed, structured loop fails + with program capture enabled. + [(#8615)](https://github.com/PennyLaneAI/pennylane/pull/8615) + * The `~.BasisRotation` graph decomposition was re-written in a qjit friendly way with PennyLane control flow. [(#8560)](https://github.com/PennyLaneAI/pennylane/pull/8560) [(#8608)](https://github.com/PennyLaneAI/pennylane/pull/8608) diff --git a/pennylane/control_flow/for_loop.py b/pennylane/control_flow/for_loop.py index b6d87f947d2..ba2005ebee3 100644 --- a/pennylane/control_flow/for_loop.py +++ b/pennylane/control_flow/for_loop.py @@ -13,12 +13,15 @@ # limitations under the License. """For loop.""" import functools +import logging +import warnings from typing import Literal from pennylane import capture from pennylane.capture import FlatFn, enabled from pennylane.capture.dynamic_shapes import register_custom_staging_rule from pennylane.compiler.compiler import AvailableCompilers, active_compiler +from pennylane.exceptions import CaptureWarning from ._loop_abstract_axes import ( add_abstract_shapes, @@ -28,6 +31,9 @@ validate_no_resizing_returns, ) +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + def for_loop( start, stop=None, step=1, *, allow_array_resizing: Literal["auto", True, False] = "auto" @@ -395,10 +401,22 @@ def _call_capture_enabled(self, *init_state): import jax # pylint: disable=import-outside-toplevel - jaxpr_body_fn, abstract_shapes, flat_args, out_tree = self._get_jaxpr( - init_state, allow_array_resizing=self.allow_array_resizing - ) - + try: + jaxpr_body_fn, abstract_shapes, flat_args, out_tree = self._get_jaxpr( + init_state, allow_array_resizing=self.allow_array_resizing + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.exception(e, exc_info=True) + warnings.warn( + ( + "Structured capture of qml.for_loop failed with error:" + f"\n\n{e}.\n\nFull error logged at exception level. " + "Use qml.logging.enable_logging() to view." + "\nFalling back to unrolled Python for loop." + ), + CaptureWarning, + ) + return self._call_capture_disabled(*init_state) for_loop_prim = _get_for_loop_qfunc_prim() consts_slice = slice(0, len(jaxpr_body_fn.consts)) @@ -417,6 +435,7 @@ def _call_capture_enabled(self, *init_state): args_slice=args_slice, abstract_shapes_slice=abstract_shapes_slice, ) + results = results[-out_tree.num_leaves :] return jax.tree_util.tree_unflatten(out_tree, results) diff --git a/pennylane/exceptions.py b/pennylane/exceptions.py index 054ac965118..07972e7f882 100644 --- a/pennylane/exceptions.py +++ b/pennylane/exceptions.py @@ -193,6 +193,10 @@ class AutoGraphWarning(Warning): """Warnings related to PennyLane's AutoGraph submodule.""" +class CaptureWarning(Warning): + """Warnings related to the capture of the program into a condensed PLxPR format.""" + + # ============================================================================= # Autograph and Compilation Errors # ============================================================================= diff --git a/tests/capture/autograph/test_autograph_for_loop.py b/tests/capture/autograph/test_autograph_for_loop.py index 98e44cffbd7..6c5cefb0ad0 100644 --- a/tests/capture/autograph/test_autograph_for_loop.py +++ b/tests/capture/autograph/test_autograph_for_loop.py @@ -574,9 +574,8 @@ def f(): qml.RY(params[i], wires=0) return qml.expval(qml.PauliZ(0)) - with pytest.raises( - AutoGraphError, - match="Make sure that loop variables are not used in tracing-incompatible ways", + with pytest.warns( + qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed" ): run_autograph(f)() diff --git a/tests/capture/test_capture_for_loop.py b/tests/capture/test_capture_for_loop.py index 3d587b0faad..1d32a991c87 100644 --- a/tests/capture/test_capture_for_loop.py +++ b/tests/capture/test_capture_for_loop.py @@ -296,8 +296,12 @@ def w(i0): a0, b0 = jnp.ones(i0), jnp.ones(i0) return f(a0, b0) - with pytest.raises(ValueError, match="Detected dynamically shaped arrays being resized"): - jax.make_jaxpr(w)(1) + with pytest.warns( + qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed" + ): + jaxpr = jax.make_jaxpr(w)(1) + + assert for_loop_prim not in {eqn.primitive for eqn in jaxpr.eqns} def test_error_is_combining_independent_shapes(self): """Test that a useful error is raised if two arrays with dynamic shapes are combined.""" @@ -310,10 +314,12 @@ def w(i0): a0, b0 = jnp.ones(i0), jnp.ones(i0) return f(a0, b0) - with pytest.raises( - ValueError, match="attempt to combine arrays with two different dynamic shapes." + with pytest.warns( + qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed" ): - jax.make_jaxpr(w)(2) + jaxpr = jax.make_jaxpr(w)(2) + + assert for_loop_prim not in {eqn.primitive for eqn in jaxpr.eqns} def test_array_initialized_with_size_of_other_arg(self): """Test that one argument can have a shape that matches another argument, but @@ -347,8 +353,11 @@ def f(i, a): return f(jnp.arange(i0)) - with pytest.raises(ValueError, match="due to a closure variable with a dynamic shape"): - jax.make_jaxpr(w)(3) + with pytest.warns( + qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed" + ): + jaxpr = jax.make_jaxpr(w)(3) + assert for_loop_prim not in {eqn.primitive for eqn in jaxpr.eqns} @pytest.mark.parametrize("allow_array_resizing", ("auto", False)) def test_loop_with_argument_combining(self, allow_array_resizing): diff --git a/tests/ftqc/test_parametric_mid_measure.py b/tests/ftqc/test_parametric_mid_measure.py index 3c026af8f6e..3267ff68cb1 100644 --- a/tests/ftqc/test_parametric_mid_measure.py +++ b/tests/ftqc/test_parametric_mid_measure.py @@ -512,23 +512,23 @@ def circ(): return qml.expval(qml.Z(2)) plxpr = jax.make_jaxpr(circ)() - captured_measurement = str(plxpr.eqns[0]) # measurement is captured as epxected - assert "measure_in_basis" in captured_measurement - assert f"plane={plane}" in captured_measurement - assert f"postselect={postselect}" in captured_measurement - assert f"reset={reset}" in captured_measurement + assert plxpr.eqns[0].primitive.name == "measure_in_basis" + assert plxpr.eqns[0].params["plane"] == plane + assert plxpr.eqns[0].params["postselect"] == postselect + assert plxpr.eqns[0].params["reset"] == reset # parameters held in invars assert jax.numpy.isclose(angle, plxpr.eqns[0].invars[0].val) assert jax.numpy.isclose(wire, plxpr.eqns[0].invars[1].val) # measurement value is assigned and passed forward - conditional = str(plxpr.eqns[1]) - assert "cond" in conditional - assert captured_measurement[:8] == "a:bool[]" - assert "lambda ; a:i64[]" in conditional + assert plxpr.eqns[1].primitive.name == "cond" + assert plxpr.eqns[1].invars[0] == plxpr.eqns[0].outvars[0] + invar_aval = plxpr.eqns[1].params["jaxpr_branches"][0].invars[0].aval + assert invar_aval.dtype == jax.numpy.int64 + assert invar_aval.shape == () @pytest.mark.capture @pytest.mark.parametrize("angle, plane", [(1.23, "XY"), (1.5707, "YZ"), (-0.34, "ZX")]) @@ -557,13 +557,13 @@ def circ(): return qml.expval(qml.Z(2)) plxpr = jax.make_jaxpr(circ)() - captured_measurement = str(plxpr.eqns[0]) - - # measurement is captured as expected - assert "measure_in_basis" in captured_measurement - assert f"plane={plane}" in captured_measurement - assert f"postselect={postselect}" in captured_measurement - assert f"reset={reset}" in captured_measurement + assert plxpr.eqns[0].primitive.name == "measure_in_basis" + assert plxpr.eqns[0].params["plane"] == plane + assert plxpr.eqns[0].params["postselect"] == postselect + assert plxpr.eqns[0].params["reset"] == reset + outvar_aval = plxpr.eqns[0].outvars[0].aval + assert outvar_aval.shape == () + assert outvar_aval.dtype == jax.numpy.bool # dynamic parameters held in invars for numpy, and consts for jax if "jax" in angle_type: @@ -574,11 +574,10 @@ def circ(): # Wires captured as invars assert jax.numpy.allclose(wire, plxpr.eqns[0].invars[1].val) - # measurement value is assigned and passed forward - conditional = str(plxpr.eqns[1]) - assert "cond" in conditional - assert captured_measurement[:8] == "a:bool[]" - assert "lambda ; a:i64[]" in conditional + assert plxpr.eqns[1].primitive.name == "cond" + invar_aval = plxpr.eqns[1].params["jaxpr_branches"][0].invars[0].aval + assert invar_aval.shape == () + assert invar_aval.dtype == jax.numpy.int64 @pytest.mark.capture @pytest.mark.parametrize( diff --git a/tests/pytest.ini b/tests/pytest.ini index 8c4317837af..43831855274 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -33,6 +33,7 @@ filterwarnings = error:AutoGraph will not transform the function:pennylane.capture.autograph.AutoGraphWarning error:.*Deprecated NumPy 1\..*:DeprecationWarning error:Both 'shots=' parameter and 'set_shots' transform are specified. :UserWarning + error::pennylane.exceptions.CaptureWarning #addopts = --benchmark-disable xfail_strict=true rng_salt = v0.44.0