Skip to content

Commit 4308b9e

Browse files
committed
Add spatial trajectory plotting APIs and example
1 parent 2af6602 commit 4308b9e

File tree

4 files changed

+456
-5
lines changed

4 files changed

+456
-5
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,5 @@ benchmarks/spatial/outputs/
108108
ref/
109109
temp_debug_files/
110110
/example/comparison_outputs
111+
example/outputs/
111112
docs/SPATIAL_NAV_MODULE_DESIGN.md

example/spatial_plotting_demo.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#!/usr/bin/env python3
2+
"""Demonstrate trajectory plotting helpers for canns_lib.spatial.Agent.
3+
4+
The script mirrors RatInABox usage: we create an environment with walls and a
5+
central hole, simulate a stochastic agent, and save several plots:
6+
7+
- trajectory.png: path overlaid on the environment
8+
- heatmap.png: spatial occupancy heatmap
9+
- speeds.png / rotation.png: histograms of speeds and rotational velocities
10+
11+
Run with:
12+
13+
uv run --no-sync python example/spatial_plotting_demo.py
14+
15+
Output files are written next to the script in ``example/outputs/``.
16+
"""
17+
18+
from __future__ import annotations
19+
20+
import pathlib
21+
22+
import matplotlib
23+
24+
matplotlib.use("Agg")
25+
from matplotlib import pyplot as plt
26+
27+
import numpy as np
28+
29+
from canns_lib import spatial
30+
31+
OUTPUT_DIR = pathlib.Path(__file__).resolve().parent / "outputs"
32+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
33+
34+
ENVIRONMENT_PARAMS = {
35+
"dimensionality": "2D",
36+
"boundary_conditions": "solid",
37+
"walls": [
38+
[[0.1, 0.1], [0.9, 0.1]],
39+
[[0.9, 0.1], [0.9, 0.9]],
40+
[[0.1, 0.9], [0.9, 0.9]],
41+
[[0.1, 0.1], [0.1, 0.9]],
42+
],
43+
"holes": [
44+
[[0.35, 0.35], [0.65, 0.35], [0.65, 0.65], [0.35, 0.65]],
45+
],
46+
}
47+
48+
AGENT_PARAMS = {
49+
"speed_mean": 0.08,
50+
"speed_std": 0.02,
51+
"rotational_velocity_std": np.deg2rad(50),
52+
"speed_coherence_time": 0.7,
53+
"rotational_velocity_coherence_time": 0.12,
54+
"wall_repel_distance": 0.15,
55+
"wall_repel_strength": 1.5,
56+
"thigmotaxis": 0.4,
57+
}
58+
59+
60+
def main() -> None:
61+
env = spatial.Environment(**ENVIRONMENT_PARAMS)
62+
agent = spatial.Agent(env, params=AGENT_PARAMS, rng_seed=2025, init_pos=[0.4, 0.2])
63+
64+
for _ in range(2000):
65+
agent.update(dt=0.02)
66+
67+
# Trajectory plot
68+
fig, ax = agent.plot_trajectory()
69+
fig.savefig(OUTPUT_DIR / "trajectory.png", dpi=150)
70+
plt.close(fig)
71+
72+
# Heatmap
73+
fig, ax = agent.plot_position_heatmap(bins=60)
74+
fig.savefig(OUTPUT_DIR / "heatmap.png", dpi=150)
75+
plt.close(fig)
76+
77+
# Speed histogram
78+
fig, ax = agent.plot_histogram_of_speeds(bins=40)
79+
fig.savefig(OUTPUT_DIR / "speeds.png", dpi=150)
80+
plt.close(fig)
81+
82+
# Rotational velocity histogram
83+
fig, ax = agent.plot_histogram_of_rotational_velocities(bins=40)
84+
fig.savefig(OUTPUT_DIR / "rotation.png", dpi=150)
85+
plt.close(fig)
86+
87+
print(f"Saved plots to {OUTPUT_DIR}")
88+
89+
90+
if __name__ == "__main__":
91+
main()

0 commit comments

Comments
 (0)