diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index 018e7ae6..f9427ab7 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from copy import copy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np @@ -200,6 +200,9 @@ def __init__( frame_attr_key=DEFAULT_ATTR_KEYS.T, bbox_attr_key=DEFAULT_ATTR_KEYS.BBOX, ) + self.graph.node_added.connect(self._on_node_added) + self.graph.node_removed.connect(self._on_node_removed) + self.graph.node_attrs_updated.connect(self._on_node_attrs_updated) @property def shape(self) -> tuple[int, ...]: @@ -351,3 +354,95 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=True): mask: Mask mask.paint_buffer(buffer, value, offset=self._offset) + + def _on_node_added(self, node_id: int) -> None: + self._invalidate_node_region_by_id(node_id) + + def _on_node_removed(self, node_id: int) -> None: + self._invalidate_node_region_by_id(node_id) + + def _on_node_attrs_updated(self, node_ids: Sequence[int], attr_keys: Sequence[str]) -> None: + changed_keys = set(attr_keys) + if ( + self._attr_key not in changed_keys + and DEFAULT_ATTR_KEYS.T not in changed_keys + and DEFAULT_ATTR_KEYS.BBOX not in changed_keys + and DEFAULT_ATTR_KEYS.MASK not in changed_keys + ): + return + + if DEFAULT_ATTR_KEYS.T in changed_keys: + self._cache.clear() + return + + if DEFAULT_ATTR_KEYS.BBOX in changed_keys or DEFAULT_ATTR_KEYS.MASK in changed_keys: + for node_id in node_ids: + attrs = self._node_attrs(node_id) + if attrs is None: + continue + time = attrs.get(DEFAULT_ATTR_KEYS.T) + if time is None: + self._cache.clear() + return + try: + self._cache.invalidate(time=int(time)) + except (TypeError, ValueError): + self._cache.clear() + return + return + + for node_id in node_ids: + self._invalidate_node_region_by_id(node_id) + + def _invalidate_node_region_by_id(self, node_id: int) -> None: + attrs = self._node_attrs(node_id) + if attrs is None: + return + self._invalidate_node_region(attrs) + + def _node_attrs(self, node_id: int) -> dict[str, Any] | None: + try: + return self.graph.nodes[node_id].to_dict() + except (IndexError, KeyError, ValueError): + return None + + def _invalidate_node_region(self, attrs: dict[str, Any]) -> None: + time = attrs.get(DEFAULT_ATTR_KEYS.T) + bbox = attrs.get(DEFAULT_ATTR_KEYS.BBOX) + + if time is None or bbox is None: + self._cache.clear() + return + + try: + time_int = int(time) + except (TypeError, ValueError): + self._cache.clear() + return + + if time_int < 0 or time_int >= self.original_shape[0]: + return + + spatial_ndim = len(self.original_shape) - 1 + try: + bbox_array = np.asarray(bbox, dtype=float).reshape(-1) + except (TypeError, ValueError): + self._cache.invalidate(time=time_int) + return + + if bbox_array.size != 2 * spatial_ndim: + self._cache.invalidate(time=time_int) + return + + volume_slicing = [] + for axis in range(spatial_ndim): + start = int(np.floor(bbox_array[axis])) + stop = int(np.ceil(bbox_array[axis + spatial_ndim])) + axis_size = self.original_shape[axis + 1] + start = max(0, min(axis_size, start)) + stop = max(0, min(axis_size, stop)) + if stop <= start: + return + volume_slicing.append(slice(start, stop)) + + self._cache.invalidate(time=time_int, volume_slicing=tuple(volume_slicing)) diff --git a/src/tracksdata/array/_nd_chunk_cache.py b/src/tracksdata/array/_nd_chunk_cache.py index 447dabe6..44b9e0be 100644 --- a/src/tracksdata/array/_nd_chunk_cache.py +++ b/src/tracksdata/array/_nd_chunk_cache.py @@ -114,6 +114,13 @@ def _chunk_bounds(self, slices: tuple[slice, ...]) -> tuple[tuple[int, int], ... """Return inclusive chunk-index bounds for every axis.""" return tuple((s.start // cs, (s.stop - 1) // cs) for s, cs in zip(slices, self.chunk_shape, strict=True)) + def _chunk_slice(self, chunk_idx: tuple[int, ...]) -> tuple[slice, ...]: + """Return the absolute volume slice for a chunk index.""" + return tuple( + slice(ci * cs, min((ci + 1) * cs, fs)) + for ci, cs, fs in zip(chunk_idx, self.chunk_shape, self.shape, strict=True) + ) + def get(self, time: int, volume_slicing: tuple[slice | int | Sequence[int], ...]) -> np.ndarray: """ Retrieve data for `time` and arbitrary dimensional slices. @@ -146,13 +153,57 @@ def get(self, time: int, volume_slicing: tuple[slice | int | Sequence[int], ...] continue # already filled # Absolute slice covering this chunk - chunk_slc = tuple( - slice(ci * cs, min((ci + 1) * cs, fs)) - for ci, cs, fs in zip(chunk_idx, self.chunk_shape, self.shape, strict=True) - ) + chunk_slc = self._chunk_slice(chunk_idx) # Handle the case where chunk_slc exceeds volume_slices self.compute_func(time, chunk_slc, store_entry.buffer) store_entry.ready[chunk_idx] = True # Return view on the big buffer return store_entry.buffer[volume_slicing] + + def clear(self) -> None: + """Clear all cached buffers.""" + self._store.clear() + + def invalidate( + self, + time: int, + volume_slicing: tuple[slice | int | Sequence[int], ...] | None = None, + ) -> None: + """ + Invalidate cached data for one time point or one region. + + Parameters + ---------- + time : int + Time point to invalidate. + volume_slicing : tuple[slice | int | Sequence[int], ...] | None + If provided, only chunks overlapping this region are invalidated. + If None, the full buffer for `time` is removed from cache. + """ + if time not in self._store: + return + + if volume_slicing is None: + del self._store[time] + return + + if len(volume_slicing) != self.ndim: + raise ValueError("Number of slices must equal dimensionality") + + normalized_slices: list[slice] = [] + for slc, size in zip(volume_slicing, self.shape, strict=True): + slc = _to_slice(slc) + start = 0 if slc.start is None else max(0, min(size, slc.start)) + stop = size if slc.stop is None else max(0, min(size, slc.stop)) + if stop <= start: + return + normalized_slices.append(slice(start, stop)) + + store_entry = self._store[time] + bounds = self._chunk_bounds(tuple(normalized_slices)) + chunk_ranges = [range(lo, hi + 1) for lo, hi in bounds] + for chunk_idx in itertools.product(*chunk_ranges): + chunk_slc = self._chunk_slice(chunk_idx) + store_entry.buffer[chunk_slc] = 0 + store_entry.ready[chunk_idx] = False diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py index f8dbc552..f8f83abd 100644 --- a/src/tracksdata/array/_test/test_graph_array.py +++ b/src/tracksdata/array/_test/test_graph_array.py @@ -378,3 +378,91 @@ def test_graph_array_raise_error_on_non_scalar_attr_key(graph_backend: BaseGraph with pytest.raises(ValueError, match="Attribute values for key 'label' must be scalar"): GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label") + + +def test_graph_array_updates_after_add_and_remove(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("label", dtype=pl.Int64) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) + + base_mask = Mask(np.array([[True]], dtype=bool), bbox=np.array([10, 20, 11, 21])) + graph_backend.add_node( + { + DEFAULT_ATTR_KEYS.T: 0, + "label": 1, + DEFAULT_ATTR_KEYS.MASK: base_mask, + DEFAULT_ATTR_KEYS.BBOX: base_mask.bbox, + } + ) + + array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label") + _ = np.asarray(array_view[0]) # populate cache + + new_mask = Mask(np.array([[True]], dtype=bool), bbox=np.array([30, 40, 31, 41])) + new_node_id = graph_backend.add_node( + { + DEFAULT_ATTR_KEYS.T: 0, + "label": 7, + DEFAULT_ATTR_KEYS.MASK: new_mask, + DEFAULT_ATTR_KEYS.BBOX: new_mask.bbox, + } + ) + + assert np.asarray(array_view[0])[30, 40] == 7 + + graph_backend.remove_node(new_node_id) + assert np.asarray(array_view[0])[30, 40] == 0 + + +def test_graph_array_updates_after_node_attr_update(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("label", dtype=pl.Int64) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) + + mask = Mask(np.array([[True]], dtype=bool), bbox=np.array([10, 20, 11, 21])) + node_id = graph_backend.add_node( + { + DEFAULT_ATTR_KEYS.T: 0, + "label": 1, + DEFAULT_ATTR_KEYS.MASK: mask, + DEFAULT_ATTR_KEYS.BBOX: mask.bbox, + } + ) + + array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label") + _ = np.asarray(array_view[0]) # populate cache + + graph_backend.update_node_attrs(node_ids=[node_id], attrs={"label": 9}) + assert np.asarray(array_view[0])[10, 20] == 9 + + +def test_graph_array_updates_after_bbox_and_mask_update(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("label", dtype=pl.Int64) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) + + mask = Mask(np.array([[True]], dtype=bool), bbox=np.array([10, 20, 11, 21])) + node_id = graph_backend.add_node( + { + DEFAULT_ATTR_KEYS.T: 0, + "label": 3, + DEFAULT_ATTR_KEYS.MASK: mask, + DEFAULT_ATTR_KEYS.BBOX: mask.bbox, + } + ) + + array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label") + _ = np.asarray(array_view[0]) # populate cache + + moved_mask = Mask(np.array([[True]], dtype=bool), bbox=np.array([30, 40, 31, 41])) + graph_backend.update_node_attrs( + node_ids=[node_id], + attrs={ + DEFAULT_ATTR_KEYS.BBOX: [moved_mask.bbox], + DEFAULT_ATTR_KEYS.MASK: [moved_mask], + }, + ) + + refreshed = np.asarray(array_view[0]) + assert refreshed[10, 20] == 0 + assert refreshed[30, 40] == 3 diff --git a/src/tracksdata/array/_test/test_nd_chunk_cache.py b/src/tracksdata/array/_test/test_nd_chunk_cache.py index de3378c6..71de09bf 100644 --- a/src/tracksdata/array/_test/test_nd_chunk_cache.py +++ b/src/tracksdata/array/_test/test_nd_chunk_cache.py @@ -116,3 +116,22 @@ def test_nd_chunk_cache_correctly_slice(array_nd_chunk_cache, volume_slicing): # First request → triggers chunk computations vol1 = cache.get(1, volume_slicing) np.testing.assert_array_equal(vol1, np_array[1][volume_slicing]) + + +def test_nd_chunk_cache_invalidate_region(array_nd_chunk_cache): + cache, np_array = array_nd_chunk_cache + + _ = cache.get(1, (slice(0, 100), slice(0, 100), slice(0, 100))) + cache.invalidate(1, (slice(10, 20), slice(10, 20), slice(10, 20))) + + updated = cache.get(1, (slice(10, 20), slice(10, 20), slice(10, 20))) + np.testing.assert_array_equal(updated, np_array[1][10:20, 10:20, 10:20]) + + +def test_nd_chunk_cache_clear(array_nd_chunk_cache): + cache, _ = array_nd_chunk_cache + + _ = cache.get(1, (slice(0, 10), slice(0, 10), slice(0, 10))) + assert len(cache._store) == 1 + cache.clear() + assert len(cache._store) == 0 diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 5b3708ad..b6d6dc6a 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -49,6 +49,7 @@ class BaseGraph(abc.ABC): node_added = Signal(int) node_removed = Signal(int) + node_attrs_updated = Signal(object, object) def __init__(self) -> None: self._cache = {} diff --git a/src/tracksdata/graph/_graph_view.py b/src/tracksdata/graph/_graph_view.py index b9f82ead..6e977cdf 100644 --- a/src/tracksdata/graph/_graph_view.py +++ b/src/tracksdata/graph/_graph_view.py @@ -652,6 +652,10 @@ def update_node_attrs( ) -> None: if node_ids is None: node_ids = self.node_ids() + else: + node_ids = list(node_ids) + + emit_signal = is_signal_on(self.node_attrs_updated) self._root.update_node_attrs( node_ids=node_ids, @@ -660,13 +664,24 @@ def update_node_attrs( # because attributes are passed by reference, we need don't need if both are rustworkx graphs if not self._is_root_rx_graph: if self.sync: - super().update_node_attrs( - node_ids=self._map_to_local(node_ids), - attrs=attrs, - ) + local_node_ids = self._map_to_local(node_ids) + if emit_signal: + with self.node_attrs_updated.blocked(): + super().update_node_attrs( + node_ids=local_node_ids, + attrs=attrs, + ) + else: + super().update_node_attrs( + node_ids=local_node_ids, + attrs=attrs, + ) else: self._out_of_sync = True + if emit_signal: + self.node_attrs_updated.emit_fast(list(node_ids), tuple(attrs.keys())) + def update_edge_attrs( self, *, diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index ef4a3f4f..7704cebc 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -1216,6 +1216,8 @@ def update_node_attrs( """ if node_ids is None: node_ids = self.node_ids() + else: + node_ids = list(node_ids) for key, value in attrs.items(): if key not in self.node_attr_keys(): @@ -1231,6 +1233,9 @@ def update_node_attrs( for node_id, v in zip(node_ids, value, strict=False): self._graph[node_id][key] = v + if is_signal_on(self.node_attrs_updated): + self.node_attrs_updated.emit_fast(list(node_ids), tuple(attrs.keys())) + def update_edge_attrs( self, *, @@ -1937,8 +1942,15 @@ def update_node_attrs( node_ids : Sequence[int] | None The node ids to update. """ - node_ids = self._get_local_ids() if node_ids is None else self._map_to_local(node_ids) - super().update_node_attrs(attrs=attrs, node_ids=node_ids) + external_node_ids = self.node_ids() if node_ids is None else list(node_ids) + local_node_ids = self._map_to_local(external_node_ids) + + if is_signal_on(self.node_attrs_updated): + with self.node_attrs_updated.blocked(): + super().update_node_attrs(attrs=attrs, node_ids=local_node_ids) + self.node_attrs_updated.emit_fast(external_node_ids, tuple(attrs.keys())) + else: + super().update_node_attrs(attrs=attrs, node_ids=local_node_ids) def remove_node(self, node_id: int) -> None: """ diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 985cbdc9..b8cb079a 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -1755,8 +1755,17 @@ def update_node_attrs( if "t" in attrs: raise ValueError("Node attribute 't' cannot be updated.") + emit_signal = is_signal_on(self.node_attrs_updated) + if node_ids is None and emit_signal: + node_ids = self.node_ids() + elif node_ids is not None: + node_ids = list(node_ids) + self._update_table(self.Node, node_ids, DEFAULT_ATTR_KEYS.NODE_ID, attrs) + if emit_signal: + self.node_attrs_updated.emit_fast(list(node_ids), tuple(attrs.keys())) + def update_edge_attrs( self, *, diff --git a/src/tracksdata/graph/filters/_spatial_filter.py b/src/tracksdata/graph/filters/_spatial_filter.py index 52ce080e..985d7ce4 100644 --- a/src/tracksdata/graph/filters/_spatial_filter.py +++ b/src/tracksdata/graph/filters/_spatial_filter.py @@ -1,4 +1,5 @@ import time +from collections.abc import Sequence from typing import TYPE_CHECKING, Any import numpy as np @@ -225,50 +226,15 @@ def __init__( frame_attr_key: str | None = DEFAULT_ATTR_KEYS.T, bbox_attr_key: str = DEFAULT_ATTR_KEYS.BBOX, ) -> None: - from spatial_graph import PointRTree - self._graph = graph self._frame_attr_key = frame_attr_key self._bbox_attr_key = bbox_attr_key - - if frame_attr_key is None: - attr_keys = [DEFAULT_ATTR_KEYS.NODE_ID, bbox_attr_key] - else: - attr_keys = [DEFAULT_ATTR_KEYS.NODE_ID, frame_attr_key, bbox_attr_key] - nodes_df = graph.node_attrs(attr_keys=attr_keys) - node_ids = np.ascontiguousarray(nodes_df[DEFAULT_ATTR_KEYS.NODE_ID].to_numpy(), dtype=np.int64).copy() - - if nodes_df.is_empty(): - self._node_rtree = None - else: - bboxes = self._bboxes_to_array(nodes_df[bbox_attr_key]) - if bboxes.shape[1] % 2 != 0: - raise ValueError(f"Bounding box coordinates must have even number of dimensions, got {bboxes.shape[1]}") - num_dims = bboxes.shape[1] // 2 - - if frame_attr_key is None: - self._ndims = num_dims - positions_min = np.ascontiguousarray(bboxes[:, :num_dims], dtype=np.float32) - positions_max = np.ascontiguousarray(bboxes[:, num_dims:], dtype=np.float32) - else: - frames = nodes_df[frame_attr_key].to_numpy() - self._ndims = num_dims + 1 # +1 for the frame dimension - positions_min = np.ascontiguousarray( - np.hstack((frames[:, np.newaxis], bboxes[:, :num_dims])), dtype=np.float32 - ) - positions_max = np.ascontiguousarray( - np.hstack((frames[:, np.newaxis], bboxes[:, num_dims:])), dtype=np.float32 - ) - self._node_rtree = PointRTree( - item_dtype="int64", - coord_dtype="float32", - dims=self._ndims, - ) - self._node_rtree.insert_bb_items(node_ids, positions_min, positions_max) + self._rebuild_index() # setup signal connections self._graph.node_added.connect(self._add_node) self._graph.node_removed.connect(self._remove_node) + self._graph.node_attrs_updated.connect(self._on_node_attrs_updated) def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter": """ @@ -330,6 +296,60 @@ def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter": ) return self._graph.filter(node_ids=node_ids) + def _rebuild_index(self) -> None: + from spatial_graph import PointRTree + + if self._frame_attr_key is None: + attr_keys = [DEFAULT_ATTR_KEYS.NODE_ID, self._bbox_attr_key] + else: + attr_keys = [DEFAULT_ATTR_KEYS.NODE_ID, self._frame_attr_key, self._bbox_attr_key] + + nodes_df = self._graph.node_attrs(attr_keys=attr_keys) + node_ids = np.ascontiguousarray(nodes_df[DEFAULT_ATTR_KEYS.NODE_ID].to_numpy(), dtype=np.int64).copy() + + if nodes_df.is_empty(): + self._node_rtree = None + return + + bboxes = self._bboxes_to_array(nodes_df[self._bbox_attr_key]) + if bboxes.shape[1] % 2 != 0: + raise ValueError(f"Bounding box coordinates must have even number of dimensions, got {bboxes.shape[1]}") + + num_dims = bboxes.shape[1] // 2 + if self._frame_attr_key is None: + self._ndims = num_dims + positions_min = np.ascontiguousarray(bboxes[:, :num_dims], dtype=np.float32) + positions_max = np.ascontiguousarray(bboxes[:, num_dims:], dtype=np.float32) + else: + frames = nodes_df[self._frame_attr_key].to_numpy() + self._ndims = num_dims + 1 + positions_min = np.ascontiguousarray( + np.hstack((frames[:, np.newaxis], bboxes[:, :num_dims])), + dtype=np.float32, + ) + positions_max = np.ascontiguousarray( + np.hstack((frames[:, np.newaxis], bboxes[:, num_dims:])), + dtype=np.float32, + ) + + self._node_rtree = PointRTree( + item_dtype="int64", + coord_dtype="float32", + dims=self._ndims, + ) + self._node_rtree.insert_bb_items(node_ids, positions_min, positions_max) + + def _on_node_attrs_updated(self, node_ids: Sequence[int], attr_keys: Sequence[str]) -> None: + """ + Refresh index when updates affect spatial coordinates. + + BBox updates can change both min/max bounds and potentially the frame axis, + so we rebuild from current graph state to keep the index consistent. + """ + changed = set(attr_keys) + if self._bbox_attr_key in changed or (self._frame_attr_key is not None and self._frame_attr_key in changed): + self._rebuild_index() + def _attrs_to_bb_window(self, attrs: dict[str, Any]) -> tuple[np.ndarray, np.ndarray]: """ Convert attributes to bounding box window.