Skip to content

Commit 4c31904

Browse files
Control flow fallback (#8615)
**Context:** One of the challenges with turning autograph on by default is that it converts things it shouldn't convert into a structured for loop. When that happens, the whole program just fails and falls over. To make things more user friendly, we instead just fall back to a standard python for loop. Less performant, but at least it works. **Description of the Change:** Fall back to standard python for loop if the capture of the jaxpr fails. **Benefits:** Things continue to run. This fallback can be turned off by turning the `qml.exceptions.CaptureWarning` into an error. **Possible Drawbacks:** Things are unrolled, even if unrolled with a warning. **Related GitHub Issues:** [sc-103710] --------- Co-authored-by: Isaac De Vlugt <[email protected]>
1 parent 1283666 commit 4c31904

File tree

7 files changed

+70
-35
lines changed

7 files changed

+70
-35
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616
<h3>Improvements 🛠</h3>
1717

18+
* `qml.for_loop` will now fall back to a standard Python `for` loop if capturing a condensed, structured loop fails
19+
with program capture enabled.
20+
[(#8615)](https://github.com/PennyLaneAI/pennylane/pull/8615)
21+
1822
* The `~.BasisRotation` graph decomposition was re-written in a qjit friendly way with PennyLane control flow.
1923
[(#8560)](https://github.com/PennyLaneAI/pennylane/pull/8560)
2024
[(#8608)](https://github.com/PennyLaneAI/pennylane/pull/8608)

pennylane/control_flow/for_loop.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313
# limitations under the License.
1414
"""For loop."""
1515
import functools
16+
import logging
17+
import warnings
1618
from typing import Literal
1719

1820
from pennylane import capture
1921
from pennylane.capture import FlatFn, enabled
2022
from pennylane.capture.dynamic_shapes import register_custom_staging_rule
2123
from pennylane.compiler.compiler import AvailableCompilers, active_compiler
24+
from pennylane.exceptions import CaptureWarning
2225

2326
from ._loop_abstract_axes import (
2427
add_abstract_shapes,
@@ -28,6 +31,9 @@
2831
validate_no_resizing_returns,
2932
)
3033

34+
logger = logging.getLogger(__name__)
35+
logger.addHandler(logging.NullHandler())
36+
3137

3238
def for_loop(
3339
start, stop=None, step=1, *, allow_array_resizing: Literal["auto", True, False] = "auto"
@@ -395,10 +401,22 @@ def _call_capture_enabled(self, *init_state):
395401

396402
import jax # pylint: disable=import-outside-toplevel
397403

398-
jaxpr_body_fn, abstract_shapes, flat_args, out_tree = self._get_jaxpr(
399-
init_state, allow_array_resizing=self.allow_array_resizing
400-
)
401-
404+
try:
405+
jaxpr_body_fn, abstract_shapes, flat_args, out_tree = self._get_jaxpr(
406+
init_state, allow_array_resizing=self.allow_array_resizing
407+
)
408+
except Exception as e: # pylint: disable=broad-exception-caught
409+
logger.exception(e, exc_info=True)
410+
warnings.warn(
411+
(
412+
"Structured capture of qml.for_loop failed with error:"
413+
f"\n\n{e}.\n\nFull error logged at exception level. "
414+
"Use qml.logging.enable_logging() to view."
415+
"\nFalling back to unrolled Python for loop."
416+
),
417+
CaptureWarning,
418+
)
419+
return self._call_capture_disabled(*init_state)
402420
for_loop_prim = _get_for_loop_qfunc_prim()
403421

404422
consts_slice = slice(0, len(jaxpr_body_fn.consts))
@@ -417,6 +435,7 @@ def _call_capture_enabled(self, *init_state):
417435
args_slice=args_slice,
418436
abstract_shapes_slice=abstract_shapes_slice,
419437
)
438+
420439
results = results[-out_tree.num_leaves :]
421440
return jax.tree_util.tree_unflatten(out_tree, results)
422441

pennylane/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ class AutoGraphWarning(Warning):
193193
"""Warnings related to PennyLane's AutoGraph submodule."""
194194

195195

196+
class CaptureWarning(Warning):
197+
"""Warnings related to the capture of the program into a condensed PLxPR format."""
198+
199+
196200
# =============================================================================
197201
# Autograph and Compilation Errors
198202
# =============================================================================

tests/capture/autograph/test_autograph_for_loop.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -574,9 +574,8 @@ def f():
574574
qml.RY(params[i], wires=0)
575575
return qml.expval(qml.PauliZ(0))
576576

577-
with pytest.raises(
578-
AutoGraphError,
579-
match="Make sure that loop variables are not used in tracing-incompatible ways",
577+
with pytest.warns(
578+
qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed"
580579
):
581580
run_autograph(f)()
582581

tests/capture/test_capture_for_loop.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,12 @@ def w(i0):
296296
a0, b0 = jnp.ones(i0), jnp.ones(i0)
297297
return f(a0, b0)
298298

299-
with pytest.raises(ValueError, match="Detected dynamically shaped arrays being resized"):
300-
jax.make_jaxpr(w)(1)
299+
with pytest.warns(
300+
qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed"
301+
):
302+
jaxpr = jax.make_jaxpr(w)(1)
303+
304+
assert for_loop_prim not in {eqn.primitive for eqn in jaxpr.eqns}
301305

302306
def test_error_is_combining_independent_shapes(self):
303307
"""Test that a useful error is raised if two arrays with dynamic shapes are combined."""
@@ -310,10 +314,12 @@ def w(i0):
310314
a0, b0 = jnp.ones(i0), jnp.ones(i0)
311315
return f(a0, b0)
312316

313-
with pytest.raises(
314-
ValueError, match="attempt to combine arrays with two different dynamic shapes."
317+
with pytest.warns(
318+
qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed"
315319
):
316-
jax.make_jaxpr(w)(2)
320+
jaxpr = jax.make_jaxpr(w)(2)
321+
322+
assert for_loop_prim not in {eqn.primitive for eqn in jaxpr.eqns}
317323

318324
def test_array_initialized_with_size_of_other_arg(self):
319325
"""Test that one argument can have a shape that matches another argument, but
@@ -347,8 +353,11 @@ def f(i, a):
347353

348354
return f(jnp.arange(i0))
349355

350-
with pytest.raises(ValueError, match="due to a closure variable with a dynamic shape"):
351-
jax.make_jaxpr(w)(3)
356+
with pytest.warns(
357+
qml.exceptions.CaptureWarning, match="Structured capture of qml.for_loop failed"
358+
):
359+
jaxpr = jax.make_jaxpr(w)(3)
360+
assert for_loop_prim not in {eqn.primitive for eqn in jaxpr.eqns}
352361

353362
@pytest.mark.parametrize("allow_array_resizing", ("auto", False))
354363
def test_loop_with_argument_combining(self, allow_array_resizing):

tests/ftqc/test_parametric_mid_measure.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -512,23 +512,23 @@ def circ():
512512
return qml.expval(qml.Z(2))
513513

514514
plxpr = jax.make_jaxpr(circ)()
515-
captured_measurement = str(plxpr.eqns[0])
516515

517516
# measurement is captured as epxected
518-
assert "measure_in_basis" in captured_measurement
519-
assert f"plane={plane}" in captured_measurement
520-
assert f"postselect={postselect}" in captured_measurement
521-
assert f"reset={reset}" in captured_measurement
517+
assert plxpr.eqns[0].primitive.name == "measure_in_basis"
518+
assert plxpr.eqns[0].params["plane"] == plane
519+
assert plxpr.eqns[0].params["postselect"] == postselect
520+
assert plxpr.eqns[0].params["reset"] == reset
522521

523522
# parameters held in invars
524523
assert jax.numpy.isclose(angle, plxpr.eqns[0].invars[0].val)
525524
assert jax.numpy.isclose(wire, plxpr.eqns[0].invars[1].val)
526525

527526
# measurement value is assigned and passed forward
528-
conditional = str(plxpr.eqns[1])
529-
assert "cond" in conditional
530-
assert captured_measurement[:8] == "a:bool[]"
531-
assert "lambda ; a:i64[]" in conditional
527+
assert plxpr.eqns[1].primitive.name == "cond"
528+
assert plxpr.eqns[1].invars[0] == plxpr.eqns[0].outvars[0]
529+
invar_aval = plxpr.eqns[1].params["jaxpr_branches"][0].invars[0].aval
530+
assert invar_aval.dtype == jax.numpy.int64
531+
assert invar_aval.shape == ()
532532

533533
@pytest.mark.capture
534534
@pytest.mark.parametrize("angle, plane", [(1.23, "XY"), (1.5707, "YZ"), (-0.34, "ZX")])
@@ -557,13 +557,13 @@ def circ():
557557
return qml.expval(qml.Z(2))
558558

559559
plxpr = jax.make_jaxpr(circ)()
560-
captured_measurement = str(plxpr.eqns[0])
561-
562-
# measurement is captured as expected
563-
assert "measure_in_basis" in captured_measurement
564-
assert f"plane={plane}" in captured_measurement
565-
assert f"postselect={postselect}" in captured_measurement
566-
assert f"reset={reset}" in captured_measurement
560+
assert plxpr.eqns[0].primitive.name == "measure_in_basis"
561+
assert plxpr.eqns[0].params["plane"] == plane
562+
assert plxpr.eqns[0].params["postselect"] == postselect
563+
assert plxpr.eqns[0].params["reset"] == reset
564+
outvar_aval = plxpr.eqns[0].outvars[0].aval
565+
assert outvar_aval.shape == ()
566+
assert outvar_aval.dtype == jax.numpy.bool
567567

568568
# dynamic parameters held in invars for numpy, and consts for jax
569569
if "jax" in angle_type:
@@ -574,11 +574,10 @@ def circ():
574574
# Wires captured as invars
575575
assert jax.numpy.allclose(wire, plxpr.eqns[0].invars[1].val)
576576

577-
# measurement value is assigned and passed forward
578-
conditional = str(plxpr.eqns[1])
579-
assert "cond" in conditional
580-
assert captured_measurement[:8] == "a:bool[]"
581-
assert "lambda ; a:i64[]" in conditional
577+
assert plxpr.eqns[1].primitive.name == "cond"
578+
invar_aval = plxpr.eqns[1].params["jaxpr_branches"][0].invars[0].aval
579+
assert invar_aval.shape == ()
580+
assert invar_aval.dtype == jax.numpy.int64
582581

583582
@pytest.mark.capture
584583
@pytest.mark.parametrize(

tests/pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ filterwarnings =
3333
error:AutoGraph will not transform the function:pennylane.capture.autograph.AutoGraphWarning
3434
error:.*Deprecated NumPy 1\..*:DeprecationWarning
3535
error:Both 'shots=' parameter and 'set_shots' transform are specified. :UserWarning
36+
error::pennylane.exceptions.CaptureWarning
3637
#addopts = --benchmark-disable
3738
xfail_strict=true
3839
rng_salt = v0.44.0

0 commit comments

Comments
 (0)