Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
320 changes: 314 additions & 6 deletions automated_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,10 +522,318 @@ def test_crop():
res = skel.crop(bbx)
assert np.all(res.vertices == skel.vertices)







class TestSkeletonChunking:
def test_single_chunk_small_skeleton(self):
"""Test that a small skeleton fits in a single chunk."""
vertices = np.array([
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[1.0, 1.0, 0.0],
], dtype=np.float32)
edges = np.array([
[0, 1],
[1, 2],
], dtype=np.uint32)

skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(chunk_size=(10.0, 10.0, 10.0))

assert len(chunks) == 1
assert (0, 0, 0) in chunks
assert len(chunks[(0, 0, 0)].vertices) > 0

def test_multiple_chunks_2x2_grid(self):
"""Test skeleton split across a 2x2x1 grid."""
vertices = np.array([
[0.0, 0.0, 0.0],
[1.5, 0.0, 0.0],
[0.0, 1.5, 0.0],
[1.5, 1.5, 0.0],
], dtype=np.float32)
edges = np.array([
[0, 1],
[0, 2],
[1, 3],
[2, 3],
], dtype=np.uint32)

skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(chunk_size=(1.0, 1.0, 1.0))

# Should have chunks in multiple grid cells
assert len(chunks) > 1

# All chunks should be valid skeletons
for chunk_skel in chunks.values():
assert len(chunk_skel.vertices) > 0
assert len(chunk_skel.edges) > 0

def test_edge_crossing_boundary(self):
"""Test that edges crossing chunk boundaries are split."""
vertices = np.array([
[0.5, 0.5, 0.5],
[1.5, 0.5, 0.5], # Crosses boundary at x=1.0
], dtype=np.float32)
edges = np.array([[0, 1]], dtype=np.uint32)

skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(
chunk_size=(1.0, 1.0, 1.0),
origin=(0,0,0),
)

# Should have 2 chunks
assert len(chunks) == 2
assert (0, 0, 0) in chunks
assert (1, 0, 0) in chunks

sk0 = chunks[(0, 0, 0)]
sk1 = chunks[(1, 0, 0)]
assert list(sk0.edges[0]) == [0,1]
assert list(sk1.edges[0]) == [0,1]

assert list(sk0.vertices[1]) == [1,0.5,0.5]
assert list(sk1.vertices[0]) == [1,0.5,0.5]

def test_custom_origin(self):
"""Test chunking with custom origin."""
vertices = np.array([
[5.0, 5.0, 5.0],
[6.0, 5.0, 5.0],
], dtype=np.float32)
edges = np.array([[0, 1]], dtype=np.uint32)

skeleton = Skeleton(vertices, edges)

# # With default origin (should be at min vertex)
chunks_default = skeleton.chunk(chunk_size=(1.0, 1.0, 1.0))

# With custom origin at (0, 0, 0)
chunks_custom = skeleton.chunk(
chunk_size=(1.0, 1.0, 1.0),
origin=np.array([0.0, 0.0, 0.0])
)

# Keys should be different due to different grid alignment
assert set(chunks_default.keys()) != set(chunks_custom.keys())

# With custom origin, should be in grid cell (5, 5, 5)
assert (5, 4, 4) in chunks_custom

def test_empty_skeleton(self):
"""Test chunking an empty skeleton."""
vertices = np.array([], dtype=np.float32).reshape(0, 3)
edges = np.array([], dtype=np.uint32).reshape(0, 2)

skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(chunk_size=(1.0, 1.0, 1.0))

# Should return empty dict or single empty chunk
assert len(chunks) == 0 or all(
len(c.vertices) == 0 for c in chunks.values()
)

def test_vertex_on_boundary(self):
"""Test vertices exactly on chunk boundaries."""
vertices = np.array([
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0], # Exactly on boundary
[2.0, 0.0, 0.0],
], dtype=np.float32)
edges = np.array([
[0, 1],
[1, 2],
], dtype=np.uint32)

skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(chunk_size=(1.0, 1.0, 1.0))

# Should handle boundary vertices consistently
assert len(chunks) >= 2

# Total edges across all chunks should match or exceed original
# (may be more due to splitting)
total_edges = sum(len(c.edges) for c in chunks.values())
assert total_edges >= len(edges)

def test_3d_grid(self):
"""Test chunking across all three dimensions."""
vertices = np.array([
[0.5, 0.5, 0.5],
[1.5, 1.5, 1.5],
], dtype=np.float32)
edges = np.array([[0, 1]], dtype=np.uint32)

skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(
chunk_size=(1.0, 1.0, 1.0),
origin=(0,0,0),
)

# Diagonal line should cross through multiple chunks
assert len(chunks) >= 2

# Check that chunks are in different grid positions
keys = list(chunks.keys())
assert keys[0] != keys[-1]

def test_non_uniform_chunk_size(self):
"""Test with different chunk sizes in each dimension."""
vertices = np.array([
[0.0, 0.0, 0.0],
[2.0, 4.0, 8.0],
], dtype=np.float32)
edges = np.array([[0, 1]], dtype=np.uint32)

skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(chunk_size=(1.0, 2.0, 4.0))

assert len(chunks) > 0

# Verify grid keys make sense with non-uniform sizes
for (gx, gy, gz) in chunks.keys():
assert gx >= 0 and gx <= 2 # 2.0 / 1.0 = 2 chunks in x
assert gy >= 0 and gy <= 2 # 4.0 / 2.0 = 2 chunks in y
assert gz >= 0 and gz <= 2 # 8.0 / 4.0 = 2 chunks in z

def test_large_chunk_size(self):
"""Test with chunk size larger than skeleton bounds."""
vertices = np.array([
[0.0, 0.0, 0.0],
[1.0, 1.0, 1.0],
], dtype=np.float32)
edges = np.array([[0, 1]], dtype=np.uint32)

skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(chunk_size=(100.0, 100.0, 100.0))

# Everything should fit in one chunk
assert len(chunks) == 1

def test_vertex_preservation(self):
"""Test that all original vertices are preserved across chunks."""
vertices = np.array([
[0.0, 0.0, 0.0],
[0.5, 0.5, 0.5],
[1.5, 0.5, 0.5],
[2.0, 1.0, 1.0],
], dtype=np.float32)
edges = np.array([
[0, 1],
[1, 2],
[2, 3],
], dtype=np.uint32)

skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(chunk_size=(1.0, 1.0, 1.0))

# Collect all unique vertices from chunks
all_chunk_vertices = []
for chunk_skel in chunks.values():
all_chunk_vertices.append(chunk_skel.vertices)

# Should have at least as many vertices as original
# (may have more due to splitting at boundaries)
total_verts = sum(len(v) for v in all_chunk_vertices)
assert total_verts >= len(vertices)

def test_connectivity_within_chunks(self):
"""Test that edges within chunks reference valid vertices."""
vertices = np.array([
[0.0, 0.0, 0.0],
[0.5, 0.0, 0.0],
[1.0, 0.0, 0.0],
[1.5, 0.0, 0.0],
], dtype=np.float32)
edges = np.array([
[0, 1],
[1, 2],
[2, 3],
], dtype=np.uint32)

skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(chunk_size=(1.0, 1.0, 1.0))

# Verify all edges reference valid vertex indices
for chunk_skel in chunks.values():
num_verts = len(chunk_skel.vertices)
for edge in chunk_skel.edges:
assert edge[0] < num_verts
assert edge[1] < num_verts
assert edge[0] >= 0
assert edge[1] >= 0

def test_different_edge_dtypes(self):
"""Test with different edge array data types."""
vertices = np.array([
[0.0, 0.0, 0.0],
[1.0, 1.0, 1.0],
], dtype=np.float32)

for dtype in [np.uint8, np.uint16, np.uint32, np.uint64]:
edges = np.array([[0, 1]], dtype=dtype)
skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(chunk_size=(1.0, 1.0, 1.0))

assert len(chunks) > 0, f"Failed for dtype {dtype}"

def test_negative_coordinates(self):
"""Test chunking with negative coordinates."""
vertices = np.array([
[-1.0, -1.0, -1.0],
[1.0, 1.0, 1.0],
], dtype=np.float32)
edges = np.array([[0, 1]], dtype=np.uint32)

skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(chunk_size=(1.0, 1.0, 1.0))

assert len(chunks) > 0

# With default origin at min vertex, grid indices should be >= 0
for key in chunks.keys():
assert all(coord >= 0 for coord in key)

def test_star_topology(self):
"""Test skeleton with star topology (one vertex connected to many)."""
center = np.array([[0.0, 0.0, 0.0]], dtype=np.float32)
spokes = np.array([
[2.0, 0.0, 0.0],
[0.0, 2.0, 0.0],
[0.0, 0.0, 2.0],
[-2.0, 0.0, 0.0],
[0.0, -2.0, 0.0],
[0.0, 0.0, -2.0],
], dtype=np.float32)

vertices = np.vstack([center, spokes])
edges = np.array([
[0, i] for i in range(1, 7)
], dtype=np.uint32)

skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(chunk_size=(1.0, 1.0, 1.0))

assert len(chunks) > 1

@pytest.mark.parametrize("chunk_size", [
(0.5, 0.5, 0.5),
(1.0, 1.0, 1.0),
(2.0, 2.0, 2.0),
(10.0, 10.0, 10.0),
])
def test_various_chunk_sizes(self, chunk_size):
"""Test with various chunk sizes."""
vertices = np.array([
[0.0, 0.0, 0.0],
[5.0, 5.0, 5.0],
], dtype=np.float32)
edges = np.array([[0, 1]], dtype=np.uint32)

skeleton = Skeleton(vertices, edges)
chunks = skeleton.chunk(chunk_size=chunk_size)

assert len(chunks) > 0
assert all(isinstance(k, tuple) and len(k) == 3 for k in chunks.keys())
assert all(isinstance(v, Skeleton) for v in chunks.values())

Loading