Skip to content
2 changes: 1 addition & 1 deletion src/tracksdata/array/_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _validate_shape(
"""Helper function to validate the shape argument."""
if shape is None:
try:
shape = graph.metadata()["shape"]
shape = graph.metadata["shape"]
except KeyError as e:
raise KeyError(
f"`shape` is required to `{func_name}`. "
Expand Down
2 changes: 1 addition & 1 deletion src/tracksdata/functional/_test/test_napari.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_napari_conversion(metadata_shape: bool) -> None:

shape = (2, 10, 22, 32)
if metadata_shape:
graph.update_metadata(shape=shape)
graph.metadata.update(shape=shape)
arg_shape = None
else:
arg_shape = shape
Expand Down
4 changes: 2 additions & 2 deletions src/tracksdata/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Graph backends for representing tracking data as directed graphs in memory or on disk."""

from tracksdata.graph._base_graph import BaseGraph
from tracksdata.graph._base_graph import BaseGraph, MetadataView
from tracksdata.graph._graph_view import GraphView
from tracksdata.graph._rustworkx_graph import IndexedRXGraph, RustWorkXGraph
from tracksdata.graph._sql_graph import SQLGraph

InMemoryGraph = RustWorkXGraph

__all__ = ["BaseGraph", "GraphView", "InMemoryGraph", "IndexedRXGraph", "RustWorkXGraph", "SQLGraph"]
__all__ = ["BaseGraph", "GraphView", "InMemoryGraph", "IndexedRXGraph", "MetadataView", "RustWorkXGraph", "SQLGraph"]
160 changes: 128 additions & 32 deletions src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,75 @@
T = TypeVar("T", bound="BaseGraph")


class MetadataView(dict[str, Any]):
"""Dictionary-like metadata view that syncs mutations back to the graph."""

_MISSING = object()

def __init__(
self,
graph: "BaseGraph",
data: dict[str, Any],
*,
is_public: bool = True,
) -> None:
super().__init__(data)
self._graph = graph
self._is_public = is_public

def __setitem__(self, key: str, value: Any) -> None:
self._graph._set_metadata_with_validation(is_public=self._is_public, **{key: value})
super().__setitem__(key, value)

def __delitem__(self, key: str) -> None:
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
super().__delitem__(key)

def pop(self, key: str, default: Any = _MISSING) -> Any:
self._graph._validate_metadata_key(key, is_public=self._is_public)

if key not in self:
if default is self._MISSING:
raise KeyError(key)
return default

value = super().__getitem__(key)
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
super().pop(key, None)
return value

def popitem(self) -> tuple[str, Any]:
key, value = super().popitem()
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
return key, value

def clear(self) -> None:
keys = list(self.keys())
for key in keys:
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
super().clear()

def setdefault(self, key: str, default: Any = None) -> Any:
if key in self:
return super().__getitem__(key)
self._graph._set_metadata_with_validation(is_public=self._is_public, **{key: default})
super().__setitem__(key, default)
return default

def update(self, *args, **kwargs) -> None:
updates = dict(*args, **kwargs)
if updates:
self._graph._set_metadata_with_validation(is_public=self._is_public, **updates)
super().update(updates)


class BaseGraph(abc.ABC):
"""
Base class for a graph backend.
"""

_PRIVATE_METADATA_PREFIX = "__private_"

node_added = Signal(int)
node_removed = Signal(int)

Expand Down Expand Up @@ -1186,7 +1250,8 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T:
node_attrs = node_attrs.drop(DEFAULT_ATTR_KEYS.NODE_ID)

graph = cls(**kwargs)
graph.update_metadata(**other.metadata())
graph.metadata.update(other.metadata)
graph._private_metadata.update(other._private_metadata_for_copy())

current_node_attr_schemas = graph._node_attr_schemas()
for k, v in other._node_attr_schemas().items():
Expand Down Expand Up @@ -1216,7 +1281,6 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T:
current_edge_attr_schemas = graph._edge_attr_schemas()
for k, v in other._edge_attr_schemas().items():
if k not in current_edge_attr_schemas:
print(f"Adding edge attribute key: {k} with dtype: {v.dtype} and default value: {v.default_value}")
graph.add_edge_attr_key(k, v.dtype, v.default_value)

edge_attrs = edge_attrs.with_columns(
Expand Down Expand Up @@ -1786,7 +1850,7 @@ def to_geff(
for k, v in edge_attrs.to_dict().items()
}

td_metadata = self.metadata().copy()
td_metadata = self.metadata.copy()
td_metadata.pop("geff", None) # avoid geff being written multiple times

geff_metadata = geff.GeffMetadata(
Expand Down Expand Up @@ -1824,57 +1888,89 @@ def to_geff(
zarr_format=zarr_format,
)

@abc.abstractmethod
def metadata(self) -> dict[str, Any]:
@property
def metadata(self) -> MetadataView:
"""
Return the metadata of the graph.

Returns
-------
dict[str, Any]
MetadataView
The metadata of the graph as a dictionary.

Examples
--------
```python
metadata = graph.metadata()
metadata = graph.metadata
print(metadata["shape"])
```
"""
return MetadataView(
graph=self,
data={k: v for k, v in self._metadata().items() if not self._is_private_metadata_key(k)},
is_public=True,
)

@abc.abstractmethod
def update_metadata(self, **kwargs) -> None:
"""
Set or update metadata for the graph.
@property
def _private_metadata(self) -> MetadataView:
return MetadataView(
graph=self,
data={k: v for k, v in self._metadata().items() if self._is_private_metadata_key(k)},
is_public=False,
)

Parameters
----------
**kwargs : Any
The metadata items to set by key. Values will be stored as JSON.
@classmethod
def _is_private_metadata_key(cls, key: str) -> bool:
return key.startswith(cls._PRIVATE_METADATA_PREFIX)

def _validate_metadata_key(self, key: str, *, is_public: bool) -> None:
if not isinstance(key, str):
raise TypeError(f"Metadata key must be a string. Got {type(key)}.")
is_private_key = self._is_private_metadata_key(key)
if is_public and is_private_key:
raise ValueError(f"Metadata key '{key}' is reserved for internal use.")
if not is_public and not is_private_key:
raise ValueError(
f"Metadata key '{key}' is not private. Private metadata keys must start with "
f"'{self._PRIVATE_METADATA_PREFIX}'."
)

Examples
--------
```python
graph.update_metadata(shape=[1, 25, 25], path="path/to/image.ome.zarr")
graph.update_metadata(description="Tracking data from experiment 1")
```
def _validate_metadata_keys(self, keys: Sequence[str], *, is_public: bool) -> None:
for key in keys:
self._validate_metadata_key(key, is_public=is_public)

def _set_metadata_with_validation(self, is_public: bool = True, **kwargs) -> None:
self._validate_metadata_keys(kwargs.keys(), is_public=is_public)
self._update_metadata(**kwargs)

def _remove_metadata_with_validation(self, key: str, *, is_public: bool = True) -> None:
self._validate_metadata_key(key, is_public=is_public)
self._remove_metadata(key)

def _private_metadata_for_copy(self) -> dict[str, Any]:
"""
Return private metadata entries that should be propagated by `from_other`.

Backends can override this to exclude backend-specific private metadata.
"""
return dict(self._private_metadata)

@abc.abstractmethod
def remove_metadata(self, key: str) -> None:
def _metadata(self) -> dict[str, Any]:
"""
Return the full metadata including private keys.
"""
Remove a metadata key from the graph.

Parameters
----------
key : str
The key of the metadata to remove.
@abc.abstractmethod
def _update_metadata(self, **kwargs) -> None:
"""
Backend-specific metadata update implementation without public key validation.
"""

Examples
--------
```python
graph.remove_metadata("shape")
```
@abc.abstractmethod
def _remove_metadata(self, key: str) -> None:
"""
Backend-specific metadata removal implementation without public key validation.
"""

def to_traccuracy_graph(self, array_view_kwargs: dict[str, Any] | None = None) -> "TrackingGraph":
Expand Down
12 changes: 6 additions & 6 deletions src/tracksdata/graph/_graph_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,11 +847,11 @@ def copy(self, **kwargs) -> "GraphView":
"Use `detach` to create a new reference-less graph with the same nodes and edges."
)

def metadata(self) -> dict[str, Any]:
return self._root.metadata()
def _metadata(self) -> dict[str, Any]:
return self._root._metadata()

def update_metadata(self, **kwargs) -> None:
self._root.update_metadata(**kwargs)
def _update_metadata(self, **kwargs) -> None:
self._root._update_metadata(**kwargs)

def remove_metadata(self, key: str) -> None:
self._root.remove_metadata(key)
def _remove_metadata(self, key: str) -> None:
self._root._remove_metadata(key)
25 changes: 10 additions & 15 deletions src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None:
self._time_to_nodes: dict[int, list[int]] = {}
self.__node_attr_schemas: dict[str, AttrSchema] = {}
self.__edge_attr_schemas: dict[str, AttrSchema] = {}
self._overlaps: list[list[int, 2]] = []
self._overlaps: list[list[int]] = []

# Add default node attributes with inferred schemas
self.__node_attr_schemas[DEFAULT_ATTR_KEYS.T] = AttrSchema(
Expand Down Expand Up @@ -371,7 +371,7 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None:

elif not isinstance(self._graph.attrs, dict):
LOG.warning(
"previous attribute %s will be added to key 'old_attrs' of `graph.metadata()`",
"previous attribute %s will be added to key 'old_attrs' of `graph.metadata`",
self._graph.attrs,
)
self._graph.attrs = {
Expand Down Expand Up @@ -1153,16 +1153,11 @@ def edge_attrs(

edge_map = rx_graph.edge_index_map()
if len(edge_map) == 0:
return pl.DataFrame(
{
key: []
for key in [
*attr_keys,
DEFAULT_ATTR_KEYS.EDGE_SOURCE,
DEFAULT_ATTR_KEYS.EDGE_TARGET,
]
}
)
empty_columns = {}
for key in [*attr_keys, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]:
schema = self._edge_attr_schemas()[key]
empty_columns[key] = pl.Series(name=key, values=[], dtype=schema.dtype)
return pl.DataFrame(empty_columns)

source, target, data = zip(*edge_map.values(), strict=False)

Expand Down Expand Up @@ -1499,13 +1494,13 @@ def edge_id(self, source_id: int, target_id: int) -> int:
"""
return self.rx_graph.get_edge_data(source_id, target_id)[DEFAULT_ATTR_KEYS.EDGE_ID]

def metadata(self) -> dict[str, Any]:
def _metadata(self) -> dict[str, Any]:
return self._graph.attrs

def update_metadata(self, **kwargs) -> None:
def _update_metadata(self, **kwargs) -> None:
self._graph.attrs.update(kwargs)

def remove_metadata(self, key: str) -> None:
def _remove_metadata(self, key: str) -> None:
self._graph.attrs.pop(key, None)

def edge_list(self) -> list[list[int, int]]:
Expand Down
Loading
Loading