File tree Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Original file line number Diff line number Diff line change 1919has_jax = True
2020try :
2121 import jax
22- from jax ._src .interpreters .partial_eval import TracingEqn
2322 from jax .interpreters import partial_eval as pe
2423except ImportError : # pragma: no cover
2524 has_jax = False # pragma: no cover
@@ -166,8 +165,7 @@ def register_custom_staging_rule(
166165 # and https://github.com/jax-ml/jax/blob/9e62994bce7c7fcbb2f6a50c9ef89526cd2c2be6/jax/_src/lax/lax.py#L208
167166 # for reference to how jax is handling staging rules for dynamic shapes in v0.4.28
168167 # JAX 0.6.2 to 0.7.0 introduced breaking changes in custom staging rules for dynamic shapes:
169- # 1. DynamicJaxprTracer constructor now requires the var as 3rd argument (previously created internally)
170- # 2. TracingEqn must be used instead of JaxprEqn for trace.frame.add_eqn
168+ # DynamicJaxprTracer constructor now requires the var as 3rd argument (previously created internally)
171169 #
172170 # This implementation creates vars first using trace.frame.newvar() before constructing
173171 # DynamicJaxprTracer instances, fixing dynamic shape support that was broken in JAX 0.7.0.
@@ -222,7 +220,7 @@ def custom_staging_rule(
222220 else :
223221 out_tracers , returned_vars = (), ()
224222
225- # JAX 0.7.0: Use t.val to get var from tracer, and TracingEqn for frame.add_eqn
223+ # JAX 0.7.0: Use t.val to get var from tracer
226224 invars = [t .val for t in tracers ]
227225 eqn = jax .core .new_jaxpr_eqn (
228226 invars ,
You can’t perform that action at this time.
0 commit comments