Skip to content
Open
Show file tree
Hide file tree
Changes from 74 commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
5d32755
bump to 0.7.0
JerryChen97 Oct 22, 2025
a0afb1d
not the stable
JerryChen97 Oct 22, 2025
bfb1fc8
draft
JerryChen97 Oct 22, 2025
5732af2
more?
JerryChen97 Oct 22, 2025
6e2baed
Better cast_like and is_abstract
JerryChen97 Oct 22, 2025
816c6f6
After (JAX 0.7.0): vjp_func.args[0].args == ([],) for independent fun…
JerryChen97 Oct 22, 2025
5de0722
batches of assert value; rval to be hashable as wel
JerryChen97 Oct 22, 2025
53a9922
more
JerryChen97 Oct 22, 2025
fc3573b
more
JerryChen97 Oct 23, 2025
cb18461
enhance make_hashable
JerryChen97 Oct 23, 2025
3f2215f
Skip some weird fails for now
JerryChen97 Oct 23, 2025
b5150fc
skip tracer frist
JerryChen97 Oct 23, 2025
1c24e7e
more dynamic shape skips
JerryChen97 Oct 23, 2025
3c33252
patch jax
JerryChen97 Oct 23, 2025
d37c711
patch refactored in alignment with Catalyst
JerryChen97 Oct 23, 2025
1c556c1
don't import jax arbitrarily
JerryChen97 Oct 23, 2025
2c4d31b
more fix?
JerryChen97 Oct 23, 2025
4b2bed5
rm xfail
JerryChen97 Oct 23, 2025
6ad808d
rm xfail
JerryChen97 Oct 23, 2025
63ea71d
rm remains
JerryChen97 Oct 23, 2025
ff17844
fix all singles doubles
JerryChen97 Oct 23, 2025
2e4353d
Apply suggestions from code review
JerryChen97 Oct 23, 2025
7f47bbd
improve
JerryChen97 Oct 23, 2025
cb6e58f
deal with pylints
JerryChen97 Oct 23, 2025
9a1dfb3
disable protected-access
JerryChen97 Oct 24, 2025
555c8fb
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Oct 27, 2025
3d08400
update make hashable
JerryChen97 Oct 27, 2025
d5ca52a
more robust sorted call
JerryChen97 Oct 28, 2025
91ef7a7
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Oct 28, 2025
53b252d
refine the `_make_hashable` logic
JerryChen97 Oct 29, 2025
f328031
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 6, 2025
d161795
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 6, 2025
01c547c
fix output
JerryChen97 Nov 6, 2025
a2a3c7b
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 12, 2025
6bd7904
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 14, 2025
b8c7f24
temp fix
JerryChen97 Nov 14, 2025
821642b
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 14, 2025
a2ab4ac
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 17, 2025
da4d4aa
Apply suggestions from code review
JerryChen97 Nov 20, 2025
8d300c6
more slice improvement
JerryChen97 Nov 20, 2025
0995d9c
fix
JerryChen97 Nov 20, 2025
3107dcf
fix more
JerryChen97 Nov 20, 2025
46f2e55
remove restore list
JerryChen97 Nov 20, 2025
1ae39b7
get rid of _restore_dict (except for map_wires)
JerryChen97 Nov 20, 2025
f8c1da5
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 20, 2025
dcfab13
remove _restore_dict
JerryChen97 Nov 20, 2025
714387a
xfail templates
JerryChen97 Nov 21, 2025
b6f1410
rm unnecessary local import
JerryChen97 Nov 21, 2025
a56f971
two more xfail (all within templates subfolder)
JerryChen97 Nov 21, 2025
8ab6e2e
Try not sorting
JerryChen97 Nov 21, 2025
9a74f0d
improtve the dev note
JerryChen97 Nov 21, 2025
b40954a
clean some remains forgotten to revert
JerryChen97 Nov 21, 2025
f3a8016
Update pennylane/capture/custom_primitives.py
JerryChen97 Nov 21, 2025
35a8905
remove historical comments that not make sense anymore
JerryChen97 Nov 21, 2025
368024a
Update pennylane/capture/custom_primitives.py
JerryChen97 Nov 21, 2025
f32f8f9
oooooooops
JerryChen97 Nov 21, 2025
f663bf6
move all the imports to the top level
JerryChen97 Nov 21, 2025
abdce2b
Update pennylane/capture/jax_patches.py
JerryChen97 Nov 21, 2025
b4d0a6b
doc req jax fix
JerryChen97 Nov 21, 2025
1dcd7ed
pylint
JerryChen97 Nov 21, 2025
7e8849f
jax~=0.6.0 -> ==0.7.0
JerryChen97 Nov 21, 2025
ea66ea1
== instead of ~=; jaxlib also update
JerryChen97 Nov 21, 2025
b706e7f
xfailed
JerryChen97 Nov 24, 2025
6699308
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 24, 2025
de4c9a5
skip those impossible doctests
JerryChen97 Nov 24, 2025
3ddbf28
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 24, 2025
b40c40d
Revert "skip those impossible doctests"
JerryChen97 Nov 24, 2025
c80f6b3
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 25, 2025
02d2f38
Try: still use 0.6.2 for doctest
JerryChen97 Nov 25, 2025
aae5e6f
import from top-level
JerryChen97 Nov 26, 2025
2e5284e
remove too obvious comments
JerryChen97 Nov 26, 2025
b5c6ea4
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 26, 2025
fd9ac16
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Dec 1, 2025
ce13ff7
[Capture] JAX patcher for capture (#8654)
JerryChen97 Dec 1, 2025
5b88e87
Update pennylane/workflow/_capture_qnode.py
JerryChen97 Dec 1, 2025
89335e7
delete unused
JerryChen97 Dec 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def _(*args, **kwargs):
from .autograph import run_autograph, disable_autograph
from .dynamic_shapes import determine_abstracted_axes, register_custom_staging_rule

# Import Patcher for contextual patching (preferred over global patches)
from .patching import Patcher
from .jax_patches import get_jax_patches

# by defining this here, we avoid
# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module)
# on use of from capture import AbstractOperator
Expand Down Expand Up @@ -238,4 +242,6 @@ def __getattr__(key):
"FlatFn",
"run_autograph",
"make_plxpr",
"Patcher",
"get_jax_patches",
)
2 changes: 2 additions & 0 deletions pennylane/capture/custom_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
This submodule offers custom primitives for the PennyLane capture module.
"""

# pylint: disable=too-many-return-statements

from enum import Enum
from typing import Any

Expand Down
323 changes: 323 additions & 0 deletions pennylane/capture/jax_patches.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are actually all for dynamic shapes. In #8654 we can see that they are only required to be active within the dynamic shape context

Original file line number Diff line number Diff line change
@@ -0,0 +1,323 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Runtime patches for JAX internals to fix compatibility issues.

This module patches JAX internal functions to fix bugs that affect PennyLane's
capture mechanism. These patches are applied at module import time and are
version-specific.

This approach is inspired by Catalyst's JAX patches (see catalyst.jax_extras.patches),
which similarly monkey-patch JAX internal functions for compatibility. The key insight
is to add a `make_eqn` helper method to DynamicJaxprTrace that properly creates
TracingEqn objects, which JAX 0.7.0 requires but doesn't provide in all code paths.

JAX 0.7.0+ Patches
------------------

1. **_dyn_shape_staging_rule**: Fixed dynamic shape handling in lax/lax.py lines 267-275.

The bug in the original JAX implementation:
- Uses `pe.new_jaxpr_eqn` which creates a JaxprEqn, but `trace.frame.add_eqn`
expects a TracingEqn. This causes an AssertionError.
- This bug affects ALL array creation operations with traced dimensions:
* jax.numpy.arange(traced_value)
* jax.numpy.ones((traced_value,))
* jax.numpy.zeros(traced_value)
* Any operation using lax.broadcasted_iota with dynamic shapes

The fix:
- For StagingJaxprTrace (has counter): Use `pe.new_eqn_recipe` which properly
creates equation recipes for dynamic tracing.
- For DynamicJaxprTrace (no counter): Create TracingEqn directly with proper
JaxprEqnContext, avoiding the AssertionError.
- This enables array creation with traced dimensions to work correctly.

2. **pjit_staging_rule**: Fixed dynamic shape handling in pjit.py lines 1894-1898.

The bug in the original JAX implementation:
- Uses `core.new_jaxpr_eqn` which creates a JaxprEqn, but `trace.frame.add_eqn`
expects a TracingEqn. This causes an AssertionError.
- Accesses `arg.var` which doesn't exist for DynamicJaxprTracer objects.

The fix:
- Use `pe.new_eqn_recipe` which properly creates equation recipes for dynamic tracing.
- Wrap outvars in DynamicJaxprTracer instances before creating the equation.
- Special handling for DynamicJaxprTrace (eval_jaxpr path) to avoid counter errors.
- This enables pjit operations with dynamic shapes to work correctly.

Impact
------
These patches fix many dynamic shape tests that were previously failing due to these JAX bugs:
- Array creation operations (jnp.arange, jnp.ones, jnp.zeros with traced dimensions)
- Cond operations with dynamic shapes
- For loop operations with dynamic shapes
- While loop operations with dynamic shapes
- Custom staging rules with dynamic shapes

Without these patches, any operation creating arrays with traced dimensions would fail
with AssertionError in trace.frame.add_eqn.
"""

# pylint: disable=too-many-arguments
# pylint: disable=unused-import,no-else-return,unidiomatic-typecheck,use-dict-literal
# pylint: disable=protected-access

has_jax = True
try:
import jax
from jax._src import config as jax_config
from jax._src import core, pjit, source_info_util
from jax._src.core import JaxprEqnContext, Var
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters.partial_eval import (
DynamicJaxprTracer,
TracingEqn,
compute_on,
xla_metadata_lib,
)
from jax._src.lax import lax
from packaging.version import Version
except ModuleNotFoundError: # pragma: no cover
has_jax = False # pragma: no cover


def _add_make_eqn_helper():
"""
Return a make_eqn helper method to DynamicJaxprTrace.

This helper properly creates TracingEqn objects, which is needed for JAX 0.7.0
compatibility. This is based on Catalyst's approach to the same issue.

Returns:
tuple: (DynamicJaxprTrace, "make_eqn", make_eqn).
"""

def make_eqn(
self,
in_tracers: list,
out_avals_or_tracers: list,
primitive,
params: dict,
effects: set,
source_info=None,
ctx=None,
):
"""Create a tracing equation properly.

Args:
in_tracers: Input tracers
out_avals_or_tracers: Output abstract values OR output tracers (with vars already created)
primitive: The primitive operation
params: Parameters for the primitive
effects: Effects of the operation
source_info: Source information for debugging
ctx: JaxprEqnContext (created if not provided)

Returns:
(eqn, out_tracers): TracingEqn and output tracers
"""
source_info = source_info or source_info_util.new_source_info()
ctx = ctx or JaxprEqnContext(
compute_on.current_compute_type(),
jax_config.threefry_partitionable.value,
xla_metadata_lib.current_xla_metadata(),
)

# Normalize out_avals to a list
if not isinstance(out_avals_or_tracers, (list, tuple)):
out_avals = [out_avals_or_tracers]
else:
out_avals = out_avals_or_tracers

outvars = [self.frame.newvar(aval) for aval in out_avals]

if jax_config.enable_checks.value:
assert all(isinstance(x, DynamicJaxprTracer) for x in in_tracers)
assert all(isinstance(v, Var) for v in outvars)

eqn = TracingEqn(list(in_tracers), outvars, primitive, params, effects, source_info, ctx)

# Create output tracers - manually create DynamicJaxprTracer objects
# We pass the equation as the parent parameter (4th argument to __init__)
out_tracers = [
DynamicJaxprTracer(self, aval, v, source_info, eqn)
for aval, v in zip(out_avals, outvars)
]

return eqn, out_tracers

return (pe.DynamicJaxprTrace, "make_eqn", make_eqn)


def _patch_dyn_shape_staging_rule():
"""
Return _dyn_shape_staging_rule patch to fix dynamic shape handling.

The bug in JAX 0.7.0's lax/lax.py lines 267-275 is that it uses:
- pe.new_jaxpr_eqn instead of proper TracingEqn creation

This causes an AssertionError when add_eqn expects a TracingEqn but gets a JaxprEqn.
This affects all array creation operations with traced dimensions like jnp.arange,
jnp.ones, jnp.zeros, etc.

The fix uses the make_eqn helper to properly create TracingEqn objects.

Returns:
list: List of patch tuples.
"""

def patched_dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args, **params):
"""Patched version of _dyn_shape_staging_rule using make_eqn helper."""
# Check if we have a StagingJaxprTrace (has counter) or DynamicJaxprTrace (no counter)
if hasattr(trace, "counter"):
# StagingJaxprTrace path - use new_eqn_recipe (original JAX approach)
var = trace.frame.newvar(out_aval)
out_tracer = pe.DynamicJaxprTracer(trace, out_aval, var, source_info)
eqn = pe.new_eqn_recipe(
trace, args, [out_tracer], prim, params, core.no_effects, source_info
)
out_tracer.recipe = eqn
return out_tracer

# DynamicJaxprTrace path - use make_eqn helper
eqn, out_tracers = trace.make_eqn(
args, out_aval, prim, params, core.no_effects, source_info
)
trace.frame.add_eqn(eqn)
# Return single tracer (not list) since out_aval is a single value
return out_tracers[0]

# Return just the core patch - the wrappers will call the patched version
return [
(lax, "_dyn_shape_staging_rule", patched_dyn_shape_staging_rule),
]


def _patch_pjit_staging_rule():
"""
Return pjit_staging_rule patch to fix dynamic shape handling.

The bug in JAX 0.7.0's pjit.py lines 1894-1898 is that it uses:
- core.new_jaxpr_eqn instead of pe.new_eqn_recipe
- arg.var instead of accessing the correct tracer value

This causes an AssertionError when add_eqn expects a TracingEqn but gets a JaxprEqn.

Returns:
list: List of patch tuples.
"""
# Store the original function
original_staging_rule = pjit.pjit_staging_rule

def patched_pjit_staging_rule(trace, source_info, *args, **params):
"""Patched version of pjit_staging_rule with dynamic shape fixes."""
# Use the original implementation for most cases
if not jax_config.dynamic_shapes.value:
return original_staging_rule(trace, source_info, *args, **params)

# Check if we're in the inline path
if (
params["inline"]
and all(isinstance(i, pjit.UnspecifiedValue) for i in params["in_shardings"])
and all(isinstance(o, pjit.UnspecifiedValue) for o in params["out_shardings"])
and all(i is None for i in params["in_layouts"])
and all(o is None for o in params["out_layouts"])
):
# Use original for inline path
return original_staging_rule(trace, source_info, *args, **params)

jaxpr = params["jaxpr"]

# This is the dynamic shapes path that needs fixing
jaxpr, in_fwd, out_shardings, out_layouts = pjit._pjit_forwarding(
jaxpr, params["out_shardings"], params["out_layouts"]
)
params = {
**params,
"jaxpr": jaxpr,
"out_shardings": out_shardings,
"out_layouts": out_layouts,
}

# Fix 1: Use list instead of map to create outvars
outvars = [trace.frame.newvar(aval) for aval in pjit._out_type(jaxpr)]

# DynamicJaxprTrace (eval_jaxpr path) vs StagingJaxprTrace need different approaches
# DynamicJaxprTrace doesn't have 'counter' attribute
if not hasattr(trace, "counter"):
# For DynamicJaxprTrace: Use make_eqn helper (from _add_make_eqn_helper)
# Pass avals (not tracers) to make_eqn
in_tracers = [core.get_referent(arg) for arg in args]
out_avals = [v.aval for v in outvars]
eqn, out_tracers = trace.make_eqn(
in_tracers, out_avals, pjit.jit_p, params, jaxpr.effects, source_info
)
trace.frame.add_eqn(eqn)
else:
# For StagingJaxprTrace: Use new_eqn_recipe
eqn = pe.new_eqn_recipe(
trace,
args,
[pe.DynamicJaxprTracer(trace, v.aval, v, source_info) for v in outvars],
pjit.jit_p,
params,
jaxpr.effects,
source_info,
)

out_tracers = [pe.DynamicJaxprTracer(trace, v.aval, v, source_info) for v in outvars]
for t in out_tracers:
t.recipe = eqn

# Handle forwarding
out_tracers_ = iter(out_tracers)
out_tracers = [args[f] if isinstance(f, int) else next(out_tracers_) for f in in_fwd]
assert next(out_tracers_, None) is None

return out_tracers

return [
(pjit, "pjit_staging_rule", patched_pjit_staging_rule),
(pe.custom_staging_rules, "__dict_item__", pjit.jit_p, patched_pjit_staging_rule),
]


def get_jax_patches():
"""Get patch tuples for use with Patcher context manager.

Returns a tuple of (obj, attr, new_value) tuples that can be passed to Patcher.
These patches fix JAX 0.7.0+ compatibility issues for dynamic shapes and pjit.

Returns:
tuple: Patch tuples for Patcher, or empty tuple if patches not needed

Example:
>>> from pennylane.capture.patching import Patcher
>>> from pennylane.capture.jax_patches import get_jax_patches
>>> with Patcher(*get_jax_patches()):
... # JAX operations with patches applied
... jaxpr = jax.make_jaxpr(my_function)(args)
"""
if not has_jax:
return ()

patches = []

# Get all patches from the helper functions
patches.append(_add_make_eqn_helper())
patches.extend(_patch_dyn_shape_staging_rule())
patches.extend(_patch_pjit_staging_rule())

return tuple(patches)
Loading
Loading