Skip to content

Commit ea0773e

Browse files
committed
rm more
1 parent 705f963 commit ea0773e

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

pennylane/capture/dynamic_shapes.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
has_jax = True
2020
try:
2121
import jax
22-
from jax._src.interpreters.partial_eval import TracingEqn
2322
from jax.interpreters import partial_eval as pe
2423
except ImportError: # pragma: no cover
2524
has_jax = False # pragma: no cover
@@ -166,8 +165,7 @@ def register_custom_staging_rule(
166165
# and https://github.com/jax-ml/jax/blob/9e62994bce7c7fcbb2f6a50c9ef89526cd2c2be6/jax/_src/lax/lax.py#L208
167166
# for reference to how jax is handling staging rules for dynamic shapes in v0.4.28
168167
# JAX 0.6.2 to 0.7.0 introduced breaking changes in custom staging rules for dynamic shapes:
169-
# 1. DynamicJaxprTracer constructor now requires the var as 3rd argument (previously created internally)
170-
# 2. TracingEqn must be used instead of JaxprEqn for trace.frame.add_eqn
168+
# DynamicJaxprTracer constructor now requires the var as 3rd argument (previously created internally)
171169
#
172170
# This implementation creates vars first using trace.frame.newvar() before constructing
173171
# DynamicJaxprTracer instances, fixing dynamic shape support that was broken in JAX 0.7.0.
@@ -222,7 +220,7 @@ def custom_staging_rule(
222220
else:
223221
out_tracers, returned_vars = (), ()
224222

225-
# JAX 0.7.0: Use t.val to get var from tracer, and TracingEqn for frame.add_eqn
223+
# JAX 0.7.0: Use t.val to get var from tracer
226224
invars = [t.val for t in tracers]
227225
eqn = jax.core.new_jaxpr_eqn(
228226
invars,

0 commit comments

Comments
 (0)