diff --git a/manim/mobject/table.py b/manim/mobject/table.py index dc4519dea3..87592f8da8 100644 --- a/manim/mobject/table.py +++ b/manim/mobject/table.py @@ -66,6 +66,7 @@ def construct(self): import itertools as it from collections.abc import Callable, Iterable, Sequence +from typing import Any, Self from manim.mobject.geometry.line import Line from manim.mobject.geometry.polygram import Polygon @@ -186,9 +187,9 @@ def construct(self): def __init__( self, - table: Iterable[Iterable[float | str | VMobject]], - row_labels: Iterable[VMobject] | None = None, - col_labels: Iterable[VMobject] | None = None, + table: Sequence[Sequence[float | str | VMobject]], + row_labels: Sequence[VMobject] | None = None, + col_labels: Sequence[VMobject] | None = None, top_left_entry: VMobject | None = None, v_buff: float = 0.8, h_buff: float = 1.3, @@ -198,16 +199,25 @@ def __init__( include_background_rectangle: bool = False, background_rectangle_color: ParsableManimColor = BLACK, element_to_mobject: Callable[ + [float | str], + VMobject, + ] + | Callable[ + [VMobject], + VMobject, + ] + | Callable[ [float | str | VMobject], VMobject, - ] = Paragraph, + ] + | type[VMobject] = Paragraph, element_to_mobject_config: dict = {}, arrange_in_grid_config: dict = {}, line_config: dict = {}, - **kwargs, + **kwargs: Any, ): - self.row_labels = row_labels - self.col_labels = col_labels + self.row_labels = list(row_labels) if row_labels else None + self.col_labels = list(col_labels) if col_labels else None self.top_left_entry = top_left_entry self.row_dim = len(table) self.col_dim = len(table[0]) @@ -230,7 +240,7 @@ def __init__( raise ValueError("Not all rows in table have the same length.") super().__init__(**kwargs) - mob_table = self._table_to_mob_table(table) + mob_table: list[list[VMobject]] = self._table_to_mob_table(table) self.elements_without_labels = VGroup(*it.chain(*mob_table)) mob_table = self._add_labels(mob_table) self._organize_mob_table(mob_table) @@ -252,7 +262,7 @@ def __init__( def _table_to_mob_table( self, table: Iterable[Iterable[float | str | VMobject]], - ) -> list: + ) -> list[list[VMobject]]: """Initializes the entries of ``table`` as :class:`~.VMobject`. Parameters @@ -268,13 +278,15 @@ def _table_to_mob_table( """ return [ [ - self.element_to_mobject(item, **self.element_to_mobject_config) + # error: Argument 1 has incompatible type "float | str | VMobject"; expected "float | str" [arg-type] + # error: Argument 1 has incompatible type "float | str | VMobject"; expected "VMobject" [arg-type] + self.element_to_mobject(item, **self.element_to_mobject_config) # type: ignore[arg-type] for item in row ] for row in table ] - def _organize_mob_table(self, table: Iterable[Iterable[VMobject]]) -> VGroup: + def _organize_mob_table(self, table: Sequence[Sequence[VMobject]]) -> VGroup: """Arranges the :class:`~.VMobject` of ``table`` in a grid. Parameters @@ -300,7 +312,7 @@ def _organize_mob_table(self, table: Iterable[Iterable[VMobject]]) -> VGroup: ) return help_table - def _add_labels(self, mob_table: VGroup) -> VGroup: + def _add_labels(self, mob_table: list[list[VMobject]]) -> list[list[VMobject]]: """Adds labels to an in a grid arranged :class:`~.VGroup`. Parameters @@ -319,13 +331,13 @@ def _add_labels(self, mob_table: VGroup) -> VGroup: if self.col_labels is not None: if self.row_labels is not None: if self.top_left_entry is not None: - col_labels = [self.top_left_entry] + self.col_labels + col_labels = [self.top_left_entry] + list(self.col_labels) mob_table.insert(0, col_labels) else: # Placeholder to use arrange_in_grid if top_left_entry is not set. # Import OpenGLVMobject to work with --renderer=opengl dummy_mobject = get_vectorized_mobject_class()() - col_labels = [dummy_mobject] + self.col_labels + col_labels = [dummy_mobject] + list(self.col_labels) mob_table.insert(0, col_labels) else: mob_table.insert(0, self.col_labels) @@ -682,7 +694,10 @@ def construct(self): item.set_color(random_bright_color()) self.add(table) """ - return VGroup(*self.row_labels) + if self.row_labels: + return VGroup(*self.row_labels) + else: + return VGroup() def get_col_labels(self) -> VGroup: """Return the column labels of the table. @@ -710,7 +725,10 @@ def construct(self): item.set_color(random_bright_color()) self.add(table) """ - return VGroup(*self.col_labels) + if self.col_labels: + return VGroup(*self.col_labels) + else: + return VGroup() def get_labels(self) -> VGroup: """Returns the labels of the table. @@ -753,7 +771,7 @@ def add_background_to_entries(self, color: ParsableManimColor = BLACK) -> Table: mob.add_background_rectangle(color=ManimColor(color)) return self - def get_cell(self, pos: Sequence[int] = (1, 1), **kwargs) -> Polygon: + def get_cell(self, pos: Sequence[int] = (1, 1), **kwargs: Any) -> Polygon: """Returns one specific cell as a rectangular :class:`~.Polygon` without the entry. Parameters @@ -814,7 +832,7 @@ def get_highlighted_cell( self, pos: Sequence[int] = (1, 1), color: ParsableManimColor = PURE_YELLOW, - **kwargs, + **kwargs: Any, ) -> BackgroundRectangle: """Returns a :class:`~.BackgroundRectangle` of the cell at the given position. @@ -853,7 +871,7 @@ def add_highlighted_cell( self, pos: Sequence[int] = (1, 1), color: ParsableManimColor = PURE_YELLOW, - **kwargs, + **kwargs: Any, ) -> Table: """Highlights one cell at a specific position on the table by adding a :class:`~.BackgroundRectangle`. @@ -896,7 +914,7 @@ def create( label_animation: Callable[[VMobject | VGroup], Animation] = Write, element_animation: Callable[[VMobject | VGroup], Animation] = Create, entry_animation: Callable[[VMobject | VGroup], Animation] = FadeIn, - **kwargs, + **kwargs: Any, ) -> AnimationGroup: """Customized create-type function for tables. @@ -936,7 +954,7 @@ def construct(self): self.play(table.create()) self.wait() """ - animations: Sequence[Animation] = [ + animations: list[Animation] = [ line_animation( VGroup(self.vertical_lines, self.horizontal_lines), **kwargs, @@ -963,12 +981,14 @@ def construct(self): return AnimationGroup(*animations, lag_ratio=lag_ratio) - def scale(self, scale_factor: float, **kwargs): + def scale( + self, scale_factor: float, scale_stroke: bool = False, **kwargs: Any + ) -> Self: # h_buff and v_buff must be adjusted so that Table.get_cell # can construct an accurate polygon for a cell. self.h_buff *= scale_factor self.v_buff *= scale_factor - super().scale(scale_factor, **kwargs) + super().scale(scale_factor, scale_stroke=scale_stroke, **kwargs) return self @@ -994,9 +1014,10 @@ def construct(self): def __init__( self, - table: Iterable[Iterable[float | str]], - element_to_mobject: Callable[[float | str], VMobject] = MathTex, - **kwargs, + table: Sequence[Sequence[float | str]], + element_to_mobject: Callable[[float | str], VMobject] + | type[VMobject] = MathTex, + **kwargs: Any, ): """ Special case of :class:`~.Table` with `element_to_mobject` set to :class:`~.MathTex`. @@ -1049,9 +1070,10 @@ def construct(self): def __init__( self, - table: Iterable[Iterable[VMobject]], - element_to_mobject: Callable[[VMobject], VMobject] = lambda m: m, - **kwargs, + table: Sequence[Sequence[VMobject]], + element_to_mobject: Callable[[VMobject], VMobject] + | type[VMobject] = lambda m: m, + **kwargs: Any, ): """ Special case of :class:`~.Table` with ``element_to_mobject`` set to an identity function. @@ -1097,9 +1119,10 @@ def construct(self): def __init__( self, - table: Iterable[Iterable[float | str]], - element_to_mobject: Callable[[float | str], VMobject] = Integer, - **kwargs, + table: Sequence[Sequence[float | str]], + element_to_mobject: Callable[[float | str], VMobject] + | type[VMobject] = Integer, + **kwargs: Any, ): """ Special case of :class:`~.Table` with `element_to_mobject` set to :class:`~.Integer`. @@ -1141,10 +1164,11 @@ def construct(self): def __init__( self, - table: Iterable[Iterable[float | str]], - element_to_mobject: Callable[[float | str], VMobject] = DecimalNumber, + table: Sequence[Sequence[float | str]], + element_to_mobject: Callable[[float | str], VMobject] + | type[VMobject] = DecimalNumber, element_to_mobject_config: dict = {"num_decimal_places": 1}, - **kwargs, + **kwargs: Any, ): """ Special case of :class:`~.Table` with ``element_to_mobject`` set to :class:`~.DecimalNumber`. diff --git a/mypy.ini b/mypy.ini index 0ee928de4c..c20bffe8f3 100644 --- a/mypy.ini +++ b/mypy.ini @@ -94,9 +94,6 @@ ignore_errors = True [mypy-manim.mobject.opengl.opengl_vectorized_mobject] ignore_errors = True -[mypy-manim.mobject.table] -ignore_errors = True - [mypy-manim.mobject.types.point_cloud_mobject] ignore_errors = True