diff --git a/.gitignore b/.gitignore index 941303b3..b4de732b 100644 --- a/.gitignore +++ b/.gitignore @@ -193,4 +193,6 @@ benchmarks/outputs/*.md # Version file generated by hatch-vcs src/tracksdata/__about__.py + +# vscode .vscode/ diff --git a/docs/examples.md b/docs/examples.md index 496ff285..70ea4569 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -13,7 +13,7 @@ Here's a complete basic example that demonstrates the core workflow of TracksDat ## Key Components Explained - **Graph**: The core data structure holding nodes (objects) and edges (connections) -- **Nodes Operators**: Extract object features from segmented images (RegionPropsNodes, MaskNodes, etc.) +- **Nodes Operators**: Extract object features from segmented images (RegionPropsNodes, etc.) - **Edges Operators**: Create temporal connections between objects (DistanceEdges, IoUEdges, etc.) - **Solvers**: Optimize a minimization problem to find the best tracking assignments (NearestNeighborsSolver, ILPSolver) - **Functional**: Utilities for format conversion and visualization diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index 0dd1e3c0..fa351e9e 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -4,8 +4,8 @@ from tracksdata.array._base_array import ArrayIndex, BaseReadOnlyArray from tracksdata.attrs import NodeAttr from tracksdata.constants import DEFAULT_ATTR_KEYS +from tracksdata.functional._mask import paint_mask_to_buffer from tracksdata.graph._base_graph import BaseGraph -from tracksdata.nodes._mask import Mask from tracksdata.utils._dtypes import polars_dtype_to_numpy_dtype @@ -69,7 +69,7 @@ def __getitem__(self, index: ArrayIndex) -> ArrayLike: return np.zeros(self.shape[1:], dtype=self.dtype) df = graph_filter.node_attrs( - attr_keys=[self._attr_key, DEFAULT_ATTR_KEYS.MASK], + attr_keys=[self._attr_key, DEFAULT_ATTR_KEYS.MASK, DEFAULT_ATTR_KEYS.BBOX], ) dtype = 
polars_dtype_to_numpy_dtype(df[self._attr_key].dtype) @@ -83,9 +83,13 @@ def __getitem__(self, index: ArrayIndex) -> ArrayLike: # TODO: reuse buffer buffer = np.zeros(self.shape[1:], dtype=self.dtype) - for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=False): - mask: Mask - mask.paint_buffer(buffer, value, offset=self._offset) + for bbox, mask, value in zip( + df[DEFAULT_ATTR_KEYS.BBOX], + df[DEFAULT_ATTR_KEYS.MASK], + df[self._attr_key], + strict=True, + ): + paint_mask_to_buffer(buffer, bbox, mask, value, offset=self._offset) return buffer else: diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py index 7599d7b3..1c2938a5 100644 --- a/src/tracksdata/array/_test/test_graph_array.py +++ b/src/tracksdata/array/_test/test_graph_array.py @@ -4,7 +4,7 @@ from tracksdata.array import GraphArrayView from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph import RustWorkXGraph -from tracksdata.nodes._mask import Mask +from tracksdata.utils._test_utils import setup_mask_attrs # NOTE: this could be generic test for all array backends # when more slicing operations are implemented we could test as in: @@ -59,14 +59,14 @@ def test_graph_array_view_getitem_with_nodes() -> None: # Add attribute keys graph.add_node_attr_key("label", 0) - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + setup_mask_attrs(graph) # Create a mask - mask_data = np.array([[True, True], [True, False]], dtype=bool) - mask = Mask(mask_data, bbox=np.array([10, 20, 12, 22])) # y_min, x_min, y_max, x_max + mask = np.array([[True, True], [True, False]], dtype=bool) + bbox = np.array([10, 20, 12, 22]) # Add a node with mask and label - graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "label": 5, DEFAULT_ATTR_KEYS.MASK: mask}) + graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "label": 5, DEFAULT_ATTR_KEYS.MASK: mask, DEFAULT_ATTR_KEYS.BBOX: bbox}) array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), 
attr_key="label") @@ -94,19 +94,19 @@ def test_graph_array_view_getitem_multiple_nodes() -> None: # Add attribute keys graph.add_node_attr_key("label", 0) - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + setup_mask_attrs(graph) # Create two masks at different locations - mask1_data = np.array([[True, True]], dtype=bool) - mask1 = Mask(mask1_data, bbox=np.array([10, 20, 11, 22])) + mask1 = np.array([[True, True]], dtype=bool) + bbox1 = np.array([10, 20, 11, 22]) - mask2_data = np.array([[True]], dtype=bool) - mask2 = Mask(mask2_data, bbox=np.array([30, 40, 31, 41])) + mask2 = np.array([[True]], dtype=bool) + bbox2 = np.array([30, 40, 31, 41]) # Add nodes with different labels - graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "label": 3, DEFAULT_ATTR_KEYS.MASK: mask1}) + graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "label": 3, DEFAULT_ATTR_KEYS.MASK: mask1, DEFAULT_ATTR_KEYS.BBOX: bbox1}) - graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "label": 7, DEFAULT_ATTR_KEYS.MASK: mask2}) + graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "label": 7, DEFAULT_ATTR_KEYS.MASK: mask2, DEFAULT_ATTR_KEYS.BBOX: bbox2}) array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="label") @@ -129,14 +129,15 @@ def test_graph_array_view_getitem_boolean_dtype() -> None: # Add attribute keys graph.add_node_attr_key("is_active", False) - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - + setup_mask_attrs(graph) # Create a mask - mask_data = np.array([[True]], dtype=bool) - mask = Mask(mask_data, bbox=np.array([10, 20, 11, 21])) + mask = np.array([[True]], dtype=bool) + bbox = np.array([10, 20, 11, 21]) # Add a node with boolean attribute - graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "is_active": True, DEFAULT_ATTR_KEYS.MASK: mask}) + graph.add_node( + {DEFAULT_ATTR_KEYS.T: 0, "is_active": True, DEFAULT_ATTR_KEYS.MASK: mask, DEFAULT_ATTR_KEYS.BBOX: bbox} + ) array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="is_active") @@ -155,14 +156,15 @@ def 
test_graph_array_view_dtype_inference() -> None: # Add attribute keys graph.add_node_attr_key("float_label", 0.0) - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - + setup_mask_attrs(graph) # Create a mask - mask_data = np.array([[True]], dtype=bool) - mask = Mask(mask_data, bbox=np.array([10, 20, 11, 21])) + mask = np.array([[True]], dtype=bool) + bbox = np.array([10, 20, 11, 21]) # Add a node with float attribute - graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "float_label": 3.14, DEFAULT_ATTR_KEYS.MASK: mask}) + graph.add_node( + {DEFAULT_ATTR_KEYS.T: 0, "float_label": 3.14, DEFAULT_ATTR_KEYS.MASK: mask, DEFAULT_ATTR_KEYS.BBOX: bbox} + ) array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="float_label") diff --git a/src/tracksdata/edges/_iou_edges.py b/src/tracksdata/edges/_iou_edges.py index c70a3492..7cb3cf22 100644 --- a/src/tracksdata/edges/_iou_edges.py +++ b/src/tracksdata/edges/_iou_edges.py @@ -1,6 +1,8 @@ +import numpy as np + from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.edges._generic_edges import GenericFuncEdgeAttrs -from tracksdata.nodes._mask import Mask +from tracksdata.functional._mask import mask_iou class IoUEdgeAttr(GenericFuncEdgeAttrs): @@ -14,15 +16,23 @@ class IoUEdgeAttr(GenericFuncEdgeAttrs): The key to use for the output of the IoU. mask_key : str The key to use for the masks of the nodes. + bbox_key : str + The key to use for the bounding boxes of the nodes. 
""" def __init__( self, output_key: str, mask_key: str = DEFAULT_ATTR_KEYS.MASK, + bbox_key: str = DEFAULT_ATTR_KEYS.BBOX, ): + def _compute_iou(source_attrs: dict[str, np.ndarray], target_attrs: dict[str, np.ndarray]) -> float: + return mask_iou( + source_attrs[bbox_key], source_attrs[mask_key], target_attrs[bbox_key], target_attrs[mask_key] + ) + super().__init__( - func=Mask.iou, - attr_keys=mask_key, + func=_compute_iou, + attr_keys=[mask_key, bbox_key], output_key=output_key, ) diff --git a/src/tracksdata/edges/_test/test_distance_edges.py b/src/tracksdata/edges/_test/test_distance_edges.py index b5e60756..883e2b4a 100644 --- a/src/tracksdata/edges/_test/test_distance_edges.py +++ b/src/tracksdata/edges/_test/test_distance_edges.py @@ -4,6 +4,11 @@ from tracksdata.edges import DistanceEdges from tracksdata.graph import RustWorkXGraph from tracksdata.options import get_options, options_context +from tracksdata.utils._test_utils import ( + setup_custom_node_attr, + setup_spatial_attrs_2d, + setup_spatial_attrs_3d, +) def test_distance_edges_init_default_params() -> None: @@ -48,8 +53,7 @@ def test_distance_edges_add_edges_single_timepoint_no_previous() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + setup_spatial_attrs_2d(graph) # Add nodes only at t=1 (no t=0) graph.add_node({DEFAULT_ATTR_KEYS.T: 1, "x": 0.0, "y": 0.0}) @@ -67,8 +71,7 @@ def test_distance_edges_add_edges_single_timepoint_no_current() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + setup_spatial_attrs_2d(graph) # Add nodes only at t=0 (no t=1) graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -86,8 +89,7 @@ def test_distance_edges_add_edges_2d_coordinates() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + 
setup_spatial_attrs_2d(graph) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -114,9 +116,7 @@ def test_distance_edges_add_edges_3d_coordinates() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_node_attr_key("z", 0.0) + setup_spatial_attrs_3d(graph) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0, "z": 0.0}) @@ -139,8 +139,8 @@ def test_distance_edges_add_edges_custom_attr_keys() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("pos_x", 0.0) - graph.add_node_attr_key("pos_y", 0.0) + setup_custom_node_attr(graph, "pos_x", 0.0) + setup_custom_node_attr(graph, "pos_y", 0.0) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "pos_x": 0.0, "pos_y": 0.0}) @@ -163,8 +163,7 @@ def test_distance_edges_add_edges_distance_threshold() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + setup_spatial_attrs_2d(graph) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -186,8 +185,7 @@ def test_distance_edges_add_edges_multiple_timepoints(n_workers: int) -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + setup_spatial_attrs_2d(graph) # Add nodes at multiple timepoints for t in range(3): @@ -209,8 +207,7 @@ def test_distance_edges_add_edges_custom_weight_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + setup_spatial_attrs_2d(graph) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -237,8 +234,7 @@ def test_distance_edges_n_neighbors_limit() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - 
graph.add_node_attr_key("y", 0.0) + setup_spatial_attrs_2d(graph) # Add many nodes at t=0 for i in range(5): @@ -264,8 +260,7 @@ def test_distance_edges_add_edges_with_delta_t(n_workers: int) -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + setup_spatial_attrs_2d(graph) # Add nodes at t=0, t=1, t=2 node_0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) diff --git a/src/tracksdata/edges/_test/test_generic_edges.py b/src/tracksdata/edges/_test/test_generic_edges.py index e900af16..7d3bef93 100644 --- a/src/tracksdata/edges/_test/test_generic_edges.py +++ b/src/tracksdata/edges/_test/test_generic_edges.py @@ -3,6 +3,11 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.edges import GenericFuncEdgeAttrs from tracksdata.graph import RustWorkXGraph +from tracksdata.utils._test_utils import ( + setup_custom_node_attr, + setup_edge_distance_attr, + setup_spatial_attrs_2d, +) def _scalar_distance_func(source_val: float, target_val: float) -> float: @@ -41,8 +46,8 @@ def test_generic_edges_add_weights_single_attr_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_custom_node_attr(graph, "x", 0.0) + setup_edge_distance_attr(graph) # Add nodes at time 0 node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -78,9 +83,8 @@ def test_generic_edges_add_weights_multiple_attr_keys() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Add nodes at time 0 node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -112,8 +116,8 @@ def test_generic_edges_add_weights_all_time_points() -> None: graph = RustWorkXGraph() # Register 
attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_custom_node_attr(graph, "x", 0.0) + setup_edge_distance_attr(graph) # Add nodes at different time points node0_t0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -142,7 +146,7 @@ def test_generic_edges_no_edges_at_time_point() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) + setup_custom_node_attr(graph, "x", 0.0) # Add nodes but no edges at time 0 graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -163,8 +167,8 @@ def test_generic_edges_creates_output_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_custom_node_attr(graph, "x", 0.0) + setup_edge_distance_attr(graph) # Add nodes and edge node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -191,9 +195,9 @@ def test_generic_edges_dict_input_function() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("value", 0.0) - graph.add_node_attr_key("weight", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_custom_node_attr(graph, "value", 0.0) + setup_custom_node_attr(graph, "weight", 0.0) + setup_edge_distance_attr(graph) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "value": 10.0, "weight": 2.0}) diff --git a/src/tracksdata/edges/_test/test_iou_edges.py b/src/tracksdata/edges/_test/test_iou_edges.py index 8fddde60..bfd64f6e 100644 --- a/src/tracksdata/edges/_test/test_iou_edges.py +++ b/src/tracksdata/edges/_test/test_iou_edges.py @@ -4,8 +4,12 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.edges import IoUEdgeAttr from tracksdata.graph import RustWorkXGraph -from tracksdata.nodes import Mask from tracksdata.options import get_options, options_context +from tracksdata.utils._test_utils import ( + setup_custom_node_attr, + 
setup_edge_distance_attr, + setup_mask_attrs, +) def test_iou_edges_init_default() -> None: @@ -13,17 +17,15 @@ def test_iou_edges_init_default() -> None: operator = IoUEdgeAttr(output_key="iou_score") assert operator.output_key == "iou_score" - assert operator.attr_keys == DEFAULT_ATTR_KEYS.MASK - assert operator.func == Mask.iou + assert operator.attr_keys == [DEFAULT_ATTR_KEYS.MASK, DEFAULT_ATTR_KEYS.BBOX] def test_iou_edges_init_custom() -> None: """Test IoUEdgesOperator initialization with custom parameters.""" - operator = IoUEdgeAttr(output_key="custom_iou", mask_key="custom_mask") + operator = IoUEdgeAttr(output_key="custom_iou", mask_key="custom_mask", bbox_key="custom_bbox") assert operator.output_key == "custom_iou" - assert operator.attr_keys == "custom_mask" - assert operator.func == Mask.iou + assert operator.attr_keys == ["custom_mask", "custom_bbox"] @pytest.mark.parametrize("n_workers", [1, 2]) @@ -31,24 +33,24 @@ def test_iou_edges_add_weights(n_workers: int) -> None: """Test adding IoU weights to edges with different worker counts.""" graph = RustWorkXGraph() - # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + # Set up graph attributes + setup_mask_attrs(graph) + setup_edge_distance_attr(graph) - # Create test masks + # Create test masks and bboxes mask1_data = np.array([[True, True], [True, False]], dtype=bool) - mask1 = Mask(mask1_data, bbox=np.array([0, 0, 2, 2])) + bbox1 = np.array([0, 0, 2, 2]) mask2_data = np.array([[True, False], [False, False]], dtype=bool) - mask2 = Mask(mask2_data, bbox=np.array([0, 0, 2, 2])) + bbox2 = np.array([0, 0, 2, 2]) mask3_data = np.array([[True, True], [True, True]], dtype=bool) - mask3 = Mask(mask3_data, bbox=np.array([0, 0, 2, 2])) + bbox3 = np.array([0, 0, 2, 2]) - # Add nodes with masks - node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask1}) - node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 
1, DEFAULT_ATTR_KEYS.MASK: mask2}) - node3 = graph.add_node({DEFAULT_ATTR_KEYS.T: 2, DEFAULT_ATTR_KEYS.MASK: mask3}) + # Add nodes with masks and bboxes + node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask1_data, DEFAULT_ATTR_KEYS.BBOX: bbox1}) + node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 1, DEFAULT_ATTR_KEYS.MASK: mask2_data, DEFAULT_ATTR_KEYS.BBOX: bbox2}) + node3 = graph.add_node({DEFAULT_ATTR_KEYS.T: 2, DEFAULT_ATTR_KEYS.MASK: mask3_data, DEFAULT_ATTR_KEYS.BBOX: bbox3}) # Add edge edge_id_1 = graph.add_edge(node1, node2, {DEFAULT_ATTR_KEYS.EDGE_DIST: 0.0}) @@ -81,20 +83,20 @@ def test_iou_edges_no_overlap() -> None: """Test IoU calculation with non-overlapping masks.""" graph = RustWorkXGraph() - # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + # Set up graph attributes + setup_mask_attrs(graph) + setup_edge_distance_attr(graph) # Create non-overlapping masks mask1_data = np.array([[True, True], [False, False]], dtype=bool) - mask1 = Mask(mask1_data, bbox=np.array([0, 0, 2, 2])) + bbox1 = np.array([0, 0, 2, 2]) mask2_data = np.array([[False, False], [True, True]], dtype=bool) - mask2 = Mask(mask2_data, bbox=np.array([0, 0, 2, 2])) + bbox2 = np.array([0, 0, 2, 2]) - # Add nodes with masks - node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask1}) - node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask2}) + # Add nodes with masks and bboxes + node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask1_data, DEFAULT_ATTR_KEYS.BBOX: bbox1}) + node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask2_data, DEFAULT_ATTR_KEYS.BBOX: bbox2}) # Add edge edge_id = graph.add_edge(node1, node2, {DEFAULT_ATTR_KEYS.EDGE_DIST: 0.0}) @@ -120,18 +122,17 @@ def test_iou_edges_perfect_overlap() -> None: """Test IoU calculation with perfectly overlapping masks.""" graph = 
RustWorkXGraph() - # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + # Set up graph attributes + setup_mask_attrs(graph) + setup_edge_distance_attr(graph) # Create identical masks mask_data = np.array([[True, True], [True, False]], dtype=bool) - mask1 = Mask(mask_data, bbox=np.array([0, 0, 2, 2])) - mask2 = Mask(mask_data, bbox=np.array([0, 0, 2, 2])) + bbox = np.array([0, 0, 2, 2]) - # Add nodes with masks - node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask1}) - node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask2}) + # Add nodes with masks and bboxes + node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask_data, DEFAULT_ATTR_KEYS.BBOX: bbox}) + node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask_data, DEFAULT_ATTR_KEYS.BBOX: bbox}) # Add edge edge_id = graph.add_edge(node1, node2, {DEFAULT_ATTR_KEYS.EDGE_DIST: 0.0}) @@ -156,26 +157,27 @@ def test_iou_edges_custom_mask_key() -> None: """Test IoU edges operator with custom mask key.""" graph = RustWorkXGraph() - # Register attribute keys - graph.add_node_attr_key("custom_mask", None) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + # Set up custom attributes + setup_custom_node_attr(graph, "custom_mask", None) + setup_custom_node_attr(graph, "custom_bbox", None) + setup_edge_distance_attr(graph) # Create test masks mask1_data = np.array([[True, True], [True, True]], dtype=bool) - mask1 = Mask(mask1_data, bbox=np.array([0, 0, 2, 2])) + bbox1 = np.array([0, 0, 2, 2]) mask2_data = np.array([[True, True], [False, False]], dtype=bool) - mask2 = Mask(mask2_data, bbox=np.array([0, 0, 2, 2])) + bbox2 = np.array([0, 0, 2, 2]) - # Add nodes with custom mask key - node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "custom_mask": mask1}) - node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "custom_mask": mask2}) + # Add nodes with 
custom mask and bbox keys + node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "custom_mask": mask1_data, "custom_bbox": bbox1}) + node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "custom_mask": mask2_data, "custom_bbox": bbox2}) # Add edge edge_id = graph.add_edge(node1, node2, {DEFAULT_ATTR_KEYS.EDGE_DIST: 0.0}) - # Create operator with custom mask key - operator = IoUEdgeAttr(output_key="iou_score", mask_key="custom_mask") + # Create operator with custom mask and bbox keys + operator = IoUEdgeAttr(output_key="iou_score", mask_key="custom_mask", bbox_key="custom_bbox") operator.add_edge_attrs(graph) # Check that IoU weights were calculated diff --git a/src/tracksdata/functional/__init__.py b/src/tracksdata/functional/__init__.py index ce4f78b1..630b8afe 100644 --- a/src/tracksdata/functional/__init__.py +++ b/src/tracksdata/functional/__init__.py @@ -1,5 +1,20 @@ """Functional utilities for graph operations.""" +from tracksdata.functional._mask import ( + crop_image_with_bbox, + mask_indices, + mask_intersection, + mask_iou, + paint_mask_to_buffer, +) from tracksdata.functional._napari import rx_digraph_to_napari_dict, to_napari_format -__all__ = ["rx_digraph_to_napari_dict", "to_napari_format"] +__all__ = [ + "crop_image_with_bbox", + "mask_indices", + "mask_intersection", + "mask_iou", + "paint_mask_to_buffer", + "rx_digraph_to_napari_dict", + "to_napari_format", +] diff --git a/src/tracksdata/functional/_disk_attrs.py b/src/tracksdata/functional/_disk_attrs.py new file mode 100644 index 00000000..890bee2c --- /dev/null +++ b/src/tracksdata/functional/_disk_attrs.py @@ -0,0 +1,165 @@ +from collections.abc import Sequence +from functools import lru_cache + +import numpy as np +import skimage.morphology as morph +from numpy.typing import NDArray + +from tracksdata.attrs import NodeAttr +from tracksdata.constants import DEFAULT_ATTR_KEYS +from tracksdata.graph._base_graph import BaseGraph +from tracksdata.nodes._base_node_attrs import BaseNodeAttrsOperator +from 
@lru_cache(maxsize=5)
def _spherical_mask(
    radius: int,
    ndim: int,
) -> NDArray[np.bool_]:
    """
    Get a spherical mask of a given radius and dimension.

    The result is memoized with `lru_cache`, so every call with the same
    ``(radius, ndim)`` returns the *same* array object. Treat it as read-only
    and copy it before handing it to code that may modify it in place.

    Parameters
    ----------
    radius : int
        Radius passed to `skimage.morphology.disk` (2D) or `ball` (3D).
    ndim : int
        Number of spatial dimensions, either 2 or 3.

    Returns
    -------
    NDArray[np.bool_]
        Boolean footprint of the disk / ball.

    Raises
    ------
    ValueError
        If `ndim` is neither 2 nor 3.
    """
    if ndim == 2:
        return morph.disk(radius).astype(bool)

    if ndim == 3:
        return morph.ball(radius).astype(bool)

    raise ValueError(f"Spherical is only implemented for 2D and 3D, got ndim={ndim}")


def _create_mask_and_bbox_from_coordinates(
    center: NDArray,
    radius: int,
    image_shape: tuple[int, ...] | None = None,
) -> tuple[NDArray[np.bool_], NDArray[np.int64]]:
    """
    Create a mask and bounding box from center coordinates and radius.

    Parameters
    ----------
    center : NDArray
        The center of the mask, one coordinate per spatial dimension.
        It is rounded to the nearest integer position.
    radius : int
        The radius of the mask.
    image_shape : tuple[int, ...] | None
        The shape of the image. When provided, the parts of the mask that
        fall outside ``[0, image_shape)`` are cropped away and the bounding
        box is clipped accordingly.

    Returns
    -------
    tuple[NDArray[np.bool_], NDArray[np.int64]]
        The mask and its bounding box laid out as ``[start..., end...]``
        (end exclusive). The mask is an independent copy owned by the caller.
    """
    footprint = _spherical_mask(radius, len(center))
    center = np.round(center).astype(int)

    start = center - np.asarray(footprint.shape) // 2
    end = start + footprint.shape

    if image_shape is None:
        mask = footprint
        bbox = np.concatenate([start, end], dtype=int)
    else:
        clipped_start = np.maximum(start, 0)
        clipped_end = np.minimum(end, image_shape)

        # Number of footprint pixels hanging over each image border.
        start_overhang = clipped_start - start
        end_overhang = end - clipped_end

        # `-e if e > 0 else None`: a zero overhang must keep the axis up to its
        # end, and `slice(s, -0)` would instead produce an empty slice.
        mask = footprint[
            tuple(slice(s, -e if e > 0 else None) for s, e in zip(start_overhang, end_overhang, strict=True))
        ]
        bbox = np.concatenate([clipped_start, clipped_end], dtype=int)

    # `footprint` is the object stored in the lru_cache (and `mask` may be a
    # view of it). Returning it directly would make every node share memory
    # with the cache, so one in-place edit would corrupt all masks.
    return mask.copy(), bbox
class MaskDiskAttrs(BaseNodeAttrsOperator):
    """
    Node operator that attaches a disk/ball mask and its bounding box to every node.

    The masks are purely spatial (2D disk or 3D ball around the node position),
    so the time axis must not be part of `image_shape` or `attr_keys`.

    Parameters
    ----------
    radius : int
        The radius of the mask.
    image_shape : tuple[int, ...]
        The spatial shape of the image; its length must equal the number of
        `attr_keys`. Masks are cropped to stay inside this shape.
    attr_keys : Sequence[str] | None
        The node attributes holding the mask center, one per spatial axis.
        When omitted, the trailing ``len(image_shape)`` entries of
        ``("z", "y", "x")`` are used.
    mask_output_key : str
        The key to store the mask attribute.
    bbox_output_key : str
        The key to store the bounding box attribute.

    Raises
    ------
    ValueError
        If `attr_keys` and `image_shape` have different lengths.
    """

    def __init__(
        self,
        radius: int,
        image_shape: tuple[int, ...],
        attr_keys: Sequence[str] | None = None,
        mask_output_key: str = DEFAULT_ATTR_KEYS.MASK,
        bbox_output_key: str = DEFAULT_ATTR_KEYS.BBOX,
    ):
        # The base class tracks a single output key; the mask key is used as the primary one.
        super().__init__(mask_output_key)

        n_spatial_dims = len(image_shape)

        if attr_keys is None:
            attr_keys = ["z", "y", "x"][-n_spatial_dims:]

        if len(attr_keys) != n_spatial_dims:
            raise ValueError(
                f"Expected image shape {image_shape} to have the same number of dimensions as attr_keys '{attr_keys}'."
            )

        self.radius = radius
        self.image_shape = image_shape
        self.attr_keys = attr_keys
        self.mask_output_key = mask_output_key
        self.bbox_output_key = bbox_output_key

    def _init_node_attrs(self, graph: BaseGraph) -> None:
        """Register the mask and bbox attribute keys on `graph` if they are missing."""
        for output_key in (self.mask_output_key, self.bbox_output_key):
            if output_key not in graph.node_attr_keys:
                graph.add_node_attr_key(output_key, default_value=None)

    def _node_attrs_per_time(
        self,
        t: int,
        *,
        graph: BaseGraph,
        frames: NDArray | None = None,
    ) -> tuple[list[int], dict[str, list]]:
        """
        Build the mask and bbox attributes for all nodes at time point `t`.

        Parameters
        ----------
        t : int
            The time point to process.
        graph : BaseGraph
            The graph whose nodes are read.
        frames : NDArray | None
            Not used by this operator.

        Returns
        -------
        tuple[list[int], dict[str, list]]
            The node ids at `t` and a mapping from the mask / bbox output keys
            to the per-node values. Both are empty when there is no node at `t`.
        """
        nodes_at_t = graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) == t)

        if nodes_at_t.is_empty():
            LOG.warning(f"No nodes at time point {t}")
            return [], {}

        centers = nodes_at_t.node_attrs(attr_keys=self.attr_keys)

        # One (mask, bbox) pair per node, in the row order of `centers`.
        mask_bbox_pairs = [
            _create_mask_and_bbox_from_coordinates(
                center=np.asarray([row[key] for key in self.attr_keys]),
                radius=self.radius,
                image_shape=self.image_shape,
            )
            for row in centers.rows(named=True)
        ]

        return nodes_at_t.node_ids(), {
            self.mask_output_key: [mask for mask, _ in mask_bbox_pairs],
            self.bbox_output_key: [bbox for _, bbox in mask_bbox_pairs],
        }
+ + Examples + -------- + Crop a 2D image using a bounding box: + + ```python + import numpy as np + + image = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) + bbox = np.array([1, 1, 3, 3]) # [min_y, min_x, max_y, max_x] + cropped = crop_image_with_bbox(image, bbox) + cropped + array([[6, 7], [10, 11]]) + ``` + + Crop with a fixed output shape: + ```python + cropped_fixed = crop_image_with_bbox(image, bbox, shape=(1, 4)) + cropped_fixed.shape + (1, 4) + ``` + """ + if len(bbox) % 2 != 0: + raise ValueError(f"Bbox must have even length, got {len(bbox)}") + + ndim = len(bbox) // 2 + + if image.ndim != ndim: + raise ValueError(f"Image dimensions ({image.ndim}) must match bbox dimensions ({ndim})") + + # Validate bbox coordinates + start_coords = bbox[:ndim] + end_coords = bbox[ndim:] + + if np.any(start_coords < 0) or np.any(end_coords <= start_coords): + raise ValueError(f"Invalid bbox coordinates: {bbox}") + + if shape is None: + slicing = tuple(slice(bbox[i], bbox[i + ndim]) for i in range(ndim)) + else: + if len(shape) != ndim: + raise ValueError(f"Shape dimensions ({len(shape)}) must match bbox dimensions ({ndim})") + + center = (bbox[:ndim] + bbox[ndim:]) // 2 + half_shape = np.asarray(shape) // 2 + start = np.maximum(center - half_shape, 0) + end = np.minimum(center + half_shape, image.shape) + slicing = tuple(slice(s, e) for s, e in zip(start, end, strict=True)) + + return image[slicing] + + +def mask_indices( + bbox: NDArray[np.integer], + mask: NDArray[np.bool_], + offset: NDArray[np.integer] | int = 0, +) -> tuple[NDArray[np.integer], ...]: + """ + Get the indices of pixels that are part of the mask in global coordinates. + + Parameters + ---------- + bbox : NDArray[np.integer] + The bounding box coordinates with shape (2 * ndim,). + mask : NDArray[np.bool_] + Binary mask indicating valid pixels. + offset : NDArray[np.integer] | int, optional + Additional offset to add to the indices. + + Returns + ------- + tuple[NDArray[np.integer], ...] 
+ The indices of pixels that are part of the mask. + + Raises + ------ + ValueError + If bbox length is not even or mask dimensions don't match expected bbox dimensions. + + Examples + -------- + Get global indices for a 2D mask: + + >>> import numpy as np + >>> mask = np.array([[True, False], [False, True]]) + >>> bbox = np.array([10, 20, 12, 22]) # [min_y, min_x, max_y, max_x] + >>> y_indices, x_indices = mask_indices(bbox, mask) + >>> y_indices + array([10, 11]) + >>> x_indices + array([20, 21]) + + With an offset: + + >>> y_indices, x_indices = mask_indices(bbox, mask, offset=5) + >>> y_indices + array([15, 16]) + >>> x_indices + array([25, 26]) + """ + if len(bbox) % 2 != 0: + raise ValueError(f"Bbox must have even length, got {len(bbox)}") + + ndim = len(bbox) // 2 + + if mask.ndim != ndim: + raise ValueError(f"Mask dimensions ({mask.ndim}) must match bbox dimensions ({ndim})") + + if isinstance(offset, int): + offset = np.full(ndim, offset) + + indices = list(np.nonzero(mask)) + + for i, index in enumerate(indices): + indices[i] = index + bbox[i] + offset[i] + + return tuple(indices) + + +def paint_mask_to_buffer( + buffer: np.ndarray, + bbox: NDArray[np.integer], + mask: NDArray[np.bool_], + value: int | float, + offset: NDArray[np.integer] | int = 0, +) -> None: + """ + Paint mask pixels into a buffer. + + Parameters + ---------- + buffer : np.ndarray + The buffer to paint into (modified in place). + bbox : NDArray[np.integer] + The bounding box coordinates. + mask : NDArray[np.bool_] + Binary mask indicating pixels to paint. + value : int | float + The value to paint. + offset : NDArray[np.integer] | int, optional + Additional offset to add to the indices. + + Raises + ------ + ValueError + If buffer dimensions don't match bbox dimensions or if bbox is invalid. 
+ + Examples + -------- + Paint mask pixels into a buffer: + + >>> import numpy as np + >>> buffer = np.zeros((5, 5)) + >>> mask = np.array([[True, False], [False, True]]) + >>> bbox = np.array([1, 1, 3, 3]) # [min_y, min_x, max_y, max_x] + >>> paint_mask_to_buffer(buffer, bbox, mask, value=255) + >>> buffer[1:3, 1:3] + array([[255., 0.], + [ 0., 255.]]) + """ + if len(bbox) % 2 != 0: + raise ValueError(f"Bbox must have even length, got {len(bbox)}") + + ndim = len(bbox) // 2 + + if buffer.ndim != ndim: + raise ValueError(f"Buffer dimensions ({buffer.ndim}) must match bbox dimensions ({ndim})") + indices = mask_indices(bbox, mask, offset) + buffer[indices] = value + + +def mask_iou( + bbox1: NDArray[np.integer], + mask1: NDArray[np.bool_], + bbox2: NDArray[np.integer], + mask2: NDArray[np.bool_], +) -> float: + """ + Compute Intersection over Union (IoU) between two mask/bbox pairs. + + Parameters + ---------- + bbox1 : NDArray[np.integer] + First bounding box coordinates. + mask1 : NDArray[np.bool_] + First binary mask. + bbox2 : NDArray[np.integer] + Second bounding box coordinates. + mask2 : NDArray[np.bool_] + Second binary mask. + + Returns + ------- + float + The IoU value between 0 and 1. + + Raises + ------ + ValueError + If bboxes have different dimensions or masks have different dimensions. 
+ + Examples + -------- + Calculate IoU between two overlapping masks: + + >>> import numpy as np + >>> mask1 = np.array([[True, True], [True, False]]) + >>> bbox1 = np.array([0, 0, 2, 2]) + >>> mask2 = np.array([[True, False], [True, True]]) + >>> bbox2 = np.array([0, 0, 2, 2]) + >>> iou = mask_iou(bbox1, mask1, bbox2, mask2) + >>> iou # 2 intersection pixels / 4 union pixels + 0.5 + """ + bbox1 = np.asarray(bbox1, dtype=np.int64) + bbox2 = np.asarray(bbox2, dtype=np.int64) + mask1 = np.asarray(mask1, dtype=bool) + mask2 = np.asarray(mask2, dtype=bool) + + if len(bbox1) != len(bbox2): + raise ValueError(f"Bboxes must have same length, got {len(bbox1)} and {len(bbox2)}") + + if mask1.ndim != mask2.ndim: + raise ValueError(f"Masks must have same dimensions, got {mask1.ndim} and {mask2.ndim}") + + return fast_iou_with_bbox(bbox1, bbox2, mask1, mask2) + + +def mask_intersection( + bbox1: NDArray[np.integer], + mask1: NDArray[np.bool_], + bbox2: NDArray[np.integer], + mask2: NDArray[np.bool_], +) -> float: + """ + Compute intersection between two mask/bbox pairs. + + Parameters + ---------- + bbox1 : NDArray[np.integer] + First bounding box coordinates. + mask1 : NDArray[np.bool_] + First binary mask. + bbox2 : NDArray[np.integer] + Second bounding box coordinates. + mask2 : NDArray[np.bool_] + Second binary mask. + + Returns + ------- + float + The intersection value. + + Raises + ------ + ValueError + If bboxes have different dimensions or masks have different dimensions. 
+ + Examples + -------- + Calculate intersection between two masks: + + >>> import numpy as np + >>> mask1 = np.array([[True, True], [True, False]]) + >>> bbox1 = np.array([0, 0, 2, 2]) + >>> mask2 = np.array([[True, False], [True, True]]) + >>> bbox2 = np.array([0, 0, 2, 2]) + >>> intersection = mask_intersection(bbox1, mask1, bbox2, mask2) + >>> intersection # 2 overlapping pixels + 2.0 + """ + # Ensure inputs are numpy arrays for numba compatibility + bbox1 = np.asarray(bbox1, dtype=np.int64) + bbox2 = np.asarray(bbox2, dtype=np.int64) + mask1 = np.asarray(mask1, dtype=bool) + mask2 = np.asarray(mask2, dtype=bool) + + if len(bbox1) != len(bbox2): + raise ValueError(f"Bboxes must have same length, got {len(bbox1)} and {len(bbox2)}") + + if mask1.ndim != mask2.ndim: + raise ValueError(f"Masks must have same dimensions, got {mask1.ndim} and {mask2.ndim}") + + return fast_intersection_with_bbox(bbox1, bbox2, mask1, mask2) diff --git a/src/tracksdata/functional/_test/test_disk_attrs.py b/src/tracksdata/functional/_test/test_disk_attrs.py new file mode 100644 index 00000000..8664de6a --- /dev/null +++ b/src/tracksdata/functional/_test/test_disk_attrs.py @@ -0,0 +1,315 @@ +import numpy as np +import pytest + +from tracksdata.constants import DEFAULT_ATTR_KEYS +from tracksdata.functional._disk_attrs import ( + MaskDiskAttrs, + _create_mask_and_bbox_from_coordinates, + _spherical_mask, +) +from tracksdata.graph import RustWorkXGraph +from tracksdata.utils._test_utils import setup_spatial_attrs_2d, setup_spatial_attrs_3d + + +def test_spherical_mask_2d() -> None: + """Test spherical mask creation in 2D.""" + mask = _spherical_mask(radius=2, ndim=2) + + assert mask.dtype == bool + assert mask.shape == (5, 5) # Disk radius 2 creates 5x5 mask + assert mask[2, 2] # Center should be True + + +def test_spherical_mask_3d() -> None: + """Test spherical mask creation in 3D.""" + mask = _spherical_mask(radius=1, ndim=3) + + assert mask.dtype == bool + assert mask.shape == (3, 3, 3) 
# Ball radius 1 creates 3x3x3 mask + assert mask[1, 1, 1] # Center should be True + + +def test_spherical_mask_invalid_ndim() -> None: + """Test spherical mask with invalid dimensions.""" + with pytest.raises(ValueError, match="Spherical is only implemented for 2D and 3D"): + _spherical_mask(radius=1, ndim=4) + + +def test_create_mask_and_bbox_from_coordinates_2d_basic() -> None: + """Test 2D mask creation and bbox without cropping.""" + center = np.asarray([5, 5]) + radius = 2 + + mask, bbox = _create_mask_and_bbox_from_coordinates(center, radius) + + # Should be a disk of radius 2, shape (5,5), centered at (5,5) + assert mask.shape == (5, 5) + assert mask[2, 2] # center pixel is True + np.testing.assert_array_equal(bbox, [3, 3, 8, 8]) + + +def test_create_mask_and_bbox_from_coordinates_3d_basic() -> None: + """Test 3D mask creation and bbox without cropping.""" + center = np.asarray([4, 5, 6]) + radius = 1 + + mask, bbox = _create_mask_and_bbox_from_coordinates(center, radius) + + # Should be a ball of radius 1, shape (3,3,3), centered at (4,5,6) + assert mask.shape == (3, 3, 3) + assert mask[1, 1, 1] # center voxel is True + np.testing.assert_array_equal(bbox, [3, 4, 5, 6, 7, 8]) + + +def test_create_mask_and_bbox_from_coordinates_cropping() -> None: + """Test cropping when mask falls outside the image boundary.""" + center = np.asarray([0, 0]) + radius = 5 + image_shape = (4, 3) + + mask, bbox = _create_mask_and_bbox_from_coordinates(center, radius, image_shape=image_shape) + + # Mask shape should match the bbox size + expected_shape = (4, 3) + assert mask.shape == expected_shape + + # Mask should be cropped to fit within image bounds + np.testing.assert_array_equal(bbox, [0, 0, 4, 3]) + + +def test_create_mask_and_bbox_from_coordinates_partial_cropping() -> None: + """Test partial cropping when mask partially falls outside boundary.""" + center = np.asarray([2, 8]) # Near right edge + radius = 2 + image_shape = (10, 10) + + mask, bbox = 
_create_mask_and_bbox_from_coordinates(center, radius, image_shape=image_shape) + + # Should be cropped on the right side + assert mask.shape[1] < 5 # Original disk would be 5x5 + assert bbox[3] == 10 # Right edge should be image boundary + assert bbox[1] == 6 # Left edge should be center - radius + + +def test_create_mask_and_bbox_from_coordinates_float_center() -> None: + """Test with float center coordinates (should be rounded).""" + center = np.asarray([2.7, 3.2]) + radius = 1 + + mask, bbox = _create_mask_and_bbox_from_coordinates(center, radius) + + # Center should be rounded to [3, 3] + expected_bbox = [2, 2, 5, 5] # center (3,3) with radius 1 + np.testing.assert_array_equal(bbox, expected_bbox) + + +def test_mask_disk_attrs_init_default() -> None: + """Test MaskDiskAttrs initialization with default parameters.""" + operator = MaskDiskAttrs(radius=2, image_shape=(10, 10)) + + assert operator.radius == 2 + assert operator.image_shape == (10, 10) + assert operator.attr_keys == ["y", "x"] # Default for 2D + assert operator.mask_output_key == DEFAULT_ATTR_KEYS.MASK + assert operator.bbox_output_key == DEFAULT_ATTR_KEYS.BBOX + + +def test_mask_disk_attrs_init_custom() -> None: + """Test MaskDiskAttrs initialization with custom parameters.""" + operator = MaskDiskAttrs( + radius=3, + image_shape=(5, 10, 15), + attr_keys=["z", "y", "x"], + mask_output_key="custom_mask", + bbox_output_key="custom_bbox", + ) + + assert operator.radius == 3 + assert operator.image_shape == (5, 10, 15) + assert operator.attr_keys == ["z", "y", "x"] + assert operator.mask_output_key == "custom_mask" + assert operator.bbox_output_key == "custom_bbox" + + +def test_mask_disk_attrs_init_auto_attr_keys() -> None: + """Test automatic attr_keys selection based on image_shape.""" + # 2D case + operator_2d = MaskDiskAttrs(radius=1, image_shape=(10, 20)) + assert operator_2d.attr_keys == ["y", "x"] + + # 3D case + operator_3d = MaskDiskAttrs(radius=1, image_shape=(5, 10, 15)) + assert 
operator_3d.attr_keys == ["z", "y", "x"] + + +def test_mask_disk_attrs_init_dimension_mismatch() -> None: + """Test error when image_shape and attr_keys have different dimensions.""" + with pytest.raises(ValueError, match="Expected image shape"): + MaskDiskAttrs(radius=1, image_shape=(10, 20), attr_keys=["z", "y", "x"]) # 3D keys for 2D shape + + +def test_mask_disk_attrs_add_nodes_2d() -> None: + """Test adding disk masks to 2D nodes.""" + graph = RustWorkXGraph() + + # Initialize required attributes + setup_spatial_attrs_2d(graph) + + # Add 2 nodes at t=0 + node_attrs = [{DEFAULT_ATTR_KEYS.T: 0, "y": 5.0, "x": 10.0}, {DEFAULT_ATTR_KEYS.T: 0, "y": 15.0, "x": 20.0}] + graph.bulk_add_nodes(node_attrs) + + # Add disk masks + disk_operator = MaskDiskAttrs(radius=2, image_shape=(30, 40)) + disk_operator.add_node_attrs(graph) + + # Check that masks and bboxes were added + nodes_df = graph.node_attrs() + assert DEFAULT_ATTR_KEYS.MASK in nodes_df.columns + assert DEFAULT_ATTR_KEYS.BBOX in nodes_df.columns + + # Check that we have 2 nodes with masks + assert len(nodes_df) == 2 + + # Check mask properties + masks = nodes_df[DEFAULT_ATTR_KEYS.MASK] + bboxes = nodes_df[DEFAULT_ATTR_KEYS.BBOX] + + for i in range(len(masks)): + mask = masks[i] + bbox = bboxes[i] + + assert isinstance(mask, np.ndarray) + assert mask.dtype == bool + assert mask.shape == (5, 5) # Radius 2 creates 5x5 disk + + assert len(bbox) == 4 # [min_y, min_x, max_y, max_x] + + +def test_mask_disk_attrs_add_nodes_3d() -> None: + """Test adding disk masks to 3D nodes.""" + graph = RustWorkXGraph() + + # Initialize required attributes + setup_spatial_attrs_3d(graph) + + # Add 2 nodes at t=0 with 3D coordinates + node_attrs = [ + {DEFAULT_ATTR_KEYS.T: 0, "z": 2.0, "y": 5.0, "x": 10.0}, + {DEFAULT_ATTR_KEYS.T: 0, "z": 8.0, "y": 15.0, "x": 20.0}, + ] + graph.bulk_add_nodes(node_attrs) + + # Add disk masks + disk_operator = MaskDiskAttrs(radius=1, image_shape=(10, 20, 30)) + disk_operator.add_node_attrs(graph) + + # 
Check that masks and bboxes were added + nodes_df = graph.node_attrs() + assert DEFAULT_ATTR_KEYS.MASK in nodes_df.columns + assert DEFAULT_ATTR_KEYS.BBOX in nodes_df.columns + + # Check mask properties + masks = nodes_df[DEFAULT_ATTR_KEYS.MASK] + bboxes = nodes_df[DEFAULT_ATTR_KEYS.BBOX] + + for i in range(len(masks)): + mask = masks[i] + bbox = bboxes[i] + + assert isinstance(mask, np.ndarray) + assert mask.dtype == bool + assert mask.shape == (3, 3, 3) # Radius 1 creates 3x3x3 ball + + assert len(bbox) == 6 # [min_z, min_y, min_x, max_z, max_y, max_x] + + +def test_mask_disk_attrs_with_cropping() -> None: + """Test disk mask creation with image boundary cropping.""" + graph = RustWorkXGraph() + + # Initialize required attributes + setup_spatial_attrs_2d(graph) + + # Add a node near the edge that will require cropping + node_attrs = [{DEFAULT_ATTR_KEYS.T: 0, "y": 1.0, "x": 1.0}] # Near top-left corner + graph.bulk_add_nodes(node_attrs) + + # Add disk masks with small image shape to force cropping + disk_operator = MaskDiskAttrs(radius=3, image_shape=(5, 5)) + disk_operator.add_node_attrs(graph) + + # Check that mask was cropped + nodes_df = graph.node_attrs() + mask = nodes_df[DEFAULT_ATTR_KEYS.MASK][0] + bbox = nodes_df[DEFAULT_ATTR_KEYS.BBOX][0] + + # Mask should be smaller than full disk due to cropping + assert mask.shape[0] < 7 # Full disk would be 7x7 + assert mask.shape[1] < 7 + + # Bbox should start at image boundary + assert bbox[0] == 0 # min_y should be 0 (image boundary) + assert bbox[1] == 0 # min_x should be 0 (image boundary) + + +def test_mask_disk_attrs_empty_time_point() -> None: + """Test behavior when no nodes exist at specified time point.""" + graph = RustWorkXGraph() + + # Don't add any nodes, so t=0 will be empty + disk_operator = MaskDiskAttrs(radius=1, image_shape=(10, 10)) + + # This should handle empty time point gracefully + node_ids, attrs = disk_operator._node_attrs_per_time(t=0, graph=graph) + + assert node_ids == [] + assert attrs 
== {} + + +def test_mask_disk_attrs_multiple_time_points() -> None: + """Test disk mask creation across multiple time points.""" + graph = RustWorkXGraph() + + # Initialize required attributes + setup_spatial_attrs_2d(graph) + + # Add nodes at different time points + node_attrs = [ + {DEFAULT_ATTR_KEYS.T: 0, "y": 5.0, "x": 5.0}, # t=0 + {DEFAULT_ATTR_KEYS.T: 1, "y": 10.0, "x": 10.0}, # t=1 + {DEFAULT_ATTR_KEYS.T: 1, "y": 15.0, "x": 15.0}, # t=1 + ] + graph.bulk_add_nodes(node_attrs) + + # Add disk masks + disk_operator = MaskDiskAttrs(radius=1, image_shape=(20, 20)) + disk_operator.add_node_attrs(graph) + + # Check that all nodes got masks + nodes_df = graph.node_attrs() + assert len(nodes_df) == 3 # 1 + 2 nodes + + # Check that masks were added for all time points + assert all(nodes_df[DEFAULT_ATTR_KEYS.MASK].is_not_null()) + assert all(nodes_df[DEFAULT_ATTR_KEYS.BBOX].is_not_null()) + + +def test_mask_disk_attrs_init_graph_attributes() -> None: + """Test that graph attributes are properly initialized.""" + graph = RustWorkXGraph() + + disk_operator = MaskDiskAttrs( + radius=1, image_shape=(10, 10), mask_output_key="test_mask", bbox_output_key="test_bbox" + ) + + # Before initialization, attributes shouldn't exist + assert "test_mask" not in graph.node_attr_keys + assert "test_bbox" not in graph.node_attr_keys + + # Initialize attributes + disk_operator._init_node_attrs(graph) + + # After initialization, attributes should exist + assert "test_mask" in graph.node_attr_keys + assert "test_bbox" in graph.node_attr_keys diff --git a/src/tracksdata/nodes/_test/test_mask.py b/src/tracksdata/functional/_test/test_mask.py similarity index 56% rename from src/tracksdata/nodes/_test/test_mask.py rename to src/tracksdata/functional/_test/test_mask.py index ebde2c8e..2a12ce3b 100644 --- a/src/tracksdata/nodes/_test/test_mask.py +++ b/src/tracksdata/functional/_test/test_mask.py @@ -1,36 +1,12 @@ import numpy as np -from tracksdata.nodes._mask import Mask - - -def 
test_mask_init() -> None: - """Test Mask initialization.""" - mask_array = np.array([[True, False], [False, True]], dtype=bool) - bbox = np.array([0, 0, 2, 2]) - - mask = Mask(mask_array, bbox) - assert np.array_equal(mask._mask, mask_array) - assert np.array_equal(mask._bbox, bbox) - - -def test_mask_getstate_setstate() -> None: - """Test Mask serialization and deserialization.""" - mask_array = np.array([[True, False], [False, True]], dtype=bool) - bbox = np.array([0, 0, 2, 2]) - - mask = Mask(mask_array, bbox) - - # Test serialization - state = mask.__getstate__() - assert "_mask" in state - assert np.array_equal(state["_bbox"], bbox) - - # Test deserialization - new_mask = Mask.__new__(Mask) - new_mask.__setstate__(state) - # After deserialization, the mask should be restored correctly - assert np.array_equal(new_mask._mask, mask_array) - assert np.array_equal(new_mask._bbox, bbox) +from tracksdata.functional._mask import ( + crop_image_with_bbox, + mask_indices, + mask_intersection, + mask_iou, + paint_mask_to_buffer, +) def test_mask_indices_no_offset() -> None: @@ -38,8 +14,7 @@ def test_mask_indices_no_offset() -> None: mask_array = np.array([[True, False], [False, True]], dtype=bool) bbox = np.array([1, 2, 3, 4]) # min_y, min_x, max_y, max_x - mask = Mask(mask_array, bbox) - indices = mask.mask_indices() + indices = mask_indices(bbox, mask_array) # True values are at positions (0,0) and (1,1) in the mask # With bbox offset [1, 2]: (0+1, 0+2) and (1+1, 1+2) = (1, 2) and (2, 3) @@ -56,8 +31,7 @@ def test_mask_indices_with_scalar_offset() -> None: mask_array = np.array([[True, False], [False, True]], dtype=bool) bbox = np.array([1, 2, 3, 4]) - mask = Mask(mask_array, bbox) - indices = mask.mask_indices(offset=5) + indices = mask_indices(bbox, mask_array, offset=5) # True values at (0,0) and (1,1) in mask # With bbox [1, 2] and offset 5: (0+1+5, 0+2+5) and (1+1+5, 1+2+5) = (6, 7) and (7, 8) @@ -74,9 +48,8 @@ def test_mask_indices_with_array_offset() -> None: 
mask_array = np.array([[True, False], [False, True]], dtype=bool) bbox = np.array([1, 2, 3, 4]) - mask = Mask(mask_array, bbox) offset = np.array([3, 4]) - indices = mask.mask_indices(offset=offset) + indices = mask_indices(bbox, mask_array, offset=offset) # True values at (0,0) and (1,1) in mask # With bbox [1, 2] and offset [3, 4]: (0+1+3, 0+2+4) and (1+1+3, 1+2+4) = (4, 6) and (5, 7) @@ -93,8 +66,7 @@ def test_mask_indices_3d() -> None: mask_array = np.array([[[True, False], [False, False]], [[False, False], [False, True]]], dtype=bool) bbox = np.array([1, 2, 3, 3, 4, 5]) # min_z, min_y, min_x, max_z, max_y, max_x - mask = Mask(mask_array, bbox) - indices = mask.mask_indices() + indices = mask_indices(bbox, mask_array) # True values at (0,0,0) and (1,1,1) in mask # With bbox offset [1,2,3]: (0+1, 0+2, 0+3) and (1+1, 1+2, 1+3) = (1,2,3) and (2,3,4) @@ -108,16 +80,14 @@ def test_mask_indices_3d() -> None: assert np.array_equal(indices[2], expected_x) -def test_paint_buffer() -> None: - """Test paint_buffer method.""" +def test_paint_mask_to_buffer() -> None: + """Test paint_mask_to_buffer function.""" mask_array = np.array([[True, False], [False, True]], dtype=bool) bbox = np.array([0, 0, 2, 2]) - mask = Mask(mask_array, bbox) - # Create a buffer to paint on buffer = np.zeros((4, 4), dtype=float) - mask.paint_buffer(buffer, value=5.0) + paint_mask_to_buffer(buffer, bbox, mask_array, value=5.0) # Check that the correct positions are painted expected_buffer = np.zeros((4, 4), dtype=float) @@ -127,17 +97,15 @@ def test_paint_buffer() -> None: assert np.array_equal(buffer, expected_buffer) -def test_paint_buffer_with_offset() -> None: - """Test paint_buffer method with offset.""" +def test_paint_mask_to_buffer_with_offset() -> None: + """Test paint_mask_to_buffer function with offset.""" mask_array = np.array([[True, False], [False, True]], dtype=bool) bbox = np.array([0, 0, 2, 2]) - mask = Mask(mask_array, bbox) - # Create a buffer to paint on buffer = np.zeros((6, 
6), dtype=float) offset = np.array([2, 3]) - mask.paint_buffer(buffer, value=7.0, offset=offset) + paint_mask_to_buffer(buffer, bbox, mask_array, value=7.0, offset=offset) # Check that the correct positions are painted with offset expected_buffer = np.zeros((6, 6), dtype=float) @@ -148,17 +116,15 @@ def test_paint_buffer_with_offset() -> None: def test_mask_iou() -> None: - """Test IoU calculation between masks.""" + """Test IoU calculation between mask/bbox pairs.""" # Create two overlapping masks mask1_array = np.array([[True, True], [True, False]], dtype=bool) bbox1 = np.array([0, 0, 2, 2]) - mask1 = Mask(mask1_array, bbox1) mask2_array = np.array([[True, False], [True, True]], dtype=bool) bbox2 = np.array([0, 0, 2, 2]) - mask2 = Mask(mask2_array, bbox2) - iou = mask1.iou(mask2) + iou = mask_iou(bbox1, mask1_array, bbox2, mask2_array) # Intersection: positions (0,0) and (1,0) = 2 pixels # Union: 3 + 3 - 2 = 4 pixels @@ -171,13 +137,11 @@ def test_mask_iou_no_overlap() -> None: """Test IoU calculation with non-overlapping masks.""" mask1_array = np.array([[True, False], [False, False]], dtype=bool) bbox1 = np.array([0, 0, 2, 2]) - mask1 = Mask(mask1_array, bbox1) mask2_array = np.array([[False, False], [False, True]], dtype=bool) bbox2 = np.array([0, 0, 2, 2]) - mask2 = Mask(mask2_array, bbox2) - iou = mask1.iou(mask2) + iou = mask_iou(bbox1, mask1_array, bbox2, mask2_array) assert iou == 0.0 @@ -186,20 +150,32 @@ def test_mask_iou_identical() -> None: mask_array = np.array([[True, False], [False, True]], dtype=bool) bbox = np.array([0, 0, 2, 2]) - mask1 = Mask(mask_array, bbox) - mask2 = Mask(mask_array.copy(), bbox.copy()) - - iou = mask1.iou(mask2) + iou = mask_iou(bbox, mask_array, bbox.copy(), mask_array.copy()) assert iou == 1.0 +def test_mask_intersection() -> None: + """Test intersection calculation between mask/bbox pairs.""" + # Create two overlapping masks + mask1_array = np.array([[True, True], [True, False]], dtype=bool) + bbox1 = np.array([0, 0, 2, 
2]) + + mask2_array = np.array([[True, False], [True, True]], dtype=bool) + bbox2 = np.array([0, 0, 2, 2]) + + intersection = mask_intersection(bbox1, mask1_array, bbox2, mask2_array) + + # Intersection: positions (0,0) and (1,0) = 2 pixels + expected_intersection = 2.0 + assert abs(intersection - expected_intersection) < 1e-6 + + def test_mask_empty() -> None: """Test mask with no True values.""" mask_array = np.array([[False, False], [False, False]], dtype=bool) bbox = np.array([0, 0, 2, 2]) - mask = Mask(mask_array, bbox) - indices = mask.mask_indices() + indices = mask_indices(bbox, mask_array) # Should return empty arrays assert len(indices) == 2 @@ -212,8 +188,7 @@ def test_mask_all_true() -> None: mask_array = np.array([[True, True], [True, True]], dtype=bool) bbox = np.array([1, 1, 3, 3]) - mask = Mask(mask_array, bbox) - indices = mask.mask_indices() + indices = mask_indices(bbox, mask_array) # Should return all positions expected_y = np.array([1, 1, 2, 2]) @@ -224,70 +199,23 @@ def test_mask_all_true() -> None: assert np.array_equal(indices[1], expected_x) -def test_mask_repr() -> None: - """Test mask representation.""" - mask_array = np.array([[True, False], [False, True]], dtype=bool) - bbox = np.array([0, 0, 2, 2]) +def test_crop_image_with_bbox() -> None: + """Test image cropping with bbox.""" + bbox = np.array([1, 1, 3, 3]) + image = np.array([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]) - mask = Mask(mask_array, bbox) - assert repr(mask) == "Mask(bbox=[0:2, 0:2])" + cropped_image = crop_image_with_bbox(image, bbox) + expected_crop = image[1:3, 1:3] + assert np.array_equal(cropped_image, expected_crop) -def test_mask_crop() -> None: - """Test mask cropping.""" - mask_array = np.array([[True, False], [False, True]], dtype=bool) - bbox = np.array([1, 1, 3, 3]) - mask = Mask(mask_array, bbox) +def test_crop_image_with_bbox_and_shape() -> None: + """Test image cropping with bbox and specific shape.""" + bbox = np.array([1, 1, 3, 3]) image = 
np.array([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]) - cropped_image = mask.crop(image) - assert np.array_equal(cropped_image, image[1:3, 1:3]) + cropped_image = crop_image_with_bbox(image, bbox, shape=(2, 4)) + expected_crop = image[1:3, 0:4] -def test_mask_crop_with_shape() -> None: - """Test mask cropping with shape.""" - mask_array = np.array([[True, False], [False, True]], dtype=bool) - bbox = np.array([1, 1, 3, 3]) - - mask = Mask(mask_array, bbox) - image = np.array([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]) - cropped_image = mask.crop(image, shape=(2, 4)) - assert np.array_equal(cropped_image, image[1:3, 0:4]) - - -def test_mask_from_coordinates_2d_basic() -> None: - """Test 2D mask creation and bbox without cropping.""" - center = np.asarray([5, 5]) - radius = 2 - mask = Mask.from_coordinates(center, radius) - # Should be a disk of radius 2, shape (5,5), centered at (5,5) - assert mask.mask.shape == (5, 5) - assert mask.mask[2, 2] # center pixel is True - np.testing.assert_array_equal(mask.bbox, [3, 3, 8, 8]) - - -def test_mask_from_coordinates_3d_basic() -> None: - """Test 3D mask creation and bbox without cropping.""" - center = np.asarray([4, 5, 6]) - radius = 1 - mask = Mask.from_coordinates(center, radius) - # Should be a ball of radius 1, shape (3,3,3), centered at (4,5,6) - assert mask.mask.shape == (3, 3, 3) - assert mask.mask[1, 1, 1] # center voxel is True - np.testing.assert_array_equal(mask.bbox, [3, 4, 5, 6, 7, 8]) - - -def test_mask_from_coordinates_cropping() -> None: - """Test cropping when mask falls outside the image boundary.""" - center = np.asarray([0, 0]) - radius = 5 - image_shape = (4, 3) - - mask = Mask.from_coordinates(center, radius, image_shape=image_shape) - - # Mask shape should match the bbox size - expected_shape = (4, 3) - assert mask.mask.shape == expected_shape - - # Mask should be cropped to fit within image bounds - np.testing.assert_array_equal(mask.bbox, [0, 0, 4, 3]) + assert 
np.array_equal(cropped_image, expected_crop) diff --git a/src/tracksdata/functional/_test/test_napari.py b/src/tracksdata/functional/_test/test_napari.py index a888326a..8ff89e7b 100644 --- a/src/tracksdata/functional/_test/test_napari.py +++ b/src/tracksdata/functional/_test/test_napari.py @@ -31,7 +31,6 @@ def test_napari_conversion() -> None: mask_attrs = MaskDiskAttrs( radius=2, image_shape=image_shape, - output_key=DEFAULT_ATTR_KEYS.MASK, ) mask_attrs.add_node_attrs(graph) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 5fe03991..3bddd125 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -831,20 +831,24 @@ def compute_overlaps(self, iou_threshold: float = 0.0) -> None: graph.set_overlaps(iou_threshold=0.5) ``` """ + from tracksdata.functional._mask import mask_iou + if iou_threshold < 0.0 or iou_threshold > 1.0: raise ValueError("iou_threshold must be between 0.0 and 1.0") def _estimate_overlaps(t: int) -> list[list[int, 2]]: node_attrs = self.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) == t).node_attrs( - attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.MASK], + attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.BBOX, DEFAULT_ATTR_KEYS.MASK], ) node_ids = node_attrs[DEFAULT_ATTR_KEYS.NODE_ID].to_list() masks = node_attrs[DEFAULT_ATTR_KEYS.MASK].to_list() + bboxes = node_attrs[DEFAULT_ATTR_KEYS.BBOX].to_list() overlaps = [] for i in range(len(masks)): mask_i = masks[i] + bbox_i = bboxes[i] for j in range(i + 1, len(masks)): - if mask_i.iou(masks[j]) > iou_threshold: + if mask_iou(bbox_i, mask_i, bboxes[j], masks[j]) > iou_threshold: overlaps.append([node_ids[i], node_ids[j]]) return overlaps diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 8166065b..35b6c8eb 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -848,7 +848,11 @@ def _node_attrs_from_node_ids( 
# Element-count threshold for blosc2 compression.
# Only numpy arrays with MORE than this many elements are packed with blosc2;
# arrays with at most this many elements (and every non-array value) go through cloudpickle.
BLOSC_COMPRESSION_THRESHOLD = 30


class MaybeBloscBytes(TypeDecorator):
    """
    SQLAlchemy binary column type for arbitrary Python values, with blosc2
    compression applied to sufficiently large numpy arrays.

    On write, numpy arrays whose `size` exceeds `BLOSC_COMPRESSION_THRESHOLD`
    are serialized with `blosc2.pack_array2`; everything else is serialized
    with `cloudpickle.dumps`. On read, `blosc2.unpack_array2` is attempted
    first and `cloudpickle.loads` is used when that fails.

    `None` is passed through unchanged in both directions.
    """

    impl = LargeBinary
    # NOTE(review): the type holds no per-instance state, so `cache_ok = True`
    # might be safe here — confirm against the SQLAlchemy caching rules.
    cache_ok = False

    def process_bind_param(self, value: Any, dialect: Any) -> bytes | None:
        """Serialize `value` to bytes for storage (`None` stays `None`)."""
        if value is None:
            return None

        use_blosc = isinstance(value, np.ndarray) and value.size > BLOSC_COMPRESSION_THRESHOLD
        serialize = blosc2.pack_array2 if use_blosc else cloudpickle.dumps
        return serialize(value)

    def process_result_value(self, value: bytes | None, dialect: Any) -> Any:
        """Deserialize stored bytes back into the original Python value."""
        if value is None:
            return None

        # The stored bytes carry no explicit format tag, so the format is
        # detected by trying blosc2 first and falling back to cloudpickle.
        try:
            decoded = blosc2.unpack_array2(value)
        except Exception:
            # NOTE: unpickling executes code embedded in the payload; this is
            # only safe when the database contents are trusted.
            decoded = cloudpickle.loads(value)
        return decoded
Create first graph (self) with masks - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("y", 0.0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + setup_spatial_attrs_2d(graph_backend) + setup_mask_attrs(graph_backend) # Create masks for first graph - mask1_data = np.array([[True, True], [True, True]], dtype=bool) - mask1 = Mask(mask1_data, bbox=np.array([0, 0, 2, 2])) + mask1 = np.array([[True, True], [True, True]], dtype=bool) + bbox1 = np.array([0, 0, 2, 2]) - mask2_data = np.array([[True, False], [True, False]], dtype=bool) - mask2 = Mask(mask2_data, bbox=np.array([10, 10, 12, 12])) + mask2 = np.array([[True, False], [True, False]], dtype=bool) + bbox2 = np.array([10, 10, 12, 12]) - mask3_data = np.array([[True, True, True, True, True]], dtype=bool) - mask3 = Mask(mask3_data, bbox=np.array([20, 20, 21, 25])) + mask3 = np.array([[True, True, True, True, True]], dtype=bool) + bbox3 = np.array([20, 20, 21, 25]) # Add nodes to first graph - node1 = graph_backend.add_node({"t": 0, "x": 1.0, "y": 1.0, DEFAULT_ATTR_KEYS.MASK: mask1}) - node2 = graph_backend.add_node({"t": 1, "x": 2.0, "y": 2.0, DEFAULT_ATTR_KEYS.MASK: mask2}) - node3 = graph_backend.add_node({"t": 2, "x": 3.0, "y": 3.0, DEFAULT_ATTR_KEYS.MASK: mask3}) + node1 = graph_backend.add_node( + {"t": 0, "x": 1.0, "y": 1.0, DEFAULT_ATTR_KEYS.MASK: mask1, DEFAULT_ATTR_KEYS.BBOX: bbox1} + ) + node2 = graph_backend.add_node( + {"t": 1, "x": 2.0, "y": 2.0, DEFAULT_ATTR_KEYS.MASK: mask2, DEFAULT_ATTR_KEYS.BBOX: bbox2} + ) + node3 = graph_backend.add_node( + {"t": 2, "x": 3.0, "y": 3.0, DEFAULT_ATTR_KEYS.MASK: mask3, DEFAULT_ATTR_KEYS.BBOX: bbox3} + ) graph_backend.add_edge_attr_key("weight", 0.0) # this will not be matched @@ -796,33 +801,40 @@ def test_match_method(graph_backend: BaseGraph) -> None: kwargs = {} other_graph = graph_backend.__class__(**kwargs) - other_graph.add_node_attr_key("x", 0.0) - other_graph.add_node_attr_key("y", 0.0) - 
other_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + setup_spatial_attrs_2d(other_graph) + setup_mask_attrs(other_graph) # Create overlapping masks for second graph # This mask overlaps significantly with mask1 (IoU > 0.5) ref_mask1_data = np.array([[True, True], [True, False]], dtype=bool) - ref_mask1 = Mask(ref_mask1_data, bbox=np.array([0, 0, 2, 2])) + ref_mask1_bbox = np.array([0, 0, 2, 2]) # This mask overlaps significantly with mask3 (IoU > 0.5) ref_mask2_data = np.array([[True, True, True, True]], dtype=bool) - ref_mask2 = Mask(ref_mask2_data, bbox=np.array([20, 20, 21, 24])) + ref_mask2_bbox = np.array([20, 20, 21, 24]) # This mask should NOT overlap with other masks (IoU < 0.5, should not match) ref_mask3_data = np.array([[True]], dtype=bool) - ref_mask3 = Mask(ref_mask3_data, bbox=np.array([15, 15, 16, 16])) # Different location + ref_mask3_bbox = np.array([15, 15, 16, 16]) # This mask also overlaps significantly with mask3 (IoU > 0.5) but less than `ref_mask2` # therefore it should not match ref_mask4_data = np.array([[True, True, True]], dtype=bool) - ref_mask4 = Mask(ref_mask4_data, bbox=np.array([20, 21, 21, 24])) + ref_mask4_bbox = np.array([20, 21, 21, 24]) # Add nodes to reference graph - ref_node1 = other_graph.add_node({"t": 0, "x": 1.1, "y": 1.1, DEFAULT_ATTR_KEYS.MASK: ref_mask1}) - ref_node2 = other_graph.add_node({"t": 2, "x": 3.1, "y": 3.1, DEFAULT_ATTR_KEYS.MASK: ref_mask2}) - ref_node3 = other_graph.add_node({"t": 1, "x": 2.1, "y": 2.1, DEFAULT_ATTR_KEYS.MASK: ref_mask3}) - ref_node4 = other_graph.add_node({"t": 2, "x": 3.1, "y": 3.1, DEFAULT_ATTR_KEYS.MASK: ref_mask4}) + ref_node1 = other_graph.add_node( + {"t": 0, "x": 1.1, "y": 1.1, DEFAULT_ATTR_KEYS.MASK: ref_mask1_data, DEFAULT_ATTR_KEYS.BBOX: ref_mask1_bbox} + ) + ref_node2 = other_graph.add_node( + {"t": 2, "x": 3.1, "y": 3.1, DEFAULT_ATTR_KEYS.MASK: ref_mask2_data, DEFAULT_ATTR_KEYS.BBOX: ref_mask2_bbox} + ) + ref_node3 = other_graph.add_node( + {"t": 1, "x": 2.1, "y": 
2.1, DEFAULT_ATTR_KEYS.MASK: ref_mask3_data, DEFAULT_ATTR_KEYS.BBOX: ref_mask3_bbox} + ) + ref_node4 = other_graph.add_node( + {"t": 2, "x": 3.1, "y": 3.1, DEFAULT_ATTR_KEYS.MASK: ref_mask4_data, DEFAULT_ATTR_KEYS.BBOX: ref_mask4_bbox} + ) # Add edges to reference graph - matching structure with first graph other_graph.add_edge_attr_key("weight", 0.0) @@ -1167,17 +1179,17 @@ def test_from_other_with_edges(graph_backend: BaseGraph) -> None: def test_compute_overlaps_basic(graph_backend: BaseGraph) -> None: """Test basic compute_overlaps functionality.""" - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + setup_mask_attrs(graph_backend) # Create overlapping masks at time 0 - mask1_data = np.array([[True, True], [True, True]], dtype=bool) - mask1 = Mask(mask1_data, bbox=np.array([0, 0, 2, 2])) + mask1 = np.array([[True, True], [True, True]], dtype=bool) + bbox1 = np.array([0, 0, 2, 2]) - mask2_data = np.array([[True, True], [False, False]], dtype=bool) - mask2 = Mask(mask2_data, bbox=np.array([0, 0, 2, 2])) + mask2 = np.array([[True, True], [False, False]], dtype=bool) + bbox2 = np.array([0, 0, 2, 2]) - node1 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask1}) - node2 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask2}) + node1 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask1, DEFAULT_ATTR_KEYS.BBOX: bbox1}) + node2 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask2, DEFAULT_ATTR_KEYS.BBOX: bbox2}) graph_backend.compute_overlaps(iou_threshold=0.3) @@ -1189,23 +1201,23 @@ def test_compute_overlaps_basic(graph_backend: BaseGraph) -> None: def test_compute_overlaps_with_threshold(graph_backend: BaseGraph) -> None: """Test compute_overlaps with different IoU thresholds.""" - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + setup_mask_attrs(graph_backend) # Create masks with different overlap levels - mask1_data = np.array([[True, True], [True, True]], dtype=bool) - mask1 = 
Mask(mask1_data, bbox=np.array([0, 0, 2, 2])) + mask1 = np.array([[True, True], [True, True]], dtype=bool) + bbox1 = np.array([0, 0, 2, 2]) # Partially overlapping mask (IoU = 0.5) - mask2_data = np.array([[True, True], [False, False]], dtype=bool) - mask2 = Mask(mask2_data, bbox=np.array([0, 0, 2, 2])) + mask2 = np.array([[True, True], [False, False]], dtype=bool) + bbox2 = np.array([0, 0, 2, 2]) # Non-overlapping mask - mask3_data = np.array([[True, True], [True, True]], dtype=bool) - mask3 = Mask(mask3_data, bbox=np.array([10, 10, 12, 12])) + mask3 = np.array([[True, True], [True, True]], dtype=bool) + bbox3 = np.array([10, 10, 12, 12]) - node1 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask1}) - node2 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask2}) - graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask3}) + node1 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask1, DEFAULT_ATTR_KEYS.BBOX: bbox1}) + node2 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask2, DEFAULT_ATTR_KEYS.BBOX: bbox2}) + graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask3, DEFAULT_ATTR_KEYS.BBOX: bbox3}) # With threshold 0.7, no overlaps should be found (IoU = 0.5 < 0.7) graph_backend.compute_overlaps(iou_threshold=0.7) @@ -1223,20 +1235,24 @@ def test_compute_overlaps_with_threshold(graph_backend: BaseGraph) -> None: def test_compute_overlaps_multiple_timepoints(graph_backend: BaseGraph) -> None: """Test compute_overlaps across multiple time points.""" - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + setup_mask_attrs(graph_backend) # Time 0: overlapping masks - mask1_t0 = Mask(np.array([[True, True], [True, True]], dtype=bool), bbox=np.array([0, 0, 2, 2])) - mask2_t0 = Mask(np.array([[True, True], [False, False]], dtype=bool), bbox=np.array([0, 0, 2, 2])) + mask1_t0 = np.array([[True, True], [True, True]], dtype=bool) + bbox1_t0 = np.array([0, 0, 2, 2]) + mask2_t0 = np.array([[True, True], [False, 
False]], dtype=bool) + bbox2_t0 = np.array([0, 0, 2, 2]) # Time 1: non-overlapping masks - mask1_t1 = Mask(np.array([[True, True], [True, True]], dtype=bool), bbox=np.array([0, 0, 2, 2])) - mask2_t1 = Mask(np.array([[True, True], [True, True]], dtype=bool), bbox=np.array([10, 10, 12, 12])) - - node1_t0 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask1_t0}) - node2_t0 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask2_t0}) - graph_backend.add_node({"t": 1, DEFAULT_ATTR_KEYS.MASK: mask1_t1}) - graph_backend.add_node({"t": 1, DEFAULT_ATTR_KEYS.MASK: mask2_t1}) + mask1_t1 = np.array([[True, True], [True, True]], dtype=bool) + bbox1_t1 = np.array([0, 0, 2, 2]) + mask2_t1 = np.array([[True, True], [True, True]], dtype=bool) + bbox2_t1 = np.array([10, 10, 12, 12]) + + node1_t0 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask1_t0, DEFAULT_ATTR_KEYS.BBOX: bbox1_t0}) + node2_t0 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask2_t0, DEFAULT_ATTR_KEYS.BBOX: bbox2_t0}) + graph_backend.add_node({"t": 1, DEFAULT_ATTR_KEYS.MASK: mask1_t1, DEFAULT_ATTR_KEYS.BBOX: bbox1_t1}) + graph_backend.add_node({"t": 1, DEFAULT_ATTR_KEYS.MASK: mask2_t1, DEFAULT_ATTR_KEYS.BBOX: bbox2_t1}) graph_backend.compute_overlaps(iou_threshold=0.3) diff --git a/src/tracksdata/graph/filters/_test/test_spatial_filter.py b/src/tracksdata/graph/filters/_test/test_spatial_filter.py index d31310b3..ef8a64b4 100644 --- a/src/tracksdata/graph/filters/_test/test_spatial_filter.py +++ b/src/tracksdata/graph/filters/_test/test_spatial_filter.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from tracksdata.graph import RustWorkXGraph @@ -134,10 +135,10 @@ def test_bbox_spatial_filter_overlaps() -> None: graph.add_node_attr_key("bbox", [0, 0, 0, 0]) # Add nodes with bounding boxes bboxes = [ - [0, 20, 10, 30], # Node 1 - [5, 25, 15, 35], # Node 2 - [10, 30, 20, 40], # Node 3 - [15, 35, 25, 45], # Node 4 + np.asarray([0, 20, 10, 30]), # Node 1 + np.asarray([5, 
25, 15, 35]), # Node 2 + np.asarray([10, 30, 20, 40]), # Node 3 + np.asarray([15, 35, 25, 45]), # Node 4 ] node_ids = graph.bulk_add_nodes([{"t": 0, "bbox": bbox} for bbox in bboxes]) @@ -154,12 +155,12 @@ def test_bbox_spatial_filter_overlaps() -> None: def test_bbox_spatial_filter_with_edges() -> None: """Test SpatialFilter preserves edges in subgraphs.""" graph = RustWorkXGraph() - graph.add_node_attr_key("bbox", [0, 0, 0, 0]) + graph.add_node_attr_key("bbox", np.asarray([0, 0, 0, 0])) graph.add_edge_attr_key("weight", 0.0) # Add nodes and edge - node1_id = graph.add_node({"t": 0, "bbox": [10, 20, 15, 25]}) - node2_id = graph.add_node({"t": 1, "bbox": [30, 40, 35, 45]}) + node1_id = graph.add_node({"t": 0, "bbox": np.asarray([10, 20, 15, 25])}) + node2_id = graph.add_node({"t": 1, "bbox": np.asarray([30, 40, 35, 45])}) graph.add_edge(node1_id, node2_id, {"weight": 1.0}) spatial_filter = BBoxSpatialFilter(graph, frame_attr_key="t", bbox_attr_key="bbox") diff --git a/src/tracksdata/io/_test/test_ctc_io.py b/src/tracksdata/io/_test/test_ctc_io.py index 32433959..1c7dcdc6 100644 --- a/src/tracksdata/io/_test/test_ctc_io.py +++ b/src/tracksdata/io/_test/test_ctc_io.py @@ -4,7 +4,11 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph import RustWorkXGraph -from tracksdata.nodes._mask import Mask +from tracksdata.utils._test_utils import ( + setup_custom_node_attr, + setup_edge_distance_attr, + setup_mask_attrs, +) def test_export_from_ctc_roundtrip(tmp_path: Path): @@ -12,10 +16,10 @@ def test_export_from_ctc_roundtrip(tmp_path: Path): # Create original graph with nodes and edges in_graph = RustWorkXGraph() - in_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - in_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACK_ID, -1) - in_graph.add_node_attr_key("x", -999_999) - in_graph.add_node_attr_key("y", -999_999) + setup_mask_attrs(in_graph) + setup_custom_node_attr(in_graph, DEFAULT_ATTR_KEYS.TRACK_ID, -1) + setup_custom_node_attr(in_graph, "x", 
-999_999) + setup_custom_node_attr(in_graph, "y", -999_999) node_1 = in_graph.add_node( attrs={ @@ -23,10 +27,8 @@ def test_export_from_ctc_roundtrip(tmp_path: Path): DEFAULT_ATTR_KEYS.TRACK_ID: 1, "x": 0, "y": 0, - DEFAULT_ATTR_KEYS.MASK: Mask( - mask=np.ones((2, 2), dtype=bool), - bbox=np.asarray([0, 0, 2, 2]), - ), + DEFAULT_ATTR_KEYS.MASK: np.ones((2, 2), dtype=bool), + DEFAULT_ATTR_KEYS.BBOX: np.asarray([0, 0, 2, 2]), }, ) @@ -36,10 +38,8 @@ def test_export_from_ctc_roundtrip(tmp_path: Path): DEFAULT_ATTR_KEYS.TRACK_ID: 2, "x": 1, "y": 1, - DEFAULT_ATTR_KEYS.MASK: Mask( - mask=np.ones((2, 2), dtype=bool), - bbox=np.asarray([0, 0, 2, 2]), - ), + DEFAULT_ATTR_KEYS.MASK: np.ones((2, 2), dtype=bool), + DEFAULT_ATTR_KEYS.BBOX: np.asarray([0, 0, 2, 2]), }, ) @@ -49,14 +49,12 @@ def test_export_from_ctc_roundtrip(tmp_path: Path): DEFAULT_ATTR_KEYS.TRACK_ID: 3, "x": 2, "y": 2, - DEFAULT_ATTR_KEYS.MASK: Mask( - mask=np.ones((2, 2), dtype=bool), - bbox=np.asarray([1, 1, 3, 3]), - ), + DEFAULT_ATTR_KEYS.MASK: np.ones((2, 2), dtype=bool), + DEFAULT_ATTR_KEYS.BBOX: np.asarray([1, 1, 3, 3]), }, ) - in_graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_edge_distance_attr(in_graph) in_graph.add_edge(node_1, node_2, attrs={DEFAULT_ATTR_KEYS.EDGE_DIST: 1.0}) in_graph.add_edge(node_1, node_3, attrs={DEFAULT_ATTR_KEYS.EDGE_DIST: 1.0}) diff --git a/src/tracksdata/metrics/_ctc_metrics.py b/src/tracksdata/metrics/_ctc_metrics.py index 5d48b91b..38c4c611 100644 --- a/src/tracksdata/metrics/_ctc_metrics.py +++ b/src/tracksdata/metrics/_ctc_metrics.py @@ -6,6 +6,7 @@ from toolz import curry from tracksdata.constants import DEFAULT_ATTR_KEYS +from tracksdata.functional._mask import mask_intersection from tracksdata.io._ctc import compressed_tracks_table from tracksdata.options import get_options from tracksdata.utils._dtypes import column_from_bytes, column_to_bytes @@ -66,15 +67,30 @@ def _match_single_frame( _rows = [] _cols = [] - for i, (ref_id, ref_mask) in 
enumerate( - zip(ref_group[reference_graph_key], ref_group[DEFAULT_ATTR_KEYS.MASK], strict=True) + ref_mask_sizes = ref_group[DEFAULT_ATTR_KEYS.MASK].map_elements(np.sum, return_dtype=pl.Int64) + comp_mask_sizes = comp_group[DEFAULT_ATTR_KEYS.MASK].map_elements(np.sum, return_dtype=pl.Int64) + + for i, (ref_id, ref_bbox, ref_mask, ref_mask_size) in enumerate( + zip( + ref_group[reference_graph_key], + ref_group[DEFAULT_ATTR_KEYS.BBOX], + ref_group[DEFAULT_ATTR_KEYS.MASK], + ref_mask_sizes, + strict=True, + ) ): - for j, (comp_id, comp_mask) in enumerate( - zip(comp_group[input_graph_key], comp_group[DEFAULT_ATTR_KEYS.MASK], strict=True) + for j, (comp_id, comp_bbox, comp_mask, comp_mask_size) in enumerate( + zip( + comp_group[input_graph_key], + comp_group[DEFAULT_ATTR_KEYS.BBOX], + comp_group[DEFAULT_ATTR_KEYS.MASK], + comp_mask_sizes, + strict=True, + ) ): # intersection over reference is used to select the matches - inter = ref_mask.intersection(comp_mask) - ctc_score = inter / ref_mask.size + inter = mask_intersection(ref_bbox, ref_mask, comp_bbox, comp_mask) + ctc_score = inter / ref_mask_size if ctc_score > min_reference_intersection: _mapped_ref.append(ref_id) _mapped_comp.append(comp_id) @@ -83,8 +99,8 @@ def _match_single_frame( # NOTE: there was something weird with IoU, the length when compared with `ctc_metrics` # sometimes it had an extra element - iou = inter / (ref_mask.size + comp_mask.size - inter) - _ious.append(iou.item()) + iou = inter / (ref_mask_size + comp_mask_size - inter) + _ious.append(iou) if optimal_matching and len(_rows) > 0: LOG.info("Solving optimal matching ...") @@ -150,7 +166,9 @@ def _matching_data( ("ref", reference_graph, reference_graph_key), ("comp", input_graph, input_graph_key), ]: - nodes_df = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.T, track_id_key, DEFAULT_ATTR_KEYS.MASK]) + nodes_df = graph.node_attrs( + attr_keys=[DEFAULT_ATTR_KEYS.T, track_id_key, DEFAULT_ATTR_KEYS.BBOX, DEFAULT_ATTR_KEYS.MASK] + ) if n_workers 
> 1: # required by multiprocessing nodes_df = column_to_bytes(nodes_df, DEFAULT_ATTR_KEYS.MASK) diff --git a/src/tracksdata/metrics/_test/test_metrics_visualize.py b/src/tracksdata/metrics/_test/test_metrics_visualize.py index 92b30df1..83e183c1 100644 --- a/src/tracksdata/metrics/_test/test_metrics_visualize.py +++ b/src/tracksdata/metrics/_test/test_metrics_visualize.py @@ -4,7 +4,6 @@ import numpy as np import pytest -from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph import RustWorkXGraph from tracksdata.metrics._visualize import visualize_matches from tracksdata.nodes import MaskDiskAttrs @@ -60,7 +59,6 @@ def test_visualize_matches(make_napari_viewer: Callable[[], "napari.Viewer"]) -> input_mask_attrs = MaskDiskAttrs( radius=5, image_shape=image_shape, - output_key=DEFAULT_ATTR_KEYS.MASK, ) input_mask_attrs.add_node_attrs(input_graph) @@ -68,7 +66,6 @@ def test_visualize_matches(make_napari_viewer: Callable[[], "napari.Viewer"]) -> ref_mask_attrs = MaskDiskAttrs( radius=3, image_shape=image_shape, - output_key=DEFAULT_ATTR_KEYS.MASK, ) ref_mask_attrs.add_node_attrs(ref_graph) diff --git a/src/tracksdata/nodes/__init__.py b/src/tracksdata/nodes/__init__.py index b031d4bc..7b7f5fb7 100644 --- a/src/tracksdata/nodes/__init__.py +++ b/src/tracksdata/nodes/__init__.py @@ -1,8 +1,8 @@ """Node operators for creating nodes and their respective attributes (e.g. 
masks) in a graph.""" +from tracksdata.functional._disk_attrs import MaskDiskAttrs from tracksdata.nodes._generic_nodes import GenericFuncNodeAttrs -from tracksdata.nodes._mask import Mask, MaskDiskAttrs from tracksdata.nodes._random import RandomNodes from tracksdata.nodes._regionprops import RegionPropsNodes -__all__ = ["GenericFuncNodeAttrs", "Mask", "MaskDiskAttrs", "RandomNodes", "RegionPropsNodes"] +__all__ = ["GenericFuncNodeAttrs", "MaskDiskAttrs", "RandomNodes", "RegionPropsNodes"] diff --git a/src/tracksdata/nodes/_disk.py b/src/tracksdata/nodes/_disk.py new file mode 100644 index 00000000..a1b99b8a --- /dev/null +++ b/src/tracksdata/nodes/_disk.py @@ -0,0 +1,165 @@ +from collections.abc import Sequence +from functools import lru_cache + +import numpy as np +import skimage.morphology as morph +from numpy.typing import NDArray + +from tracksdata.attrs import NodeAttr +from tracksdata.constants import DEFAULT_ATTR_KEYS +from tracksdata.graph._base_graph import BaseGraph +from tracksdata.nodes._base_node_attrs import BaseNodeAttrsOperator +from tracksdata.utils._logging import LOG + + +@lru_cache(maxsize=5) +def _spherical_mask( + radius: int, + ndim: int, +) -> NDArray[np.bool_]: + """ + Get a spherical mask of a given radius and dimension. + """ + if ndim == 2: + return morph.disk(radius).astype(bool) + + if ndim == 3: + return morph.ball(radius).astype(bool) + + raise ValueError(f"Spherical is only implemented for 2D and 3D, got ndim={ndim}") + + +def _create_mask_and_bbox_from_coordinates( + center: NDArray, + radius: int, + image_shape: tuple[int, ...] | None = None, +) -> tuple[NDArray[np.bool_], NDArray[np.int64]]: + """ + Create a mask and bounding box from center coordinates and radius. + + Parameters + ---------- + center : NDArray + The center of the mask. + radius : int + The radius of the mask. + image_shape : tuple[int, ...] | None + The shape of the image. When provided, crops regions outside the image. 
+ + Returns + ------- + tuple[NDArray[np.bool_], NDArray[np.int64]] + The mask and bounding box arrays. + """ + mask = _spherical_mask(radius, len(center)) + center = np.round(center).astype(int) + + start = center - np.asarray(mask.shape) // 2 + end = start + mask.shape + + if image_shape is None: + bbox = np.concatenate([start, end]) + else: + processed_start = np.maximum(start, 0) + processed_end = np.minimum(end, image_shape) + + start_overhang = processed_start - start + end_overhang = end - processed_end + + mask = mask[tuple(slice(s, -e if e > 0 else None) for s, e in zip(start_overhang, end_overhang, strict=True))] + + bbox = np.concatenate([processed_start, processed_end]) + + return mask, bbox + + +class DiskMaskAttrs(BaseNodeAttrsOperator): + """ + Operator to create disk masks and bounding boxes for each node. + + Creates spherical masks in space, so temporal information should not be provided. + + Parameters + ---------- + radius : int + The radius of the mask. + image_shape : tuple[int, ...] + The shape of the image, must match the number of attr_keys. + attr_keys : Sequence[str] | None + The attributes for the center of the mask. + If not provided, "z", "y", "x" will be used. + mask_output_key : str + The key to store the mask attribute. + bbox_output_key : str + The key to store the bounding box attribute. + """ + + def __init__( + self, + radius: int, + image_shape: tuple[int, ...], + attr_keys: Sequence[str] | None = None, + mask_output_key: str = DEFAULT_ATTR_KEYS.MASK, + bbox_output_key: str = DEFAULT_ATTR_KEYS.BBOX, + ): + super().__init__(mask_output_key) # Primary output key for base class + + if attr_keys is None: + default_columns = ["z", "y", "x"] + attr_keys = default_columns[-len(image_shape) :] + + if len(attr_keys) != len(image_shape): + raise ValueError( + f"Expected image shape {image_shape} to have the same number of dimensions as attr_keys '{attr_keys}'." 
+ ) + + self.radius = radius + self.image_shape = image_shape + self.attr_keys = attr_keys + self.mask_output_key = mask_output_key + self.bbox_output_key = bbox_output_key + + def _init_node_attrs(self, graph: BaseGraph) -> None: + """Initialize the node attributes for the graph.""" + if self.mask_output_key not in graph.node_attr_keys: + graph.add_node_attr_key(self.mask_output_key, default_value=None) + if self.bbox_output_key not in graph.node_attr_keys: + graph.add_node_attr_key(self.bbox_output_key, default_value=None) + + def _node_attrs_per_time( + self, + t: int, + *, + graph: BaseGraph, + frames: NDArray | None = None, + ) -> tuple[list[int], dict[str, list]]: + """ + Add mask and bbox attributes to nodes for a specific time point. + """ + # Get node IDs for the specified time point + graph_filter = graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) == t) + + if graph_filter.is_empty(): + LOG.warning(f"No nodes at time point {t}") + return [], {} + + # Get attributes for these nodes + node_attrs = graph_filter.node_attrs(attr_keys=self.attr_keys) + + masks = [] + bboxes = [] + + for data_dict in node_attrs.rows(named=True): + center = np.asarray([data_dict[key] for key in self.attr_keys]) + mask, bbox = _create_mask_and_bbox_from_coordinates( + center=center, + radius=self.radius, + image_shape=self.image_shape, + ) + masks.append(mask) + bboxes.append(bbox) + + return graph_filter.node_ids(), { + self.mask_output_key: masks, + self.bbox_output_key: bboxes, + } diff --git a/src/tracksdata/nodes/_generic_nodes.py b/src/tracksdata/nodes/_generic_nodes.py index 10ca2f9c..478f74c9 100644 --- a/src/tracksdata/nodes/_generic_nodes.py +++ b/src/tracksdata/nodes/_generic_nodes.py @@ -46,16 +46,18 @@ class GenericFuncNodeAttrs(BaseNodeAttrsOperator): graph = ... 
- def intensity_median_times_t(image: NDArray, mask: Mask, t: int) -> float: - cropped_frame = mask.crop(image) - valid_pixels = cropped_frame[mask.mask] + def intensity_median_times_t(image: NDArray, mask: NDArray, bbox: NDArray, t: int) -> float: + from tracksdata.functional._mask_utils import crop_with_bbox_and_mask + + cropped_frame = crop_with_bbox_and_mask(image, bbox) + valid_pixels = cropped_frame[mask] return np.median(valid_pixels) * t crop_attrs = GenericFuncNodeAttrs( func=intensity_median, output_key="intensity_median", - attr_keys=["mask", "t"], + attr_keys=["mask", "bbox", "t"], ) crop_attrs.add_node_attrs(graph, frames=video) @@ -68,11 +70,13 @@ def intensity_median_times_t(image: NDArray, mask: Mask, t: int) -> float: graph = ... - def intensity_median_times_t(image: NDArray, masks: list[Mask], t: list[int]) -> list[float]: + def intensity_median_times_t(image: NDArray, mask: list[NDArray], bbox: list[NDArray], t: list[int]) -> list[float]: + from tracksdata.functional._mask_utils import crop_with_bbox_and_mask + results = [] - for i in range(len(masks)): - cropped_frame = masks[i].crop(image) - valid_pixels = cropped_frame[masks[i].mask] + for i in range(len(mask)): + cropped_frame = crop_with_bbox_and_mask(image, bbox[i]) + valid_pixels = cropped_frame[mask[i]] value = np.median(valid_pixels) * t[i] results.append(value) return results @@ -81,7 +85,7 @@ def intensity_median_times_t(image: NDArray, masks: list[Mask], t: list[int]) -> crop_attrs = GenericFuncNodeAttrs( func=intensity_median, output_key="intensity_median", - attr_keys=["mask", "t"], + attr_keys=["mask", "bbox", "t"], ) crop_attrs.add_node_attrs(graph, frames=video) @@ -176,7 +180,7 @@ def _node_attrs_per_time( results.extend(batch_results) else: - for data_dict in node_attrs.rows(named=True): + for data_dict in node_attrs.iter_rows(named=True): result = self.func(*args, **data_dict) results.append(result) diff --git a/src/tracksdata/nodes/_mask.py b/src/tracksdata/nodes/_mask.py 
deleted file mode 100644 index cb51a8e0..00000000 --- a/src/tracksdata/nodes/_mask.py +++ /dev/null @@ -1,322 +0,0 @@ -from collections.abc import Sequence -from functools import cached_property, lru_cache - -import blosc2 -import numpy as np -import skimage.morphology as morph -from numpy.typing import ArrayLike, NDArray - -from tracksdata.constants import DEFAULT_ATTR_KEYS -from tracksdata.functional._iou import fast_intersection_with_bbox, fast_iou_with_bbox -from tracksdata.nodes._generic_nodes import GenericFuncNodeAttrs - - -@lru_cache(maxsize=5) -def _spherical_mask( - radius: int, - ndim: int, -) -> NDArray[np.bool_]: - """ - Get a spherical mask of a given radius and dimension. - """ - - if ndim == 2: - return morph.disk(radius) - - if ndim == 3: - return morph.ball(radius) - - raise ValueError(f"Spherical is only implemented for 2D and 3D, got ndim={ndim}") - - -class Mask: - """ - Object used to store an individual segmentation mask of a single instance (object) - - Parameters - ---------- - mask : NDArray[np.bool_] - A binary indicating the pixels that are part of the object (e.g. cell, nucleus, etc.). - bbox : np.ndarray - The bounding box of the region of interest with shape (2 * ndim,). - The first ndim elements are the start indices and the last ndim elements are the end indices. - Equivalent to slicing a numpy array with `[start:end]`. 
- Examples - -------- - ```python - mask = Mask(mask=np.array([[True, False], [False, True]]), bbox=np.array([0, 0, 2, 2])) - ``` - """ - - def __init__( - self, - mask: NDArray[np.bool_], - bbox: ArrayLike, - ): - self._mask = mask - self.bbox = bbox - - def __getstate__(self) -> dict: - data_dict = self.__dict__.copy() - prev_nthreads = blosc2.set_nthreads(1) - data_dict["_mask"] = blosc2.pack_array2(self._mask) - blosc2.set_nthreads(prev_nthreads) - return data_dict - - def __setstate__(self, state: dict) -> None: - prev_nthreads = blosc2.set_nthreads(1) - state["_mask"] = blosc2.unpack_array2(state["_mask"]) - blosc2.set_nthreads(prev_nthreads) - self.__dict__.update(state) - - @property - def mask(self) -> NDArray[np.bool_]: - return self._mask - - @property - def bbox(self) -> NDArray[np.int64]: - return self._bbox - - @bbox.setter - def bbox(self, bbox: ArrayLike) -> None: - bbox = np.asarray(bbox, dtype=np.int64) - - if self._mask.ndim != bbox.shape[0] // 2: - raise ValueError(f"Mask dimension {self._mask.ndim} does not match bbox dimension {bbox.shape[0]} // 2") - - bbox_size = bbox[self._mask.ndim :] - bbox[: self._mask.ndim] - - if np.any(self._mask.shape != bbox_size): - raise ValueError(f"Mask shape {self._mask.shape} does not match bbox size {bbox_size}") - - self._bbox: NDArray[np.int64] = bbox - - def crop( - self, - image: NDArray, - shape: tuple[int, ...] | None = None, - ) -> NDArray: - """ - Crop the mask from an image. - - Parameters - ---------- - image : NDArray - The image to crop from. - shape : tuple[int, ...] | None - The shape of the cropped image. If None, the `bbox` will be used. - - Returns - ------- - NDArray - The cropped image. 
- """ - if shape is None: - ndim = self._mask.ndim - slicing = tuple(slice(self._bbox[i], self._bbox[i + ndim]) for i in range(ndim)) - - else: - center = (self._bbox[: self._mask.ndim] + self._bbox[self._mask.ndim :]) // 2 - half_shape = np.asarray(shape) // 2 - start = np.maximum(center - half_shape, 0) - end = np.minimum(center + half_shape, image.shape) - slicing = tuple(slice(s, e) for s, e in zip(start, end, strict=True)) - - return image[slicing] - - def mask_indices( - self, - offset: NDArray[np.integer] | int = 0, - ) -> tuple[NDArray[np.integer], ...]: - """ - Get the indices of the pixels that are part of the object. - - Parameters - ---------- - offset : NDArray[np.integer] | int, optional - The offset to add to the indices, should be used with bounding box information. - - Returns - ------- - tuple[NDArray[np.integer], ...] - The indices of the pixels that are part of the object. - """ - if isinstance(offset, int): - offset = np.full(self._mask.ndim, offset) - - indices = list(np.nonzero(self._mask)) - - for i, index in enumerate(indices): - indices[i] = index + self._bbox[i] + offset[i] - - return tuple(indices) - - def paint_buffer( - self, - buffer: np.ndarray, - value: int | float, - offset: NDArray[np.integer] | int = 0, - ) -> None: - """ - Paint object into a buffer. - - Parameters - ---------- - buffer : np.ndarray - The buffer to paint inplace. - value : int | float - The value to paint the object. - offset : NDArray[np.integer] | int, optional - The offset to add to the indices, should be used with bounding box information. - """ - # TODO: make it zarr and tensorstore compatible - indices = self.mask_indices(offset) - buffer[indices] = value - - def iou(self, other: "Mask") -> float: - """ - Compute the Intersection over Union (IoU) between two masks - considering their bounding boxes location. - - Parameters - ---------- - other : Mask - The other mask to compute the IoU with. - - Returns - ------- - float - The IoU between the two masks. 
- """ - return fast_iou_with_bbox(self._bbox, other._bbox, self._mask, other._mask) - - def intersection(self, other: "Mask") -> float: - """ - Compute the intersection between two masks considering their bounding boxes location. - - Parameters - ---------- - other : Mask - The other mask to compute the intersection with. - - Returns - ------- - float - The intersection between the two masks. - """ - return fast_intersection_with_bbox(self._bbox, other._bbox, self._mask, other._mask) - - @cached_property - def size(self) -> int: - """ - Get the number of pixels that are part of the object. - """ - return self._mask.sum() - - def __repr__(self) -> str: - slicing_str = ", ".join( - f"{i}:{j}" - for i, j in zip( - self._bbox[: self._mask.ndim], - self._bbox[self._mask.ndim :], - strict=True, - ) - ) - return f"Mask(bbox=[{slicing_str}])" - - @classmethod - def from_coordinates( - cls, - center: NDArray, - radius: int, - image_shape: tuple[int, ...] | None = None, - ) -> "Mask": - """ - Create a mask from a center and a radius. - Regions outside the image are cropped. - - Parameters - ---------- - center : NDArray - The center of the mask. - radius : int - The radius of the mask. - image_shape : tuple[int, ...] | None - The shape of the image. - When provided crops regions outside the image. - - Returns - ------- - Mask - The mask. 
- """ - mask = _spherical_mask(radius, len(center)) - center = np.round(center).astype(int) - - start = center - np.asarray(mask.shape) // 2 - end = start + mask.shape - - if image_shape is None: - bbox = np.concatenate([start, end]) - else: - processed_start = np.maximum(start, 0) - processed_end = np.minimum(end, image_shape) - - start_overhang = processed_start - start - end_overhang = end - processed_end - - mask = mask[ - tuple(slice(s, -e if e > 0 else None) for s, e in zip(start_overhang, end_overhang, strict=True)) - ] - - bbox = np.concatenate([processed_start, processed_end]) - - return cls(mask, bbox) - - -class MaskDiskAttrs(GenericFuncNodeAttrs): - """ - Operator to create a disk mask for each node. - - Masks are created in space, so temporal information should not be provided. - - Parameters - ---------- - radius : int - The radius of the mask. - image_shape : tuple[int, ...] - The shape of the image, must match the number of of the attr_keys. - attr_keys : Sequence[str] | None - The attributes for the center of the mask. - If not provided, "z", "y", "x" will be used. - output_key : str - The key of the attribute to store the mask. - """ - - def __init__( - self, - radius: int, - image_shape: tuple[int, ...], - attr_keys: Sequence[str] | None = None, - output_key: str = DEFAULT_ATTR_KEYS.MASK, - ): - if attr_keys is None: - default_columns = ["z", "y", "x"] - attr_keys = default_columns[-len(image_shape) :] - - if len(attr_keys) != len(image_shape): - raise ValueError( - f"Expected image shape {image_shape} to have the same number of dimensions as attr_keys '{attr_keys}'." 
- ) - - super().__init__( - func=lambda **kwargs: Mask.from_coordinates( - center=np.asarray(list(kwargs.values())), - radius=radius, - image_shape=image_shape, - ), - output_key=output_key, - attr_keys=attr_keys, - default_value=None, - batch_size=0, - ) diff --git a/src/tracksdata/nodes/_random.py b/src/tracksdata/nodes/_random.py index 5d24fccd..123e1a37 100644 --- a/src/tracksdata/nodes/_random.py +++ b/src/tracksdata/nodes/_random.py @@ -48,9 +48,6 @@ class RandomNodes(BaseNodesOperator): [RegionPropsNodes][tracksdata.nodes.RegionPropsNodes]: Extract nodes from segmented images using region properties. - [Mask][tracksdata.nodes.Mask]: - Node operator for mask-based objects. - Examples -------- Generate 2D random nodes: diff --git a/src/tracksdata/nodes/_regionprops.py b/src/tracksdata/nodes/_regionprops.py index 70899ac4..881694ab 100644 --- a/src/tracksdata/nodes/_regionprops.py +++ b/src/tracksdata/nodes/_regionprops.py @@ -10,7 +10,6 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph._base_graph import BaseGraph from tracksdata.nodes._base_nodes import BaseNodesOperator -from tracksdata.nodes._mask import Mask from tracksdata.utils._logging import LOG from tracksdata.utils._multiprocessing import multiprocessing_apply @@ -124,6 +123,9 @@ def _init_node_attrs(self, graph: BaseGraph, axis_names: list[str]) -> None: if attr_key not in graph.node_attr_keys: graph.add_node_attr_key(attr_key, None) + if DEFAULT_ATTR_KEYS.BBOX not in graph.node_attr_keys: + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None) + # initialize the attribute keys for attr_key in axis_names + self.attrs_keys(): if attr_key not in graph.node_attr_keys: @@ -288,7 +290,7 @@ def _nodes_per_time( else: attrs[prop] = getattr(obj, prop) - attrs[DEFAULT_ATTR_KEYS.MASK] = Mask(obj.image, obj.bbox) + attrs[DEFAULT_ATTR_KEYS.MASK] = obj.image attrs[DEFAULT_ATTR_KEYS.BBOX] = np.asarray(obj.bbox, dtype=int) attrs[DEFAULT_ATTR_KEYS.T] = t diff --git 
a/src/tracksdata/nodes/_test/test_generic_nodes.py b/src/tracksdata/nodes/_test/test_generic_nodes.py index 1fdc9375..486a2da9 100644 --- a/src/tracksdata/nodes/_test/test_generic_nodes.py +++ b/src/tracksdata/nodes/_test/test_generic_nodes.py @@ -3,9 +3,11 @@ from numpy.typing import NDArray from tracksdata.constants import DEFAULT_ATTR_KEYS +from tracksdata.functional import crop_image_with_bbox from tracksdata.graph import RustWorkXGraph -from tracksdata.nodes import GenericFuncNodeAttrs, Mask +from tracksdata.nodes import GenericFuncNodeAttrs from tracksdata.options import get_options, options_context +from tracksdata.utils._test_utils import setup_custom_node_attr, setup_mask_attrs def test_crop_func_attrs_init_default() -> None: @@ -92,19 +94,19 @@ def test_crop_func_attrs_function_with_frames() -> None: """Test applying a function with frames.""" graph = RustWorkXGraph() - # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + # Set up graph attributes + setup_mask_attrs(graph) # Create test masks - mask1_data = np.array([[True, True], [True, False]], dtype=bool) - mask1 = Mask(mask1_data, bbox=np.array([0, 0, 2, 2])) + mask1 = np.array([[True, True], [True, False]], dtype=bool) + bbox1 = np.array([0, 0, 2, 2]) - mask2_data = np.array([[True, False], [False, False]], dtype=bool) - mask2 = Mask(mask2_data, bbox=np.array([0, 0, 2, 2])) + mask2 = np.array([[True, False], [False, False]], dtype=bool) + bbox2 = np.array([0, 0, 2, 2]) # Add nodes with masks - node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask1}) - node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask2}) + node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask1, DEFAULT_ATTR_KEYS.BBOX: bbox1}) + node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask2, DEFAULT_ATTR_KEYS.BBOX: bbox2}) # Create test frames frames = np.array( @@ -113,15 +115,15 @@ def 
test_crop_func_attrs_function_with_frames() -> None: ] ) - def intensity_sum(frame: NDArray, mask: Mask) -> float: - cropped = mask.crop(frame) - return float(np.sum(cropped[mask.mask])) + def intensity_sum(frame: NDArray, mask: NDArray, bbox: NDArray) -> float: + cropped = crop_image_with_bbox(frame, bbox) + return float(np.sum(cropped[mask])) # Create operator and add attributes operator = GenericFuncNodeAttrs( func=intensity_sum, output_key="intensity_sum", - attr_keys=[DEFAULT_ATTR_KEYS.MASK], + attr_keys=[DEFAULT_ATTR_KEYS.MASK, DEFAULT_ATTR_KEYS.BBOX], ) operator.add_node_attrs(graph, t=0, frames=frames) @@ -145,20 +147,24 @@ def test_crop_func_attrs_function_with_frames_and_attrs() -> None: """Test applying a function with frames and additional attributes.""" graph = RustWorkXGraph() - # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph.add_node_attr_key("multiplier", 1.0) + # Set up graph attributes + setup_mask_attrs(graph) + setup_custom_node_attr(graph, "multiplier", 1.0) # Create test masks - mask1_data = np.array([[True, True], [True, False]], dtype=bool) - mask1 = Mask(mask1_data, bbox=np.array([0, 0, 2, 2])) + mask1 = np.array([[True, True], [True, False]], dtype=bool) + bbox1 = np.array([0, 0, 2, 2]) - mask2_data = np.array([[True, False], [False, False]], dtype=bool) - mask2 = Mask(mask2_data, bbox=np.array([0, 0, 2, 2])) + mask2 = np.array([[True, False], [False, False]], dtype=bool) + bbox2 = np.array([0, 0, 2, 2]) # Add nodes with masks and multipliers - node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask1, "multiplier": 2.0}) - node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask2, "multiplier": 3.0}) + node1 = graph.add_node( + {DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask1, "multiplier": 2.0, DEFAULT_ATTR_KEYS.BBOX: bbox1} + ) + node2 = graph.add_node( + {DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask2, "multiplier": 3.0, 
DEFAULT_ATTR_KEYS.BBOX: bbox2} + ) # Create test frames frames = np.array( @@ -167,15 +173,15 @@ def test_crop_func_attrs_function_with_frames_and_attrs() -> None: ] ) - def intensity_sum_times_multiplier(frame: NDArray, mask: Mask, multiplier: float) -> float: - cropped = mask.crop(frame) - return float(np.sum(cropped[mask.mask]) * multiplier) + def intensity_sum_times_multiplier(frame: NDArray, mask: NDArray, bbox: NDArray, multiplier: float) -> float: + cropped = crop_image_with_bbox(frame, bbox) + return float(np.sum(cropped[mask]) * multiplier) # Create operator and add attributes operator = GenericFuncNodeAttrs( func=intensity_sum_times_multiplier, output_key="weighted_intensity", - attr_keys=["mask", "multiplier"], + attr_keys=[DEFAULT_ATTR_KEYS.MASK, DEFAULT_ATTR_KEYS.BBOX, "multiplier"], ) operator.add_node_attrs(graph, t=0, frames=frames) @@ -198,33 +204,33 @@ def test_crop_func_attrs_function_returns_different_types() -> None: """Test that functions can return different types.""" graph = RustWorkXGraph() - # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + # Set up graph attributes + setup_mask_attrs(graph) # Create test mask - mask_data = np.array([[True, True], [True, False]], dtype=bool) - mask = Mask(mask_data, bbox=np.array([0, 0, 2, 2])) + mask = np.array([[True, True], [True, False]], dtype=bool) + bbox = np.array([0, 0, 2, 2]) # Add node - graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask}) + graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask, DEFAULT_ATTR_KEYS.BBOX: bbox}) - def return_string(mask: Mask) -> str: + def return_string(mask: NDArray, bbox: NDArray) -> str: return "test_string" - def return_list(mask: Mask) -> list[int]: + def return_list(mask: NDArray, bbox: NDArray) -> list[int]: return [1, 2, 3] - def return_dict(mask: Mask) -> dict[str, int]: + def return_dict(mask: NDArray, bbox: NDArray) -> dict[str, int]: return {"count": 3} - def return_array(mask: Mask) -> 
NDArray: + def return_array(mask: NDArray, bbox: NDArray) -> NDArray: return np.asarray([1, 2, 3]) # Test string return type operator_str = GenericFuncNodeAttrs( func=return_string, output_key="string_result", - attr_keys=[DEFAULT_ATTR_KEYS.MASK], + attr_keys=[DEFAULT_ATTR_KEYS.MASK, DEFAULT_ATTR_KEYS.BBOX], ) operator_str.add_node_attrs(graph) @@ -232,7 +238,7 @@ def return_array(mask: Mask) -> NDArray: operator_list = GenericFuncNodeAttrs( func=return_list, output_key="list_result", - attr_keys=[DEFAULT_ATTR_KEYS.MASK], + attr_keys=[DEFAULT_ATTR_KEYS.MASK, DEFAULT_ATTR_KEYS.BBOX], ) operator_list.add_node_attrs(graph) @@ -240,7 +246,7 @@ def return_array(mask: Mask) -> NDArray: operator_dict = GenericFuncNodeAttrs( func=return_dict, output_key="dict_result", - attr_keys=[DEFAULT_ATTR_KEYS.MASK], + attr_keys=[DEFAULT_ATTR_KEYS.MASK, DEFAULT_ATTR_KEYS.BBOX], ) operator_dict.add_node_attrs(graph) @@ -248,7 +254,7 @@ def return_array(mask: Mask) -> NDArray: operator_array = GenericFuncNodeAttrs( func=return_array, output_key="array_result", - attr_keys=[DEFAULT_ATTR_KEYS.MASK], + attr_keys=[DEFAULT_ATTR_KEYS.MASK, DEFAULT_ATTR_KEYS.BBOX], ) operator_array.add_node_attrs(graph) @@ -264,18 +270,18 @@ def test_crop_func_attrs_error_handling_missing_attr_key() -> None: """Test error handling when required attr_key is missing.""" graph = RustWorkXGraph() - # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + # Set up graph attributes + setup_mask_attrs(graph) # Note: "value" is not registered # Create test mask - mask_data = np.array([[True, True], [True, False]], dtype=bool) - mask = Mask(mask_data, bbox=np.array([0, 0, 2, 2])) + mask = np.array([[True, True], [True, False]], dtype=bool) + bbox = np.array([0, 0, 2, 2]) # Add node without the required attribute - graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask}) + graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask, DEFAULT_ATTR_KEYS.BBOX: bbox}) - def 
use_value(mask: Mask, value: float) -> float: + def use_value(mask: NDArray, bbox: NDArray, value: float) -> float: return value * 2.0 # Create operator that requires "value" attribute @@ -295,19 +301,19 @@ def test_crop_func_attrs_function_with_frames_multiprocessing(n_workers: int) -> """Test applying a function with frames using different worker counts.""" graph = RustWorkXGraph() - # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + # Set up graph attributes + setup_mask_attrs(graph) # Create test masks for multiple time points - mask1_data = np.array([[True, True], [True, False]], dtype=bool) - mask1 = Mask(mask1_data, bbox=np.array([0, 0, 2, 2])) + mask1 = np.array([[True, True], [True, False]], dtype=bool) + bbox1 = np.array([0, 0, 2, 2]) - mask2_data = np.array([[True, False], [False, False]], dtype=bool) - mask2 = Mask(mask2_data, bbox=np.array([0, 0, 2, 2])) + mask2 = np.array([[True, False], [False, False]], dtype=bool) + bbox2 = np.array([0, 0, 2, 2]) # Add nodes with masks at different time points - node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask1}) - node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 1, DEFAULT_ATTR_KEYS.MASK: mask2}) + node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask1, DEFAULT_ATTR_KEYS.BBOX: bbox1}) + node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 1, DEFAULT_ATTR_KEYS.MASK: mask2, DEFAULT_ATTR_KEYS.BBOX: bbox2}) # Create test frames for multiple time points frames = np.array( @@ -317,15 +323,15 @@ def test_crop_func_attrs_function_with_frames_multiprocessing(n_workers: int) -> ] ) - def intensity_sum(frame: NDArray, mask: Mask) -> float: - cropped = mask.crop(frame) - return float(np.sum(cropped[mask.mask])) + def intensity_sum(frame: NDArray, mask: NDArray, bbox: NDArray) -> float: + cropped = crop_image_with_bbox(frame, bbox) + return float(np.sum(cropped[mask])) # Create operator and add attributes operator = GenericFuncNodeAttrs( func=intensity_sum, 
output_key="intensity_sum", - attr_keys=[DEFAULT_ATTR_KEYS.MASK], + attr_keys=[DEFAULT_ATTR_KEYS.MASK, DEFAULT_ATTR_KEYS.BBOX], ) with options_context(n_workers=n_workers): @@ -347,10 +353,10 @@ def test_crop_func_attrs_empty_graph() -> None: """Test behavior with an empty graph.""" graph = RustWorkXGraph() - # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + # Set up graph attributes + setup_custom_node_attr(graph, DEFAULT_ATTR_KEYS.MASK, None) - def dummy_func(mask: Mask) -> float: + def dummy_func(mask: NDArray) -> float: return 1.0 operator = GenericFuncNodeAttrs( @@ -413,23 +419,23 @@ def test_crop_func_attrs_batch_processing_with_frames() -> None: """Test batch processing with batch_size > 0 with frames.""" graph = RustWorkXGraph() - # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + # Set up graph attributes + setup_mask_attrs(graph) # Create test masks - mask1_data = np.array([[True, True], [True, False]], dtype=bool) - mask1 = Mask(mask1_data, bbox=np.array([0, 0, 2, 2])) + mask1 = np.array([[True, True], [True, False]], dtype=bool) + bbox1 = np.array([0, 0, 2, 2]) - mask2_data = np.array([[True, False], [False, False]], dtype=bool) - mask2 = Mask(mask2_data, bbox=np.array([0, 0, 2, 2])) + mask2 = np.array([[True, False], [False, False]], dtype=bool) + bbox2 = np.array([0, 0, 2, 2]) - mask3_data = np.array([[False, True], [True, True]], dtype=bool) - mask3 = Mask(mask3_data, bbox=np.array([0, 0, 2, 2])) + mask3 = np.array([[False, True], [True, True]], dtype=bool) + bbox3 = np.array([0, 0, 2, 2]) # Add nodes with masks - node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask1}) - node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask2}) - node3 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask3}) + node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask1, DEFAULT_ATTR_KEYS.BBOX: bbox1}) + node2 = 
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask2, DEFAULT_ATTR_KEYS.BBOX: bbox2}) + node3 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask3, DEFAULT_ATTR_KEYS.BBOX: bbox3}) # Create test frames frames = np.array( @@ -438,19 +444,19 @@ def test_crop_func_attrs_batch_processing_with_frames() -> None: ] ) - def batch_intensity_sum(frame: NDArray, mask: list[Mask]) -> list[float]: + def batch_intensity_sum(frame: NDArray, mask: list[NDArray], bbox: list[NDArray]) -> list[float]: """Batch function that calculates intensity sum for each mask.""" results = [] - for m in mask: - cropped = m.crop(frame) - results.append(float(np.sum(cropped[m.mask]))) + for m, b in zip(mask, bbox, strict=False): + cropped = crop_image_with_bbox(frame, b) + results.append(float(np.sum(cropped[m]))) return results # Create operator with batch_size = 2 operator = GenericFuncNodeAttrs( func=batch_intensity_sum, output_key="intensity_sum", - attr_keys=["mask"], + attr_keys=[DEFAULT_ATTR_KEYS.MASK, DEFAULT_ATTR_KEYS.BBOX], batch_size=2, ) diff --git a/src/tracksdata/nodes/_test/test_regionprops.py b/src/tracksdata/nodes/_test/test_regionprops.py index 7fe8466c..c0a97015 100644 --- a/src/tracksdata/nodes/_test/test_regionprops.py +++ b/src/tracksdata/nodes/_test/test_regionprops.py @@ -1,10 +1,11 @@ import numpy as np +import polars as pl import pytest from skimage.measure._regionprops import RegionProperties from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph import RustWorkXGraph -from tracksdata.nodes import Mask, RegionPropsNodes +from tracksdata.nodes import RegionPropsNodes from tracksdata.options import get_options, options_context @@ -64,6 +65,7 @@ def test_regionprops_add_nodes_2d() -> None: assert "x" in nodes_df.columns assert "area" in nodes_df.columns assert DEFAULT_ATTR_KEYS.MASK in nodes_df.columns + assert DEFAULT_ATTR_KEYS.BBOX in nodes_df.columns # Check that all nodes have t=0 assert 
all(nodes_df[DEFAULT_ATTR_KEYS.T] == 0) @@ -98,6 +100,7 @@ def test_regionprops_add_nodes_3d() -> None: assert "x" in nodes_df.columns assert "area" in nodes_df.columns assert DEFAULT_ATTR_KEYS.MASK in nodes_df.columns + assert DEFAULT_ATTR_KEYS.BBOX in nodes_df.columns def test_regionprops_add_nodes_with_intensity() -> None: @@ -231,12 +234,10 @@ def test_regionprops_mask_creation() -> None: # Check that masks were created nodes_df = graph.node_attrs() masks = nodes_df[DEFAULT_ATTR_KEYS.MASK] + bboxes = nodes_df[DEFAULT_ATTR_KEYS.BBOX] - # All masks should be Mask objects - for mask in masks: - assert isinstance(mask, Mask) - assert mask._mask is not None - assert mask._bbox is not None + assert isinstance(masks.dtype, pl.Object) + assert isinstance(bboxes.dtype, pl.List) def test_regionprops_spacing() -> None: diff --git a/src/tracksdata/solvers/_test/test_ilp_solver.py b/src/tracksdata/solvers/_test/test_ilp_solver.py index f35a07c6..3d58ecbb 100644 --- a/src/tracksdata/solvers/_test/test_ilp_solver.py +++ b/src/tracksdata/solvers/_test/test_ilp_solver.py @@ -6,6 +6,10 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph import RustWorkXGraph from tracksdata.solvers import ILPSolver +from tracksdata.utils._test_utils import ( + setup_edge_distance_attr, + setup_spatial_attrs_2d, +) def test_ilp_solver_init_default() -> None: @@ -85,8 +89,7 @@ def test_ilp_solver_solve_no_edges(caplog: pytest.LogCaptureFixture) -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + setup_spatial_attrs_2d(graph) # Add some nodes graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -107,9 +110,8 @@ def test_ilp_solver_solve_simple_case() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + 
setup_edge_distance_attr(graph) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -147,9 +149,8 @@ def test_ilp_solver_solve_with_appearance_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -187,9 +188,8 @@ def test_ilp_solver_solve_with_disappearance_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -227,9 +227,8 @@ def test_ilp_solver_solve_with_division_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Add nodes for division scenario node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -279,8 +278,7 @@ def test_ilp_solver_solve_custom_edge_weight_expr() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + setup_spatial_attrs_2d(graph) graph.add_edge_attr_key("custom_weight", 0.0) graph.add_edge_attr_key("confidence", 0.0) @@ -311,7 +309,7 @@ def test_ilp_solver_solve_custom_node_weight_expr() -> None: # Register attribute keys graph.add_node_attr_key("x", 0.0) graph.add_node_attr_key("quality", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_edge_distance_attr(graph) # Add nodes with quality 
attribute node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "quality": 0.9}) @@ -337,9 +335,8 @@ def test_ilp_solver_solve_custom_output_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Add nodes and edges node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -365,9 +362,8 @@ def test_ilp_solver_solve_with_all_weights() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -430,9 +426,8 @@ def test_ilp_solver_division_constraint() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Create a scenario where division would be tempting but should be constrained # Time 0: 1 parent node @@ -503,9 +498,8 @@ def test_ilp_solver_solve_with_inf_expr() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 5.0}) @@ -537,7 +531,7 @@ def test_ilp_solver_solve_with_pos_inf_rejection() -> None: # Register attribute keys graph.add_node_attr_key("x", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_edge_distance_attr(graph) # Add nodes node0 = 
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -566,7 +560,7 @@ def test_ilp_solver_solve_with_neg_inf_node_weight() -> None: # Register attribute keys graph.add_node_attr_key("priority", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_edge_distance_attr(graph) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "priority": 1.0}) # High priority @@ -626,9 +620,8 @@ def test_ilp_solver_solve_with_overlaps() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Add nodes - overlapping pair at time t=1 node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) diff --git a/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py b/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py index 4481013f..40f759c6 100644 --- a/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py +++ b/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py @@ -4,6 +4,10 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph import RustWorkXGraph from tracksdata.solvers import NearestNeighborsSolver +from tracksdata.utils._test_utils import ( + setup_edge_distance_attr, + setup_spatial_attrs_2d, +) def test_nearest_neighbors_solver_init_default() -> None: @@ -47,8 +51,7 @@ def test_nearest_neighbors_solver_solve_no_edges() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + setup_spatial_attrs_2d(graph) # Add some nodes graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -66,9 +69,8 @@ def test_nearest_neighbors_solver_solve_simple_case() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - 
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -102,9 +104,8 @@ def test_nearest_neighbors_solver_solve_max_children_constraint() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) # Parent @@ -142,9 +143,8 @@ def test_nearest_neighbors_solver_solve_one_parent_constraint() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) # Parent 1 @@ -174,8 +174,7 @@ def test_nearest_neighbors_solver_solve_custom_weight_expr() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + setup_spatial_attrs_2d(graph) graph.add_edge_attr_key("custom_weight", 0.0) # Add nodes @@ -207,8 +206,7 @@ def test_nearest_neighbors_solver_solve_complex_expression() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + setup_spatial_attrs_2d(graph) graph.add_edge_attr_key("distance", 0.0) graph.add_edge_attr_key("confidence", 0.0) @@ -242,9 +240,8 @@ def test_nearest_neighbors_solver_solve_custom_output_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + 
setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Add nodes and edges node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -270,9 +267,8 @@ def test_nearest_neighbors_solver_solve_with_overlaps() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Add nodes - overlapping pair at time t=1 node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -325,9 +321,8 @@ def test_nearest_neighbors_solver_solve_large_graph() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + setup_spatial_attrs_2d(graph) + setup_edge_distance_attr(graph) # Create a more complex graph structure # Time 0: nodes 0, 1 diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 1c11a36d..d153dc14 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -1,3 +1,6 @@ +from typing import Any + +import blosc2 import numpy as np import polars as pl from cloudpickle import dumps, loads @@ -18,6 +21,8 @@ UInt64, ) +from tracksdata.utils._logging import LOG + _POLARS_DTYPE_TO_NUMPY_DTYPE = { Datetime: np.datetime64, Boolean: np.bool_, @@ -56,6 +61,29 @@ def polars_dtype_to_numpy_dtype(polars_dtype: DataType) -> np.dtype: ) from e +def _try_packing_numpy_array(x: Any) -> bytes: + if isinstance(x, np.ndarray): + packed = blosc2.pack_array2(x) + else: + packed = dumps(x) + return packed + + +def _try_unpacking_numpy_array(x: bytes) -> Any: + try: + unpacked = blosc2.unpack_array2(x) + except (RuntimeError, ValueError, TypeError) as e: + # If blosc2 fails, try cloudpickle + try: + unpacked = loads(x) + except Exception as pickle_error: + raise ValueError( + 
f"Failed to deserialize data: blosc2 error: {e}, pickle error: {pickle_error}" + ) from pickle_error + + return unpacked + + def column_to_bytes(df: pl.DataFrame, column: str) -> pl.DataFrame: """ Convert a column of a DataFrame to bytes. @@ -73,10 +101,13 @@ def column_to_bytes(df: pl.DataFrame, column: str) -> pl.DataFrame: pl.DataFrame The converted DataFrame. """ - return df.with_columns(pl.col(column).map_elements(dumps, return_dtype=pl.Binary)) + prev_nthreads = blosc2.set_nthreads(1) + df = df.with_columns(pl.col(column).map_elements(_try_packing_numpy_array, return_dtype=pl.Binary)) + blosc2.set_nthreads(prev_nthreads) + return df -def column_from_bytes(df: pl.DataFrame, column: str) -> pl.DataFrame: +def column_from_bytes(df: pl.DataFrame, column: str | None = None) -> pl.DataFrame: """ Convert a column of a DataFrame from bytes. @@ -84,12 +115,41 @@ def column_from_bytes(df: pl.DataFrame, column: str) -> pl.DataFrame: ---------- df : pl.DataFrame The DataFrame to convert. - column : str - The column to convert. + column : str | None + The column to convert. If not provided, all pl.Binary columns will be converted. Returns ------- pl.DataFrame The converted DataFrame. 
""" - return df.with_columns(pl.col(column).map_elements(loads, return_dtype=pl.Object)) + # This function used to be simple + # but polars sometimes was failing to serialize numpy arrays + + if column is None: + columns = [c for c, d in zip(df.columns, df.dtypes, strict=False) if d == pl.Binary] + else: + columns = [column] # if column in df.columns and df[column].dtype == pl.Binary else [] + + # If no binary columns found, return as-is (data already deserialized) + if not columns: + return df + + prev_nthreads = blosc2.set_nthreads(1) + for c in columns: + # Always use Object dtype to avoid issues with heterogeneous array shapes/types + try: + df = df.with_columns(pl.col(c).map_elements(_try_unpacking_numpy_array, return_dtype=pl.Object)) + except (pl.exceptions.ComputeError, pl.exceptions.InvalidOperationError) as e: + # If there's an issue with map_elements (e.g., polars conversion errors), + # fall back to manual conversion + LOG.warning(f"Polars error in map_elements for column {c}: {e}, falling back to manual conversion") + try: + values = [_try_unpacking_numpy_array(val) for val in df[c].to_numpy()] + df = df.with_columns(pl.Series(name=c, values=values, dtype=pl.Object)) + except Exception as fallback_error: + LOG.error(f"Failed to deserialize column {c} even with manual fallback: {fallback_error}") + raise ValueError(f"Unable to deserialize column {c}") from fallback_error + + blosc2.set_nthreads(prev_nthreads) + return df diff --git a/src/tracksdata/utils/_test_utils.py b/src/tracksdata/utils/_test_utils.py new file mode 100644 index 00000000..7fd8fd70 --- /dev/null +++ b/src/tracksdata/utils/_test_utils.py @@ -0,0 +1,99 @@ +""" +Utility functions for setting up tests to reduce code duplication. 
+""" + +from typing import Any + +from tracksdata.constants import DEFAULT_ATTR_KEYS +from tracksdata.graph._base_graph import BaseGraph + + +def setup_mask_attrs(graph: BaseGraph) -> None: + """Set up mask and bbox attribute keys on a graph.""" + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None) + + +def setup_spatial_attrs_2d(graph: BaseGraph, default_value: float = 0.0) -> None: + """Set up 2D spatial coordinate attributes (x, y) on a graph.""" + graph.add_node_attr_key("x", default_value) + graph.add_node_attr_key("y", default_value) + + +def setup_spatial_attrs_3d(graph: BaseGraph, default_value: float = 0.0) -> None: + """Set up 3D spatial coordinate attributes (x, y, z) on a graph.""" + setup_spatial_attrs_2d(graph, default_value) + graph.add_node_attr_key("z", default_value) + + +def setup_time_attr(graph: BaseGraph, default_value: int = 0) -> None: + """Set up time attribute on a graph.""" + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.T, default_value) + + +def setup_edge_distance_attr(graph: BaseGraph, default_value: float = 0.0) -> None: + """Set up edge distance attribute on a graph.""" + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, default_value) + + +def setup_solution_attrs(graph: BaseGraph) -> None: + """Set up solution attributes for tracking on a graph.""" + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.SOLUTION, True) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.SOLUTION, True) + + +def setup_tracking_graph( + graph: BaseGraph, + *, + spatial_dims: int = 2, + include_time: bool = True, + include_mask: bool = False, + include_edge_dist: bool = True, + spatial_default: float = 0.0, + time_default: int = 0, +) -> None: + """ + Set up a graph with common tracking attributes. + + Parameters + ---------- + graph : BaseGraph + The graph to set up. + spatial_dims : int, default 2 + Number of spatial dimensions (2 or 3). + include_time : bool, default True + Whether to include time attribute. 
+ include_mask : bool, default False + Whether to include mask and bbox attributes. + include_edge_dist : bool, default True + Whether to include edge distance attribute. + spatial_default : float, default 0.0 + Default value for spatial coordinates. + time_default : int, default 0 + Default value for time attribute. + """ + if include_time: + setup_time_attr(graph, time_default) + + if spatial_dims == 2: + setup_spatial_attrs_2d(graph, spatial_default) + elif spatial_dims == 3: + setup_spatial_attrs_3d(graph, spatial_default) + else: + raise ValueError(f"Unsupported spatial dimensions: {spatial_dims}") + + if include_mask: + setup_mask_attrs(graph) + + if include_edge_dist: + setup_edge_distance_attr(graph) + + +def setup_custom_node_attr(graph: BaseGraph, key: str, default_value: Any) -> None: + """Set up a custom node attribute on a graph.""" + graph.add_node_attr_key(key, default_value) + + +def setup_custom_edge_attr(graph: BaseGraph, key: str, default_value: Any) -> None: + """Set up a custom edge attribute on a graph.""" + graph.add_edge_attr_key(key, default_value)