Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,6 @@ benchmarks/outputs/*.md

# Version file generated by hatch-vcs
src/tracksdata/__about__.py

# vscode
.vscode/
2 changes: 1 addition & 1 deletion docs/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Here's a complete basic example that demonstrates the core workflow of TracksDat
## Key Components Explained

- **Graph**: The core data structure holding nodes (objects) and edges (connections)
- **Nodes Operators**: Extract object features from segmented images (RegionPropsNodes, MaskNodes, etc.)
- **Nodes Operators**: Extract object features from segmented images (RegionPropsNodes, etc.)
- **Edges Operators**: Create temporal connections between objects (DistanceEdges, IoUEdges, etc.)
- **Solvers**: Optimize a minimization problem to find the best tracking assignments (NearestNeighborsSolver, ILPSolver)
- **Functional**: Utilities for format conversion and visualization
Expand Down
14 changes: 9 additions & 5 deletions src/tracksdata/array/_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from tracksdata.array._base_array import ArrayIndex, BaseReadOnlyArray
from tracksdata.attrs import NodeAttr
from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.functional._mask import paint_mask_to_buffer
from tracksdata.graph._base_graph import BaseGraph
from tracksdata.nodes._mask import Mask
from tracksdata.utils._dtypes import polars_dtype_to_numpy_dtype


Expand Down Expand Up @@ -69,7 +69,7 @@ def __getitem__(self, index: ArrayIndex) -> ArrayLike:
return np.zeros(self.shape[1:], dtype=self.dtype)

df = graph_filter.node_attrs(
attr_keys=[self._attr_key, DEFAULT_ATTR_KEYS.MASK],
attr_keys=[self._attr_key, DEFAULT_ATTR_KEYS.MASK, DEFAULT_ATTR_KEYS.BBOX],
)

dtype = polars_dtype_to_numpy_dtype(df[self._attr_key].dtype)
Expand All @@ -83,9 +83,13 @@ def __getitem__(self, index: ArrayIndex) -> ArrayLike:
# TODO: reuse buffer
buffer = np.zeros(self.shape[1:], dtype=self.dtype)

for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=False):
mask: Mask
mask.paint_buffer(buffer, value, offset=self._offset)
for bbox, mask, value in zip(
df[DEFAULT_ATTR_KEYS.BBOX],
df[DEFAULT_ATTR_KEYS.MASK],
df[self._attr_key],
strict=True,
):
paint_mask_to_buffer(buffer, bbox, mask, value, offset=self._offset)

return buffer
else:
Expand Down
46 changes: 24 additions & 22 deletions src/tracksdata/array/_test/test_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tracksdata.array import GraphArrayView
from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.graph import RustWorkXGraph
from tracksdata.nodes._mask import Mask
from tracksdata.utils._test_utils import setup_mask_attrs

# NOTE: this could be generic test for all array backends
# when more slicing operations are implemented we could test as in:
Expand Down Expand Up @@ -59,14 +59,14 @@ def test_graph_array_view_getitem_with_nodes() -> None:

# Add attribute keys
graph.add_node_attr_key("label", 0)
graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)
setup_mask_attrs(graph)

# Create a mask
mask_data = np.array([[True, True], [True, False]], dtype=bool)
mask = Mask(mask_data, bbox=np.array([10, 20, 12, 22])) # y_min, x_min, y_max, x_max
mask = np.array([[True, True], [True, False]], dtype=bool)
bbox = np.array([10, 20, 12, 22])

# Add a node with mask and label
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "label": 5, DEFAULT_ATTR_KEYS.MASK: mask})
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "label": 5, DEFAULT_ATTR_KEYS.MASK: mask, DEFAULT_ATTR_KEYS.BBOX: bbox})

array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="label")

Expand Down Expand Up @@ -94,19 +94,19 @@ def test_graph_array_view_getitem_multiple_nodes() -> None:

# Add attribute keys
graph.add_node_attr_key("label", 0)
graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)
setup_mask_attrs(graph)

# Create two masks at different locations
mask1_data = np.array([[True, True]], dtype=bool)
mask1 = Mask(mask1_data, bbox=np.array([10, 20, 11, 22]))
mask1 = np.array([[True, True]], dtype=bool)
bbox1 = np.array([10, 20, 11, 22])

mask2_data = np.array([[True]], dtype=bool)
mask2 = Mask(mask2_data, bbox=np.array([30, 40, 31, 41]))
mask2 = np.array([[True]], dtype=bool)
bbox2 = np.array([30, 40, 31, 41])

# Add nodes with different labels
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "label": 3, DEFAULT_ATTR_KEYS.MASK: mask1})
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "label": 3, DEFAULT_ATTR_KEYS.MASK: mask1, DEFAULT_ATTR_KEYS.BBOX: bbox1})

graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "label": 7, DEFAULT_ATTR_KEYS.MASK: mask2})
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "label": 7, DEFAULT_ATTR_KEYS.MASK: mask2, DEFAULT_ATTR_KEYS.BBOX: bbox2})

array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="label")

Expand All @@ -129,14 +129,15 @@ def test_graph_array_view_getitem_boolean_dtype() -> None:

# Add attribute keys
graph.add_node_attr_key("is_active", False)
graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)

setup_mask_attrs(graph)
# Create a mask
mask_data = np.array([[True]], dtype=bool)
mask = Mask(mask_data, bbox=np.array([10, 20, 11, 21]))
mask = np.array([[True]], dtype=bool)
bbox = np.array([10, 20, 11, 21])

# Add a node with boolean attribute
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "is_active": True, DEFAULT_ATTR_KEYS.MASK: mask})
graph.add_node(
{DEFAULT_ATTR_KEYS.T: 0, "is_active": True, DEFAULT_ATTR_KEYS.MASK: mask, DEFAULT_ATTR_KEYS.BBOX: bbox}
)

array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="is_active")

Expand All @@ -155,14 +156,15 @@ def test_graph_array_view_dtype_inference() -> None:

# Add attribute keys
graph.add_node_attr_key("float_label", 0.0)
graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)

setup_mask_attrs(graph)
# Create a mask
mask_data = np.array([[True]], dtype=bool)
mask = Mask(mask_data, bbox=np.array([10, 20, 11, 21]))
mask = np.array([[True]], dtype=bool)
bbox = np.array([10, 20, 11, 21])

# Add a node with float attribute
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "float_label": 3.14, DEFAULT_ATTR_KEYS.MASK: mask})
graph.add_node(
{DEFAULT_ATTR_KEYS.T: 0, "float_label": 3.14, DEFAULT_ATTR_KEYS.MASK: mask, DEFAULT_ATTR_KEYS.BBOX: bbox}
)

array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="float_label")

Expand Down
16 changes: 13 additions & 3 deletions src/tracksdata/edges/_iou_edges.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np

from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.edges._generic_edges import GenericFuncEdgeAttrs
from tracksdata.nodes._mask import Mask
from tracksdata.functional._mask import mask_iou


class IoUEdgeAttr(GenericFuncEdgeAttrs):
Expand All @@ -14,15 +16,23 @@ class IoUEdgeAttr(GenericFuncEdgeAttrs):
The key to use for the output of the IoU.
mask_key : str
The key to use for the masks of the nodes.
bbox_key : str
The key to use for the bounding boxes of the nodes.
"""

def __init__(
self,
output_key: str,
mask_key: str = DEFAULT_ATTR_KEYS.MASK,
bbox_key: str = DEFAULT_ATTR_KEYS.BBOX,
):
def _compute_iou(source_attrs: dict[str, np.ndarray], target_attrs: dict[str, np.ndarray]) -> float:
return mask_iou(
source_attrs[bbox_key], source_attrs[mask_key], target_attrs[bbox_key], target_attrs[mask_key]
)

super().__init__(
func=Mask.iou,
attr_keys=mask_key,
func=_compute_iou,
attr_keys=[mask_key, bbox_key],
output_key=output_key,
)
37 changes: 16 additions & 21 deletions src/tracksdata/edges/_test/test_distance_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from tracksdata.edges import DistanceEdges
from tracksdata.graph import RustWorkXGraph
from tracksdata.options import get_options, options_context
from tracksdata.utils._test_utils import (
setup_custom_node_attr,
setup_spatial_attrs_2d,
setup_spatial_attrs_3d,
)


def test_distance_edges_init_default_params() -> None:
Expand Down Expand Up @@ -48,8 +53,7 @@ def test_distance_edges_add_edges_single_timepoint_no_previous() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
setup_spatial_attrs_2d(graph)

# Add nodes only at t=1 (no t=0)
graph.add_node({DEFAULT_ATTR_KEYS.T: 1, "x": 0.0, "y": 0.0})
Expand All @@ -67,8 +71,7 @@ def test_distance_edges_add_edges_single_timepoint_no_current() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
setup_spatial_attrs_2d(graph)

# Add nodes only at t=0 (no t=1)
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0})
Expand All @@ -86,8 +89,7 @@ def test_distance_edges_add_edges_2d_coordinates() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
setup_spatial_attrs_2d(graph)

# Add nodes at t=0
_ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0})
Expand All @@ -114,9 +116,7 @@ def test_distance_edges_add_edges_3d_coordinates() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("z", 0.0)
setup_spatial_attrs_3d(graph)

# Add nodes at t=0
_ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0, "z": 0.0})
Expand All @@ -139,8 +139,8 @@ def test_distance_edges_add_edges_custom_attr_keys() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("pos_x", 0.0)
graph.add_node_attr_key("pos_y", 0.0)
setup_custom_node_attr(graph, "pos_x", 0.0)
setup_custom_node_attr(graph, "pos_y", 0.0)

# Add nodes at t=0
_ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "pos_x": 0.0, "pos_y": 0.0})
Expand All @@ -163,8 +163,7 @@ def test_distance_edges_add_edges_distance_threshold() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
setup_spatial_attrs_2d(graph)

# Add nodes at t=0
_ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0})
Expand All @@ -186,8 +185,7 @@ def test_distance_edges_add_edges_multiple_timepoints(n_workers: int) -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
setup_spatial_attrs_2d(graph)

# Add nodes at multiple timepoints
for t in range(3):
Expand All @@ -209,8 +207,7 @@ def test_distance_edges_add_edges_custom_weight_key() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
setup_spatial_attrs_2d(graph)

# Add nodes at t=0
_ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0})
Expand All @@ -237,8 +234,7 @@ def test_distance_edges_n_neighbors_limit() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
setup_spatial_attrs_2d(graph)

# Add many nodes at t=0
for i in range(5):
Expand All @@ -264,8 +260,7 @@ def test_distance_edges_add_edges_with_delta_t(n_workers: int) -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
setup_spatial_attrs_2d(graph)

# Add nodes at t=0, t=1, t=2
node_0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0})
Expand Down
30 changes: 17 additions & 13 deletions src/tracksdata/edges/_test/test_generic_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.edges import GenericFuncEdgeAttrs
from tracksdata.graph import RustWorkXGraph
from tracksdata.utils._test_utils import (
setup_custom_node_attr,
setup_edge_distance_attr,
setup_spatial_attrs_2d,
)


def _scalar_distance_func(source_val: float, target_val: float) -> float:
Expand Down Expand Up @@ -41,8 +46,8 @@ def test_generic_edges_add_weights_single_attr_key() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0)
setup_custom_node_attr(graph, "x", 0.0)
setup_edge_distance_attr(graph)

# Add nodes at time 0
node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0})
Expand Down Expand Up @@ -78,9 +83,8 @@ def test_generic_edges_add_weights_multiple_attr_keys() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0)
setup_spatial_attrs_2d(graph)
setup_edge_distance_attr(graph)

# Add nodes at time 0
node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0})
Expand Down Expand Up @@ -112,8 +116,8 @@ def test_generic_edges_add_weights_all_time_points() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0)
setup_custom_node_attr(graph, "x", 0.0)
setup_edge_distance_attr(graph)

# Add nodes at different time points
node0_t0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0})
Expand Down Expand Up @@ -142,7 +146,7 @@ def test_generic_edges_no_edges_at_time_point() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
setup_custom_node_attr(graph, "x", 0.0)

# Add nodes but no edges at time 0
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0})
Expand All @@ -163,8 +167,8 @@ def test_generic_edges_creates_output_key() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0)
setup_custom_node_attr(graph, "x", 0.0)
setup_edge_distance_attr(graph)

# Add nodes and edge
node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0})
Expand All @@ -191,9 +195,9 @@ def test_generic_edges_dict_input_function() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("value", 0.0)
graph.add_node_attr_key("weight", 0.0)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0)
setup_custom_node_attr(graph, "value", 0.0)
setup_custom_node_attr(graph, "weight", 0.0)
setup_edge_distance_attr(graph)

# Add nodes
node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "value": 10.0, "weight": 2.0})
Expand Down
Loading
Loading