Skip to content

Commit e5c81be

Browse files
JerryChen97albi3romudit2812
authored
Bump jax api hashability (#8701)
**Context:** Subset of #8525 focusing on hashability, as well as the jax API updates **Description of the Change:** **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** [sc-102157] --------- Co-authored-by: Christina Lee <[email protected]> Co-authored-by: Mudit Pandey <[email protected]>
1 parent 741b82e commit e5c81be

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+384
-171
lines changed

.github/workflows/documentation-tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ jobs:
5555
run: |
5656
pip install --upgrade --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ pennylane-catalyst pennylane-lightning
5757
pip install -e .
58+
# TODO: use 0.7.0 after updating all the documentation
5859
pip install sybil pytest "jax~=0.6.0" "jaxlib~=0.6.0" torch matplotlib pyzx
5960
6061
- name: Print Dependencies

.github/workflows/interface-dependency-versions.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ on:
1616
description: The version of JAX to use for testing
1717
required: false
1818
type: string
19-
default: '0.6.2'
19+
default: '0.7.0'
2020
catalyst_jax_version:
2121
description: The version of JAX to use for testing along with Catalyst
2222
required: false
2323
type: string
24-
default: '0.6.2'
24+
default: '0.7.0'
2525
torch_version:
2626
description: The version of PyTorch to use for testing
2727
required: false

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ timer.dat
2424
tmp/*
2525
benchmark/revisions/
2626
venv
27+
*venv*/
2728
config.toml
2829
.envrc
2930
qml_debug.log

doc/introduction/interfaces/jax.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ JAX interface
1010

1111
.. code-block:: bash
1212
13-
pip install jax~=0.6.0 jaxlib~=0.6.0
13+
pip install jax==0.7.0 jaxlib==0.7.0
1414
1515
You can then import PennyLane and JAX as follows:
1616

doc/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ appdirs
33
autograd
44
autoray
55
cachetools
6-
jax==0.6.0
7-
jaxlib==0.6.0
6+
jax==0.7.0
7+
jaxlib==0.7.0
88
mistune==0.8.4
99
m2r2
1010
# TODO: Remove once galois becomes compatible with latest numpy

pennylane/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,12 +204,12 @@
204204
from packaging.version import Version as _Version
205205

206206
if _find_spec("jax") is not None:
207-
if (jax_version := _Version(_metadata_version("jax"))) > _Version("0.6.2"): # pragma: no cover
207+
if (jax_version := _Version(_metadata_version("jax"))) > _Version("0.7.0"): # pragma: no cover
208208
warnings.warn(
209-
"PennyLane is not yet compatible with JAX versions > 0.6.2. "
209+
"PennyLane is not yet compatible with JAX versions > 0.7.0. "
210210
f"You have version {jax_version} installed. "
211-
"Please downgrade JAX to 0.6.2 to avoid runtime errors using "
212-
"python -m pip install jax~=0.6.0 jaxlib~=0.6.0",
211+
"Please downgrade JAX to 0.7.0 to avoid runtime errors using "
212+
"python -m pip install jax==0.7.0 jaxlib==0.7.0",
213213
RuntimeWarning,
214214
)
215215

pennylane/_grad.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ def _grad_abstract(*args, argnums, jaxpr, n_consts, method, h, scalar_out, fn):
8686

8787

8888
def _shape(shape, dtype, weak_type=False):
89-
if jax.config.jax_dynamic_shapes and any(not isinstance(s, int) for s in shape):
89+
if jax.config.jax_dynamic_shapes and any(
90+
not isinstance(s, int) for s in shape
91+
): # pragma: no cover
9092
return jax.core.DShapedArray(shape, dtype, weak_type=weak_type)
9193
return jax.core.ShapedArray(shape, dtype, weak_type=weak_type)
9294

pennylane/capture/base_interpreter.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _fill_in_shape_with_dyn_shape(dyn_shape: tuple["jax.core.Tracer"], shape: tu
8282
for s in shape:
8383
if s is not None:
8484
new_shape.append(s)
85-
else:
85+
else: # pragma: no cover
8686
# pull from iterable of dynamic shapes
8787
next_s = next(dyn_shape_iter)
8888
if not qml.math.is_abstract(next_s):
@@ -496,9 +496,10 @@ def handle_for_loop(
496496
self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
497497
):
498498
"""Handle a for loop primitive."""
499-
consts = args[consts_slice]
500-
init_state = args[args_slice]
501-
abstract_shapes = args[abstract_shapes_slice]
499+
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
500+
consts = args[slice(*consts_slice)]
501+
init_state = args[slice(*args_slice)]
502+
abstract_shapes = args[slice(*abstract_shapes_slice)]
502503
new_jaxpr_body_fn = jaxpr_to_jaxpr(
503504
copy(self), jaxpr_body_fn, consts, *abstract_shapes, start, *init_state
504505
)
@@ -523,6 +524,10 @@ def handle_for_loop(
523524
@PlxprInterpreter.register_primitive(cond_prim)
524525
def handle_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):
525526
"""Handle a cond primitive."""
527+
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
528+
args_slice = slice(*args_slice)
529+
consts_slices = [slice(*s) for s in consts_slices]
530+
526531
args = invals[args_slice]
527532

528533
new_jaxprs = []
@@ -560,6 +565,11 @@ def handle_while_loop(
560565
args_slice,
561566
):
562567
"""Handle a while loop primitive."""
568+
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
569+
body_slice = slice(*body_slice)
570+
cond_slice = slice(*cond_slice)
571+
args_slice = slice(*args_slice)
572+
563573
consts_body = invals[body_slice]
564574
consts_cond = invals[cond_slice]
565575
init_state = invals[args_slice]
@@ -654,6 +664,11 @@ def flatten_while_loop(
654664
args_slice,
655665
):
656666
"""Handle the while loop by a flattened python strategy."""
667+
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
668+
body_slice = slice(*body_slice)
669+
cond_slice = slice(*cond_slice)
670+
args_slice = slice(*args_slice)
671+
657672
consts_body = invals[body_slice]
658673
consts_cond = invals[cond_slice]
659674
init_state = invals[args_slice]
@@ -671,6 +686,10 @@ def flatten_while_loop(
671686
@FlattenedInterpreter.register_primitive(cond_prim)
672687
def flattened_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):
673688
"""Handle the cond primitive by a flattened python strategy."""
689+
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
690+
args_slice = slice(*args_slice)
691+
consts_slices = [slice(*s) for s in consts_slices]
692+
674693
n_branches = len(jaxpr_branches)
675694
conditions = invals[:n_branches]
676695
args = invals[args_slice]
@@ -694,6 +713,11 @@ def flattened_for(
694713
self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
695714
):
696715
"""Handle the for loop by a flattened python strategy."""
716+
# Convert tuples back to slices (tuples are used for JAX 0.7.0 hashability)
717+
consts_slice = slice(*consts_slice)
718+
args_slice = slice(*args_slice)
719+
abstract_shapes_slice = slice(*abstract_shapes_slice)
720+
697721
consts = invals[consts_slice]
698722
init_state = invals[args_slice]
699723
abstract_shapes = invals[abstract_shapes_slice]

pennylane/capture/custom_primitives.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""
1717

1818
from enum import Enum
19+
from typing import Any
1920

2021
from jax.extend.core import Primitive
2122

@@ -30,10 +31,32 @@ class PrimitiveType(Enum):
3031
TRANSFORM = "transform"
3132

3233

34+
def _make_hashable(obj: Any) -> Any:
35+
"""Convert potentially unhashable objects to hashable equivalents for JAX 0.7.0+.
36+
37+
JAX 0.7.0 requires all primitive parameters to be hashable. This helper converts
38+
common unhashable types (list, dict, slice) to hashable tuples.
39+
40+
Args:
41+
obj: Object to potentially convert to hashable form
42+
43+
Returns:
44+
Hashable version of the object
45+
"""
46+
if isinstance(obj, slice):
47+
return (obj.start, obj.stop, obj.step)
48+
if isinstance(obj, list):
49+
return tuple(_make_hashable(item) for item in obj)
50+
if isinstance(obj, dict):
51+
return tuple((k, _make_hashable(v)) for k, v in obj.items())
52+
53+
return obj
54+
55+
3356
# pylint: disable=abstract-method,too-few-public-methods
3457
class QmlPrimitive(Primitive):
3558
"""A subclass for JAX's Primitive that differentiates between different
36-
classes of primitives."""
59+
classes of primitives and automatically makes parameters hashable for JAX 0.7.0+."""
3760

3861
_prim_type: PrimitiveType = PrimitiveType.DEFAULT
3962

@@ -47,3 +70,13 @@ def prim_type(self):
4770
def prim_type(self, value: str | PrimitiveType):
4871
"""Setter for QmlPrimitive.prim_type."""
4972
self._prim_type = PrimitiveType(value)
73+
74+
def bind(self, *args, **params):
75+
"""Bind with automatic parameter hashability conversion for JAX 0.7.0+.
76+
77+
Overrides the parent bind method to automatically convert unhashable parameters
78+
(like lists, dicts, and slices) to hashable tuples, which is required by JAX 0.7.0+.
79+
"""
80+
# Convert all parameters to hashable forms
81+
hashable_params = {k: _make_hashable(v) for k, v in params.items()}
82+
return super().bind(*args, **hashable_params)

pennylane/capture/dynamic_shapes.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
has_jax = True
2020
try:
2121
import jax
22+
from jax._src.interpreters.partial_eval import TracingEqn
2223
from jax.interpreters import partial_eval as pe
2324
except ImportError: # pragma: no cover
2425
has_jax = False # pragma: no cover
@@ -47,7 +48,7 @@ def _get_shape_for_array(x, abstract_shapes: list, previous_ints: list) -> dict:
4748
return {}
4849

4950
abstract_axes = {}
50-
for i, s in enumerate(getattr(x, "shape", ())):
51+
for i, s in enumerate(getattr(x, "shape", ())): # pragma: no cover
5152
if not isinstance(s, int): # if not int, then abstract
5253
found = False
5354
# check if the shape tracer is one we have already encountered
@@ -137,8 +138,8 @@ def f(n):
137138
if not any(abstracted_axes):
138139
return None, ()
139140

140-
abstracted_axes = jax.tree_util.tree_unflatten(structure, abstracted_axes)
141-
return abstracted_axes, abstract_shapes
141+
abstracted_axes = jax.tree_util.tree_unflatten(structure, abstracted_axes) # pragma: no cover
142+
return abstracted_axes, abstract_shapes # pragma: no cover
142143

143144

144145
def register_custom_staging_rule(
@@ -164,7 +165,14 @@ def register_custom_staging_rule(
164165
# see https://github.com/jax-ml/jax/blob/9e62994bce7c7fcbb2f6a50c9ef89526cd2c2be6/jax/_src/lax/lax.py#L3538
165166
# and https://github.com/jax-ml/jax/blob/9e62994bce7c7fcbb2f6a50c9ef89526cd2c2be6/jax/_src/lax/lax.py#L208
166167
# for reference to how jax is handling staging rules for dynamic shapes in v0.4.28
167-
# see also capture/intro_to_dynamic_shapes.md
168+
# JAX 0.6.2 to 0.7.0 introduced breaking changes in custom staging rules for dynamic shapes:
169+
# 1. DynamicJaxprTracer constructor now requires the var as 3rd argument (previously created internally)
170+
# 2. TracingEqn must be used instead of JaxprEqn for trace.frame.add_eqn
171+
#
172+
# This implementation creates vars first using trace.frame.newvar() before constructing
173+
# DynamicJaxprTracer instances, fixing dynamic shape support that was broken in JAX 0.7.0.
174+
# See pennylane/capture/jax_patches.py for related fixes to JAX's own staging rules.
175+
# See also capture/intro_to_dynamic_shapes.md for dynamic shapes documentation.
168176

169177
def _tracer_and_outvar(
170178
jaxpr_trace: pe.DynamicJaxprTrace,
@@ -176,15 +184,18 @@ def _tracer_and_outvar(
176184
Returned vars are cached in env for use in future shapes
177185
"""
178186
if not hasattr(outvar.aval, "shape"):
179-
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, outvar.aval, None)
180-
return out_tracer, jaxpr_trace.makevar(out_tracer)
187+
# JAX 0.7.0: Create variable first, then pass to DynamicJaxprTracer
188+
new_var = jaxpr_trace.frame.newvar(outvar.aval)
189+
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, outvar.aval, new_var)
190+
return out_tracer, new_var
181191
new_shape = [s if isinstance(s, int) else env[s] for s in outvar.aval.shape]
182192
if all(isinstance(s, int) for s in outvar.aval.shape):
183193
new_aval = jax.core.ShapedArray(tuple(new_shape), outvar.aval.dtype)
184-
else:
194+
else: # pragma: no cover
185195
new_aval = jax.core.DShapedArray(tuple(new_shape), outvar.aval.dtype)
186-
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, new_aval, None)
187-
new_var = jaxpr_trace.makevar(out_tracer)
196+
# JAX 0.7.0: Create variable first, then pass to DynamicJaxprTracer
197+
new_var = jaxpr_trace.frame.newvar(new_aval)
198+
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, new_aval, new_var)
188199

189200
if not isinstance(outvar, jax.extend.core.Literal):
190201
env[outvar] = new_var
@@ -211,15 +222,26 @@ def custom_staging_rule(
211222
else:
212223
out_tracers, returned_vars = (), ()
213224

214-
invars = [jaxpr_trace.getvar(x) for x in tracers]
225+
# JAX 0.7.0: Use t.val to get var from tracer, and TracingEqn for frame.add_eqn
226+
invars = [t.val for t in tracers]
215227
eqn = jax.core.new_jaxpr_eqn(
216228
invars,
217229
returned_vars,
218230
primitive,
219231
params,
220232
jax.core.no_effects,
233+
source_info,
234+
)
235+
tracing_eqn = TracingEqn(
236+
list(tracers),
237+
returned_vars,
238+
primitive,
239+
params,
240+
eqn.effects,
241+
source_info,
242+
eqn.ctx,
221243
)
222-
jaxpr_trace.frame.add_eqn(eqn)
244+
jaxpr_trace.frame.add_eqn(tracing_eqn)
223245
return out_tracers
224246

225247
pe.custom_staging_rules[primitive] = custom_staging_rule

0 commit comments

Comments
 (0)