Skip to content
Open
Show file tree
Hide file tree
Changes from 72 commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
5d32755
bump to 0.7.0
JerryChen97 Oct 22, 2025
a0afb1d
not the stable
JerryChen97 Oct 22, 2025
bfb1fc8
draft
JerryChen97 Oct 22, 2025
5732af2
more?
JerryChen97 Oct 22, 2025
6e2baed
Better cast_like and is_abstract
JerryChen97 Oct 22, 2025
816c6f6
After (JAX 0.7.0): vjp_func.args[0].args == ([],) for independent fun…
JerryChen97 Oct 22, 2025
5de0722
batches of assert value; rval to be hashable as wel
JerryChen97 Oct 22, 2025
53a9922
more
JerryChen97 Oct 22, 2025
fc3573b
more
JerryChen97 Oct 23, 2025
cb18461
enhance make_hashable
JerryChen97 Oct 23, 2025
3f2215f
Skip some weird fails for now
JerryChen97 Oct 23, 2025
b5150fc
skip tracer frist
JerryChen97 Oct 23, 2025
1c24e7e
more dynamic shape skips
JerryChen97 Oct 23, 2025
3c33252
patch jax
JerryChen97 Oct 23, 2025
d37c711
patch refactored in alignment with Catalyst
JerryChen97 Oct 23, 2025
1c556c1
don't import jax arbitrarily
JerryChen97 Oct 23, 2025
2c4d31b
more fix?
JerryChen97 Oct 23, 2025
4b2bed5
rm xfail
JerryChen97 Oct 23, 2025
6ad808d
rm xfail
JerryChen97 Oct 23, 2025
63ea71d
rm remains
JerryChen97 Oct 23, 2025
ff17844
fix all singles doubles
JerryChen97 Oct 23, 2025
2e4353d
Apply suggestions from code review
JerryChen97 Oct 23, 2025
7f47bbd
improve
JerryChen97 Oct 23, 2025
cb6e58f
deal with pylints
JerryChen97 Oct 23, 2025
9a1dfb3
disable protected-access
JerryChen97 Oct 24, 2025
555c8fb
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Oct 27, 2025
3d08400
update make hashable
JerryChen97 Oct 27, 2025
d5ca52a
more robust sorted call
JerryChen97 Oct 28, 2025
91ef7a7
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Oct 28, 2025
53b252d
refine the `_make_hashable` logic
JerryChen97 Oct 29, 2025
f328031
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 6, 2025
d161795
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 6, 2025
01c547c
fix output
JerryChen97 Nov 6, 2025
a2a3c7b
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 12, 2025
6bd7904
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 14, 2025
b8c7f24
temp fix
JerryChen97 Nov 14, 2025
821642b
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 14, 2025
a2ab4ac
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 17, 2025
da4d4aa
Apply suggestions from code review
JerryChen97 Nov 20, 2025
8d300c6
more slice improvement
JerryChen97 Nov 20, 2025
0995d9c
fix
JerryChen97 Nov 20, 2025
3107dcf
fix more
JerryChen97 Nov 20, 2025
46f2e55
remove restore list
JerryChen97 Nov 20, 2025
1ae39b7
get rid of _restore_dict (except for map_wires)
JerryChen97 Nov 20, 2025
f8c1da5
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 20, 2025
dcfab13
remove _restore_dict
JerryChen97 Nov 20, 2025
714387a
xfail templates
JerryChen97 Nov 21, 2025
b6f1410
rm unnecessary local import
JerryChen97 Nov 21, 2025
a56f971
two more xfail (all within templates subfolder)
JerryChen97 Nov 21, 2025
8ab6e2e
Try not sorting
JerryChen97 Nov 21, 2025
9a74f0d
improtve the dev note
JerryChen97 Nov 21, 2025
b40954a
clean some remains forgotten to revert
JerryChen97 Nov 21, 2025
f3a8016
Update pennylane/capture/custom_primitives.py
JerryChen97 Nov 21, 2025
35a8905
remove historical comments that not make sense anymore
JerryChen97 Nov 21, 2025
368024a
Update pennylane/capture/custom_primitives.py
JerryChen97 Nov 21, 2025
f32f8f9
oooooooops
JerryChen97 Nov 21, 2025
f663bf6
move all the imports to the top level
JerryChen97 Nov 21, 2025
abdce2b
Update pennylane/capture/jax_patches.py
JerryChen97 Nov 21, 2025
b4d0a6b
doc req jax fix
JerryChen97 Nov 21, 2025
1dcd7ed
pylint
JerryChen97 Nov 21, 2025
7e8849f
jax~=0.6.0 -> ==0.7.0
JerryChen97 Nov 21, 2025
ea66ea1
== instead of ~=; jaxlib also update
JerryChen97 Nov 21, 2025
b706e7f
xfailed
JerryChen97 Nov 24, 2025
6699308
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 24, 2025
de4c9a5
skip those impossible doctests
JerryChen97 Nov 24, 2025
3ddbf28
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 24, 2025
b40c40d
Revert "skip those impossible doctests"
JerryChen97 Nov 24, 2025
c80f6b3
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 25, 2025
02d2f38
Try: still use 0.6.2 for doctest
JerryChen97 Nov 25, 2025
aae5e6f
import from top-level
JerryChen97 Nov 26, 2025
2e5284e
remove too obvious comments
JerryChen97 Nov 26, 2025
b5c6ea4
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 26, 2025
fd9ac16
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Dec 1, 2025
ce13ff7
[Capture] JAX patcher for capture (#8654)
JerryChen97 Dec 1, 2025
5b88e87
Update pennylane/workflow/_capture_qnode.py
JerryChen97 Dec 1, 2025
89335e7
delete unused
JerryChen97 Dec 1, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/documentation-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ jobs:
run: |
pip install --upgrade --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ pennylane-catalyst pennylane-lightning
pip install -e .
# TODO: use 0.7.0 after updating all the documentation
pip install sybil pytest "jax~=0.6.0" "jaxlib~=0.6.0" torch matplotlib pyzx
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?


- name: Print Dependencies
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/interface-dependency-versions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ on:
description: The version of JAX to use for testing
required: false
type: string
default: '0.6.2'
default: '0.7.0'
catalyst_jax_version:
description: The version of JAX to use for testing along with Catalyst
required: false
type: string
default: '0.6.2'
default: '0.7.0'
torch_version:
description: The version of PyTorch to use for testing
required: false
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ timer.dat
tmp/*
benchmark/revisions/
venv
*venv*/
config.toml
.envrc
qml_debug.log
Expand Down
2 changes: 1 addition & 1 deletion doc/introduction/interfaces/jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ JAX interface

.. code-block:: bash

pip install jax~=0.6.0 jaxlib~=0.6.0
pip install jax==0.7.0 jaxlib==0.7.0

You can then import PennyLane and JAX as follows:

Expand Down
4 changes: 2 additions & 2 deletions doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ appdirs
autograd
autoray
cachetools
jax==0.6.0
jaxlib==0.6.0
jax==0.7.0
jaxlib==0.7.0
mistune==0.8.4
m2r2
numpy
Expand Down
8 changes: 4 additions & 4 deletions pennylane/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,12 @@
from packaging.version import Version as _Version

if _find_spec("jax") is not None:
if (jax_version := _Version(_metadata_version("jax"))) > _Version("0.6.2"): # pragma: no cover
if (jax_version := _Version(_metadata_version("jax"))) > _Version("0.7.0"): # pragma: no cover
warnings.warn(
"PennyLane is not yet compatible with JAX versions > 0.6.2. "
"PennyLane is not yet compatible with JAX versions > 0.7.0. "
f"You have version {jax_version} installed. "
"Please downgrade JAX to 0.6.2 to avoid runtime errors using "
"python -m pip install jax~=0.6.0 jaxlib~=0.6.0",
"Please downgrade JAX to 0.7.0 to avoid runtime errors using "
"python -m pip install jax==0.7.0 jaxlib==0.7.0",
RuntimeWarning,
)

Expand Down
2 changes: 2 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ def _(*args, **kwargs):
from .autograph import run_autograph, disable_autograph
from .dynamic_shapes import determine_abstracted_axes, register_custom_staging_rule

from . import jax_patches

# by defining this here, we avoid
# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module)
# on use of from capture import AbstractOperator
Expand Down
30 changes: 27 additions & 3 deletions pennylane/capture/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,10 @@ def handle_for_loop(
self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
):
"""Handle a for loop primitive."""
consts = args[consts_slice]
init_state = args[args_slice]
abstract_shapes = args[abstract_shapes_slice]
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
consts = args[slice(*consts_slice)]
init_state = args[slice(*args_slice)]
abstract_shapes = args[slice(*abstract_shapes_slice)]
new_jaxpr_body_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_body_fn, consts, *abstract_shapes, start, *init_state
)
Expand All @@ -523,6 +524,10 @@ def handle_for_loop(
@PlxprInterpreter.register_primitive(cond_prim)
def handle_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):
"""Handle a cond primitive."""
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
args_slice = slice(*args_slice)
consts_slices = [slice(*s) for s in consts_slices]

args = invals[args_slice]

new_jaxprs = []
Expand Down Expand Up @@ -560,6 +565,11 @@ def handle_while_loop(
args_slice,
):
"""Handle a while loop primitive."""
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
body_slice = slice(*body_slice)
cond_slice = slice(*cond_slice)
args_slice = slice(*args_slice)

consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]
Expand Down Expand Up @@ -654,6 +664,11 @@ def flatten_while_loop(
args_slice,
):
"""Handle the while loop by a flattened python strategy."""
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
body_slice = slice(*body_slice)
cond_slice = slice(*cond_slice)
args_slice = slice(*args_slice)

consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]
Expand All @@ -671,6 +686,10 @@ def flatten_while_loop(
@FlattenedInterpreter.register_primitive(cond_prim)
def flattened_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):
"""Handle the cond primitive by a flattened python strategy."""
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
args_slice = slice(*args_slice)
consts_slices = [slice(*s) for s in consts_slices]

n_branches = len(jaxpr_branches)
conditions = invals[:n_branches]
args = invals[args_slice]
Expand All @@ -694,6 +713,11 @@ def flattened_for(
self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
):
"""Handle the for loop by a flattened python strategy."""
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
consts_slice = slice(*consts_slice)
args_slice = slice(*args_slice)
abstract_shapes_slice = slice(*abstract_shapes_slice)

consts = invals[consts_slice]
init_state = invals[args_slice]
abstract_shapes = invals[abstract_shapes_slice]
Expand Down
55 changes: 54 additions & 1 deletion pennylane/capture/custom_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
This submodule offers custom primitives for the PennyLane capture module.
"""

# pylint: disable=too-many-return-statements

from enum import Enum
from typing import Any

from jax.extend.core import Primitive

Expand All @@ -30,10 +33,50 @@ class PrimitiveType(Enum):
TRANSFORM = "transform"


def _make_hashable(obj: Any) -> Any:
"""Convert potentially unhashable objects to hashable equivalents for JAX 0.7.0+.

JAX 0.7.0 requires all primitive parameters to be hashable. This helper converts
common unhashable types (list, dict, slice) to hashable tuples.

Args:
obj: Object to potentially convert to hashable form

Returns:
Hashable version of the object
"""
if isinstance(obj, slice):
return (obj.start, obj.stop, obj.step)

# First, check if the object is already hashable
try:
hash(obj)
return obj
except TypeError:
pass

# Import here to avoid circular dependency and only when needed
# pylint: disable=import-outside-toplevel
import jax
import numpy as np

if isinstance(obj, jax.core.Tracer):
raise ValueError("Tracers should never occur in primitive metadata.")
if isinstance(obj, np.ndarray) or (hasattr(jax, "Array") and isinstance(obj, jax.Array)):
raise ValueError("Arrays should never be in primitive metadata.")
if isinstance(obj, list):
return tuple(_make_hashable(item) for item in obj)
if isinstance(obj, dict):
# Python 3.7+ maintains dict insertion order, so no need to sort
# For the same primitive constructed the same way, keys are always in the same order
return tuple((k, _make_hashable(v)) for k, v in obj.items())
return obj


# pylint: disable=abstract-method,too-few-public-methods
class QmlPrimitive(Primitive):
"""A subclass for JAX's Primitive that differentiates between different
classes of primitives."""
classes of primitives and automatically makes parameters hashable for JAX 0.7.0+."""

_prim_type: PrimitiveType = PrimitiveType.DEFAULT

Expand All @@ -47,3 +90,13 @@ def prim_type(self):
def prim_type(self, value: str | PrimitiveType):
"""Setter for QmlPrimitive.prim_type."""
self._prim_type = PrimitiveType(value)

def bind(self, *args, **params):
"""Bind with automatic parameter hashability conversion for JAX 0.7.0+.

Overrides the parent bind method to automatically convert unhashable parameters
(like lists, dicts, and slices) to hashable tuples, which is required by JAX 0.7.0+.
"""
# Convert all parameters to hashable forms
hashable_params = {k: _make_hashable(v) for k, v in params.items()}
return super().bind(*args, **hashable_params)
35 changes: 27 additions & 8 deletions pennylane/capture/dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
has_jax = True
try:
import jax
from jax._src import compute_on, config, xla_metadata_lib
from jax._src.interpreters.partial_eval import JaxprEqnContext, TracingEqn
from jax.interpreters import partial_eval as pe
except ImportError: # pragma: no cover
has_jax = False # pragma: no cover
Expand Down Expand Up @@ -164,7 +166,14 @@ def register_custom_staging_rule(
# see https://github.com/jax-ml/jax/blob/9e62994bce7c7fcbb2f6a50c9ef89526cd2c2be6/jax/_src/lax/lax.py#L3538
# and https://github.com/jax-ml/jax/blob/9e62994bce7c7fcbb2f6a50c9ef89526cd2c2be6/jax/_src/lax/lax.py#L208
# for reference to how jax is handling staging rules for dynamic shapes in v0.4.28
# see also capture/intro_to_dynamic_shapes.md
# JAX 0.6.2 to 0.7.0 introduced breaking changes in custom staging rules for dynamic shapes:
# 1. DynamicJaxprTracer constructor now requires the var as 3rd argument (previously created internally)
# 2. TracingEqn must be used instead of JaxprEqn for trace.frame.add_eqn
#
# This implementation creates vars first using trace.frame.newvar() before constructing
# DynamicJaxprTracer instances, fixing dynamic shape support that was broken in JAX 0.7.0.
# See pennylane/capture/jax_patches.py for related fixes to JAX's own staging rules.
# See also capture/intro_to_dynamic_shapes.md for dynamic shapes documentation.

def _tracer_and_outvar(
jaxpr_trace: pe.DynamicJaxprTrace,
Expand All @@ -176,15 +185,18 @@ def _tracer_and_outvar(
Returned vars are cached in env for use in future shapes
"""
if not hasattr(outvar.aval, "shape"):
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, outvar.aval, None)
return out_tracer, jaxpr_trace.makevar(out_tracer)
# JAX 0.7.0: Create variable first, then pass to DynamicJaxprTracer
new_var = jaxpr_trace.frame.newvar(outvar.aval)
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, outvar.aval, new_var)
return out_tracer, new_var
new_shape = [s if isinstance(s, int) else env[s] for s in outvar.aval.shape]
if all(isinstance(s, int) for s in outvar.aval.shape):
new_aval = jax.core.ShapedArray(tuple(new_shape), outvar.aval.dtype)
else:
new_aval = jax.core.DShapedArray(tuple(new_shape), outvar.aval.dtype)
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, new_aval, None)
new_var = jaxpr_trace.makevar(out_tracer)
# JAX 0.7.0: Create variable first, then pass to DynamicJaxprTracer
new_var = jaxpr_trace.frame.newvar(new_aval)
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, new_aval, new_var)

if not isinstance(outvar, jax.extend.core.Literal):
env[outvar] = new_var
Expand All @@ -211,13 +223,20 @@ def custom_staging_rule(
else:
out_tracers, returned_vars = (), ()

invars = [jaxpr_trace.getvar(x) for x in tracers]
eqn = jax.core.new_jaxpr_eqn(
invars,
ctx = JaxprEqnContext(
compute_on.current_compute_type(),
config.threefry_partitionable.value,
xla_metadata_lib.current_xla_metadata(),
)

eqn = TracingEqn(
tracers, # in_tracers (not invars!)
returned_vars,
primitive,
params,
jax.core.no_effects,
source_info,
ctx,
)
jaxpr_trace.frame.add_eqn(eqn)
return out_tracers
Expand Down
6 changes: 3 additions & 3 deletions pennylane/capture/expand_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ class ExpandTransformsInterpreter(PlxprInterpreter):
def _(
self, *invals, inner_jaxpr, args_slice, consts_slice, targs_slice, tkwargs, transform
): # pylint: disable=too-many-arguments
args = invals[args_slice]
consts = invals[consts_slice]
targs = invals[targs_slice]
args = invals[slice(*args_slice)]
consts = invals[slice(*consts_slice)]
targs = invals[slice(*targs_slice)]

def wrapper(*inner_args):
return copy(self).eval(inner_jaxpr, consts, *inner_args)
Expand Down
Loading
Loading