1414"""
1515Contains a utility for handling inputs with dynamically shaped arrays.
1616"""
17- from collections .abc import Callable , Sequence
17+ from collections .abc import Callable
1818
1919has_jax = True
2020try :
2121 import jax
2222 from jax ._src .interpreters .partial_eval import TracingEqn
23- from jax .interpreters import partial_eval as pe
24- except ImportError : # pragma: no cover
25- has_jax = False # pragma: no cover
23+
24+
25+ except ImportError as e : # pragma: no cover
26+ has_jax = False
2627
2728
2829def _get_shape_for_array (x , abstract_shapes : list , previous_ints : list ) -> dict :
@@ -123,7 +124,10 @@ def f(n):
123124
124125 """
125126 if not has_jax : # pragma: no cover
126- raise ImportError ("jax must be installed to use determine_abstracted_axes" )
127+ raise ImportError (
128+ "JAX == 0.7.0 must be installed to use determine_abstracted_axes. "
129+ "Install with: pip install jax==0.7.0 jaxlib==0.7.0 "
130+ )
127131 if not jax .config .jax_dynamic_shapes :
128132 return None , ()
129133
@@ -175,18 +179,20 @@ def register_custom_staging_rule(
175179 # See also capture/intro_to_dynamic_shapes.md for dynamic shapes documentation.
176180
177181 def _tracer_and_outvar (
178- jaxpr_trace : pe . DynamicJaxprTrace ,
182+ jaxpr_trace ,
179183 outvar : jax .extend .core .Var ,
180184 env : dict [jax .extend .core .Var , jax .extend .core .Var ],
181- ) -> tuple [ pe . DynamicJaxprTracer , jax . extend . core . Var ] :
185+ ):
182186 """
183187 Create a new tracer and return var from the true branch outvar.
184188 Returned vars are cached in env for use in future shapes
185189 """
186190 if not hasattr (outvar .aval , "shape" ):
187191 # JAX 0.7.0: Create variable first, then pass to DynamicJaxprTracer
188192 new_var = jaxpr_trace .frame .newvar (outvar .aval )
189- out_tracer = pe .DynamicJaxprTracer (jaxpr_trace , outvar .aval , new_var )
193+ out_tracer = jax .interpreters .partial_eval .DynamicJaxprTracer (
194+ jaxpr_trace , outvar .aval , new_var
195+ )
190196 return out_tracer , new_var
191197 new_shape = [s if isinstance (s , int ) else env [s ] for s in outvar .aval .shape ]
192198 if all (isinstance (s , int ) for s in outvar .aval .shape ):
@@ -195,15 +201,15 @@ def _tracer_and_outvar(
195201 new_aval = jax .core .DShapedArray (tuple (new_shape ), outvar .aval .dtype )
196202 # JAX 0.7.0: Create variable first, then pass to DynamicJaxprTracer
197203 new_var = jaxpr_trace .frame .newvar (new_aval )
198- out_tracer = pe .DynamicJaxprTracer (jaxpr_trace , new_aval , new_var )
204+ out_tracer = jax .interpreters .partial_eval .DynamicJaxprTracer (
205+ jaxpr_trace , new_aval , new_var
206+ )
199207
200208 if not isinstance (outvar , jax .extend .core .Literal ):
201209 env [outvar ] = new_var
202210 return out_tracer , new_var
203211
204- def custom_staging_rule (
205- jaxpr_trace : pe .DynamicJaxprTrace , source_info , * tracers : pe .DynamicJaxprTracer , ** params
206- ) -> Sequence [pe .DynamicJaxprTracer ] | pe .DynamicJaxprTracer :
212+ def custom_staging_rule (jaxpr_trace , source_info , * tracers , ** params ):
207213 """
208214 Add new jaxpr equation to the jaxpr_trace and return new tracers.
209215 """
@@ -244,4 +250,4 @@ def custom_staging_rule(
244250 jaxpr_trace .frame .add_eqn (tracing_eqn )
245251 return out_tracers
246252
247- pe .custom_staging_rules [primitive ] = custom_staging_rule
253+ jax . interpreters . partial_eval .custom_staging_rules [primitive ] = custom_staging_rule
0 commit comments