2929AbstractShapeLocation = namedtuple ("AbstractShapeLocation" , ("arg_idx" , "shape_idx" ))
3030
3131
32- def add_abstract_shapes (f , shape_locations : list [list [AbstractShapeLocation ]]):
32+ def add_abstract_shapes (f , shape_locations : list [list [AbstractShapeLocation ]]): # pragma: no cover
3333 """Add the abstract shapes at the specified locations to the output of f.
3434
3535 Here we can see that the shapes at argument 0, shape index 0 and
@@ -65,7 +65,7 @@ def new_f(*args, **kwargs):
6565 return new_f
6666
6767
68- def get_dummy_arg (arg ):
68+ def get_dummy_arg (arg ): # pragma: no cover
6969 """If any axes are abstract, replace them with an empty numpy array.
7070
7171 Even if abstracted_axes specifies two dimensions as having different dynamic shapes,
@@ -112,7 +112,7 @@ def validate_no_resizing_returns(
112112 """
113113 offset = len (locations ) # number of abstract shapes. We start from the first normal arg.
114114
115- for locations_list in locations :
115+ for locations_list in locations : # pragma: no cover
116116 loc0 = locations_list [0 ]
117117 first_var = jaxpr .outvars [loc0 .arg_idx + offset ].aval .shape [loc0 .shape_idx ]
118118 for compare_loc in locations_list [1 :]:
@@ -132,7 +132,7 @@ def validate_no_resizing_returns(
132132
133133
134134def _has_dynamic_shape (val ):
135- return any (not isinstance (s , int ) for s in getattr (val , "shape" , ()))
135+ return any (not isinstance (s , int ) for s in getattr (val , "shape" , ())) # pragma: no cover
136136
137137
138138def handle_jaxpr_error (
@@ -142,7 +142,7 @@ def handle_jaxpr_error(
142142 about 'Incompatible shapes for broadcasting'."""
143143 import jax # pylint: disable=import-outside-toplevel
144144
145- if "Incompatible shapes for broadcasting" in str (e ) and jax .config .jax_dynamic_shapes :
145+ if "Incompatible shapes for broadcasting" in str (e ) and jax .config .jax_dynamic_shapes : # pragma: no cover
146146 closures = sum (((fn .__closure__ or ()) for fn in fns ), ())
147147 if any (_has_dynamic_shape (i .cell_contents ) for i in closures ):
148148 msg = (
@@ -176,7 +176,7 @@ def add_arg(self, x_idx: int, x):
176176 arg_abstracted_axes = {}
177177
178178 for shape_idx , s in enumerate (getattr (x , "shape" , ())):
179- if not isinstance (s , int ): # if not int, then abstract
179+ if not isinstance (s , int ): # pragma: no cover
180180 found = False
181181 if not self .allow_array_resizing :
182182 for previous_idx , previous_shape in enumerate (self .abstract_shapes ):
@@ -189,7 +189,7 @@ def add_arg(self, x_idx: int, x):
189189 break
190190 # haven't encountered it, so add it to abstract_axes
191191 # and use new number designation
192- if not found :
192+ if not found : # pragma: no cover
193193 arg_abstracted_axes [shape_idx ] = len (self .abstract_shapes )
194194 self .shape_locations .append ([AbstractShapeLocation (x_idx , shape_idx )])
195195 self .abstract_shapes .append (s )
@@ -262,5 +262,5 @@ def f(*args, allow_array_resizing):
262262 if not any (calculator .abstracted_axes ):
263263 return None , [], []
264264
265- abstracted_axes = jax .tree_util .tree_unflatten (structure , calculator .abstracted_axes )
266- return abstracted_axes , calculator .abstract_shapes , calculator .shape_locations
265+ abstracted_axes = jax .tree_util .tree_unflatten (structure , calculator .abstracted_axes ) # pragma: no cover
266+ return abstracted_axes , calculator .abstract_shapes , calculator .shape_locations # pragma: no cover
0 commit comments