97 changes: 96 additions & 1 deletion src/tracksdata/array/_graph_array.py
@@ -1,6 +1,6 @@
from collections.abc import Sequence
from copy import copy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np

@@ -200,6 +200,9 @@ def __init__(
frame_attr_key=DEFAULT_ATTR_KEYS.T,
bbox_attr_key=DEFAULT_ATTR_KEYS.BBOX,
)
self.graph.node_added.connect(self._on_node_added)
self.graph.node_removed.connect(self._on_node_removed)
self.graph.node_attrs_updated.connect(self._on_node_attrs_updated)

@property
def shape(self) -> tuple[int, ...]:
@@ -351,3 +354,95 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda
for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=True):
mask: Mask
mask.paint_buffer(buffer, value, offset=self._offset)

def _on_node_added(self, node_id: int) -> None:
self._invalidate_node_region_by_id(node_id)

def _on_node_removed(self, node_id: int) -> None:
self._invalidate_node_region_by_id(node_id)

def _on_node_attrs_updated(self, node_ids: Sequence[int], attr_keys: Sequence[str]) -> None:
changed_keys = set(attr_keys)
if (
self._attr_key not in changed_keys
and DEFAULT_ATTR_KEYS.T not in changed_keys
and DEFAULT_ATTR_KEYS.BBOX not in changed_keys
and DEFAULT_ATTR_KEYS.MASK not in changed_keys
):
return

if DEFAULT_ATTR_KEYS.T in changed_keys:
self._cache.clear()
return

if DEFAULT_ATTR_KEYS.BBOX in changed_keys or DEFAULT_ATTR_KEYS.MASK in changed_keys:
for node_id in node_ids:
attrs = self._node_attrs(node_id)
if attrs is None:
continue
time = attrs.get(DEFAULT_ATTR_KEYS.T)
if time is None:
self._cache.clear()
return
try:
self._cache.invalidate(time=int(time))
except (TypeError, ValueError):
self._cache.clear()
return
return

for node_id in node_ids:
self._invalidate_node_region_by_id(node_id)

def _invalidate_node_region_by_id(self, node_id: int) -> None:
attrs = self._node_attrs(node_id)
if attrs is None:
return
self._invalidate_node_region(attrs)

def _node_attrs(self, node_id: int) -> dict[str, Any] | None:
try:
return self.graph.nodes[node_id].to_dict()
except (IndexError, KeyError, ValueError):
return None

def _invalidate_node_region(self, attrs: dict[str, Any]) -> None:
time = attrs.get(DEFAULT_ATTR_KEYS.T)
bbox = attrs.get(DEFAULT_ATTR_KEYS.BBOX)

if time is None or bbox is None:
self._cache.clear()
return

try:
time_int = int(time)
except (TypeError, ValueError):
self._cache.clear()
return

if time_int < 0 or time_int >= self.original_shape[0]:
return

spatial_ndim = len(self.original_shape) - 1
try:
bbox_array = np.asarray(bbox, dtype=float).reshape(-1)
except (TypeError, ValueError):
self._cache.invalidate(time=time_int)
return

if bbox_array.size != 2 * spatial_ndim:
self._cache.invalidate(time=time_int)
return

volume_slicing = []
for axis in range(spatial_ndim):
start = int(np.floor(bbox_array[axis]))
stop = int(np.ceil(bbox_array[axis + spatial_ndim]))
axis_size = self.original_shape[axis + 1]
start = max(0, min(axis_size, start))
stop = max(0, min(axis_size, stop))
if stop <= start:
return
volume_slicing.append(slice(start, stop))

self._cache.invalidate(time=time_int, volume_slicing=tuple(volume_slicing))
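
Taken together, the three handlers turn graph mutations into targeted cache invalidation: node additions and removals invalidate only the bbox region at the node's time point, while attribute updates fall back to whole-frame or whole-cache invalidation when the time or bbox cannot be trusted. A minimal sketch of the resulting behavior, mirroring the tests further down (`graph` is assumed to be a prepared backend with the "label", mask, and bbox attribute keys already registered):

import numpy as np

view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="label")
_ = np.asarray(view[0])  # populate the chunk cache for t=0

# add_node fires node_added -> _on_node_added -> _invalidate_node_region_by_id,
# which invalidates only the chunks overlapping the new node's bbox at t=0.
mask = Mask(np.array([[True]], dtype=bool), bbox=np.array([30, 40, 31, 41]))
graph.add_node(
    {
        DEFAULT_ATTR_KEYS.T: 0,
        "label": 7,
        DEFAULT_ATTR_KEYS.MASK: mask,
        DEFAULT_ATTR_KEYS.BBOX: mask.bbox,
    }
)
assert np.asarray(view[0])[30, 40] == 7  # re-read recomputes just that region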
59 changes: 55 additions & 4 deletions src/tracksdata/array/_nd_chunk_cache.py
@@ -114,6 +114,13 @@ def _chunk_bounds(self, slices: tuple[slice, ...]) -> tuple[tuple[int, int], ...
"""Return inclusive chunk-index bounds for every axis."""
return tuple((s.start // cs, (s.stop - 1) // cs) for s, cs in zip(slices, self.chunk_shape, strict=True))

def _chunk_slice(self, chunk_idx: tuple[int, ...]) -> tuple[slice, ...]:
"""Return the absolute volume slice for a chunk index."""
return tuple(
slice(ci * cs, min((ci + 1) * cs, fs))
for ci, cs, fs in zip(chunk_idx, self.chunk_shape, self.shape, strict=True)
)

def get(self, time: int, volume_slicing: tuple[slice | int | Sequence[int], ...]) -> np.ndarray:
"""
Retrieve data for `time` and arbitrary dimensional slices.
@@ -146,13 +153,57 @@ def get(self, time: int, volume_slicing: tuple[slice | int | Sequence[int], ...]
continue # already filled

# Absolute slice covering this chunk
chunk_slc = tuple(
slice(ci * cs, min((ci + 1) * cs, fs))
for ci, cs, fs in zip(chunk_idx, self.chunk_shape, self.shape, strict=True)
)
chunk_slc = self._chunk_slice(chunk_idx)
# Note: chunk_slc may extend beyond the requested volume_slicing; the whole chunk is filled
self.compute_func(time, chunk_slc, store_entry.buffer)
store_entry.ready[chunk_idx] = True

# Return view on the big buffer
return store_entry.buffer[volume_slicing]

def clear(self) -> None:
"""Clear all cached buffers."""
self._store.clear()

def invalidate(
self,
time: int,
volume_slicing: tuple[slice | int | Sequence[int], ...] | None = None,
) -> None:
"""
Invalidate cached data for one time point or one region.

Parameters
----------
time : int
Time point to invalidate.
volume_slicing : tuple[slice | int | Sequence[int], ...] | None
If provided, only chunks overlapping this region are invalidated.
If None, the full buffer for `time` is removed from cache.
"""
if time not in self._store:
return

if volume_slicing is None:
del self._store[time]
return

if len(volume_slicing) != self.ndim:
raise ValueError("Number of slices must equal dimensionality")

normalized_slices: list[slice] = []
for slc, size in zip(volume_slicing, self.shape, strict=True):
slc = _to_slice(slc)
start = 0 if slc.start is None else max(0, min(size, slc.start))
stop = size if slc.stop is None else max(0, min(size, slc.stop))
if stop <= start:
return
normalized_slices.append(slice(start, stop))

store_entry = self._store[time]
bounds = self._chunk_bounds(tuple(normalized_slices))
chunk_ranges = [range(lo, hi + 1) for lo, hi in bounds]
for chunk_idx in itertools.product(*chunk_ranges):
chunk_slc = self._chunk_slice(chunk_idx)
store_entry.buffer[chunk_slc] = 0
store_entry.ready[chunk_idx] = False
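
`clear()` drops every cached frame, while `invalidate()` resets only the chunks overlapping a region (or the whole frame when `volume_slicing` is None), zeroing the buffer and marking those chunks not-ready so the next `get()` recomputes them. A minimal usage sketch over a plain NumPy source volume; the class name and constructor arguments here are assumptions inferred from how the cache is exercised in the tests:

import numpy as np

data = np.random.rand(4, 100, 100, 100)  # (t, z, y, x) source volume

def compute(time, chunk_slc, buffer):
    # Write the requested chunk region of `buffer` from the source.
    buffer[chunk_slc] = data[time][chunk_slc]

cache = NDChunkCache(  # hypothetical constructor signature
    shape=(100, 100, 100),
    chunk_shape=(32, 32, 32),
    compute_func=compute,
)
_ = cache.get(1, (slice(0, 100),) * 3)     # fills all chunks for t=1

data[1, 10:20, 10:20, 10:20] += 1.0        # source changed out of band
cache.invalidate(1, (slice(10, 20),) * 3)  # reset only overlapping chunks
fresh = cache.get(1, (slice(10, 20),) * 3) # recomputed from `data`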
88 changes: 88 additions & 0 deletions src/tracksdata/array/_test/test_graph_array.py
@@ -378,3 +378,91 @@ def test_graph_array_raise_error_on_non_scalar_attr_key(graph_backend: BaseGraph

with pytest.raises(ValueError, match="Attribute values for key 'label' must be scalar"):
GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label")


def test_graph_array_updates_after_add_and_remove(graph_backend: BaseGraph) -> None:
graph_backend.add_node_attr_key("label", dtype=pl.Int64)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4))

base_mask = Mask(np.array([[True]], dtype=bool), bbox=np.array([10, 20, 11, 21]))
graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 1,
DEFAULT_ATTR_KEYS.MASK: base_mask,
DEFAULT_ATTR_KEYS.BBOX: base_mask.bbox,
}
)

array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label")
_ = np.asarray(array_view[0]) # populate cache

new_mask = Mask(np.array([[True]], dtype=bool), bbox=np.array([30, 40, 31, 41]))
new_node_id = graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 7,
DEFAULT_ATTR_KEYS.MASK: new_mask,
DEFAULT_ATTR_KEYS.BBOX: new_mask.bbox,
}
)

assert np.asarray(array_view[0])[30, 40] == 7

graph_backend.remove_node(new_node_id)
assert np.asarray(array_view[0])[30, 40] == 0


def test_graph_array_updates_after_node_attr_update(graph_backend: BaseGraph) -> None:
graph_backend.add_node_attr_key("label", dtype=pl.Int64)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4))

mask = Mask(np.array([[True]], dtype=bool), bbox=np.array([10, 20, 11, 21]))
node_id = graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 1,
DEFAULT_ATTR_KEYS.MASK: mask,
DEFAULT_ATTR_KEYS.BBOX: mask.bbox,
}
)

array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label")
_ = np.asarray(array_view[0]) # populate cache

graph_backend.update_node_attrs(node_ids=[node_id], attrs={"label": 9})
assert np.asarray(array_view[0])[10, 20] == 9


def test_graph_array_updates_after_bbox_and_mask_update(graph_backend: BaseGraph) -> None:
graph_backend.add_node_attr_key("label", dtype=pl.Int64)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4))

mask = Mask(np.array([[True]], dtype=bool), bbox=np.array([10, 20, 11, 21]))
node_id = graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 3,
DEFAULT_ATTR_KEYS.MASK: mask,
DEFAULT_ATTR_KEYS.BBOX: mask.bbox,
}
)

array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label")
_ = np.asarray(array_view[0]) # populate cache

moved_mask = Mask(np.array([[True]], dtype=bool), bbox=np.array([30, 40, 31, 41]))
graph_backend.update_node_attrs(
node_ids=[node_id],
attrs={
DEFAULT_ATTR_KEYS.BBOX: [moved_mask.bbox],
DEFAULT_ATTR_KEYS.MASK: [moved_mask],
},
)

refreshed = np.asarray(array_view[0])
assert refreshed[10, 20] == 0
assert refreshed[30, 40] == 3
19 changes: 19 additions & 0 deletions src/tracksdata/array/_test/test_nd_chunk_cache.py
@@ -116,3 +116,22 @@ def test_nd_chunk_cache_correctly_slice(array_nd_chunk_cache, volume_slicing):
# First request → triggers chunk computations
vol1 = cache.get(1, volume_slicing)
np.testing.assert_array_equal(vol1, np_array[1][volume_slicing])


def test_nd_chunk_cache_invalidate_region(array_nd_chunk_cache):
cache, np_array = array_nd_chunk_cache

_ = cache.get(1, (slice(0, 100), slice(0, 100), slice(0, 100)))
cache.invalidate(1, (slice(10, 20), slice(10, 20), slice(10, 20)))

updated = cache.get(1, (slice(10, 20), slice(10, 20), slice(10, 20)))
np.testing.assert_array_equal(updated, np_array[1][10:20, 10:20, 10:20])


def test_nd_chunk_cache_clear(array_nd_chunk_cache):
cache, _ = array_nd_chunk_cache

_ = cache.get(1, (slice(0, 10), slice(0, 10), slice(0, 10)))
assert len(cache._store) == 1
cache.clear()
assert len(cache._store) == 0
1 change: 1 addition & 0 deletions src/tracksdata/graph/_base_graph.py
@@ -49,6 +49,7 @@ class BaseGraph(abc.ABC):

node_added = Signal(int)
node_removed = Signal(int)
node_attrs_updated = Signal(object, object)

def __init__(self) -> None:
self._cache = {}
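
The new `node_attrs_updated` signal carries two payloads, the updated node ids and the attribute keys that changed, which lets listeners scope their reaction instead of refreshing everything. A minimal listener sketch (the callback name is illustrative; `graph` is any `BaseGraph` subclass):

def on_attrs_updated(node_ids, attr_keys):
    # node_ids: Sequence[int]; attr_keys: Sequence[str]
    print(f"nodes {list(node_ids)} changed keys {tuple(attr_keys)}")

graph.node_attrs_updated.connect(on_attrs_updated)
graph.update_node_attrs(node_ids=[node_id], attrs={"label": 5})
# prints: nodes [<node_id>] changed keys ('label',)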
23 changes: 19 additions & 4 deletions src/tracksdata/graph/_graph_view.py
@@ -652,6 +652,10 @@ def update_node_attrs(
) -> None:
if node_ids is None:
node_ids = self.node_ids()
else:
node_ids = list(node_ids)

emit_signal = is_signal_on(self.node_attrs_updated)

self._root.update_node_attrs(
node_ids=node_ids,
@@ -660,13 +664,24 @@
# because attributes are passed by reference, we don't need this step if both are rustworkx graphs
if not self._is_root_rx_graph:
if self.sync:
super().update_node_attrs(
node_ids=self._map_to_local(node_ids),
attrs=attrs,
)
local_node_ids = self._map_to_local(node_ids)
if emit_signal:
with self.node_attrs_updated.blocked():
super().update_node_attrs(
node_ids=local_node_ids,
attrs=attrs,
)
else:
super().update_node_attrs(
node_ids=local_node_ids,
attrs=attrs,
)
else:
self._out_of_sync = True

if emit_signal:
self.node_attrs_updated.emit_fast(list(node_ids), tuple(attrs.keys()))

def update_edge_attrs(
self,
*,
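
The `blocked()` guard is what keeps listeners from seeing two emissions for one update: the delegated `super().update_node_attrs(...)` would itself emit with view-local ids, so it is silenced and the view emits once with the ids the caller passed. The rustworkx view below applies the same pattern around its external-to-local id mapping. A minimal sketch of the observable behavior, with illustrative names:

calls = []
view.node_attrs_updated.connect(lambda ids, keys: calls.append((list(ids), tuple(keys))))

view.update_node_attrs(node_ids=[external_id], attrs={"label": 2})
# Exactly one emission, carrying the caller's external ids. Without
# blocked(), the inner super() call would append a second entry with
# the view-local ids.
assert calls == [([external_id], ("label",))]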
16 changes: 14 additions & 2 deletions src/tracksdata/graph/_rustworkx_graph.py
@@ -1216,6 +1216,8 @@ def update_node_attrs(
"""
if node_ids is None:
node_ids = self.node_ids()
else:
node_ids = list(node_ids)

for key, value in attrs.items():
if key not in self.node_attr_keys():
@@ -1231,6 +1233,9 @@
for node_id, v in zip(node_ids, value, strict=False):
self._graph[node_id][key] = v

if is_signal_on(self.node_attrs_updated):
self.node_attrs_updated.emit_fast(list(node_ids), tuple(attrs.keys()))

def update_edge_attrs(
self,
*,
@@ -1937,8 +1942,15 @@ def update_node_attrs(
node_ids : Sequence[int] | None
The node ids to update.
"""
node_ids = self._get_local_ids() if node_ids is None else self._map_to_local(node_ids)
super().update_node_attrs(attrs=attrs, node_ids=node_ids)
external_node_ids = self.node_ids() if node_ids is None else list(node_ids)
local_node_ids = self._map_to_local(external_node_ids)

if is_signal_on(self.node_attrs_updated):
with self.node_attrs_updated.blocked():
super().update_node_attrs(attrs=attrs, node_ids=local_node_ids)
self.node_attrs_updated.emit_fast(external_node_ids, tuple(attrs.keys()))
else:
super().update_node_attrs(attrs=attrs, node_ids=local_node_ids)

def remove_node(self, node_id: int) -> None:
"""