Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
95 commits
Select commit Hold shift + click to select a range
5d32755
bump to 0.7.0
JerryChen97 Oct 22, 2025
a0afb1d
not the stable
JerryChen97 Oct 22, 2025
bfb1fc8
draft
JerryChen97 Oct 22, 2025
5732af2
more?
JerryChen97 Oct 22, 2025
6e2baed
Better cast_like and is_abstract
JerryChen97 Oct 22, 2025
816c6f6
After (JAX 0.7.0): vjp_func.args[0].args == ([],) for independent fun…
JerryChen97 Oct 22, 2025
5de0722
batches of assert value; rval to be hashable as wel
JerryChen97 Oct 22, 2025
53a9922
more
JerryChen97 Oct 22, 2025
fc3573b
more
JerryChen97 Oct 23, 2025
cb18461
enhance make_hashable
JerryChen97 Oct 23, 2025
3f2215f
Skip some weird fails for now
JerryChen97 Oct 23, 2025
b5150fc
skip tracer frist
JerryChen97 Oct 23, 2025
1c24e7e
more dynamic shape skips
JerryChen97 Oct 23, 2025
3c33252
patch jax
JerryChen97 Oct 23, 2025
d37c711
patch refactored in alignment with Catalyst
JerryChen97 Oct 23, 2025
1c556c1
don't import jax arbitrarily
JerryChen97 Oct 23, 2025
2c4d31b
more fix?
JerryChen97 Oct 23, 2025
4b2bed5
rm xfail
JerryChen97 Oct 23, 2025
6ad808d
rm xfail
JerryChen97 Oct 23, 2025
63ea71d
rm remains
JerryChen97 Oct 23, 2025
ff17844
fix all singles doubles
JerryChen97 Oct 23, 2025
2e4353d
Apply suggestions from code review
JerryChen97 Oct 23, 2025
7f47bbd
improve
JerryChen97 Oct 23, 2025
cb6e58f
deal with pylints
JerryChen97 Oct 23, 2025
9a1dfb3
disable protected-access
JerryChen97 Oct 24, 2025
555c8fb
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Oct 27, 2025
3d08400
update make hashable
JerryChen97 Oct 27, 2025
d5ca52a
more robust sorted call
JerryChen97 Oct 28, 2025
91ef7a7
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Oct 28, 2025
53b252d
refine the `_make_hashable` logic
JerryChen97 Oct 29, 2025
f328031
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 6, 2025
d161795
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 6, 2025
01c547c
fix output
JerryChen97 Nov 6, 2025
a2a3c7b
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 12, 2025
6bd7904
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 14, 2025
b8c7f24
temp fix
JerryChen97 Nov 14, 2025
821642b
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 14, 2025
a2ab4ac
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 17, 2025
da4d4aa
Apply suggestions from code review
JerryChen97 Nov 20, 2025
8d300c6
more slice improvement
JerryChen97 Nov 20, 2025
0995d9c
fix
JerryChen97 Nov 20, 2025
3107dcf
fix more
JerryChen97 Nov 20, 2025
46f2e55
remove restore list
JerryChen97 Nov 20, 2025
1ae39b7
get rid of _restore_dict (except for map_wires)
JerryChen97 Nov 20, 2025
f8c1da5
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 20, 2025
dcfab13
remove _restore_dict
JerryChen97 Nov 20, 2025
714387a
xfail templates
JerryChen97 Nov 21, 2025
b6f1410
rm unnecessary local import
JerryChen97 Nov 21, 2025
a56f971
two more xfail (all within templates subfolder)
JerryChen97 Nov 21, 2025
8ab6e2e
Try not sorting
JerryChen97 Nov 21, 2025
9a74f0d
improtve the dev note
JerryChen97 Nov 21, 2025
b40954a
clean some remains forgotten to revert
JerryChen97 Nov 21, 2025
f3a8016
Update pennylane/capture/custom_primitives.py
JerryChen97 Nov 21, 2025
35a8905
remove historical comments that not make sense anymore
JerryChen97 Nov 21, 2025
368024a
Update pennylane/capture/custom_primitives.py
JerryChen97 Nov 21, 2025
f32f8f9
oooooooops
JerryChen97 Nov 21, 2025
f663bf6
move all the imports to the top level
JerryChen97 Nov 21, 2025
abdce2b
Update pennylane/capture/jax_patches.py
JerryChen97 Nov 21, 2025
b4d0a6b
doc req jax fix
JerryChen97 Nov 21, 2025
1dcd7ed
pylint
JerryChen97 Nov 21, 2025
7e8849f
jax~=0.6.0 -> ==0.7.0
JerryChen97 Nov 21, 2025
ea66ea1
== instead of ~=; jaxlib also update
JerryChen97 Nov 21, 2025
b706e7f
xfailed
JerryChen97 Nov 24, 2025
6699308
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 24, 2025
de4c9a5
skip those impossible doctests
JerryChen97 Nov 24, 2025
3ddbf28
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 24, 2025
b40c40d
Revert "skip those impossible doctests"
JerryChen97 Nov 24, 2025
c80f6b3
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 Nov 25, 2025
02d2f38
Try: still use 0.6.2 for doctest
JerryChen97 Nov 25, 2025
57075da
everything but patches (dynamic shapes)
JerryChen97 Nov 25, 2025
2cc4b81
not include and xfail
JerryChen97 Nov 25, 2025
713714c
Merge branch 'master' into bump-jax-api-hashability
JerryChen97 Nov 25, 2025
1139e66
simplify `_make_hashable`
JerryChen97 Nov 26, 2025
b32dc30
improve tests_templates:
JerryChen97 Nov 26, 2025
b3c6759
Merge branch 'master' into bump-jax-api-hashability
JerryChen97 Nov 26, 2025
5f4fda4
remove too obvious commetns
JerryChen97 Nov 26, 2025
35e7091
rm a unrelated xfail mistakenly introduced
JerryChen97 Nov 26, 2025
86c934d
format
JerryChen97 Nov 26, 2025
1177d19
clean some unused suppressions
JerryChen97 Nov 26, 2025
23333c4
align with the jax pattern; keep the changes minimized compared with …
JerryChen97 Nov 26, 2025
15755c9
Merge branch 'master' into bump-jax-api-hashability
JerryChen97 Nov 27, 2025
705f963
removeusage of TracingEqn
JerryChen97 Nov 27, 2025
ea0773e
rm more
JerryChen97 Nov 27, 2025
158c6a3
Merge branch 'master' into bump-jax-api-hashability
JerryChen97 Nov 27, 2025
1cc9b22
1 more simplification: use make_eqn from pe
JerryChen97 Nov 27, 2025
644682d
Merge branch 'master' into bump-jax-api-hashability
JerryChen97 Nov 28, 2025
b934030
trigger
JerryChen97 Nov 28, 2025
522e8f2
try fix catalyst
JerryChen97 Nov 28, 2025
d51e9ab
revert https://github.com/PennyLaneAI/pennylane/pull/8701/commits/70…
JerryChen97 Nov 28, 2025
209e9ba
Trigger by Update .github/workflows/documentation-tests.yml
JerryChen97 Nov 29, 2025
d4683bb
no cover
JerryChen97 Nov 29, 2025
5d3300d
Merge remote-tracking branch 'refs/remotes/origin/bump-jax-api-hashab…
JerryChen97 Nov 29, 2025
9d7a237
no cover loop dynamic shape part
JerryChen97 Nov 29, 2025
73bc8ae
mf
JerryChen97 Nov 29, 2025
3dc74a5
no cover more
JerryChen97 Nov 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/documentation-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ jobs:
run: |
pip install --upgrade --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ pennylane-catalyst pennylane-lightning
pip install -e .
# TODO: use 0.7.0 after updating all the documentation
pip install sybil pytest "jax~=0.6.0" "jaxlib~=0.6.0" torch matplotlib pyzx

- name: Print Dependencies
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/interface-dependency-versions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ on:
description: The version of JAX to use for testing
required: false
type: string
default: '0.6.2'
default: '0.7.0'
catalyst_jax_version:
description: The version of JAX to use for testing along with Catalyst
required: false
type: string
default: '0.6.2'
default: '0.7.0'
torch_version:
description: The version of PyTorch to use for testing
required: false
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ timer.dat
tmp/*
benchmark/revisions/
venv
*venv*/
config.toml
.envrc
qml_debug.log
Expand Down
2 changes: 1 addition & 1 deletion doc/introduction/interfaces/jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ JAX interface

.. code-block:: bash

pip install jax~=0.6.0 jaxlib~=0.6.0
pip install jax==0.7.0 jaxlib==0.7.0

You can then import PennyLane and JAX as follows:

Expand Down
4 changes: 2 additions & 2 deletions doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ appdirs
autograd
autoray
cachetools
jax==0.6.0
jaxlib==0.6.0
jax==0.7.0
jaxlib==0.7.0
mistune==0.8.4
m2r2
# TODO: Remove once galois becomes compatible with latest numpy
Expand Down
8 changes: 4 additions & 4 deletions pennylane/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,12 @@
from packaging.version import Version as _Version

if _find_spec("jax") is not None:
if (jax_version := _Version(_metadata_version("jax"))) > _Version("0.6.2"): # pragma: no cover
if (jax_version := _Version(_metadata_version("jax"))) > _Version("0.7.0"): # pragma: no cover
warnings.warn(
"PennyLane is not yet compatible with JAX versions > 0.6.2. "
"PennyLane is not yet compatible with JAX versions > 0.7.0. "
f"You have version {jax_version} installed. "
"Please downgrade JAX to 0.6.2 to avoid runtime errors using "
"python -m pip install jax~=0.6.0 jaxlib~=0.6.0",
"Please downgrade JAX to 0.7.0 to avoid runtime errors using "
"python -m pip install jax==0.7.0 jaxlib==0.7.0",
RuntimeWarning,
)

Expand Down
4 changes: 3 additions & 1 deletion pennylane/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def _grad_abstract(*args, argnums, jaxpr, n_consts, method, h, scalar_out, fn):


def _shape(shape, dtype, weak_type=False):
if jax.config.jax_dynamic_shapes and any(not isinstance(s, int) for s in shape):
if jax.config.jax_dynamic_shapes and any(
not isinstance(s, int) for s in shape
): # pragma: no cover
return jax.core.DShapedArray(shape, dtype, weak_type=weak_type)
return jax.core.ShapedArray(shape, dtype, weak_type=weak_type)

Expand Down
32 changes: 28 additions & 4 deletions pennylane/capture/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _fill_in_shape_with_dyn_shape(dyn_shape: tuple["jax.core.Tracer"], shape: tu
for s in shape:
if s is not None:
new_shape.append(s)
else:
else: # pragma: no cover
# pull from iterable of dynamic shapes
next_s = next(dyn_shape_iter)
if not qml.math.is_abstract(next_s):
Expand Down Expand Up @@ -496,9 +496,10 @@ def handle_for_loop(
self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
):
"""Handle a for loop primitive."""
consts = args[consts_slice]
init_state = args[args_slice]
abstract_shapes = args[abstract_shapes_slice]
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
consts = args[slice(*consts_slice)]
init_state = args[slice(*args_slice)]
abstract_shapes = args[slice(*abstract_shapes_slice)]
new_jaxpr_body_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_body_fn, consts, *abstract_shapes, start, *init_state
)
Expand All @@ -523,6 +524,10 @@ def handle_for_loop(
@PlxprInterpreter.register_primitive(cond_prim)
def handle_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):
"""Handle a cond primitive."""
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
args_slice = slice(*args_slice)
consts_slices = [slice(*s) for s in consts_slices]

args = invals[args_slice]

new_jaxprs = []
Expand Down Expand Up @@ -560,6 +565,11 @@ def handle_while_loop(
args_slice,
):
"""Handle a while loop primitive."""
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
body_slice = slice(*body_slice)
cond_slice = slice(*cond_slice)
args_slice = slice(*args_slice)

consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]
Expand Down Expand Up @@ -654,6 +664,11 @@ def flatten_while_loop(
args_slice,
):
"""Handle the while loop by a flattened python strategy."""
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
body_slice = slice(*body_slice)
cond_slice = slice(*cond_slice)
args_slice = slice(*args_slice)

consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]
Expand All @@ -671,6 +686,10 @@ def flatten_while_loop(
@FlattenedInterpreter.register_primitive(cond_prim)
def flattened_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):
"""Handle the cond primitive by a flattened python strategy."""
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
args_slice = slice(*args_slice)
consts_slices = [slice(*s) for s in consts_slices]

n_branches = len(jaxpr_branches)
conditions = invals[:n_branches]
args = invals[args_slice]
Expand All @@ -694,6 +713,11 @@ def flattened_for(
self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
):
"""Handle the for loop by a flattened python strategy."""
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
consts_slice = slice(*consts_slice)
args_slice = slice(*args_slice)
abstract_shapes_slice = slice(*abstract_shapes_slice)

consts = invals[consts_slice]
init_state = invals[args_slice]
abstract_shapes = invals[abstract_shapes_slice]
Expand Down
35 changes: 34 additions & 1 deletion pennylane/capture/custom_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""

from enum import Enum
from typing import Any

from jax.extend.core import Primitive

Expand All @@ -30,10 +31,32 @@ class PrimitiveType(Enum):
TRANSFORM = "transform"


def _make_hashable(obj: Any) -> Any:
"""Convert potentially unhashable objects to hashable equivalents for JAX 0.7.0+.

JAX 0.7.0 requires all primitive parameters to be hashable. This helper converts
common unhashable types (list, dict, slice) to hashable tuples.

Args:
obj: Object to potentially convert to hashable form

Returns:
Hashable version of the object
"""
if isinstance(obj, slice):
return (obj.start, obj.stop, obj.step)
if isinstance(obj, list):
return tuple(_make_hashable(item) for item in obj)
if isinstance(obj, dict):
return tuple((k, _make_hashable(v)) for k, v in obj.items())

return obj


# pylint: disable=abstract-method,too-few-public-methods
class QmlPrimitive(Primitive):
"""A subclass for JAX's Primitive that differentiates between different
classes of primitives."""
classes of primitives and automatically makes parameters hashable for JAX 0.7.0+."""

_prim_type: PrimitiveType = PrimitiveType.DEFAULT

Expand All @@ -47,3 +70,13 @@ def prim_type(self):
def prim_type(self, value: str | PrimitiveType):
"""Setter for QmlPrimitive.prim_type."""
self._prim_type = PrimitiveType(value)

def bind(self, *args, **params):
"""Bind with automatic parameter hashability conversion for JAX 0.7.0+.

Overrides the parent bind method to automatically convert unhashable parameters
(like lists, dicts, and slices) to hashable tuples, which is required by JAX 0.7.0+.
"""
# Convert all parameters to hashable forms
hashable_params = {k: _make_hashable(v) for k, v in params.items()}
return super().bind(*args, **hashable_params)
44 changes: 33 additions & 11 deletions pennylane/capture/dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
has_jax = True
try:
import jax
from jax._src.interpreters.partial_eval import TracingEqn
from jax.interpreters import partial_eval as pe
except ImportError: # pragma: no cover
has_jax = False # pragma: no cover
Expand Down Expand Up @@ -47,7 +48,7 @@ def _get_shape_for_array(x, abstract_shapes: list, previous_ints: list) -> dict:
return {}

abstract_axes = {}
for i, s in enumerate(getattr(x, "shape", ())):
for i, s in enumerate(getattr(x, "shape", ())): # pragma: no cover
if not isinstance(s, int): # if not int, then abstract
found = False
# check if the shape tracer is one we have already encountered
Expand Down Expand Up @@ -137,8 +138,8 @@ def f(n):
if not any(abstracted_axes):
return None, ()

abstracted_axes = jax.tree_util.tree_unflatten(structure, abstracted_axes)
return abstracted_axes, abstract_shapes
abstracted_axes = jax.tree_util.tree_unflatten(structure, abstracted_axes) # pragma: no cover
return abstracted_axes, abstract_shapes # pragma: no cover


def register_custom_staging_rule(
Expand All @@ -164,7 +165,14 @@ def register_custom_staging_rule(
# see https://github.com/jax-ml/jax/blob/9e62994bce7c7fcbb2f6a50c9ef89526cd2c2be6/jax/_src/lax/lax.py#L3538
# and https://github.com/jax-ml/jax/blob/9e62994bce7c7fcbb2f6a50c9ef89526cd2c2be6/jax/_src/lax/lax.py#L208
# for reference to how jax is handling staging rules for dynamic shapes in v0.4.28
# see also capture/intro_to_dynamic_shapes.md
# JAX 0.6.2 to 0.7.0 introduced breaking changes in custom staging rules for dynamic shapes:
# 1. DynamicJaxprTracer constructor now requires the var as 3rd argument (previously created internally)
# 2. TracingEqn must be used instead of JaxprEqn for trace.frame.add_eqn
#
# This implementation creates vars first using trace.frame.newvar() before constructing
# DynamicJaxprTracer instances, fixing dynamic shape support that was broken in JAX 0.7.0.
# See pennylane/capture/jax_patches.py for related fixes to JAX's own staging rules.
# See also capture/intro_to_dynamic_shapes.md for dynamic shapes documentation.

def _tracer_and_outvar(
jaxpr_trace: pe.DynamicJaxprTrace,
Expand All @@ -176,15 +184,18 @@ def _tracer_and_outvar(
Returned vars are cached in env for use in future shapes
"""
if not hasattr(outvar.aval, "shape"):
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, outvar.aval, None)
return out_tracer, jaxpr_trace.makevar(out_tracer)
# JAX 0.7.0: Create variable first, then pass to DynamicJaxprTracer
new_var = jaxpr_trace.frame.newvar(outvar.aval)
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, outvar.aval, new_var)
return out_tracer, new_var
new_shape = [s if isinstance(s, int) else env[s] for s in outvar.aval.shape]
if all(isinstance(s, int) for s in outvar.aval.shape):
new_aval = jax.core.ShapedArray(tuple(new_shape), outvar.aval.dtype)
else:
else: # pragma: no cover
new_aval = jax.core.DShapedArray(tuple(new_shape), outvar.aval.dtype)
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, new_aval, None)
new_var = jaxpr_trace.makevar(out_tracer)
# JAX 0.7.0: Create variable first, then pass to DynamicJaxprTracer
new_var = jaxpr_trace.frame.newvar(new_aval)
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, new_aval, new_var)

if not isinstance(outvar, jax.extend.core.Literal):
env[outvar] = new_var
Expand All @@ -211,15 +222,26 @@ def custom_staging_rule(
else:
out_tracers, returned_vars = (), ()

invars = [jaxpr_trace.getvar(x) for x in tracers]
# JAX 0.7.0: Use t.val to get var from tracer, and TracingEqn for frame.add_eqn
invars = [t.val for t in tracers]
eqn = jax.core.new_jaxpr_eqn(
invars,
returned_vars,
primitive,
params,
jax.core.no_effects,
source_info,
)
tracing_eqn = TracingEqn(
list(tracers),
returned_vars,
primitive,
params,
eqn.effects,
source_info,
eqn.ctx,
)
jaxpr_trace.frame.add_eqn(eqn)
jaxpr_trace.frame.add_eqn(tracing_eqn)
return out_tracers

pe.custom_staging_rules[primitive] = custom_staging_rule
6 changes: 3 additions & 3 deletions pennylane/capture/expand_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ class ExpandTransformsInterpreter(PlxprInterpreter):
def _(
self, *invals, inner_jaxpr, args_slice, consts_slice, targs_slice, tkwargs, transform
): # pylint: disable=too-many-arguments
args = invals[args_slice]
consts = invals[consts_slice]
targs = invals[targs_slice]
args = invals[slice(*args_slice)]
consts = invals[slice(*consts_slice)]
targs = invals[slice(*targs_slice)]

def wrapper(*inner_args):
return copy(self).eval(inner_jaxpr, consts, *inner_args)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/capture/make_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def fn(x):
if not has_jax: # pragma: no cover
raise ImportError(
"Module jax is required for the ``make_plxpr`` function. "
"You can install jax via: pip install jax~=0.6.0"
"You can install jax via: pip install jax==0.7.0"
)

if not qml.capture.enabled():
Expand Down
Loading
Loading