From 29177f30247cb94daeb9f2bf51297f694a2818ff Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 17 Feb 2026 13:16:47 +0900 Subject: [PATCH 01/12] added private metadata machinery --- src/tracksdata/graph/_base_graph.py | 46 +++++++++++++++++-- src/tracksdata/graph/_graph_view.py | 12 ++--- src/tracksdata/graph/_rustworkx_graph.py | 6 +-- src/tracksdata/graph/_sql_graph.py | 6 +-- .../graph/_test/test_graph_backends.py | 21 +++++++++ 5 files changed, 76 insertions(+), 15 deletions(-) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 5b3708ad..4340ca1d 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -47,6 +47,8 @@ class BaseGraph(abc.ABC): Base class for a graph backend. """ + _PRIVATE_METADATA_PREFIX = "__private_" + node_added = Signal(int) node_removed = Signal(int) @@ -1187,6 +1189,9 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: graph = cls(**kwargs) graph.update_metadata(**other.metadata()) + private_metadata = other._private_metadata() + if private_metadata: + graph._update_metadata(**private_metadata) current_node_attr_schemas = graph._node_attr_schemas() for k, v in other._node_attr_schemas().items(): @@ -1824,7 +1829,6 @@ def to_geff( zarr_format=zarr_format, ) - @abc.abstractmethod def metadata(self) -> dict[str, Any]: """ Return the metadata of the graph. @@ -1841,8 +1845,8 @@ def metadata(self) -> dict[str, Any]: print(metadata["shape"]) ``` """ + return {k: v for k, v in self._metadata().items() if not self._is_private_metadata_key(k)} - @abc.abstractmethod def update_metadata(self, **kwargs) -> None: """ Set or update metadata for the graph. @@ -1859,8 +1863,9 @@ def update_metadata(self, **kwargs) -> None: graph.update_metadata(description="Tracking data from experiment 1") ``` """ + self._validate_public_metadata_keys(kwargs.keys()) + self._update_metadata(**kwargs) - @abc.abstractmethod def remove_metadata(self, key: str) -> None: """ Remove a metadata key from the graph. @@ -1876,6 +1881,41 @@ def remove_metadata(self, key: str) -> None: graph.remove_metadata("shape") ``` """ + self._validate_public_metadata_key(key) + self._remove_metadata(key) + + @classmethod + def _is_private_metadata_key(cls, key: str) -> bool: + return key.startswith(cls._PRIVATE_METADATA_PREFIX) + + def _validate_public_metadata_key(self, key: str) -> None: + if self._is_private_metadata_key(key): + raise ValueError(f"Metadata key '{key}' is reserved for internal use.") + + def _validate_public_metadata_keys(self, keys: Sequence[str]) -> None: + for key in keys: + self._validate_public_metadata_key(key) + + def _private_metadata(self) -> dict[str, Any]: + return {k: v for k, v in self._metadata().items() if self._is_private_metadata_key(k)} + + @abc.abstractmethod + def _metadata(self) -> dict[str, Any]: + """ + Return the full metadata including private keys. + """ + + @abc.abstractmethod + def _update_metadata(self, **kwargs) -> None: + """ + Backend-specific metadata update implementation without public key validation. + """ + + @abc.abstractmethod + def _remove_metadata(self, key: str) -> None: + """ + Backend-specific metadata removal implementation without public key validation. 
+ """ def to_traccuracy_graph(self, array_view_kwargs: dict[str, Any] | None = None) -> "TrackingGraph": """ diff --git a/src/tracksdata/graph/_graph_view.py b/src/tracksdata/graph/_graph_view.py index b9f82ead..c689931d 100644 --- a/src/tracksdata/graph/_graph_view.py +++ b/src/tracksdata/graph/_graph_view.py @@ -847,11 +847,11 @@ def copy(self, **kwargs) -> "GraphView": "Use `detach` to create a new reference-less graph with the same nodes and edges." ) - def metadata(self) -> dict[str, Any]: - return self._root.metadata() + def _metadata(self) -> dict[str, Any]: + return self._root._metadata() - def update_metadata(self, **kwargs) -> None: - self._root.update_metadata(**kwargs) + def _update_metadata(self, **kwargs) -> None: + self._root._update_metadata(**kwargs) - def remove_metadata(self, key: str) -> None: - self._root.remove_metadata(key) + def _remove_metadata(self, key: str) -> None: + self._root._remove_metadata(key) diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index ef4a3f4f..229eacc2 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -1499,13 +1499,13 @@ def edge_id(self, source_id: int, target_id: int) -> int: """ return self.rx_graph.get_edge_data(source_id, target_id)[DEFAULT_ATTR_KEYS.EDGE_ID] - def metadata(self) -> dict[str, Any]: + def _metadata(self) -> dict[str, Any]: return self._graph.attrs - def update_metadata(self, **kwargs) -> None: + def _update_metadata(self, **kwargs) -> None: self._graph.attrs.update(kwargs) - def remove_metadata(self, key: str) -> None: + def _remove_metadata(self, key: str) -> None: self._graph.attrs.pop(key, None) def edge_list(self) -> list[list[int, int]]: diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 985cbdc9..c8ea38ed 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -1992,19 +1992,19 @@ def remove_edge( raise ValueError(f"Edge {edge_id} does not exist in the graph.") session.commit() - def metadata(self) -> dict[str, Any]: + def _metadata(self) -> dict[str, Any]: with Session(self._engine) as session: result = session.query(self.Metadata).all() return {row.key: row.value for row in result} - def update_metadata(self, **kwargs) -> None: + def _update_metadata(self, **kwargs) -> None: with Session(self._engine) as session: for key, value in kwargs.items(): metadata_entry = self.Metadata(key=key, value=value) session.merge(metadata_entry) session.commit() - def remove_metadata(self, key: str) -> None: + def _remove_metadata(self, key: str) -> None: with Session(self._engine) as session: session.query(self.Metadata).filter(self.Metadata.key == key).delete() session.commit() diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index d6084cd8..7619188e 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -2511,6 +2511,27 @@ def test_metadata_multiple_dtypes(graph_backend: BaseGraph) -> None: assert "mixed_list" not in retrieved +def test_private_metadata_is_hidden_from_public_apis(graph_backend: BaseGraph) -> None: + private_key = "__private_dtype_map" + + graph_backend._update_metadata(**{private_key: {"x": "float64"}}) + graph_backend.update_metadata(shape=[1, 2, 3]) + + public_metadata = graph_backend.metadata() + assert private_key not in public_metadata + assert public_metadata["shape"] == [1, 2, 3] + + with 
pytest.raises(ValueError, match="reserved for internal use"): + graph_backend.update_metadata(**{private_key: {"x": "int64"}}) + + with pytest.raises(ValueError, match="reserved for internal use"): + graph_backend.remove_metadata(private_key) + + # Internal APIs can still remove private keys. + graph_backend._remove_metadata(private_key) + assert private_key not in graph_backend._metadata() + + def test_pickle_roundtrip(graph_backend: BaseGraph) -> None: if isinstance(graph_backend, SQLGraph): pytest.skip("SQLGraph does not support pickle roundtrip") From d8292f1c9b01ac75c94316d8a990937be3bd74e2 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 17 Feb 2026 14:01:04 +0900 Subject: [PATCH 02/12] before adding private --- src/tracksdata/array/_graph_array.py | 2 +- .../functional/_test/test_napari.py | 2 +- src/tracksdata/graph/__init__.py | 4 +- src/tracksdata/graph/_base_graph.py | 118 +++++++++++------- src/tracksdata/graph/_rustworkx_graph.py | 2 +- .../graph/_test/test_graph_backends.py | 44 +++---- src/tracksdata/io/_test/test_ctc_io.py | 2 +- src/tracksdata/nodes/_regionprops.py | 4 +- .../nodes/_test/test_regionprops.py | 36 +++--- 9 files changed, 123 insertions(+), 91 deletions(-) diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index 018e7ae6..80418986 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -23,7 +23,7 @@ def _validate_shape( """Helper function to validate the shape argument.""" if shape is None: try: - shape = graph.metadata()["shape"] + shape = graph.metadata["shape"] except KeyError as e: raise KeyError( f"`shape` is required to `{func_name}`. " diff --git a/src/tracksdata/functional/_test/test_napari.py b/src/tracksdata/functional/_test/test_napari.py index 9b4a81dc..712cf53d 100644 --- a/src/tracksdata/functional/_test/test_napari.py +++ b/src/tracksdata/functional/_test/test_napari.py @@ -31,7 +31,7 @@ def test_napari_conversion(metadata_shape: bool) -> None: shape = (2, 10, 22, 32) if metadata_shape: - graph.update_metadata(shape=shape) + graph.metadata.update(shape=shape) arg_shape = None else: arg_shape = shape diff --git a/src/tracksdata/graph/__init__.py b/src/tracksdata/graph/__init__.py index fcf207e2..3906949b 100644 --- a/src/tracksdata/graph/__init__.py +++ b/src/tracksdata/graph/__init__.py @@ -1,10 +1,10 @@ """Graph backends for representing tracking data as directed graphs in memory or on disk.""" -from tracksdata.graph._base_graph import BaseGraph +from tracksdata.graph._base_graph import BaseGraph, MetadataView from tracksdata.graph._graph_view import GraphView from tracksdata.graph._rustworkx_graph import IndexedRXGraph, RustWorkXGraph from tracksdata.graph._sql_graph import SQLGraph InMemoryGraph = RustWorkXGraph -__all__ = ["BaseGraph", "GraphView", "InMemoryGraph", "IndexedRXGraph", "RustWorkXGraph", "SQLGraph"] +__all__ = ["BaseGraph", "GraphView", "InMemoryGraph", "IndexedRXGraph", "MetadataView", "RustWorkXGraph", "SQLGraph"] diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 4340ca1d..bfc8239b 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -42,6 +42,61 @@ T = TypeVar("T", bound="BaseGraph") +class MetadataView(dict[str, Any]): + """Dictionary-like metadata view that syncs mutations back to the graph.""" + + _MISSING = object() + + def __init__(self, graph: "BaseGraph", data: dict[str, Any]) -> None: + super().__init__(data) + self._graph = graph + + def 
__setitem__(self, key: str, value: Any) -> None: + self._graph._set_public_metadata(**{key: value}) + super().__setitem__(key, value) + + def __delitem__(self, key: str) -> None: + self._graph._remove_public_metadata(key) + super().__delitem__(key) + + def pop(self, key: str, default: Any = _MISSING) -> Any: + self._graph._validate_public_metadata_key(key) + + if key not in self: + if default is self._MISSING: + raise KeyError(key) + return default + + value = super().__getitem__(key) + self._graph._remove_metadata(key) + super().pop(key, None) + return value + + def popitem(self) -> tuple[str, Any]: + key, value = super().popitem() + self._graph._remove_metadata(key) + return key, value + + def clear(self) -> None: + keys = list(self.keys()) + for key in keys: + self._graph._remove_metadata(key) + super().clear() + + def setdefault(self, key: str, default: Any = None) -> Any: + if key in self: + return super().__getitem__(key) + self._graph._set_public_metadata(**{key: default}) + super().__setitem__(key, default) + return default + + def update(self, *args, **kwargs) -> None: + updates = dict(*args, **kwargs) + if updates: + self._graph._set_public_metadata(**updates) + super().update(updates) + + class BaseGraph(abc.ABC): """ Base class for a graph backend. @@ -1188,7 +1243,7 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: node_attrs = node_attrs.drop(DEFAULT_ATTR_KEYS.NODE_ID) graph = cls(**kwargs) - graph.update_metadata(**other.metadata()) + graph.metadata.update(other.metadata) private_metadata = other._private_metadata() if private_metadata: graph._update_metadata(**private_metadata) @@ -1791,7 +1846,7 @@ def to_geff( for k, v in edge_attrs.to_dict().items() } - td_metadata = self.metadata().copy() + td_metadata = self.metadata.copy() td_metadata.pop("geff", None) # avoid geff being written multiple times geff_metadata = geff.GeffMetadata( @@ -1829,66 +1884,35 @@ def to_geff( zarr_format=zarr_format, ) - def metadata(self) -> dict[str, Any]: + @property + def metadata(self) -> MetadataView: """ Return the metadata of the graph. Returns ------- - dict[str, Any] + MetadataView The metadata of the graph as a dictionary. Examples -------- ```python - metadata = graph.metadata() + metadata = graph.metadata print(metadata["shape"]) ``` """ - return {k: v for k, v in self._metadata().items() if not self._is_private_metadata_key(k)} - - def update_metadata(self, **kwargs) -> None: - """ - Set or update metadata for the graph. - - Parameters - ---------- - **kwargs : Any - The metadata items to set by key. Values will be stored as JSON. - - Examples - -------- - ```python - graph.update_metadata(shape=[1, 25, 25], path="path/to/image.ome.zarr") - graph.update_metadata(description="Tracking data from experiment 1") - ``` - """ - self._validate_public_metadata_keys(kwargs.keys()) - self._update_metadata(**kwargs) - - def remove_metadata(self, key: str) -> None: - """ - Remove a metadata key from the graph. - - Parameters - ---------- - key : str - The key of the metadata to remove. 
- - Examples - -------- - ```python - graph.remove_metadata("shape") - ``` - """ - self._validate_public_metadata_key(key) - self._remove_metadata(key) + return MetadataView( + graph=self, + data={k: v for k, v in self._metadata().items() if not self._is_private_metadata_key(k)}, + ) @classmethod def _is_private_metadata_key(cls, key: str) -> bool: return key.startswith(cls._PRIVATE_METADATA_PREFIX) def _validate_public_metadata_key(self, key: str) -> None: + if not isinstance(key, str): + raise TypeError(f"Metadata key must be a string. Got {type(key)}.") if self._is_private_metadata_key(key): raise ValueError(f"Metadata key '{key}' is reserved for internal use.") @@ -1896,6 +1920,14 @@ def _validate_public_metadata_keys(self, keys: Sequence[str]) -> None: for key in keys: self._validate_public_metadata_key(key) + def _set_public_metadata(self, **kwargs) -> None: + self._validate_public_metadata_keys(kwargs.keys()) + self._update_metadata(**kwargs) + + def _remove_public_metadata(self, key: str) -> None: + self._validate_public_metadata_key(key) + self._remove_metadata(key) + def _private_metadata(self) -> dict[str, Any]: return {k: v for k, v in self._metadata().items() if self._is_private_metadata_key(k)} diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 229eacc2..05cf1c17 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -371,7 +371,7 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: elif not isinstance(self._graph.attrs, dict): LOG.warning( - "previous attribute %s will be added to key 'old_attrs' of `graph.metadata()`", + "previous attribute %s will be added to key 'old_attrs' of `graph.metadata`", self._graph.attrs, ) self._graph.attrs = { diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 7619188e..73a1161c 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1359,7 +1359,7 @@ def test_from_other_with_edges( ) -> None: """Ensure from_other preserves structure across backend conversions.""" # Create source graph with nodes, edges, and attributes - graph_backend.update_metadata(special_key="special_value") + graph_backend.metadata.update(special_key="special_value") graph_backend.add_node_attr_key("x", dtype=pl.Float64) graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=-1) @@ -1386,7 +1386,7 @@ def test_from_other_with_edges( assert set(new_graph.node_attr_keys()) == set(graph_backend.node_attr_keys()) assert set(new_graph.edge_attr_keys()) == set(graph_backend.edge_attr_keys()) - assert new_graph.metadata() == graph_backend.metadata() + assert new_graph.metadata == graph_backend.metadata assert new_graph._node_attr_schemas() == graph_backend._node_attr_schemas() assert new_graph._edge_attr_schemas() == graph_backend._edge_attr_schemas() @@ -2322,7 +2322,7 @@ def _fill_mock_geff_graph(graph_backend: BaseGraph) -> None: graph_backend.add_edge_attr_key("weight", pl.Float16) - graph_backend.update_metadata( + graph_backend.metadata.update( shape=[1, 25, 25], path="path/to/image.ome.zarr", ) @@ -2383,11 +2383,11 @@ def test_geff_roundtrip(graph_backend: BaseGraph) -> None: geff_graph, _ = IndexedRXGraph.from_geff(output_store) - assert "geff" in geff_graph.metadata() + assert "geff" in geff_graph.metadata # geff metadata was not stored in original graph - geff_graph.metadata().pop("geff") - assert 
geff_graph.metadata() == graph_backend.metadata() + geff_graph.metadata.pop("geff") + assert geff_graph.metadata == graph_backend.metadata assert geff_graph.num_nodes() == 3 assert geff_graph.num_edges() == 2 @@ -2442,11 +2442,11 @@ def test_geff_with_keymapping(graph_backend: BaseGraph) -> None: edge_attr_key_map={"weight": "weight_new"}, ) - assert "geff" in geff_graph.metadata() + assert "geff" in geff_graph.metadata # geff metadata was not stored in original graph - geff_graph.metadata().pop("geff") - assert geff_graph.metadata() == graph_backend.metadata() + geff_graph.metadata.pop("geff") + assert geff_graph.metadata == graph_backend.metadata assert geff_graph.num_nodes() == 3 assert geff_graph.num_edges() == 2 @@ -2483,30 +2483,30 @@ def test_metadata_multiple_dtypes(graph_backend: BaseGraph) -> None: } # Update metadata with all test values - graph_backend.update_metadata(**test_metadata) + graph_backend.metadata.update(**test_metadata) # Retrieve and verify - retrieved = graph_backend.metadata() + retrieved = graph_backend.metadata for key, expected_value in test_metadata.items(): assert key in retrieved, f"Key '{key}' not found in metadata" assert retrieved[key] == expected_value, f"Value mismatch for '{key}': {retrieved[key]} != {expected_value}" # Test updating existing keys - graph_backend.update_metadata(string="updated_value", new_key="new_value") - retrieved = graph_backend.metadata() + graph_backend.metadata.update(string="updated_value", new_key="new_value") + retrieved = graph_backend.metadata assert retrieved["string"] == "updated_value" assert retrieved["new_key"] == "new_value" assert retrieved["integer"] == 42 # Other values unchanged # Testing removing metadata - graph_backend.remove_metadata("string") - retrieved = graph_backend.metadata() + graph_backend.metadata.pop("string", None) + retrieved = graph_backend.metadata assert "string" not in retrieved - graph_backend.remove_metadata("mixed_list") - retrieved = graph_backend.metadata() + graph_backend.metadata.pop("mixed_list", None) + retrieved = graph_backend.metadata assert "string" not in retrieved assert "mixed_list" not in retrieved @@ -2515,17 +2515,17 @@ def test_private_metadata_is_hidden_from_public_apis(graph_backend: BaseGraph) - private_key = "__private_dtype_map" graph_backend._update_metadata(**{private_key: {"x": "float64"}}) - graph_backend.update_metadata(shape=[1, 2, 3]) + graph_backend.metadata.update(shape=[1, 2, 3]) - public_metadata = graph_backend.metadata() + public_metadata = graph_backend.metadata assert private_key not in public_metadata assert public_metadata["shape"] == [1, 2, 3] with pytest.raises(ValueError, match="reserved for internal use"): - graph_backend.update_metadata(**{private_key: {"x": "int64"}}) + graph_backend.metadata.update(**{private_key: {"x": "int64"}}) with pytest.raises(ValueError, match="reserved for internal use"): - graph_backend.remove_metadata(private_key) + graph_backend.metadata.pop(private_key, None) # Internal APIs can still remove private keys. 
graph_backend._remove_metadata(private_key) @@ -2606,7 +2606,7 @@ def test_to_traccuracy_graph(graph_backend: BaseGraph) -> None: graph_backend.add_node_attr_key("y", pl.Float64) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) - graph_backend.update_metadata(shape=[3, 25, 25]) + graph_backend.metadata.update(shape=[3, 25, 25]) # Create masks for first graph mask1_data = np.array([[True, True], [True, True]], dtype=bool) diff --git a/src/tracksdata/io/_test/test_ctc_io.py b/src/tracksdata/io/_test/test_ctc_io.py index 7c5fb925..01025213 100644 --- a/src/tracksdata/io/_test/test_ctc_io.py +++ b/src/tracksdata/io/_test/test_ctc_io.py @@ -68,7 +68,7 @@ def test_export_from_ctc_roundtrip(tmp_path: Path, metadata_shape: bool) -> None in_graph.add_edge(node_1, node_3, attrs={DEFAULT_ATTR_KEYS.EDGE_DIST: 1.0}) if metadata_shape: - in_graph.update_metadata(shape=(2, 4, 4)) + in_graph.metadata.update(shape=(2, 4, 4)) shape = None else: shape = (2, 4, 4) diff --git a/src/tracksdata/nodes/_regionprops.py b/src/tracksdata/nodes/_regionprops.py index c78feb32..5be49713 100644 --- a/src/tracksdata/nodes/_regionprops.py +++ b/src/tracksdata/nodes/_regionprops.py @@ -230,8 +230,8 @@ def add_nodes( axis_names = self._axis_names(labels) self._init_node_attrs(graph, axis_names, ndims=labels.ndim) - if "shape" not in graph.metadata(): - graph.update_metadata(shape=labels.shape) + if "shape" not in graph.metadata: + graph.metadata.update(shape=labels.shape) if t is None: time_points = range(labels.shape[0]) diff --git a/src/tracksdata/nodes/_test/test_regionprops.py b/src/tracksdata/nodes/_test/test_regionprops.py index 350d231b..567c62e0 100644 --- a/src/tracksdata/nodes/_test/test_regionprops.py +++ b/src/tracksdata/nodes/_test/test_regionprops.py @@ -79,8 +79,8 @@ def test_regionprops_add_nodes_2d() -> None: operator = RegionPropsNodes(extra_properties=extra_properties) operator.add_nodes(graph, labels=labels) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that nodes were added assert graph.num_nodes() == 2 # Two regions (labels 1 and 2) @@ -115,8 +115,8 @@ def test_regionprops_add_nodes_3d() -> None: operator = RegionPropsNodes(extra_properties=extra_properties) operator.add_nodes(graph, labels=labels) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that nodes were added assert graph.num_nodes() == 2 # Two regions @@ -150,8 +150,8 @@ def test_regionprops_add_nodes_with_intensity() -> None: operator.add_nodes(graph, labels=labels, intensity_image=intensity) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that nodes were added with intensity attributes nodes_df = graph.node_attrs() @@ -181,8 +181,8 @@ def test_regionprops_add_nodes_timelapse(n_workers: int) -> None: with options_context(n_workers=n_workers): operator.add_nodes(graph, labels=labels) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that nodes were added for both time points nodes_df = graph.node_attrs() @@ -209,8 +209,8 @@ def 
test_regionprops_add_nodes_timelapse_with_intensity() -> None: operator.add_nodes(graph, labels=labels, intensity_image=intensity) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that nodes were added with intensity attributes nodes_df = graph.node_attrs() @@ -237,8 +237,8 @@ def double_area(region: RegionProperties) -> float: operator.add_nodes(graph, labels=labels, t=0) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that custom property was calculated nodes_df = graph.node_attrs() @@ -275,8 +275,8 @@ def test_regionprops_mask_creation() -> None: operator.add_nodes(graph, labels=labels, t=0) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that masks were created nodes_df = graph.node_attrs() @@ -300,8 +300,8 @@ def test_regionprops_spacing() -> None: operator.add_nodes(graph, labels=labels, t=0) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that nodes were added (spacing affects internal calculations) nodes_df = graph.node_attrs() @@ -323,8 +323,8 @@ def test_regionprops_empty_labels() -> None: operator.add_nodes(graph, labels=labels, t=0) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # No nodes should be added assert graph.num_nodes() == 0 From cff58981bd54328d07a6208e48c2410541504c31 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 17 Feb 2026 14:11:08 +0900 Subject: [PATCH 03/12] added private metadata view --- src/tracksdata/graph/_base_graph.py | 63 ++++++++++++------- .../graph/_test/test_graph_backends.py | 9 ++- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index bfc8239b..28e8baa2 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -47,20 +47,27 @@ class MetadataView(dict[str, Any]): _MISSING = object() - def __init__(self, graph: "BaseGraph", data: dict[str, Any]) -> None: + def __init__( + self, + graph: "BaseGraph", + data: dict[str, Any], + *, + is_public: bool = True, + ) -> None: super().__init__(data) self._graph = graph + self._is_public = is_public def __setitem__(self, key: str, value: Any) -> None: - self._graph._set_public_metadata(**{key: value}) + self._graph._set_public_metadata(is_public=self._is_public, **{key: value}) super().__setitem__(key, value) def __delitem__(self, key: str) -> None: - self._graph._remove_public_metadata(key) + self._graph._remove_public_metadata(key, is_public=self._is_public) super().__delitem__(key) def pop(self, key: str, default: Any = _MISSING) -> Any: - self._graph._validate_public_metadata_key(key) + self._graph._validate_metadata_key(key, is_public=self._is_public) if key not in self: if default is self._MISSING: @@ -68,32 +75,32 @@ def pop(self, key: str, default: Any = _MISSING) -> Any: return default value = super().__getitem__(key) - self._graph._remove_metadata(key) + self._graph._remove_public_metadata(key, is_public=self._is_public) 
super().pop(key, None) return value def popitem(self) -> tuple[str, Any]: key, value = super().popitem() - self._graph._remove_metadata(key) + self._graph._remove_public_metadata(key, is_public=self._is_public) return key, value def clear(self) -> None: keys = list(self.keys()) for key in keys: - self._graph._remove_metadata(key) + self._graph._remove_public_metadata(key, is_public=self._is_public) super().clear() def setdefault(self, key: str, default: Any = None) -> Any: if key in self: return super().__getitem__(key) - self._graph._set_public_metadata(**{key: default}) + self._graph._set_public_metadata(is_public=self._is_public, **{key: default}) super().__setitem__(key, default) return default def update(self, *args, **kwargs) -> None: updates = dict(*args, **kwargs) if updates: - self._graph._set_public_metadata(**updates) + self._graph._set_public_metadata(is_public=self._is_public, **updates) super().update(updates) @@ -1244,9 +1251,7 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: graph = cls(**kwargs) graph.metadata.update(other.metadata) - private_metadata = other._private_metadata() - if private_metadata: - graph._update_metadata(**private_metadata) + graph._private_metadata.update(other._private_metadata) current_node_attr_schemas = graph._node_attr_schemas() for k, v in other._node_attr_schemas().items(): @@ -1904,33 +1909,45 @@ def metadata(self) -> MetadataView: return MetadataView( graph=self, data={k: v for k, v in self._metadata().items() if not self._is_private_metadata_key(k)}, + is_public=True, + ) + + @property + def _private_metadata(self) -> MetadataView: + return MetadataView( + graph=self, + data={k: v for k, v in self._metadata().items() if self._is_private_metadata_key(k)}, + is_public=False, ) @classmethod def _is_private_metadata_key(cls, key: str) -> bool: return key.startswith(cls._PRIVATE_METADATA_PREFIX) - def _validate_public_metadata_key(self, key: str) -> None: + def _validate_metadata_key(self, key: str, *, is_public: bool) -> None: if not isinstance(key, str): raise TypeError(f"Metadata key must be a string. Got {type(key)}.") - if self._is_private_metadata_key(key): + is_private_key = self._is_private_metadata_key(key) + if is_public and is_private_key: raise ValueError(f"Metadata key '{key}' is reserved for internal use.") + if not is_public and not is_private_key: + raise ValueError( + f"Metadata key '{key}' is not private. Private metadata keys must start with " + f"'{self._PRIVATE_METADATA_PREFIX}'." 
+ ) - def _validate_public_metadata_keys(self, keys: Sequence[str]) -> None: + def _validate_metadata_keys(self, keys: Sequence[str], *, is_public: bool) -> None: for key in keys: - self._validate_public_metadata_key(key) + self._validate_metadata_key(key, is_public=is_public) - def _set_public_metadata(self, **kwargs) -> None: - self._validate_public_metadata_keys(kwargs.keys()) + def _set_public_metadata(self, is_public: bool = True, **kwargs) -> None: + self._validate_metadata_keys(kwargs.keys(), is_public=is_public) self._update_metadata(**kwargs) - def _remove_public_metadata(self, key: str) -> None: - self._validate_public_metadata_key(key) + def _remove_public_metadata(self, key: str, *, is_public: bool = True) -> None: + self._validate_metadata_key(key, is_public=is_public) self._remove_metadata(key) - def _private_metadata(self) -> dict[str, Any]: - return {k: v for k, v in self._metadata().items() if self._is_private_metadata_key(k)} - @abc.abstractmethod def _metadata(self) -> dict[str, Any]: """ diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 73a1161c..e9088c75 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -2514,7 +2514,7 @@ def test_metadata_multiple_dtypes(graph_backend: BaseGraph) -> None: def test_private_metadata_is_hidden_from_public_apis(graph_backend: BaseGraph) -> None: private_key = "__private_dtype_map" - graph_backend._update_metadata(**{private_key: {"x": "float64"}}) + graph_backend._private_metadata.update(**{private_key: {"x": "float64"}}) graph_backend.metadata.update(shape=[1, 2, 3]) public_metadata = graph_backend.metadata @@ -2527,8 +2527,11 @@ def test_private_metadata_is_hidden_from_public_apis(graph_backend: BaseGraph) - with pytest.raises(ValueError, match="reserved for internal use"): graph_backend.metadata.pop(private_key, None) - # Internal APIs can still remove private keys. - graph_backend._remove_metadata(private_key) + with pytest.raises(ValueError, match="is not private"): + graph_backend._private_metadata.update(shape=[1, 2, 3]) + + # Private metadata view can remove private keys. 
+ graph_backend._private_metadata.pop(private_key, None) assert private_key not in graph_backend._metadata() From 68b01d40c6368a5120474dc38e88276f6eb121da Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 17 Feb 2026 14:14:06 +0900 Subject: [PATCH 04/12] renamed func --- src/tracksdata/graph/_base_graph.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 28e8baa2..03dc3a01 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -59,11 +59,11 @@ def __init__( self._is_public = is_public def __setitem__(self, key: str, value: Any) -> None: - self._graph._set_public_metadata(is_public=self._is_public, **{key: value}) + self._graph._set_metadata_with_validation(is_public=self._is_public, **{key: value}) super().__setitem__(key, value) def __delitem__(self, key: str) -> None: - self._graph._remove_public_metadata(key, is_public=self._is_public) + self._graph._remove_metadata_with_validation(key, is_public=self._is_public) super().__delitem__(key) def pop(self, key: str, default: Any = _MISSING) -> Any: @@ -75,32 +75,32 @@ def pop(self, key: str, default: Any = _MISSING) -> Any: return default value = super().__getitem__(key) - self._graph._remove_public_metadata(key, is_public=self._is_public) + self._graph._remove_metadata_with_validation(key, is_public=self._is_public) super().pop(key, None) return value def popitem(self) -> tuple[str, Any]: key, value = super().popitem() - self._graph._remove_public_metadata(key, is_public=self._is_public) + self._graph._remove_metadata_with_validation(key, is_public=self._is_public) return key, value def clear(self) -> None: keys = list(self.keys()) for key in keys: - self._graph._remove_public_metadata(key, is_public=self._is_public) + self._graph._remove_metadata_with_validation(key, is_public=self._is_public) super().clear() def setdefault(self, key: str, default: Any = None) -> Any: if key in self: return super().__getitem__(key) - self._graph._set_public_metadata(is_public=self._is_public, **{key: default}) + self._graph._set_metadata_with_validation(is_public=self._is_public, **{key: default}) super().__setitem__(key, default) return default def update(self, *args, **kwargs) -> None: updates = dict(*args, **kwargs) if updates: - self._graph._set_public_metadata(is_public=self._is_public, **updates) + self._graph._set_metadata_with_validation(is_public=self._is_public, **updates) super().update(updates) @@ -1940,11 +1940,11 @@ def _validate_metadata_keys(self, keys: Sequence[str], *, is_public: bool) -> No for key in keys: self._validate_metadata_key(key, is_public=is_public) - def _set_public_metadata(self, is_public: bool = True, **kwargs) -> None: + def _set_metadata_with_validation(self, is_public: bool = True, **kwargs) -> None: self._validate_metadata_keys(kwargs.keys(), is_public=is_public) self._update_metadata(**kwargs) - def _remove_public_metadata(self, key: str, *, is_public: bool = True) -> None: + def _remove_metadata_with_validation(self, key: str, *, is_public: bool = True) -> None: self._validate_metadata_key(key, is_public=is_public) self._remove_metadata(key) From 1ae242670a8051827723c926883359657980238a Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 17 Feb 2026 15:34:33 +0900 Subject: [PATCH 05/12] implementation of saving and loading dtypes as metadata --- src/tracksdata/graph/_base_graph.py | 72 ++++++++++++++++- src/tracksdata/graph/_rustworkx_graph.py | 45 
++++++----- src/tracksdata/graph/_sql_graph.py | 71 ++++++++++++----- .../graph/_test/test_graph_backends.py | 77 +++++++++++++++++++ src/tracksdata/utils/_dtypes.py | 61 +++++++++++++++ .../utils/_test/test_dtype_serialization.py | 45 +++++++++++ 6 files changed, 332 insertions(+), 39 deletions(-) create mode 100644 src/tracksdata/utils/_test/test_dtype_serialization.py diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 03dc3a01..5c2c8f85 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -4,6 +4,7 @@ from collections.abc import Sequence from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload +import warnings import geff import numpy as np @@ -21,7 +22,9 @@ from tracksdata.utils._dtypes import ( AttrSchema, column_to_numpy, + deserialize_polars_dtype, polars_dtype_to_numpy_dtype, + serialize_polars_dtype, ) from tracksdata.utils._logging import LOG from tracksdata.utils._multiprocessing import multiprocessing_apply @@ -110,6 +113,7 @@ class BaseGraph(abc.ABC): """ _PRIVATE_METADATA_PREFIX = "__private_" + _PRIVATE_DTYPE_MAP_KEY = "__private_dtype_map" node_added = Signal(int) node_removed = Signal(int) @@ -1281,7 +1285,6 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: current_edge_attr_schemas = graph._edge_attr_schemas() for k, v in other._edge_attr_schemas().items(): if k not in current_edge_attr_schemas: - print(f"Adding edge attribute key: {k} with dtype: {v.dtype} and default value: {v.default_value}") graph.add_edge_attr_key(k, v.dtype, v.default_value) edge_attrs = edge_attrs.with_columns( @@ -1948,6 +1951,73 @@ def _remove_metadata_with_validation(self, key: str, *, is_public: bool = True) self._validate_metadata_key(key, is_public=is_public) self._remove_metadata(key) + def _get_private_dtype_map(self) -> dict[str, dict[str, str]]: + dtype_map = self._private_metadata.get(self._PRIVATE_DTYPE_MAP_KEY, {}) + if not isinstance(dtype_map, dict): + return {"node": {}, "edge": {}} + + node_dtype_map = dtype_map.get("node", {}) + edge_dtype_map = dtype_map.get("edge", {}) + if not isinstance(node_dtype_map, dict): + node_dtype_map = {} + if not isinstance(edge_dtype_map, dict): + edge_dtype_map = {} + + return {"node": dict(node_dtype_map), "edge": dict(edge_dtype_map)} + + def _set_private_dtype_map(self, dtype_map: dict[str, dict[str, str]]) -> None: + self._private_metadata.update( + **{ + self._PRIVATE_DTYPE_MAP_KEY: { + "node": dict(dtype_map.get("node", {})), + "edge": dict(dtype_map.get("edge", {})), + } + } + ) + + def _set_attr_dtype_metadata(self, *, key: str, dtype: pl.DataType, is_node: bool) -> None: + dtype_map = self._get_private_dtype_map() + map_key = "node" if is_node else "edge" + dtype_map[map_key][key] = serialize_polars_dtype(dtype) + self._set_private_dtype_map(dtype_map) + + def _remove_attr_dtype_metadata(self, *, key: str, is_node: bool) -> None: + dtype_map = self._get_private_dtype_map() + map_key = "node" if is_node else "edge" + dtype_map[map_key].pop(key, None) + self._set_private_dtype_map(dtype_map) + + def _attr_dtype_from_metadata(self, *, key: str, is_node: bool) -> pl.DataType | None: + dtype_map = self._get_private_dtype_map() + map_key = "node" if is_node else "edge" + encoded_dtype = dtype_map[map_key].get(key) + if not isinstance(encoded_dtype, str): + return None + + try: + return deserialize_polars_dtype(encoded_dtype) + except Exception: + warnings.warn( + f"Initializing schemas from existing database 
tables for the key {key}. " + "This is a fallback mechanism when loading existing graphs, and may not perfectly restore the original schemas. " + "This method is deprecated and will be removed in the major release. ", + UserWarning, + ) + return None + + def _sync_attr_dtype_metadata(self) -> None: + dtype_map = { + "node": { + key: serialize_polars_dtype(schema.dtype) + for key, schema in self._node_attr_schemas().items() + }, + "edge": { + key: serialize_polars_dtype(schema.dtype) + for key, schema in self._edge_attr_schemas().items() + }, + } + self._set_private_dtype_map(dtype_map) + @abc.abstractmethod def _metadata(self) -> dict[str, Any]: """ diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 05cf1c17..a15dd1f5 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -400,11 +400,13 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: for key, value in first_node_attrs.items(): if key == DEFAULT_ATTR_KEYS.NODE_ID: continue - try: - dtype = pl.Series([value]).dtype - except (ValueError, TypeError): - # If polars can't infer dtype (e.g., for complex objects), use Object - dtype = pl.Object + dtype = self._attr_dtype_from_metadata(key=key, is_node=True) + if dtype is None: + try: + dtype = pl.Series([value]).dtype + except (ValueError, TypeError): + # If polars can't infer dtype (e.g., for complex objects), use Object + dtype = pl.Object self.__node_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) # Process edges: set edge IDs and infer schemas @@ -422,13 +424,17 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: # TODO: check if EDGE_SOURCE and EDGE_TARGET should be also ignored or in the schema if key == DEFAULT_ATTR_KEYS.EDGE_ID: continue - try: - dtype = pl.Series([value]).dtype - except (ValueError, TypeError): - # If polars can't infer dtype (e.g., for complex objects), use Object - dtype = pl.Object + dtype = self._attr_dtype_from_metadata(key=key, is_node=False) + if dtype is None: + try: + dtype = pl.Series([value]).dtype + except (ValueError, TypeError): + # If polars can't infer dtype (e.g., for complex objects), use Object + dtype = pl.Object self.__edge_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) + self._sync_attr_dtype_metadata() + def _node_attr_schemas(self) -> dict[str, AttrSchema]: return self.__node_attr_schemas @@ -986,6 +992,7 @@ def add_node_attr_key( # Store schema self.__node_attr_schemas[schema.key] = schema + self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=True) def remove_node_attr_key(self, key: str) -> None: """ @@ -998,6 +1005,7 @@ def remove_node_attr_key(self, key: str) -> None: raise ValueError(f"Cannot remove required node attribute key {key}") del self.__node_attr_schemas[key] + self._remove_attr_dtype_metadata(key=key, is_node=True) for node_attr in self.rx_graph.nodes(): node_attr.pop(key, None) @@ -1026,6 +1034,7 @@ def add_edge_attr_key( # Store schema self.__edge_attr_schemas[schema.key] = schema + self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=False) def remove_edge_attr_key(self, key: str) -> None: """ @@ -1035,6 +1044,7 @@ def remove_edge_attr_key(self, key: str) -> None: raise ValueError(f"Edge attribute key {key} does not exist") del self.__edge_attr_schemas[key] + self._remove_attr_dtype_metadata(key=key, is_node=False) for edge_attr in self.rx_graph.edges(): edge_attr.pop(key, None) @@ -1153,16 +1163,11 @@ def edge_attrs( edge_map = 
rx_graph.edge_index_map() if len(edge_map) == 0: - return pl.DataFrame( - { - key: [] - for key in [ - *attr_keys, - DEFAULT_ATTR_KEYS.EDGE_SOURCE, - DEFAULT_ATTR_KEYS.EDGE_TARGET, - ] - } - ) + empty_columns = {} + for key in [*attr_keys, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: + schema = self._edge_attr_schemas()[key] + empty_columns[key] = pl.Series(name=key, values=[], dtype=schema.dtype) + return pl.DataFrame(empty_columns) source, target, data = zip(*edge_map.values(), strict=False) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index c8ea38ed..b36e3ab8 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -479,6 +479,7 @@ def __init__( # Initialize schemas from existing table columns self._init_schemas_from_tables() + self._sync_attr_dtype_metadata() self._max_id_per_time = {} self._update_max_id_per_time() @@ -556,12 +557,19 @@ def _init_schemas_from_tables(self) -> None: Initialize AttrSchema objects from existing database table columns. This is used when loading an existing graph from the database. """ + + node_column_names = list(self.Node.__table__.columns.keys()) + preferred_node_order = [DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.NODE_ID] + ordered_node_columns = [name for name in preferred_node_order if name in node_column_names] + ordered_node_columns.extend(name for name in node_column_names if name not in preferred_node_order) + # Initialize node schemas from Node table columns - for column_name in self.Node.__table__.columns.keys(): + for column_name in ordered_node_columns: if column_name not in self.__node_attr_schemas: - column = self.Node.__table__.columns[column_name] - # Infer polars dtype from SQLAlchemy type - pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) + pl_dtype = self._attr_dtype_from_metadata(key=column_name, is_node=True) + if pl_dtype is None: + column = self.Node.__table__.columns[column_name] + pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) # AttrSchema.__post_init__ will infer the default_value self.__node_attr_schemas[column_name] = AttrSchema( key=column_name, @@ -572,9 +580,10 @@ def _init_schemas_from_tables(self) -> None: for column_name in self.Edge.__table__.columns.keys(): # Skip internal edge columns if column_name not in self.__edge_attr_schemas: - column = self.Edge.__table__.columns[column_name] - # Infer polars dtype from SQLAlchemy type - pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) + pl_dtype = self._attr_dtype_from_metadata(key=column_name, is_node=False) + if pl_dtype is None: + column = self.Edge.__table__.columns[column_name] + pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) # AttrSchema.__post_init__ will infer the default_value self.__edge_attr_schemas[column_name] = AttrSchema( key=column_name, @@ -593,11 +602,16 @@ def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaD else: schemas = self._edge_attr_schemas() - # Return schema overrides for special types that need explicit casting + # Return schema overrides for columns safely represented in SQL. + # Pickled columns are unpickled and casted in a second pass. 
return { key: schema.dtype for key, schema in schemas.items() - if not (schema.dtype == pl.Object or isinstance(schema.dtype, pl.Array | pl.List)) + if ( + key in table_class.__table__.columns + and not isinstance(table_class.__table__.columns[key].type, sa.PickleType | sa.LargeBinary) + and not (schema.dtype == pl.Object or isinstance(schema.dtype, pl.Array | pl.List)) + ) } def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: @@ -607,12 +621,19 @@ def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFra else: schemas = self._edge_attr_schemas() - # Cast array columns (stored as blobs in database) - df = df.with_columns( - pl.Series(key, df[key].to_list(), dtype=schema.dtype) - for key, schema in schemas.items() - if isinstance(schema.dtype, pl.Array) and key in df.columns - ) + casts: list[pl.Series] = [] + for key, schema in schemas.items(): + if key not in df.columns: + continue + + try: + casts.append(pl.Series(key, df[key].to_list(), dtype=schema.dtype)) + except Exception: + # Keep original dtype when values cannot be casted to the target schema. + continue + + if casts: + df = df.with_columns(casts) return df def _update_max_id_per_time(self) -> None: @@ -1289,6 +1310,8 @@ def node_attrs( # indices are included by default and must be removed if attr_keys is not None: nodes_df = nodes_df.select([pl.col(c) for c in attr_keys]) + else: + nodes_df = nodes_df.select([pl.col(c) for c in self._node_attr_schemas() if c in nodes_df.columns]) if unpack: nodes_df = unpack_array_attrs(nodes_df) @@ -1331,6 +1354,8 @@ def edge_attrs( if unpack: edges_df = unpack_array_attrs(edges_df) + elif attr_keys is None: + edges_df = edges_df.select([pl.col(c) for c in self._edge_attr_schemas() if c in edges_df.columns]) return edges_df @@ -1575,6 +1600,9 @@ def _add_new_column( sa_column = sa.Column(schema.key, sa_type, default=default_value) str_dialect_type = sa_column.type.compile(dialect=self._engine.dialect) + identifier_preparer = self._engine.dialect.identifier_preparer + quoted_table_name = identifier_preparer.format_table(table_class.__table__) + quoted_column_name = identifier_preparer.quote(sa_column.name) # Properly quote default values based on type if isinstance(default_value, str): @@ -1585,8 +1613,8 @@ def _add_new_column( quoted_default = str(default_value) add_column_stmt = sa.DDL( - f"ALTER TABLE {table_class.__table__} ADD " - f"COLUMN {sa_column.name} {str_dialect_type} " + f"ALTER TABLE {quoted_table_name} ADD " + f"COLUMN {quoted_column_name} {str_dialect_type} " f"DEFAULT {quoted_default}", ) LOG.info("add %s column statement:\n'%s'", table_class.__table__, add_column_stmt) @@ -1601,7 +1629,10 @@ def _add_new_column( table_class.__table__.append_column(sa_column) def _drop_column(self, table_class: type[DeclarativeBase], key: str) -> None: - drop_column_stmt = sa.DDL(f"ALTER TABLE {table_class.__table__} DROP COLUMN {key}") + identifier_preparer = self._engine.dialect.identifier_preparer + quoted_table_name = identifier_preparer.format_table(table_class.__table__) + quoted_column_name = identifier_preparer.quote(key) + drop_column_stmt = sa.DDL(f"ALTER TABLE {quoted_table_name} DROP COLUMN {quoted_column_name}") LOG.info("drop %s column statement:\n'%s'", table_class.__table__, drop_column_stmt) with Session(self._engine) as session: @@ -1625,6 +1656,7 @@ def add_node_attr_key( # Add column to database self._add_new_column(self.Node, schema) + self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, 
is_node=True) def remove_node_attr_key(self, key: str) -> None: if key not in self.node_attr_keys(): @@ -1635,6 +1667,7 @@ def remove_node_attr_key(self, key: str) -> None: self._drop_column(self.Node, key) self.__node_attr_schemas.pop(key, None) + self._remove_attr_dtype_metadata(key=key, is_node=True) def add_edge_attr_key( self, @@ -1650,6 +1683,7 @@ def add_edge_attr_key( # Add column to database self._add_new_column(self.Edge, schema) + self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=False) def remove_edge_attr_key(self, key: str) -> None: if key not in self.edge_attr_keys(): @@ -1657,6 +1691,7 @@ def remove_edge_attr_key(self, key: str) -> None: self._drop_column(self.Edge, key) self.__edge_attr_schemas.pop(key, None) + self._remove_attr_dtype_metadata(key=key, is_node=False) def num_edges(self) -> int: with Session(self._engine) as session: diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index e9088c75..f1ede1b2 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1437,6 +1437,83 @@ def test_from_other_with_edges( assert new_overlaps == source_overlaps +@pytest.mark.parametrize( + ("target_cls", "target_kwargs"), + [ + pytest.param(RustWorkXGraph, {}, id="rustworkx"), + pytest.param( + SQLGraph, + { + "drivername": "sqlite", + "database": ":memory:", + "engine_kwargs": {"connect_args": {"check_same_thread": False}}, + }, + id="sql", + ), + pytest.param(IndexedRXGraph, {}, id="indexed"), + ], +) +def test_from_other_preserves_schema_roundtrip(target_cls: type[BaseGraph], target_kwargs: dict[str, Any]) -> None: + """Test that from_other preserves node and edge attribute schemas across backends.""" + graph = RustWorkXGraph() + for dtype in [ + pl.Float16, pl.Float32, + pl.Float64, + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + pl.Date, pl.Datetime, + pl.Boolean, + pl.Array(pl.Float32, 3), + pl.List(pl.Int32), + pl.Struct({"a": pl.Int8, "b": pl.Array(pl.String, 2)}), + pl.String, + pl.Object]: + graph.add_node_attr_key(f"attr_{dtype}", dtype=dtype) + graph.add_node({"t":0, + "attr_Float16": np.float16(1.5), + "attr_Float32": np.float32(2.5), + "attr_Float64": np.float64(3.5), + "attr_Int8": np.int8(4), + "attr_Int16": np.int16(5), + "attr_Int32": np.int32(6), + "attr_Int64": np.int64(7), + "attr_UInt8": np.uint8(8), + "attr_UInt16": np.uint16(9), + "attr_UInt32": np.uint32(10), + "attr_UInt64": np.uint64(11), + "attr_Date": pl.date(2024, 1, 1), + "attr_Datetime": pl.datetime(2024, 1, 1, 12, 0, 0), + "attr_Boolean": True, + "attr_Array(Float32, shape=(3,))": np.array([1.0, 2.0, 3.0], dtype=np.float32), + "attr_List(Int32)": [1, 2, 3], + "attr_Struct({'a': Int8, 'b': Array(String, shape=(2,))})": {"a": 1, "b": np.array(["x", "y"], dtype=object)}, + "attr_String": "test", + "attr_Object": {"key": "value"}}) + graph2 = target_cls.from_other(graph, **target_kwargs) + + assert graph2.num_nodes() == graph.num_nodes() + assert set(graph2.node_attr_keys()) == set(graph.node_attr_keys()) + + assert graph2._node_attr_schemas() == graph._node_attr_schemas() + assert graph2._edge_attr_schemas() == graph._edge_attr_schemas() + assert graph2.node_attrs().schema == graph.node_attrs().schema + assert graph2.edge_attrs().schema == graph.edge_attrs().schema + + graph3 = RustWorkXGraph.from_other(graph2) + assert graph3._node_attr_schemas() == graph._node_attr_schemas() + assert 
graph3._edge_attr_schemas() == graph._edge_attr_schemas() + assert graph3.node_attrs().schema == graph.node_attrs().schema + assert graph3.edge_attrs().schema == graph.edge_attrs().schema + + + + @pytest.mark.parametrize( ("target_cls", "target_kwargs"), [ diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 8e671487..f0f24e25 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -1,6 +1,8 @@ from __future__ import annotations +import base64 from dataclasses import dataclass +import io from typing import Any import numpy as np @@ -202,6 +204,37 @@ def copy(self) -> AttrSchema: """ return AttrSchema(key=self.key, dtype=self.dtype, default_value=self.default_value) + def __eq__(self, other: object) -> bool: + if not isinstance(other, AttrSchema): + return NotImplemented + return ( + self.key == other.key + and self.dtype == other.dtype + and _values_equal(self.default_value, other.default_value) + ) + + +def _values_equal(left: Any, right: Any) -> bool: + if isinstance(left, np.ndarray) and isinstance(right, np.ndarray): + return bool(np.array_equal(left, right)) + if isinstance(left, dict) and isinstance(right, dict): + if left.keys() != right.keys(): + return False + return all(_values_equal(left[k], right[k]) for k in left) + if isinstance(left, list | tuple) and isinstance(right, list | tuple): + if len(left) != len(right): + return False + return all(_values_equal(lv, rv) for lv, rv in zip(left, right, strict=True)) + + try: + value = left == right + except Exception: + return False + + if isinstance(value, np.ndarray): + return bool(np.all(value)) + return bool(value) + def process_attr_key_args( key_or_schema: str | AttrSchema, @@ -445,6 +478,34 @@ def sqlalchemy_type_to_polars_dtype(sa_type: TypeEngine) -> pl.DataType: return pl.Object +def serialize_polars_dtype(dtype: pl.DataType) -> str: + """ + Serializes a Polars dtype to a safe, cross-platform base64 string + using the Arrow IPC format. + """ + # Wrap the dtype in an empty DataFrame schema + # We use an empty DataFrame so no actual data is processed, only metadata. + dummy_df = pl.DataFrame(schema={"dummy": dtype}) + # Write to Arrow IPC (binary buffer) + # IPC is stable across versions/platforms unlike internal serialization. + buffer = io.BytesIO() + dummy_df.write_ipc(buffer) + # Encode binary to a standard string + return base64.b64encode(buffer.getvalue()).decode('utf-8') + + +def deserialize_polars_dtype(encoded_dtype: str) -> pl.DataType: + """ + Recovers a Polars dtype from a base64 string. + """ + # Decode string back to binary + data = base64.b64decode(encoded_dtype) + # Read the IPC buffer + buffer = io.BytesIO(data) + restored_df = pl.read_ipc(buffer) + # Extract the dtype from the schema + return restored_df.schema["dummy"] + def validate_default_value_dtype_compatibility(default_value: Any, dtype: pl.DataType) -> None: """ Validate that a default value is compatible with a polars dtype. 
diff --git a/src/tracksdata/utils/_test/test_dtype_serialization.py b/src/tracksdata/utils/_test/test_dtype_serialization.py new file mode 100644 index 00000000..3a997209 --- /dev/null +++ b/src/tracksdata/utils/_test/test_dtype_serialization.py @@ -0,0 +1,45 @@ +import base64 +import binascii + +import polars as pl +import pytest + +from tracksdata.utils._dtypes import deserialize_polars_dtype, serialize_polars_dtype + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Int64, + pl.Float32, + pl.Boolean, + pl.String, + pl.List(pl.Int16), + pl.Array(pl.Float64, 4), + pl.Array(pl.Int32, (2, 3)), + pl.Struct({"x": pl.Int64, "y": pl.List(pl.String)}), + pl.Datetime("us", "UTC"), + ], +) +def test_serialize_deserialize_polars_dtype_roundtrip(dtype: pl.DataType) -> None: + encoded = serialize_polars_dtype(dtype) + + assert isinstance(encoded, str) + assert encoded + assert base64.b64decode(encoded) + + restored_dtype = deserialize_polars_dtype(encoded) + + assert restored_dtype == dtype + + +def test_deserialize_polars_dtype_invalid_base64_raises() -> None: + with pytest.raises(binascii.Error): + deserialize_polars_dtype("not-base64") + + +def test_deserialize_polars_dtype_non_ipc_payload_raises() -> None: + encoded = base64.b64encode(b"not-arrow-ipc").decode("utf-8") + + with pytest.raises((OSError, pl.exceptions.PolarsError)): + deserialize_polars_dtype(encoded) From c50a07b9eaf33665312b3b8db58ac8107b6d2a77 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 17 Feb 2026 15:48:05 +0900 Subject: [PATCH 06/12] lint --- src/tracksdata/graph/_base_graph.py | 16 ++-- src/tracksdata/graph/_sql_graph.py | 2 +- .../graph/_test/test_graph_backends.py | 86 ++++++++++--------- src/tracksdata/utils/_dtypes.py | 5 +- 4 files changed, 57 insertions(+), 52 deletions(-) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 5c2c8f85..9f35119e 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -1,10 +1,10 @@ import abc import functools import operator +import warnings from collections.abc import Sequence from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload -import warnings import geff import numpy as np @@ -1999,22 +1999,18 @@ def _attr_dtype_from_metadata(self, *, key: str, is_node: bool) -> pl.DataType | except Exception: warnings.warn( f"Initializing schemas from existing database tables for the key {key}. " - "This is a fallback mechanism when loading existing graphs, and may not perfectly restore the original schemas. " + "This is a fallback mechanism when loading existing graphs, and may not " + "perfectly restore the original schemas. " "This method is deprecated and will be removed in the major release. 
", UserWarning, + stacklevel=2, ) return None def _sync_attr_dtype_metadata(self) -> None: dtype_map = { - "node": { - key: serialize_polars_dtype(schema.dtype) - for key, schema in self._node_attr_schemas().items() - }, - "edge": { - key: serialize_polars_dtype(schema.dtype) - for key, schema in self._edge_attr_schemas().items() - }, + "node": {key: serialize_polars_dtype(schema.dtype) for key, schema in self._node_attr_schemas().items()}, + "edge": {key: serialize_polars_dtype(schema.dtype) for key, schema in self._edge_attr_schemas().items()}, } self._set_private_dtype_map(dtype_map) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index b36e3ab8..735a10d3 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -557,7 +557,7 @@ def _init_schemas_from_tables(self) -> None: Initialize AttrSchema objects from existing database table columns. This is used when loading an existing graph from the database. """ - + node_column_names = list(self.Node.__table__.columns.keys()) preferred_node_order = [DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.NODE_ID] ordered_node_columns = [name for name in preferred_node_order if name in node_column_names] diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index f1ede1b2..0bc7dcf9 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1457,44 +1457,54 @@ def test_from_other_preserves_schema_roundtrip(target_cls: type[BaseGraph], targ """Test that from_other preserves node and edge attribute schemas across backends.""" graph = RustWorkXGraph() for dtype in [ - pl.Float16, pl.Float32, - pl.Float64, - pl.Int8, - pl.Int16, - pl.Int32, - pl.Int64, - pl.UInt8, - pl.UInt16, - pl.UInt32, - pl.UInt64, - pl.Date, pl.Datetime, - pl.Boolean, - pl.Array(pl.Float32, 3), - pl.List(pl.Int32), - pl.Struct({"a": pl.Int8, "b": pl.Array(pl.String, 2)}), - pl.String, - pl.Object]: + pl.Float16, + pl.Float32, + pl.Float64, + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + pl.Date, + pl.Datetime, + pl.Boolean, + pl.Array(pl.Float32, 3), + pl.List(pl.Int32), + pl.Struct({"a": pl.Int8, "b": pl.Array(pl.String, 2)}), + pl.String, + pl.Object, + ]: graph.add_node_attr_key(f"attr_{dtype}", dtype=dtype) - graph.add_node({"t":0, - "attr_Float16": np.float16(1.5), - "attr_Float32": np.float32(2.5), - "attr_Float64": np.float64(3.5), - "attr_Int8": np.int8(4), - "attr_Int16": np.int16(5), - "attr_Int32": np.int32(6), - "attr_Int64": np.int64(7), - "attr_UInt8": np.uint8(8), - "attr_UInt16": np.uint16(9), - "attr_UInt32": np.uint32(10), - "attr_UInt64": np.uint64(11), - "attr_Date": pl.date(2024, 1, 1), - "attr_Datetime": pl.datetime(2024, 1, 1, 12, 0, 0), - "attr_Boolean": True, - "attr_Array(Float32, shape=(3,))": np.array([1.0, 2.0, 3.0], dtype=np.float32), - "attr_List(Int32)": [1, 2, 3], - "attr_Struct({'a': Int8, 'b': Array(String, shape=(2,))})": {"a": 1, "b": np.array(["x", "y"], dtype=object)}, - "attr_String": "test", - "attr_Object": {"key": "value"}}) + graph.add_node( + { + "t": 0, + "attr_Float16": np.float16(1.5), + "attr_Float32": np.float32(2.5), + "attr_Float64": np.float64(3.5), + "attr_Int8": np.int8(4), + "attr_Int16": np.int16(5), + "attr_Int32": np.int32(6), + "attr_Int64": np.int64(7), + "attr_UInt8": np.uint8(8), + "attr_UInt16": np.uint16(9), + "attr_UInt32": np.uint32(10), + "attr_UInt64": np.uint64(11), + "attr_Date": 
pl.date(2024, 1, 1), + "attr_Datetime": pl.datetime(2024, 1, 1, 12, 0, 0), + "attr_Boolean": True, + "attr_Array(Float32, shape=(3,))": np.array([1.0, 2.0, 3.0], dtype=np.float32), + "attr_List(Int32)": [1, 2, 3], + "attr_Struct({'a': Int8, 'b': Array(String, shape=(2,))})": { + "a": 1, + "b": np.array(["x", "y"], dtype=object), + }, + "attr_String": "test", + "attr_Object": {"key": "value"}, + } + ) graph2 = target_cls.from_other(graph, **target_kwargs) assert graph2.num_nodes() == graph.num_nodes() @@ -1512,8 +1522,6 @@ def test_from_other_preserves_schema_roundtrip(target_cls: type[BaseGraph], targ assert graph3.edge_attrs().schema == graph.edge_attrs().schema - - @pytest.mark.parametrize( ("target_cls", "target_kwargs"), [ diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index f0f24e25..90fc2006 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -1,8 +1,8 @@ from __future__ import annotations import base64 -from dataclasses import dataclass import io +from dataclasses import dataclass from typing import Any import numpy as np @@ -491,7 +491,7 @@ def serialize_polars_dtype(dtype: pl.DataType) -> str: buffer = io.BytesIO() dummy_df.write_ipc(buffer) # Encode binary to a standard string - return base64.b64encode(buffer.getvalue()).decode('utf-8') + return base64.b64encode(buffer.getvalue()).decode("utf-8") def deserialize_polars_dtype(encoded_dtype: str) -> pl.DataType: @@ -506,6 +506,7 @@ def deserialize_polars_dtype(encoded_dtype: str) -> pl.DataType: # Extract the dtype from the schema return restored_df.schema["dummy"] + def validate_default_value_dtype_compatibility(default_value: Any, dtype: pl.DataType) -> None: """ Validate that a default value is compatible with a polars dtype. 
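Before the next patch reworks the metadata plumbing, a brief illustration of the equality semantics that the schema roundtrip assertions above rely on; the schema values here are illustrative:

```python
import numpy as np
import polars as pl

from tracksdata.utils._dtypes import AttrSchema

left = AttrSchema(
    key="vector",
    dtype=pl.Array(pl.Float32, 3),
    default_value=np.array([1.0, 2.0, 3.0], dtype=np.float32),
)
right = left.copy()

# Plain `==` on numpy defaults yields an element-wise array rather than a bool,
# which is why AttrSchema.__eq__ delegates to _values_equal: it compares
# arrays via np.array_equal and recurses into dicts, lists, and tuples.
assert left == right
```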
From e9bf28f90f4591b4fb82c1a6c93e7703fbe8aae3 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 18 Feb 2026 10:45:18 +0900 Subject: [PATCH 07/12] restricted dtype metadata to sqlgraph --- src/tracksdata/graph/_base_graph.py | 73 +-------- src/tracksdata/graph/_rustworkx_graph.py | 32 ++-- src/tracksdata/graph/_sql_graph.py | 151 +++++++++++++----- .../graph/_test/test_graph_backends.py | 64 ++++++++ src/tracksdata/utils/_dtypes.py | 91 +++++++++++ .../utils/_test/test_dtype_serialization.py | 27 +++- 6 files changed, 310 insertions(+), 128 deletions(-) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 9f35119e..90500b67 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -1,7 +1,6 @@ import abc import functools import operator -import warnings from collections.abc import Sequence from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload @@ -22,9 +21,7 @@ from tracksdata.utils._dtypes import ( AttrSchema, column_to_numpy, - deserialize_polars_dtype, polars_dtype_to_numpy_dtype, - serialize_polars_dtype, ) from tracksdata.utils._logging import LOG from tracksdata.utils._multiprocessing import multiprocessing_apply @@ -113,7 +110,6 @@ class BaseGraph(abc.ABC): """ _PRIVATE_METADATA_PREFIX = "__private_" - _PRIVATE_DTYPE_MAP_KEY = "__private_dtype_map" node_added = Signal(int) node_removed = Signal(int) @@ -1255,7 +1251,7 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: graph = cls(**kwargs) graph.metadata.update(other.metadata) - graph._private_metadata.update(other._private_metadata) + graph._private_metadata.update(other._private_metadata_for_copy()) current_node_attr_schemas = graph._node_attr_schemas() for k, v in other._node_attr_schemas().items(): @@ -1951,68 +1947,13 @@ def _remove_metadata_with_validation(self, key: str, *, is_public: bool = True) self._validate_metadata_key(key, is_public=is_public) self._remove_metadata(key) - def _get_private_dtype_map(self) -> dict[str, dict[str, str]]: - dtype_map = self._private_metadata.get(self._PRIVATE_DTYPE_MAP_KEY, {}) - if not isinstance(dtype_map, dict): - return {"node": {}, "edge": {}} - - node_dtype_map = dtype_map.get("node", {}) - edge_dtype_map = dtype_map.get("edge", {}) - if not isinstance(node_dtype_map, dict): - node_dtype_map = {} - if not isinstance(edge_dtype_map, dict): - edge_dtype_map = {} - - return {"node": dict(node_dtype_map), "edge": dict(edge_dtype_map)} - - def _set_private_dtype_map(self, dtype_map: dict[str, dict[str, str]]) -> None: - self._private_metadata.update( - **{ - self._PRIVATE_DTYPE_MAP_KEY: { - "node": dict(dtype_map.get("node", {})), - "edge": dict(dtype_map.get("edge", {})), - } - } - ) - - def _set_attr_dtype_metadata(self, *, key: str, dtype: pl.DataType, is_node: bool) -> None: - dtype_map = self._get_private_dtype_map() - map_key = "node" if is_node else "edge" - dtype_map[map_key][key] = serialize_polars_dtype(dtype) - self._set_private_dtype_map(dtype_map) - - def _remove_attr_dtype_metadata(self, *, key: str, is_node: bool) -> None: - dtype_map = self._get_private_dtype_map() - map_key = "node" if is_node else "edge" - dtype_map[map_key].pop(key, None) - self._set_private_dtype_map(dtype_map) - - def _attr_dtype_from_metadata(self, *, key: str, is_node: bool) -> pl.DataType | None: - dtype_map = self._get_private_dtype_map() - map_key = "node" if is_node else "edge" - encoded_dtype = dtype_map[map_key].get(key) - if not isinstance(encoded_dtype, str): - 
return None - - try: - return deserialize_polars_dtype(encoded_dtype) - except Exception: - warnings.warn( - f"Initializing schemas from existing database tables for the key {key}. " - "This is a fallback mechanism when loading existing graphs, and may not " - "perfectly restore the original schemas. " - "This method is deprecated and will be removed in the major release. ", - UserWarning, - stacklevel=2, - ) - return None + def _private_metadata_for_copy(self) -> dict[str, Any]: + """ + Return private metadata entries that should be propagated by `from_other`. - def _sync_attr_dtype_metadata(self) -> None: - dtype_map = { - "node": {key: serialize_polars_dtype(schema.dtype) for key, schema in self._node_attr_schemas().items()}, - "edge": {key: serialize_polars_dtype(schema.dtype) for key, schema in self._edge_attr_schemas().items()}, - } - self._set_private_dtype_map(dtype_map) + Backends can override this to exclude backend-specific private metadata. + """ + return dict(self._private_metadata) @abc.abstractmethod def _metadata(self) -> dict[str, Any]: diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index a15dd1f5..a415b89d 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -343,7 +343,7 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: self._time_to_nodes: dict[int, list[int]] = {} self.__node_attr_schemas: dict[str, AttrSchema] = {} self.__edge_attr_schemas: dict[str, AttrSchema] = {} - self._overlaps: list[list[int, 2]] = [] + self._overlaps: list[list[int]] = [] # Add default node attributes with inferred schemas self.__node_attr_schemas[DEFAULT_ATTR_KEYS.T] = AttrSchema( @@ -400,13 +400,11 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: for key, value in first_node_attrs.items(): if key == DEFAULT_ATTR_KEYS.NODE_ID: continue - dtype = self._attr_dtype_from_metadata(key=key, is_node=True) - if dtype is None: - try: - dtype = pl.Series([value]).dtype - except (ValueError, TypeError): - # If polars can't infer dtype (e.g., for complex objects), use Object - dtype = pl.Object + try: + dtype = pl.Series([value]).dtype + except (ValueError, TypeError): + # If polars can't infer dtype (e.g., for complex objects), use Object + dtype = pl.Object self.__node_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) # Process edges: set edge IDs and infer schemas @@ -424,17 +422,13 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: # TODO: check if EDGE_SOURCE and EDGE_TARGET should be also ignored or in the schema if key == DEFAULT_ATTR_KEYS.EDGE_ID: continue - dtype = self._attr_dtype_from_metadata(key=key, is_node=False) - if dtype is None: - try: - dtype = pl.Series([value]).dtype - except (ValueError, TypeError): - # If polars can't infer dtype (e.g., for complex objects), use Object - dtype = pl.Object + try: + dtype = pl.Series([value]).dtype + except (ValueError, TypeError): + # If polars can't infer dtype (e.g., for complex objects), use Object + dtype = pl.Object self.__edge_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) - self._sync_attr_dtype_metadata() - def _node_attr_schemas(self) -> dict[str, AttrSchema]: return self.__node_attr_schemas @@ -992,7 +986,6 @@ def add_node_attr_key( # Store schema self.__node_attr_schemas[schema.key] = schema - self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=True) def remove_node_attr_key(self, key: str) -> None: """ @@ -1005,7 +998,6 @@ def 
remove_node_attr_key(self, key: str) -> None: raise ValueError(f"Cannot remove required node attribute key {key}") del self.__node_attr_schemas[key] - self._remove_attr_dtype_metadata(key=key, is_node=True) for node_attr in self.rx_graph.nodes(): node_attr.pop(key, None) @@ -1034,7 +1026,6 @@ def add_edge_attr_key( # Store schema self.__edge_attr_schemas[schema.key] = schema - self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=False) def remove_edge_attr_key(self, key: str) -> None: """ @@ -1044,7 +1035,6 @@ def remove_edge_attr_key(self, key: str) -> None: raise ValueError(f"Edge attribute key {key} does not exist") del self.__edge_attr_schemas[key] - self._remove_attr_dtype_metadata(key=key, is_node=False) for edge_attr in self.rx_graph.edges(): edge_attr.pop(key, None) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 735a10d3..b9da5c1a 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -20,8 +20,10 @@ from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_bytes_columns from tracksdata.utils._dtypes import ( AttrSchema, + deserialize_attr_schema, polars_dtype_to_sqlalchemy_type, process_attr_key_args, + serialize_attr_schema, sqlalchemy_type_to_polars_dtype, ) from tracksdata.utils._logging import LOG @@ -441,6 +443,7 @@ class SQLGraph(BaseGraph): """ node_id_time_multiplier: int = 1_000_000_000 + _PRIVATE_SQL_SCHEMA_STORE_KEY = "__private_sql_attr_schema_store" Base: type[DeclarativeBase] Node: type[DeclarativeBase] Edge: type[DeclarativeBase] @@ -469,8 +472,6 @@ def __init__( # Create unique classes for this instance self._define_schema(overwrite=overwrite) - self.__node_attr_schemas: dict[str, AttrSchema] = {} - self.__edge_attr_schemas: dict[str, AttrSchema] = {} if overwrite: self.Base.metadata.drop_all(self._engine) @@ -479,7 +480,6 @@ def __init__( # Initialize schemas from existing table columns self._init_schemas_from_tables() - self._sync_attr_dtype_metadata() self._max_id_per_time = {} self._update_max_id_per_time() @@ -552,43 +552,109 @@ class Metadata(Base): self.Overlap = Overlap self.Metadata = Metadata + @classmethod + def _empty_attr_schema_store(cls) -> dict[str, dict[str, str]]: + return {"node": {}, "edge": {}} + + def _attr_schema_store(self) -> dict[str, dict[str, str]]: + store = self._private_metadata.get(self._PRIVATE_SQL_SCHEMA_STORE_KEY, {}) + if not isinstance(store, dict): + return self._empty_attr_schema_store() + + normalized = self._empty_attr_schema_store() + for section_key in ("node", "edge"): + section = store.get(section_key, {}) + if not isinstance(section, dict): + continue + for key, encoded_schema in section.items(): + if isinstance(encoded_schema, str): + normalized[section_key][key] = encoded_schema + + return normalized + + def _set_attr_schema_store(self, store: dict[str, dict[str, str]]) -> None: + normalized = self._empty_attr_schema_store() + for section_key in ("node", "edge"): + section = store.get(section_key, {}) + if not isinstance(section, dict): + continue + for key, encoded_schema in section.items(): + if isinstance(encoded_schema, str): + normalized[section_key][key] = encoded_schema + + self._private_metadata.update(**{self._PRIVATE_SQL_SCHEMA_STORE_KEY: normalized}) + + def _get_attr_schemas_from_store(self, *, is_node: bool) -> dict[str, AttrSchema]: + section_key = "node" if is_node else "edge" + section = self._attr_schema_store()[section_key] + + schemas: dict[str, AttrSchema] = {} + for key, encoded_schema 
in section.items(): + try: + schemas[key] = deserialize_attr_schema(encoded_schema, key=key) + except Exception: + LOG.warning( + "Failed to deserialize SQL schema metadata for key '%s'. Falling back to table inference.", + key, + ) + + return schemas + + def _set_attr_schemas_to_store(self, *, is_node: bool, schemas: dict[str, AttrSchema]) -> None: + section_key = "node" if is_node else "edge" + store = self._attr_schema_store() + store[section_key] = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} + self._set_attr_schema_store(store) + + @property + def __node_attr_schemas(self) -> dict[str, AttrSchema]: + return self._get_attr_schemas_from_store(is_node=True) + + @__node_attr_schemas.setter + def __node_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: + self._set_attr_schemas_to_store(is_node=True, schemas=schemas) + + @property + def __edge_attr_schemas(self) -> dict[str, AttrSchema]: + return self._get_attr_schemas_from_store(is_node=False) + + @__edge_attr_schemas.setter + def __edge_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: + self._set_attr_schemas_to_store(is_node=False, schemas=schemas) + def _init_schemas_from_tables(self) -> None: """ Initialize AttrSchema objects from existing database table columns. This is used when loading an existing graph from the database. """ - node_column_names = list(self.Node.__table__.columns.keys()) preferred_node_order = [DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.NODE_ID] ordered_node_columns = [name for name in preferred_node_order if name in node_column_names] ordered_node_columns.extend(name for name in node_column_names if name not in preferred_node_order) - # Initialize node schemas from Node table columns + node_schemas = {k: v for k, v in self.__node_attr_schemas.items() if k in ordered_node_columns} for column_name in ordered_node_columns: - if column_name not in self.__node_attr_schemas: - pl_dtype = self._attr_dtype_from_metadata(key=column_name, is_node=True) - if pl_dtype is None: - column = self.Node.__table__.columns[column_name] - pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) - # AttrSchema.__post_init__ will infer the default_value - self.__node_attr_schemas[column_name] = AttrSchema( - key=column_name, - dtype=pl_dtype, - ) + if column_name in node_schemas: + continue + column = self.Node.__table__.columns[column_name] + node_schemas[column_name] = AttrSchema( + key=column_name, + dtype=sqlalchemy_type_to_polars_dtype(column.type), + ) + self.__node_attr_schemas = node_schemas # Initialize edge schemas from Edge table columns + edge_column_names = list(self.Edge.__table__.columns.keys()) + edge_schemas = {k: v for k, v in self.__edge_attr_schemas.items() if k in edge_column_names} for column_name in self.Edge.__table__.columns.keys(): - # Skip internal edge columns - if column_name not in self.__edge_attr_schemas: - pl_dtype = self._attr_dtype_from_metadata(key=column_name, is_node=False) - if pl_dtype is None: - column = self.Edge.__table__.columns[column_name] - pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) - # AttrSchema.__post_init__ will infer the default_value - self.__edge_attr_schemas[column_name] = AttrSchema( - key=column_name, - dtype=pl_dtype, - ) + if column_name in edge_schemas: + continue + column = self.Edge.__table__.columns[column_name] + edge_schemas[column_name] = AttrSchema( + key=column_name, + dtype=sqlalchemy_type_to_polars_dtype(column.type), + ) + self.__edge_attr_schemas = edge_schemas def _restore_pickled_column_types(self, table: 
sa.Table) -> None: for column in table.columns: @@ -1648,15 +1714,14 @@ def add_node_attr_key( dtype: pl.DataType | None = None, default_value: Any = None, ) -> None: + node_schemas = self.__node_attr_schemas # Process arguments and create validated schema - schema = process_attr_key_args(key_or_schema, dtype, default_value, self.__node_attr_schemas) - - # Store schema - self.__node_attr_schemas[schema.key] = schema + schema = process_attr_key_args(key_or_schema, dtype, default_value, node_schemas) # Add column to database self._add_new_column(self.Node, schema) - self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=True) + node_schemas[schema.key] = schema + self.__node_attr_schemas = node_schemas def remove_node_attr_key(self, key: str) -> None: if key not in self.node_attr_keys(): @@ -1665,9 +1730,10 @@ def remove_node_attr_key(self, key: str) -> None: if key in (DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.T): raise ValueError(f"Cannot remove required node attribute key {key}") + node_schemas = self.__node_attr_schemas self._drop_column(self.Node, key) - self.__node_attr_schemas.pop(key, None) - self._remove_attr_dtype_metadata(key=key, is_node=True) + node_schemas.pop(key, None) + self.__node_attr_schemas = node_schemas def add_edge_attr_key( self, @@ -1675,23 +1741,23 @@ def add_edge_attr_key( dtype: pl.DataType | None = None, default_value: Any = None, ) -> None: + edge_schemas = self.__edge_attr_schemas # Process arguments and create validated schema - schema = process_attr_key_args(key_or_schema, dtype, default_value, self.__edge_attr_schemas) - - # Store schema - self.__edge_attr_schemas[schema.key] = schema + schema = process_attr_key_args(key_or_schema, dtype, default_value, edge_schemas) # Add column to database self._add_new_column(self.Edge, schema) - self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=False) + edge_schemas[schema.key] = schema + self.__edge_attr_schemas = edge_schemas def remove_edge_attr_key(self, key: str) -> None: if key not in self.edge_attr_keys(): raise ValueError(f"Edge attribute key {key} does not exist") + edge_schemas = self.__edge_attr_schemas self._drop_column(self.Edge, key) - self.__edge_attr_schemas.pop(key, None) - self._remove_attr_dtype_metadata(key=key, is_node=False) + edge_schemas.pop(key, None) + self.__edge_attr_schemas = edge_schemas def num_edges(self) -> int: with Session(self._engine) as session: @@ -2032,6 +2098,11 @@ def _metadata(self) -> dict[str, Any]: result = session.query(self.Metadata).all() return {row.key: row.value for row in result} + def _private_metadata_for_copy(self) -> dict[str, Any]: + private_metadata = super()._private_metadata_for_copy() + private_metadata.pop(self._PRIVATE_SQL_SCHEMA_STORE_KEY, None) + return private_metadata + def _update_metadata(self, **kwargs) -> None: with Session(self._engine) as session: for key, value in kwargs.items(): diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 0bc7dcf9..99ee386e 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1704,6 +1704,70 @@ def test_sql_graph_max_id_restored_per_timepoint(tmp_path: Path) -> None: assert next_id == first_id + 1 +def test_sql_graph_schema_defaults_survive_reload(tmp_path: Path) -> None: + """Reloading a SQLGraph should preserve dtype and default schema metadata.""" + db_path = tmp_path / "schema_defaults.db" + graph = SQLGraph("sqlite", str(db_path)) + + 
node_array_default = np.array([1.0, 2.0, 3.0], dtype=np.float32) + node_object_default = {"nested": [1, 2, 3]} + edge_score_default = 0.25 + + graph.add_node_attr_key("node_array_default", pl.Array(pl.Float32, 3), node_array_default) + graph.add_node_attr_key("node_object_default", pl.Object, node_object_default) + graph.add_edge_attr_key("edge_score_default", pl.Float32, edge_score_default) + graph._engine.dispose() + + reloaded = SQLGraph("sqlite", str(db_path)) + + node_schemas = reloaded._node_attr_schemas() + edge_schemas = reloaded._edge_attr_schemas() + np.testing.assert_array_equal(node_schemas["node_array_default"].default_value, node_array_default) + assert node_schemas["node_array_default"].dtype == pl.Array(pl.Float32, 3) + assert node_schemas["node_object_default"].default_value == node_object_default + assert node_schemas["node_object_default"].dtype == pl.Object + assert edge_schemas["edge_score_default"].default_value == edge_score_default + assert edge_schemas["edge_score_default"].dtype == pl.Float32 + + +def test_sql_schema_metadata_not_copied_to_in_memory_graphs() -> None: + """SQL-private schema metadata should not leak into in-memory backends via from_other.""" + sql_graph = SQLGraph("sqlite", ":memory:") + sql_graph.add_node_attr_key("node_array_default", pl.Array(pl.Float32, 3), np.array([1.0, 2.0, 3.0], np.float32)) + sql_graph.add_node_attr_key("node_object_default", pl.Object, {"payload": [1, 2, 3]}) + sql_graph.add_edge_attr_key("edge_score_default", pl.Float32, 0.25) + + n1 = sql_graph.add_node( + { + "t": 0, + "node_array_default": np.array([1.0, 1.0, 1.0], dtype=np.float32), + "node_object_default": {"payload": [10]}, + } + ) + n2 = sql_graph.add_node( + { + "t": 1, + "node_array_default": np.array([2.0, 2.0, 2.0], dtype=np.float32), + "node_object_default": {"payload": [20]}, + } + ) + sql_graph.add_edge(n1, n2, {"edge_score_default": 0.75}) + + assert SQLGraph._PRIVATE_SQL_SCHEMA_STORE_KEY in sql_graph._private_metadata + + rx_graph = RustWorkXGraph.from_other(sql_graph) + assert SQLGraph._PRIVATE_SQL_SCHEMA_STORE_KEY not in rx_graph._metadata() + + sql_graph_roundtrip = SQLGraph.from_other( + rx_graph, + drivername="sqlite", + database=":memory:", + engine_kwargs={"connect_args": {"check_same_thread": False}}, + ) + assert sql_graph_roundtrip._node_attr_schemas() == sql_graph._node_attr_schemas() + assert sql_graph_roundtrip._edge_attr_schemas() == sql_graph._edge_attr_schemas() + + def test_compute_overlaps_invalid_threshold(graph_backend: BaseGraph) -> None: """Test compute_overlaps with invalid threshold values.""" with pytest.raises(ValueError, match=r"iou_threshold must be between 0.0 and 1\.0"): diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 90fc2006..0245acdc 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -507,6 +507,97 @@ def deserialize_polars_dtype(encoded_dtype: str) -> pl.DataType: return restored_df.schema["dummy"] +_ATTR_SCHEMA_DTYPE_COL = "__attr_schema_dtype__" +_ATTR_SCHEMA_DEFAULT_COL = "__attr_schema_default__" +_ATTR_SCHEMA_DTYPE_PICKLE_COL = "__attr_schema_dtype_pickle__" + + +def serialize_attr_schema(schema: AttrSchema) -> str: + """ + Serialize an AttrSchema into a base64-encoded Arrow IPC payload. + + The payload stores dtype metadata and the default value in the same + DataFrame serialization so schema roundtrip can restore both fields. 
+ """ + default_payload = dumps(schema.default_value) + dtype_payload = dumps(schema.dtype) + df = pl.DataFrame( + { + _ATTR_SCHEMA_DTYPE_COL: pl.Series( + _ATTR_SCHEMA_DTYPE_COL, + values=[None], + dtype=schema.dtype, + ), + _ATTR_SCHEMA_DEFAULT_COL: pl.Series( + _ATTR_SCHEMA_DEFAULT_COL, + values=[default_payload], + dtype=pl.Binary, + ), + _ATTR_SCHEMA_DTYPE_PICKLE_COL: pl.Series( + _ATTR_SCHEMA_DTYPE_PICKLE_COL, + values=[dtype_payload], + dtype=pl.Binary, + ), + } + ) + + buffer = io.BytesIO() + try: + df.write_ipc(buffer) + except Exception: + # Fallback for dtypes that cannot be represented in Arrow IPC schema + # (e.g., pl.Object). Keep everything in one DataFrame payload. + fallback_df = pl.DataFrame( + { + _ATTR_SCHEMA_DTYPE_COL: pl.Series( + _ATTR_SCHEMA_DTYPE_COL, + values=[None], + dtype=pl.Binary, + ), + _ATTR_SCHEMA_DEFAULT_COL: pl.Series( + _ATTR_SCHEMA_DEFAULT_COL, + values=[default_payload], + dtype=pl.Binary, + ), + _ATTR_SCHEMA_DTYPE_PICKLE_COL: pl.Series( + _ATTR_SCHEMA_DTYPE_PICKLE_COL, + values=[dtype_payload], + dtype=pl.Binary, + ), + } + ) + buffer = io.BytesIO() + fallback_df.write_ipc(buffer) + + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + +def deserialize_attr_schema(encoded_schema: str, *, key: str) -> AttrSchema: + """ + Deserialize an AttrSchema previously encoded by `serialize_attr_schema`. + """ + data = base64.b64decode(encoded_schema) + buffer = io.BytesIO(data) + restored_df = pl.read_ipc(buffer) + + if _ATTR_SCHEMA_DTYPE_PICKLE_COL in restored_df.columns: + dtype_pickle = restored_df[_ATTR_SCHEMA_DTYPE_PICKLE_COL][0] + else: + dtype_pickle = None + + if dtype_pickle is not None: + dtype = loads(dtype_pickle) + else: + dtype = restored_df.schema[_ATTR_SCHEMA_DTYPE_COL] + + if not pl.datatypes.is_polars_dtype(dtype): + raise TypeError(f"Decoded value is not a polars dtype: {type(dtype)}") + + default_payload = restored_df[_ATTR_SCHEMA_DEFAULT_COL][0] + default_value = loads(default_payload) if default_payload is not None else None + return AttrSchema(key=key, dtype=dtype, default_value=default_value) + + def validate_default_value_dtype_compatibility(default_value: Any, dtype: pl.DataType) -> None: """ Validate that a default value is compatible with a polars dtype. 
diff --git a/src/tracksdata/utils/_test/test_dtype_serialization.py b/src/tracksdata/utils/_test/test_dtype_serialization.py
index 3a997209..51b2659a 100644
--- a/src/tracksdata/utils/_test/test_dtype_serialization.py
+++ b/src/tracksdata/utils/_test/test_dtype_serialization.py
@@ -1,10 +1,17 @@
 import base64
 import binascii
 
+import numpy as np
 import polars as pl
 import pytest
 
-from tracksdata.utils._dtypes import deserialize_polars_dtype, serialize_polars_dtype
+from tracksdata.utils._dtypes import (
+    AttrSchema,
+    deserialize_attr_schema,
+    deserialize_polars_dtype,
+    serialize_attr_schema,
+    serialize_polars_dtype,
+)
 
 
 @pytest.mark.parametrize(
@@ -43,3 +50,21 @@ def test_deserialize_polars_dtype_non_ipc_payload_raises() -> None:
 
     with pytest.raises((OSError, pl.exceptions.PolarsError)):
         deserialize_polars_dtype(encoded)
+
+
+@pytest.mark.parametrize(
+    "schema",
+    [
+        AttrSchema(key="score", dtype=pl.Float64, default_value=1.25),
+        AttrSchema(
+            key="vector",
+            dtype=pl.Array(pl.Float32, 3),
+            default_value=np.array([1.0, 2.0, 3.0], dtype=np.float32),
+        ),
+        AttrSchema(key="payload", dtype=pl.Object, default_value={"nested": [1, 2, 3]}),
+    ],
+)
+def test_serialize_deserialize_attr_schema_roundtrip(schema: AttrSchema) -> None:
+    encoded = serialize_attr_schema(schema)
+    restored = deserialize_attr_schema(encoded, key=schema.key)
+    assert restored == schema
From 9aa9c3a686754a9c540ffd8bbe7c8507579eeecd Mon Sep 17 00:00:00 2001
From: Yohsuke Fukai
Date: Wed, 18 Feb 2026 11:03:12 +0900
Subject: [PATCH 08/12] updated serialization strategies

---
 .../graph/_test/test_graph_backends.py        |   3 +-
 src/tracksdata/utils/_dtypes.py               | 113 +++++++-----------
 .../utils/_test/test_dtype_serialization.py   |  33 +++--
 3 files changed, 68 insertions(+), 81 deletions(-)

diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py
index 99ee386e..63013a4d 100644
--- a/src/tracksdata/graph/_test/test_graph_backends.py
+++ b/src/tracksdata/graph/_test/test_graph_backends.py
@@ -1,3 +1,4 @@
+import datetime as dt
 from pathlib import Path
 from typing import Any
 
@@ -1493,7 +1494,7 @@ def test_from_other_preserves_schema_roundtrip(target_cls: type[BaseGraph], targ
             "attr_UInt32": np.uint32(10),
             "attr_UInt64": np.uint64(11),
             "attr_Date": pl.date(2024, 1, 1),
-            "attr_Datetime": pl.datetime(2024, 1, 1, 12, 0, 0),
+            "attr_Datetime": dt.datetime(2024, 1, 1, 12, 0, 0),
             "attr_Boolean": True,
             "attr_Array(Float32, shape=(3,))": np.array([1.0, 2.0, 3.0], dtype=np.float32),
             "attr_List(Int32)": [1, 2, 3],
diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py
index 0245acdc..05338fa5 100644
--- a/src/tracksdata/utils/_dtypes.py
+++ b/src/tracksdata/utils/_dtypes.py
@@ -478,66 +478,48 @@ def sqlalchemy_type_to_polars_dtype(sa_type: TypeEngine) -> pl.DataType:
     return pl.Object
 
 
-def serialize_polars_dtype(dtype: pl.DataType) -> str:
-    """
-    Serializes a Polars dtype to a safe, cross-platform base64 string
-    using the Arrow IPC format.
-    """
-    # Wrap the dtype in an empty DataFrame schema
-    # We use an empty DataFrame so no actual data is processed, only metadata.
-    dummy_df = pl.DataFrame(schema={"dummy": dtype})
-    # Write to Arrow IPC (binary buffer)
-    # IPC is stable across versions/platforms unlike internal serialization.
- buffer = io.BytesIO() - dummy_df.write_ipc(buffer) - # Encode binary to a standard string - return base64.b64encode(buffer.getvalue()).decode("utf-8") +def _normalize_default_for_dtype(default_value: Any, dtype: pl.DataType) -> Any: + if isinstance(dtype, pl.Array | pl.List) and isinstance(default_value, np.ndarray): + return default_value.tolist() + return default_value -def deserialize_polars_dtype(encoded_dtype: str) -> pl.DataType: - """ - Recovers a Polars dtype from a base64 string. - """ - # Decode string back to binary - data = base64.b64decode(encoded_dtype) - # Read the IPC buffer - buffer = io.BytesIO(data) - restored_df = pl.read_ipc(buffer) - # Extract the dtype from the schema - return restored_df.schema["dummy"] +def _normalize_deserialized_default(default_value: Any, dtype: pl.DataType) -> Any: + if isinstance(dtype, pl.Array): + if isinstance(default_value, pl.Series): + default_value = default_value.to_list() + numpy_dtype = polars_dtype_to_numpy_dtype(dtype.inner, allow_sequence=True) + return np.asarray(default_value, dtype=numpy_dtype).reshape(dtype.shape) + + if isinstance(dtype, pl.List): + if isinstance(default_value, pl.Series): + return default_value.to_list() + if isinstance(default_value, np.ndarray): + return default_value.tolist() + + return default_value -_ATTR_SCHEMA_DTYPE_COL = "__attr_schema_dtype__" -_ATTR_SCHEMA_DEFAULT_COL = "__attr_schema_default__" -_ATTR_SCHEMA_DTYPE_PICKLE_COL = "__attr_schema_dtype_pickle__" +_ATTR_SCHEMA_VALUE_COL = "__attr_schema_value__" +_ATTR_SCHEMA_FALLBACK_COL = "__attr_schema_fallback__" def serialize_attr_schema(schema: AttrSchema) -> str: """ Serialize an AttrSchema into a base64-encoded Arrow IPC payload. - The payload stores dtype metadata and the default value in the same - DataFrame serialization so schema roundtrip can restore both fields. + The primary format stores schema.default_value in the first row of a + single dummy column whose dtype is schema.dtype. This keeps dtype and + default value in one Arrow IPC payload. """ - default_payload = dumps(schema.default_value) - dtype_payload = dumps(schema.dtype) + normalized_default = _normalize_default_for_dtype(schema.default_value, schema.dtype) df = pl.DataFrame( { - _ATTR_SCHEMA_DTYPE_COL: pl.Series( - _ATTR_SCHEMA_DTYPE_COL, - values=[None], + _ATTR_SCHEMA_VALUE_COL: pl.Series( + _ATTR_SCHEMA_VALUE_COL, + values=[normalized_default], dtype=schema.dtype, ), - _ATTR_SCHEMA_DEFAULT_COL: pl.Series( - _ATTR_SCHEMA_DEFAULT_COL, - values=[default_payload], - dtype=pl.Binary, - ), - _ATTR_SCHEMA_DTYPE_PICKLE_COL: pl.Series( - _ATTR_SCHEMA_DTYPE_PICKLE_COL, - values=[dtype_payload], - dtype=pl.Binary, - ), } ) @@ -545,23 +527,14 @@ def serialize_attr_schema(schema: AttrSchema) -> str: try: df.write_ipc(buffer) except Exception: - # Fallback for dtypes that cannot be represented in Arrow IPC schema - # (e.g., pl.Object). Keep everything in one DataFrame payload. + # Some dtypes (e.g. pl.Object) cannot roundtrip through Arrow IPC schema. + # Store pickled (dtype, default) in the first row of a binary dummy column. 
+ fallback_payload = dumps((schema.dtype, schema.default_value)) fallback_df = pl.DataFrame( { - _ATTR_SCHEMA_DTYPE_COL: pl.Series( - _ATTR_SCHEMA_DTYPE_COL, - values=[None], - dtype=pl.Binary, - ), - _ATTR_SCHEMA_DEFAULT_COL: pl.Series( - _ATTR_SCHEMA_DEFAULT_COL, - values=[default_payload], - dtype=pl.Binary, - ), - _ATTR_SCHEMA_DTYPE_PICKLE_COL: pl.Series( - _ATTR_SCHEMA_DTYPE_PICKLE_COL, - values=[dtype_payload], + _ATTR_SCHEMA_FALLBACK_COL: pl.Series( + _ATTR_SCHEMA_FALLBACK_COL, + values=[fallback_payload], dtype=pl.Binary, ), } @@ -580,21 +553,21 @@ def deserialize_attr_schema(encoded_schema: str, *, key: str) -> AttrSchema: buffer = io.BytesIO(data) restored_df = pl.read_ipc(buffer) - if _ATTR_SCHEMA_DTYPE_PICKLE_COL in restored_df.columns: - dtype_pickle = restored_df[_ATTR_SCHEMA_DTYPE_PICKLE_COL][0] - else: - dtype_pickle = None - - if dtype_pickle is not None: - dtype = loads(dtype_pickle) + if _ATTR_SCHEMA_VALUE_COL in restored_df.columns: + dtype = restored_df.schema[_ATTR_SCHEMA_VALUE_COL] + default_value = restored_df[_ATTR_SCHEMA_VALUE_COL][0] + elif _ATTR_SCHEMA_FALLBACK_COL in restored_df.columns: + fallback_payload = restored_df[_ATTR_SCHEMA_FALLBACK_COL][0] + if fallback_payload is None: + raise ValueError("Fallback schema payload is missing.") + dtype, default_value = loads(fallback_payload) else: - dtype = restored_df.schema[_ATTR_SCHEMA_DTYPE_COL] + raise ValueError("Unrecognized attr schema payload format.") if not pl.datatypes.is_polars_dtype(dtype): raise TypeError(f"Decoded value is not a polars dtype: {type(dtype)}") - default_payload = restored_df[_ATTR_SCHEMA_DEFAULT_COL][0] - default_value = loads(default_payload) if default_payload is not None else None + default_value = _normalize_deserialized_default(default_value, dtype) return AttrSchema(key=key, dtype=dtype, default_value=default_value) diff --git a/src/tracksdata/utils/_test/test_dtype_serialization.py b/src/tracksdata/utils/_test/test_dtype_serialization.py index 51b2659a..1f406224 100644 --- a/src/tracksdata/utils/_test/test_dtype_serialization.py +++ b/src/tracksdata/utils/_test/test_dtype_serialization.py @@ -1,5 +1,6 @@ import base64 import binascii +import io import numpy as np import polars as pl @@ -8,9 +9,7 @@ from tracksdata.utils._dtypes import ( AttrSchema, deserialize_attr_schema, - deserialize_polars_dtype, serialize_attr_schema, - serialize_polars_dtype, ) @@ -28,28 +27,29 @@ pl.Datetime("us", "UTC"), ], ) -def test_serialize_deserialize_polars_dtype_roundtrip(dtype: pl.DataType) -> None: - encoded = serialize_polars_dtype(dtype) +def test_serialize_deserialize_attr_schema_dtype_roundtrip(dtype: pl.DataType) -> None: + schema = AttrSchema(key="dummy", dtype=dtype) + encoded = serialize_attr_schema(schema) assert isinstance(encoded, str) assert encoded assert base64.b64decode(encoded) - restored_dtype = deserialize_polars_dtype(encoded) + restored = deserialize_attr_schema(encoded, key=schema.key) - assert restored_dtype == dtype + assert restored == schema -def test_deserialize_polars_dtype_invalid_base64_raises() -> None: +def test_deserialize_attr_schema_invalid_base64_raises() -> None: with pytest.raises(binascii.Error): - deserialize_polars_dtype("not-base64") + deserialize_attr_schema("not-base64", key="dummy") -def test_deserialize_polars_dtype_non_ipc_payload_raises() -> None: +def test_deserialize_attr_schema_non_ipc_payload_raises() -> None: encoded = base64.b64encode(b"not-arrow-ipc").decode("utf-8") with pytest.raises((OSError, pl.exceptions.PolarsError)): - 
deserialize_polars_dtype(encoded) + deserialize_attr_schema(encoded, key="dummy") @pytest.mark.parametrize( @@ -68,3 +68,16 @@ def test_serialize_deserialize_attr_schema_roundtrip(schema: AttrSchema) -> None encoded = serialize_attr_schema(schema) restored = deserialize_attr_schema(encoded, key=schema.key) assert restored == schema + + +def test_serialize_attr_schema_stores_default_in_dummy_row() -> None: + schema = AttrSchema(key="score", dtype=pl.Float64, default_value=1.25) + encoded = serialize_attr_schema(schema) + + payload = base64.b64decode(encoded) + df = pl.read_ipc(io.BytesIO(payload)) + + assert "__attr_schema_value__" in df.columns + assert df.schema["__attr_schema_value__"] == pl.Float64 + assert df["__attr_schema_value__"][0] == 1.25 + assert "__attr_schema_dtype_pickle__" not in df.columns From 7e61ac33ee6ec0976c01eed1e084fa86c2ee7bcc Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 18 Feb 2026 11:11:10 +0900 Subject: [PATCH 09/12] solved failing tests --- src/tracksdata/solvers/_ilp_solver.py | 9 ++++++++- src/tracksdata/solvers/_nearest_neighbors_solver.py | 3 ++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/tracksdata/solvers/_ilp_solver.py b/src/tracksdata/solvers/_ilp_solver.py index 6485eaf7..3f6676d5 100644 --- a/src/tracksdata/solvers/_ilp_solver.py +++ b/src/tracksdata/solvers/_ilp_solver.py @@ -175,6 +175,9 @@ def _evaluate_expr( expr: Attr, df: pl.DataFrame, ) -> list[float]: + if df.is_empty(): + return [] + if len(expr.expr_columns) == 0: return [expr.evaluate(df).item()] * len(df) else: @@ -388,7 +391,11 @@ def solve( node_attr_keys.extend(self.merge_weight_expr.columns) nodes_df = graph.node_attrs(attr_keys=node_attr_keys) - edges_df = graph.edge_attrs(attr_keys=self.edge_weight_expr.columns) + # When no edges exist, avoid requesting edge weight columns that may not + # be registered in the backend schema yet. _solve() handles this as a + # regular "no edges" ValueError. + edge_attr_keys = [] if graph.num_edges() == 0 else self.edge_weight_expr.columns + edges_df = graph.edge_attrs(attr_keys=edge_attr_keys) self._add_objective_and_variables(nodes_df, edges_df) self._add_continuous_flow_constraints(nodes_df[DEFAULT_ATTR_KEYS.NODE_ID].to_list(), edges_df) diff --git a/src/tracksdata/solvers/_nearest_neighbors_solver.py b/src/tracksdata/solvers/_nearest_neighbors_solver.py index 34011dee..21915290 100644 --- a/src/tracksdata/solvers/_nearest_neighbors_solver.py +++ b/src/tracksdata/solvers/_nearest_neighbors_solver.py @@ -235,7 +235,8 @@ def solve( The graph view of the solution if `return_solution` is True, otherwise None. 
""" # get edges and sort them by weight - edges_df = graph.edge_attrs(attr_keys=self.edge_weight_expr.columns) + edge_attr_keys = [] if graph.num_edges() == 0 else self.edge_weight_expr.columns + edges_df = graph.edge_attrs(attr_keys=edge_attr_keys) if len(edges_df) == 0: raise ValueError("No edges found in the graph, there is nothing to solve.") From e5968bf44aa9697756df354b97f005f05c2a2bd6 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 18 Feb 2026 11:45:10 +0900 Subject: [PATCH 10/12] added test for shape-less pl.Array (xfail) --- src/tracksdata/graph/_sql_graph.py | 8 +------- .../graph/_test/test_graph_backends.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index b9da5c1a..2d0ea3ad 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -590,13 +590,7 @@ def _get_attr_schemas_from_store(self, *, is_node: bool) -> dict[str, AttrSchema schemas: dict[str, AttrSchema] = {} for key, encoded_schema in section.items(): - try: - schemas[key] = deserialize_attr_schema(encoded_schema, key=key) - except Exception: - LOG.warning( - "Failed to deserialize SQL schema metadata for key '%s'. Falling back to table inference.", - key, - ) + schemas[key] = deserialize_attr_schema(encoded_schema, key=key) return schemas diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 63013a4d..f60d9641 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1523,6 +1523,23 @@ def test_from_other_preserves_schema_roundtrip(target_cls: type[BaseGraph], targ assert graph3.edge_attrs().schema == graph.edge_attrs().schema +@pytest.mark.xfail(reason="This is because of the lack of support of shape-less pl.Array in write_ipc of polars.") +def test_from_other_with_array_no_shape(): + """Test that from_other raises an error when trying to copy array attributes without shape information.""" + graph = RustWorkXGraph() + graph.add_node_attr_key("array_attr", pl.Array) + graph.add_node({"t": 0, "array_attr": np.array([1.0, 2.0, 3.0], dtype=np.float32)}) + + # This should raise an error because the schema does not include shape information + graph2 = SQLGraph.from_other( + graph, drivername="sqlite", database=":memory:", engine_kwargs={"connect_args": {"check_same_thread": False}} + ) + assert graph2.num_nodes() == graph.num_nodes() + assert set(graph2.node_attr_keys()) == set(graph.node_attr_keys()) + assert graph2._node_attr_schemas() == graph._node_attr_schemas() + assert graph2.node_attrs().schema == graph.node_attrs().schema + + @pytest.mark.parametrize( ("target_cls", "target_kwargs"), [ From b4acde3e8985b7eacd5a24328f3d4840dc49889c Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 19 Feb 2026 10:43:59 +0900 Subject: [PATCH 11/12] working --- src/tracksdata/graph/_sql_graph.py | 65 ++++--------------- .../graph/_test/test_graph_backends.py | 6 +- 2 files changed, 16 insertions(+), 55 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 2d0ea3ad..20ea59a9 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -443,7 +443,8 @@ class SQLGraph(BaseGraph): """ node_id_time_multiplier: int = 1_000_000_000 - _PRIVATE_SQL_SCHEMA_STORE_KEY = "__private_sql_attr_schema_store" + _PRIVATE_SQL_NODE_SCHEMA_STORE_KEY = "__private_sql_node_attr_schema_store" + 
_PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY = "__private_sql_edge_attr_schema_store" Base: type[DeclarativeBase] Node: type[DeclarativeBase] Edge: type[DeclarativeBase] @@ -552,69 +553,26 @@ class Metadata(Base): self.Overlap = Overlap self.Metadata = Metadata - @classmethod - def _empty_attr_schema_store(cls) -> dict[str, dict[str, str]]: - return {"node": {}, "edge": {}} - - def _attr_schema_store(self) -> dict[str, dict[str, str]]: - store = self._private_metadata.get(self._PRIVATE_SQL_SCHEMA_STORE_KEY, {}) - if not isinstance(store, dict): - return self._empty_attr_schema_store() - - normalized = self._empty_attr_schema_store() - for section_key in ("node", "edge"): - section = store.get(section_key, {}) - if not isinstance(section, dict): - continue - for key, encoded_schema in section.items(): - if isinstance(encoded_schema, str): - normalized[section_key][key] = encoded_schema - - return normalized - - def _set_attr_schema_store(self, store: dict[str, dict[str, str]]) -> None: - normalized = self._empty_attr_schema_store() - for section_key in ("node", "edge"): - section = store.get(section_key, {}) - if not isinstance(section, dict): - continue - for key, encoded_schema in section.items(): - if isinstance(encoded_schema, str): - normalized[section_key][key] = encoded_schema - - self._private_metadata.update(**{self._PRIVATE_SQL_SCHEMA_STORE_KEY: normalized}) - - def _get_attr_schemas_from_store(self, *, is_node: bool) -> dict[str, AttrSchema]: - section_key = "node" if is_node else "edge" - section = self._attr_schema_store()[section_key] - - schemas: dict[str, AttrSchema] = {} - for key, encoded_schema in section.items(): - schemas[key] = deserialize_attr_schema(encoded_schema, key=key) - - return schemas - - def _set_attr_schemas_to_store(self, *, is_node: bool, schemas: dict[str, AttrSchema]) -> None: - section_key = "node" if is_node else "edge" - store = self._attr_schema_store() - store[section_key] = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} - self._set_attr_schema_store(store) @property def __node_attr_schemas(self) -> dict[str, AttrSchema]: - return self._get_attr_schemas_from_store(is_node=True) + encoded_schemas = self._private_metadata.get(self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY, {}) + return {key: deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} @__node_attr_schemas.setter def __node_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: - self._set_attr_schemas_to_store(is_node=True, schemas=schemas) + encoded_schemas = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} + self._private_metadata[self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY] = encoded_schemas @property def __edge_attr_schemas(self) -> dict[str, AttrSchema]: - return self._get_attr_schemas_from_store(is_node=False) + encoded_schemas = self._private_metadata.get(self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY, {}) + return {key: deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} @__edge_attr_schemas.setter def __edge_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: - self._set_attr_schemas_to_store(is_node=False, schemas=schemas) + encoded_schemas = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} + self._private_metadata[self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY] = encoded_schemas def _init_schemas_from_tables(self) -> None: """ @@ -2094,7 +2052,8 @@ def _metadata(self) -> dict[str, Any]: def _private_metadata_for_copy(self) -> dict[str, 
Any]: private_metadata = super()._private_metadata_for_copy() - private_metadata.pop(self._PRIVATE_SQL_SCHEMA_STORE_KEY, None) + private_metadata.pop(self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY, None) + private_metadata.pop(self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY, None) return private_metadata def _update_metadata(self, **kwargs) -> None: diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index f60d9641..8aa2324f 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1771,10 +1771,12 @@ def test_sql_schema_metadata_not_copied_to_in_memory_graphs() -> None: ) sql_graph.add_edge(n1, n2, {"edge_score_default": 0.75}) - assert SQLGraph._PRIVATE_SQL_SCHEMA_STORE_KEY in sql_graph._private_metadata + assert SQLGraph._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY in sql_graph._private_metadata + assert SQLGraph._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY in sql_graph._private_metadata rx_graph = RustWorkXGraph.from_other(sql_graph) - assert SQLGraph._PRIVATE_SQL_SCHEMA_STORE_KEY not in rx_graph._metadata() + assert SQLGraph._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY not in rx_graph._metadata() + assert SQLGraph._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY not in rx_graph._metadata() sql_graph_roundtrip = SQLGraph.from_other( rx_graph, From cc55976e997cf7c48ffcd750b8bf57220b0ba8c0 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 19 Feb 2026 10:48:06 +0900 Subject: [PATCH 12/12] simplified code --- src/tracksdata/graph/_sql_graph.py | 133 +++++++++++++++++------------ 1 file changed, 79 insertions(+), 54 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 20ea59a9..65784572 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -479,9 +479,6 @@ def __init__( self.Base.metadata.create_all(self._engine) - # Initialize schemas from existing table columns - self._init_schemas_from_tables() - self._max_id_per_time = {} self._update_max_id_per_time() @@ -553,72 +550,102 @@ class Metadata(Base): self.Overlap = Overlap self.Metadata = Metadata + @staticmethod + def _default_node_attr_schemas() -> dict[str, AttrSchema]: + return { + DEFAULT_ATTR_KEYS.T: AttrSchema(key=DEFAULT_ATTR_KEYS.T, dtype=pl.Int32), + DEFAULT_ATTR_KEYS.NODE_ID: AttrSchema(key=DEFAULT_ATTR_KEYS.NODE_ID, dtype=pl.Int64), + } + + @staticmethod + def _default_edge_attr_schemas() -> dict[str, AttrSchema]: + return { + DEFAULT_ATTR_KEYS.EDGE_ID: AttrSchema(key=DEFAULT_ATTR_KEYS.EDGE_ID, dtype=pl.Int32), + DEFAULT_ATTR_KEYS.EDGE_SOURCE: AttrSchema(key=DEFAULT_ATTR_KEYS.EDGE_SOURCE, dtype=pl.Int64), + DEFAULT_ATTR_KEYS.EDGE_TARGET: AttrSchema(key=DEFAULT_ATTR_KEYS.EDGE_TARGET, dtype=pl.Int64), + } + + def _attr_schemas_from_metadata( + self, + *, + table_class: type[DeclarativeBase], + metadata_key: str, + default_schemas: dict[str, AttrSchema], + preferred_order: Sequence[str], + ) -> dict[str, AttrSchema]: + encoded_schemas = self._private_metadata.get(metadata_key, {}) + schemas = default_schemas.copy() + schemas.update( + {key: deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} + ) + + # Legacy databases may not have schema metadata for all columns. 
+ for column_name, column in table_class.__table__.columns.items(): + if column_name not in schemas: + schemas[column_name] = AttrSchema( + key=column_name, + dtype=sqlalchemy_type_to_polars_dtype(column.type), + ) + + ordered_keys = [key for key in preferred_order if key in schemas] + ordered_keys.extend(key for key in table_class.__table__.columns.keys() if key not in ordered_keys) + ordered_keys.extend(key for key in schemas if key not in ordered_keys) + return {key: schemas[key] for key in ordered_keys} + + def _attr_schemas_for_table(self, table_class: type[DeclarativeBase]) -> dict[str, AttrSchema]: + if table_class.__tablename__ == self.Node.__tablename__: + return self._node_attr_schemas() + return self._edge_attr_schemas() + + @staticmethod + def _is_pickled_sql_type(column_type: TypeEngine) -> bool: + return isinstance(column_type, sa.PickleType | sa.LargeBinary) @property def __node_attr_schemas(self) -> dict[str, AttrSchema]: - encoded_schemas = self._private_metadata.get(self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY, {}) - return {key: deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} + return self._attr_schemas_from_metadata( + table_class=self.Node, + metadata_key=self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY, + default_schemas=self._default_node_attr_schemas(), + preferred_order=[DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.NODE_ID], + ) @__node_attr_schemas.setter def __node_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: + merged_schemas = self._default_node_attr_schemas() + merged_schemas.update(schemas) + schemas = merged_schemas encoded_schemas = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} self._private_metadata[self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY] = encoded_schemas @property def __edge_attr_schemas(self) -> dict[str, AttrSchema]: - encoded_schemas = self._private_metadata.get(self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY, {}) - return {key: deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} + return self._attr_schemas_from_metadata( + table_class=self.Edge, + metadata_key=self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY, + default_schemas=self._default_edge_attr_schemas(), + preferred_order=[ + DEFAULT_ATTR_KEYS.EDGE_ID, + DEFAULT_ATTR_KEYS.EDGE_SOURCE, + DEFAULT_ATTR_KEYS.EDGE_TARGET, + ], + ) @__edge_attr_schemas.setter def __edge_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: + merged_schemas = self._default_edge_attr_schemas() + merged_schemas.update(schemas) + schemas = merged_schemas encoded_schemas = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} self._private_metadata[self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY] = encoded_schemas - def _init_schemas_from_tables(self) -> None: - """ - Initialize AttrSchema objects from existing database table columns. - This is used when loading an existing graph from the database. 
- """ - node_column_names = list(self.Node.__table__.columns.keys()) - preferred_node_order = [DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.NODE_ID] - ordered_node_columns = [name for name in preferred_node_order if name in node_column_names] - ordered_node_columns.extend(name for name in node_column_names if name not in preferred_node_order) - - node_schemas = {k: v for k, v in self.__node_attr_schemas.items() if k in ordered_node_columns} - for column_name in ordered_node_columns: - if column_name in node_schemas: - continue - column = self.Node.__table__.columns[column_name] - node_schemas[column_name] = AttrSchema( - key=column_name, - dtype=sqlalchemy_type_to_polars_dtype(column.type), - ) - self.__node_attr_schemas = node_schemas - - # Initialize edge schemas from Edge table columns - edge_column_names = list(self.Edge.__table__.columns.keys()) - edge_schemas = {k: v for k, v in self.__edge_attr_schemas.items() if k in edge_column_names} - for column_name in self.Edge.__table__.columns.keys(): - if column_name in edge_schemas: - continue - column = self.Edge.__table__.columns[column_name] - edge_schemas[column_name] = AttrSchema( - key=column_name, - dtype=sqlalchemy_type_to_polars_dtype(column.type), - ) - self.__edge_attr_schemas = edge_schemas - def _restore_pickled_column_types(self, table: sa.Table) -> None: for column in table.columns: if isinstance(column.type, sa.LargeBinary): column.type = sa.PickleType() def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaDict: - # Get the appropriate schema dict based on table class - if table_class.__tablename__ == self.Node.__tablename__: - schemas = self._node_attr_schemas() - else: - schemas = self._edge_attr_schemas() + schemas = self._attr_schemas_for_table(table_class) # Return schema overrides for columns safely represented in SQL. # Pickled columns are unpickled and casted in a second pass. @@ -627,21 +654,19 @@ def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaD for key, schema in schemas.items() if ( key in table_class.__table__.columns - and not isinstance(table_class.__table__.columns[key].type, sa.PickleType | sa.LargeBinary) - and not (schema.dtype == pl.Object or isinstance(schema.dtype, pl.Array | pl.List)) + and not self._is_pickled_sql_type(table_class.__table__.columns[key].type) ) } def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: - # Get the appropriate schema dict based on table class - if table_class.__tablename__ == self.Node.__tablename__: - schemas = self._node_attr_schemas() - else: - schemas = self._edge_attr_schemas() + schemas = self._attr_schemas_for_table(table_class) casts: list[pl.Series] = [] for key, schema in schemas.items(): - if key not in df.columns: + if key not in df.columns or key not in table_class.__table__.columns: + continue + + if not self._is_pickled_sql_type(table_class.__table__.columns[key].type): continue try: