1919has_jax = True
2020try :
2121 import jax
22+ from jax ._src .interpreters .partial_eval import TracingEqn
2223 from jax .interpreters import partial_eval as pe
2324except ImportError : # pragma: no cover
2425 has_jax = False # pragma: no cover
@@ -47,7 +48,7 @@ def _get_shape_for_array(x, abstract_shapes: list, previous_ints: list) -> dict:
4748 return {}
4849
4950 abstract_axes = {}
50- for i , s in enumerate (getattr (x , "shape" , ())):
51+ for i , s in enumerate (getattr (x , "shape" , ())): # pragma: no cover
5152 if not isinstance (s , int ): # if not int, then abstract
5253 found = False
5354 # check if the shape tracer is one we have already encountered
@@ -137,8 +138,8 @@ def f(n):
137138 if not any (abstracted_axes ):
138139 return None , ()
139140
140- abstracted_axes = jax .tree_util .tree_unflatten (structure , abstracted_axes )
141- return abstracted_axes , abstract_shapes
141+ abstracted_axes = jax .tree_util .tree_unflatten (structure , abstracted_axes ) # pragma: no cover
142+ return abstracted_axes , abstract_shapes # pragma: no cover
142143
143144
144145def register_custom_staging_rule (
@@ -164,7 +165,14 @@ def register_custom_staging_rule(
164165 # see https://github.com/jax-ml/jax/blob/9e62994bce7c7fcbb2f6a50c9ef89526cd2c2be6/jax/_src/lax/lax.py#L3538
165166 # and https://github.com/jax-ml/jax/blob/9e62994bce7c7fcbb2f6a50c9ef89526cd2c2be6/jax/_src/lax/lax.py#L208
166167 # for reference to how jax is handling staging rules for dynamic shapes in v0.4.28
167- # see also capture/intro_to_dynamic_shapes.md
168+ # JAX 0.6.2 to 0.7.0 introduced breaking changes in custom staging rules for dynamic shapes:
169+ # 1. DynamicJaxprTracer constructor now requires the var as 3rd argument (previously created internally)
170+ # 2. TracingEqn must be used instead of JaxprEqn for trace.frame.add_eqn
171+ #
172+ # This implementation creates vars first using trace.frame.newvar() before constructing
173+ # DynamicJaxprTracer instances, fixing dynamic shape support that was broken in JAX 0.7.0.
174+ # See pennylane/capture/jax_patches.py for related fixes to JAX's own staging rules.
175+ # See also capture/intro_to_dynamic_shapes.md for dynamic shapes documentation.
168176
169177 def _tracer_and_outvar (
170178 jaxpr_trace : pe .DynamicJaxprTrace ,
@@ -176,15 +184,18 @@ def _tracer_and_outvar(
176184 Returned vars are cached in env for use in future shapes
177185 """
178186 if not hasattr (outvar .aval , "shape" ):
179- out_tracer = pe .DynamicJaxprTracer (jaxpr_trace , outvar .aval , None )
180- return out_tracer , jaxpr_trace .makevar (out_tracer )
187+ # JAX 0.7.0: Create variable first, then pass to DynamicJaxprTracer
188+ new_var = jaxpr_trace .frame .newvar (outvar .aval )
189+ out_tracer = pe .DynamicJaxprTracer (jaxpr_trace , outvar .aval , new_var )
190+ return out_tracer , new_var
181191 new_shape = [s if isinstance (s , int ) else env [s ] for s in outvar .aval .shape ]
182192 if all (isinstance (s , int ) for s in outvar .aval .shape ):
183193 new_aval = jax .core .ShapedArray (tuple (new_shape ), outvar .aval .dtype )
184- else :
194+ else : # pragma: no cover
185195 new_aval = jax .core .DShapedArray (tuple (new_shape ), outvar .aval .dtype )
186- out_tracer = pe .DynamicJaxprTracer (jaxpr_trace , new_aval , None )
187- new_var = jaxpr_trace .makevar (out_tracer )
196+ # JAX 0.7.0: Create variable first, then pass to DynamicJaxprTracer
197+ new_var = jaxpr_trace .frame .newvar (new_aval )
198+ out_tracer = pe .DynamicJaxprTracer (jaxpr_trace , new_aval , new_var )
188199
189200 if not isinstance (outvar , jax .extend .core .Literal ):
190201 env [outvar ] = new_var
@@ -211,15 +222,26 @@ def custom_staging_rule(
211222 else :
212223 out_tracers , returned_vars = (), ()
213224
214- invars = [jaxpr_trace .getvar (x ) for x in tracers ]
225+ # JAX 0.7.0: Use t.val to get var from tracer, and TracingEqn for frame.add_eqn
226+ invars = [t .val for t in tracers ]
215227 eqn = jax .core .new_jaxpr_eqn (
216228 invars ,
217229 returned_vars ,
218230 primitive ,
219231 params ,
220232 jax .core .no_effects ,
233+ source_info ,
234+ )
235+ tracing_eqn = TracingEqn (
236+ list (tracers ),
237+ returned_vars ,
238+ primitive ,
239+ params ,
240+ eqn .effects ,
241+ source_info ,
242+ eqn .ctx ,
221243 )
222- jaxpr_trace .frame .add_eqn (eqn )
244+ jaxpr_trace .frame .add_eqn (tracing_eqn )
223245 return out_tracers
224246
225247 pe .custom_staging_rules [primitive ] = custom_staging_rule
0 commit comments