Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions pennylane/control_flow/for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,26 @@
# limitations under the License.
"""For loop."""
import functools
import logging
import warnings
from typing import Literal

from pennylane import capture
from pennylane.capture import FlatFn, enabled
from pennylane.capture.dynamic_shapes import register_custom_staging_rule
from pennylane.compiler.compiler import AvailableCompilers, active_compiler
from pennylane.exceptions import CaptureWarning

from ._loop_abstract_axes import (
add_abstract_shapes,
get_dummy_arg,
handle_jaxpr_error,
loop_determine_abstracted_axes,
validate_no_resizing_returns,
)

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def for_loop(
start, stop=None, step=1, *, allow_array_resizing: Literal["auto", True, False] = "auto"
Expand Down Expand Up @@ -374,12 +379,9 @@ def _get_jaxpr(self, init_state, allow_array_resizing):
new_body_fn = flat_fn
dummy_init_state = flat_args

try:
jaxpr_body_fn = jax.make_jaxpr(new_body_fn, abstracted_axes=abstracted_axes)(
0, *dummy_init_state
)
except ValueError as e:
handle_jaxpr_error(e, (self.body_fn,), self.allow_array_resizing, "for_loop")
jaxpr_body_fn = jax.make_jaxpr(new_body_fn, abstracted_axes=abstracted_axes)(
0, *dummy_init_state
)

error_msg = validate_no_resizing_returns(jaxpr_body_fn.jaxpr, shape_locations, "for_loop")
if error_msg:
Expand All @@ -395,10 +397,21 @@ def _call_capture_enabled(self, *init_state):

import jax # pylint: disable=import-outside-toplevel

jaxpr_body_fn, abstract_shapes, flat_args, out_tree = self._get_jaxpr(
init_state, allow_array_resizing=self.allow_array_resizing
)

try:
jaxpr_body_fn, abstract_shapes, flat_args, out_tree = self._get_jaxpr(
init_state, allow_array_resizing=self.allow_array_resizing
)
except Exception as e: # pylint: disable=broad-exception-caught
logger.exception(e, exc_info=True)
warnings.warn(
(
"Structured capture of qml.for_loop failed with error:"
f"\n\n{e}.\n\nFull error logged at exception level. "
"Use qml.logging.enable_logging() to view."
),
CaptureWarning,
)
return self._call_capture_disabled(*init_state)
for_loop_prim = _get_for_loop_qfunc_prim()

consts_slice = slice(0, len(jaxpr_body_fn.consts))
Expand All @@ -417,6 +430,7 @@ def _call_capture_enabled(self, *init_state):
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)

results = results[-out_tree.num_leaves :]
return jax.tree_util.tree_unflatten(out_tree, results)

Expand Down
4 changes: 4 additions & 0 deletions pennylane/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ class AutoGraphWarning(Warning):
"""Warnings related to PennyLane's AutoGraph submodule."""


class CaptureWarning(Warning):
"""Warnings related to the capture of the program into a condensed PLxPR format."""


# =============================================================================
# Autograph and Compilation Errors
# =============================================================================
Expand Down
Loading