Skip to content

Commit 9d7a237

Browse files
committed
no cover loop dynamic shape part
1 parent 5d3300d commit 9d7a237

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

pennylane/control_flow/_loop_abstract_axes.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
AbstractShapeLocation = namedtuple("AbstractShapeLocation", ("arg_idx", "shape_idx"))
3030

3131

32-
def add_abstract_shapes(f, shape_locations: list[list[AbstractShapeLocation]]):
32+
def add_abstract_shapes(f, shape_locations: list[list[AbstractShapeLocation]]): # pragma: no cover
3333
"""Add the abstract shapes at the specified locations to the output of f.
3434
3535
Here we can see that the shapes at argument 0, shape index 0 and
@@ -65,7 +65,7 @@ def new_f(*args, **kwargs):
6565
return new_f
6666

6767

68-
def get_dummy_arg(arg):
68+
def get_dummy_arg(arg): # pragma: no cover
6969
"""If any axes are abstract, replace them with an empty numpy array.
7070
7171
Even if abstracted_axes specifies two dimensions as having different dynamic shapes,
@@ -112,7 +112,7 @@ def validate_no_resizing_returns(
112112
"""
113113
offset = len(locations) # number of abstract shapes. We start from the first normal arg.
114114

115-
for locations_list in locations:
115+
for locations_list in locations: # pragma: no cover
116116
loc0 = locations_list[0]
117117
first_var = jaxpr.outvars[loc0.arg_idx + offset].aval.shape[loc0.shape_idx]
118118
for compare_loc in locations_list[1:]:
@@ -132,7 +132,7 @@ def validate_no_resizing_returns(
132132

133133

134134
def _has_dynamic_shape(val):
135-
return any(not isinstance(s, int) for s in getattr(val, "shape", ()))
135+
return any(not isinstance(s, int) for s in getattr(val, "shape", ())) # pragma: no cover
136136

137137

138138
def handle_jaxpr_error(
@@ -142,7 +142,7 @@ def handle_jaxpr_error(
142142
about 'Incompatible shapes for broadcasting'."""
143143
import jax # pylint: disable=import-outside-toplevel
144144

145-
if "Incompatible shapes for broadcasting" in str(e) and jax.config.jax_dynamic_shapes:
145+
if "Incompatible shapes for broadcasting" in str(e) and jax.config.jax_dynamic_shapes: # pragma: no cover
146146
closures = sum(((fn.__closure__ or ()) for fn in fns), ())
147147
if any(_has_dynamic_shape(i.cell_contents) for i in closures):
148148
msg = (
@@ -176,7 +176,7 @@ def add_arg(self, x_idx: int, x):
176176
arg_abstracted_axes = {}
177177

178178
for shape_idx, s in enumerate(getattr(x, "shape", ())):
179-
if not isinstance(s, int): # if not int, then abstract
179+
if not isinstance(s, int): # pragma: no cover
180180
found = False
181181
if not self.allow_array_resizing:
182182
for previous_idx, previous_shape in enumerate(self.abstract_shapes):
@@ -189,7 +189,7 @@ def add_arg(self, x_idx: int, x):
189189
break
190190
# haven't encountered it, so add it to abstract_axes
191191
# and use new number designation
192-
if not found:
192+
if not found: # pragma: no cover
193193
arg_abstracted_axes[shape_idx] = len(self.abstract_shapes)
194194
self.shape_locations.append([AbstractShapeLocation(x_idx, shape_idx)])
195195
self.abstract_shapes.append(s)
@@ -262,5 +262,5 @@ def f(*args, allow_array_resizing):
262262
if not any(calculator.abstracted_axes):
263263
return None, [], []
264264

265-
abstracted_axes = jax.tree_util.tree_unflatten(structure, calculator.abstracted_axes)
266-
return abstracted_axes, calculator.abstract_shapes, calculator.shape_locations
265+
abstracted_axes = jax.tree_util.tree_unflatten(structure, calculator.abstracted_axes) # pragma: no cover
266+
return abstracted_axes, calculator.abstract_shapes, calculator.shape_locations # pragma: no cover

0 commit comments

Comments
 (0)