diff --git a/manim/mobject/mobject.py b/manim/mobject/mobject.py index 9f3818c188..342fa37542 100644 --- a/manim/mobject/mobject.py +++ b/manim/mobject/mobject.py @@ -2157,30 +2157,33 @@ def restore(self) -> Self: self.become(self.saved_state) return self - def reduce_across_dimension( - self, reduce_func: Callable[[Iterable[float]], float], dim: int + def get_extreme_coord_along_dimension( + self, measurement_function: Callable[[Iterable[float]], float], dim: int = 0 ) -> float: - """Find the min or max value from a dimension across all points in this and submobjects.""" - assert dim >= 0 - assert dim <= 2 - if len(self.submobjects) == 0 and len(self.points) == 0: - # If we have no points and no submobjects, return 0 (e.g. center) - return 0 - - # If we do not have points (but do have submobjects) - # use only the points from those. - if len(self.points) == 0: # noqa: SIM108 - rv = None - else: - # Otherwise, be sure to include our own points - rv = reduce_func(self.points[:, dim]) - # Recursively ask submobjects (if any) for the biggest/ - # smallest dimension they have and compare it to the return value. - for mobj in self.submobjects: - value = mobj.reduce_across_dimension(reduce_func, dim) - rv = value if rv is None else reduce_func([value, rv]) - assert rv is not None - return rv + """ + Finds the minimum or maximum coordinate value in a given dimension, across all points in self and its submobjects. + + measurement_function can be either max or min. + """ + if not (0 <= dim <= 2): + raise ValueError("dim must be either 0, 1 or 2") + + extreme_coord = None + + if self.points.size > 0: + extreme_coord = measurement_function(self.points[:, dim]) + + for submobject in self.submobjects: + submobject_extreme_coord = submobject.get_extreme_coord_along_dimension( + measurement_function, dim + ) + extreme_coord = ( + submobject_extreme_coord + if extreme_coord is None + else measurement_function([submobject_extreme_coord, extreme_coord]) + ) + + return extreme_coord if extreme_coord is not None else 0.0 def nonempty_submobjects(self) -> Sequence[Mobject]: return [ @@ -2333,11 +2336,11 @@ def get_nadir(self) -> Point3D: def length_over_dim(self, dim: int) -> float: """Measure the length of an :class:`~.Mobject` in a certain direction.""" - max_coord: float = self.reduce_across_dimension( + max_coord: float = self.get_extreme_coord_along_dimension( max, dim, ) - min_coord: float = self.reduce_across_dimension(min, dim) + min_coord: float = self.get_extreme_coord_along_dimension(min, dim) return max_coord - min_coord def get_coord(self, dim: int, direction: Vector3DLike = ORIGIN) -> float: