
Commit d939c9f

Authored by swahtz, cursoragent, and Copilot
LangSplatV2: Integrate TensorBoard logging with evaluation images (#55)
## Summary

Integrates TensorBoard logging with evaluation images into the LangSplatV2 training pipeline, following the GARfVDB writer pattern (`GaussianSplatSegmentationWriter`) and using language-feature-specific visualizations (PCA projections, error heatmaps, feature coverage).

- **Create `LangSplatV2Writer` + `LangSplatV2WriterConfig`** in `langsplatv2/training/langsplatv2_writer.py` — duplicates the GARfVDB writer interface (CSV metrics, disk image saving, checkpoints, TensorBoard) without a cross-project dependency
- **Add `cosine_error_map()` visualization utility** to `langsplatv2/util.py` — computes per-pixel `1 - cosine_similarity` and maps it through the turbo colormap
- **Refactor `LangSplatV2Trainer` to use the writer** — replaces internal `_log_metric()` / `_save_checkpoint()` / file management with an injected `LangSplatV2Writer` instance
- **Add periodic training image logging** — PCA of predicted/GT features, feature coverage mask, cosine error heatmap (controlled by the new `log_test_images` config flag)
- **Enhance `eval()` with full image diagnostics** — beauty render, predicted-features PCA, GT-features PCA, cosine error heatmap, alpha map, and a side-by-side comparison composite
- **Update `train_langsplatv2.py`** — accepts an `io: LangSplatV2WriterConfig` parameter and instantiates the writer before training

Fixes #54

---------

Signed-off-by: Jonathan Swartz <jonathan@jswartz.info>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 545b213 commit d939c9f

12 files changed

Lines changed: 1066 additions & 199 deletions
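
The `cosine_error_map()` utility described in the summary is straightforward to picture. A minimal sketch, assuming `[H, W, C]` feature images and matplotlib's turbo colormap; the actual helper in `langsplatv2/util.py` may differ in signature and colormap plumbing:

```python
# Hedged sketch of a cosine_error_map()-style utility, not the repo's exact code.
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F


def cosine_error_map(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    """Per-pixel 1 - cosine_similarity between [H, W, C] feature images,
    mapped through the turbo colormap to an [H, W, 3] RGB image in [0, 1]."""
    error = 1.0 - F.cosine_similarity(pred, gt, dim=-1)  # [H, W], range [0, 2]
    error = (error / 2.0).clamp(0.0, 1.0)  # rescale into [0, 1] for the colormap
    rgba = plt.get_cmap("turbo")(error.detach().cpu().numpy())  # [H, W, 4]
    return torch.from_numpy(rgba[..., :3]).float()  # drop the alpha channel
```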


instance_segmentation/garfvdb/garfvdb/util.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -136,7 +136,7 @@ def pca_projection_fast(
     V = calculate_pca_projection(features_centered, n_components, center=False)
 
     # Project data onto principal components
-    projected = torch.mm(features_flat, V.to(features.device))
+    projected = torch.mm(features_centered, V.to(features.device))
 
     # Normalize to [0, 1] range
     mins = projected.min(dim=0, keepdim=True)[0]
```
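
The fix makes the projection consistent with the principal components, which are computed from the centered features. A self-contained sketch of the corrected flow; the function and helper names here are illustrative, and `pca_projection_fast()` / `calculate_pca_projection()` in the repo differ in detail:

```python
# Illustrative sketch of the corrected PCA projection flow, not the repo's exact API.
import torch


def pca_project(features: torch.Tensor, n_components: int = 3) -> torch.Tensor:
    """Project [N, C] features onto their top principal components, scaled to [0, 1]."""
    features_flat = features.reshape(-1, features.shape[-1])
    features_centered = features_flat - features_flat.mean(dim=0, keepdim=True)

    # Principal directions come from the *centered* data ...
    _, _, Vh = torch.linalg.svd(features_centered, full_matrices=False)
    V = Vh[:n_components].T  # [C, n_components]

    # ... so the projection must use the centered features too (the fix above).
    projected = torch.mm(features_centered, V)

    # Normalize each component to [0, 1] for visualization
    mins = projected.min(dim=0, keepdim=True)[0]
    maxs = projected.max(dim=0, keepdim=True)[0]
    return (projected - mins) / (maxs - mins + 1e-10)
```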

open_vocabulary_segmentation/langsplatv2/README.md

Lines changed: 85 additions & 29 deletions
````diff
@@ -1,55 +1,111 @@
-# LangSplatV2
+# LangSplatV2 (fVDB)
 
-This project implements LangSplatV2 (Li, et al. 2025) with fVDB for open-vocabulary 3D segmentation.
+LangSplatV2-style open-vocabulary 3D segmentation using [fVDB](https://github.com/openvdb/fvdb-core) and pre-trained Gaussian splat reconstructions. This implementation trains per-Gaussian sparse coefficient fields and shared CLIP-aligned codebooks on an existing reconstruction; it does not train the underlying Gaussians or colors.
 
-## Overview
+## What this implements
 
-The LangSplatV2 scene data transformation pipeline consists of two main steps:
+- **Preprocessing**: Multi-scale SAM2 masks and OpenCLIP feature encoding for each image (cached on disk).
+- **Training**: Residual VQ codebooks and per-splat sparse logits so that rendered language features match the CLIP embeddings from SAM masks. One feature level (scale) per run; train multiple levels separately and combine at inference.
+- **Compatibility**: Same feature pipeline and training setup (loss, LR, layer schedule) as the original LangSplatV2; uses fVDB for the 3D representation and rendering.
 
-1. **Multi-scale SAM2 Segmentation**: Uses SAM2 to generate segmentation masks at multiple scales (default, small, medium, large) for each image.
+## Prerequisites
 
-2. **CLIP Feature Encoding**: Encodes each segmented region using OpenCLIP to produce language-aligned features that can be used for open-vocabulary queries.
+- **SfM scene**: COLMAP, `simple_directory`, or E57 dataset (images + cameras + optional point cloud).
+- **Pre-trained Gaussian splat reconstruction**: A `.ply` or `.pt`/`.pth` checkpoint produced by e.g. [fvdb-reality-capture](https://github.com/openvdb/fvdb-reality-capture) or another fVDB-compatible pipeline. The script uses its normalization transform so the scene and Gaussians are aligned.
 
 ## Installation
 
+From this directory (`open_vocabulary_segmentation/langsplatv2/`), with the `fvdb` conda environment active:
+
 ```bash
-# Install from the fvdb-examples repository
+conda activate fvdb
 pip install -e .
+```
+
+Dependencies (see `pyproject.toml`) include `torch`, `open-clip-torch`, `fvdb-reality-capture`, `tyro`, and optionally TensorBoard for logging.
+
+## How to run
+
+Training loads the SfM scene, applies preprocessing (SAM2 + CLIP) with caching, then runs the language-feature training loop.
+
+**Minimal (COLMAP scene + PLY reconstruction):**
 
-# Or install dependencies manually
-pip install open-clip-torch fvdb-reality-capture
+```bash
+python train_langsplatv2.py \
+    --sfm-dataset-path /path/to/colmap/scene \
+    --reconstruction-path /path/to/point_cloud.ply
+```
+
+**With explicit feature level and log directory:**
+
+```bash
+python train_langsplatv2.py \
+    --sfm-dataset-path /path/to/colmap/scene \
+    --reconstruction-path /path/to/point_cloud.ply \
+    --config.feature-level 1 \
+    --log-path langsplatv2_logs
 ```
 
+**Train all three scale levels (as in the paper):**
+
+```bash
+for level in 1 2 3; do
+    python train_langsplatv2.py \
+        --sfm-dataset-path /path/to/scene \
+        --reconstruction-path /path/to/gaussians.ply \
+        --config.feature-level $level \
+        --log-path langsplatv2_logs
+done
+```
+
+
+**Useful flags:**
+
+- `--config.feature-level` — 0=default, 1=small, 2=medium, 3=large (default: 1).
+- `--config.max-steps` — Training steps (default derived from max epochs if not set).
+- `--preprocess.image-downsample-factor` — Downsample images before SAM2/CLIP (e.g. 2 for speed).
+- `--preprocess.sam2.checkpoint` — SAM2 size: `large`, `small`, `tiny`, `base_plus`.
+- `--log-path` — Directory for run subdirectories (checkpoints, metrics). Use `None` to disable saving.
+- `--io.use-tensorboard` — Log scalars (and optionally images) to TensorBoard.
+- `--use-every-n-as-val` — Hold out every N-th image for validation (e.g. 5); -1 = no validation.
+
+## Outputs
+
+With `--log-path` set (e.g. `langsplatv2_logs`), each run writes:
+
+- `log_path/run_<timestamp>/` (or `log_path/<run_name>/` if `--run-name` is set)
+  - `checkpoints/<step>/langsplatv2_ckpt.pt` — Model state and config (when `io.save_checkpoints` is True).
+  - `metrics_log.csv` — Step, loss, and optional validation metrics.
+  - `tensorboard/` — If `io.use_tensorboard` is True.
+  - `images/` — If `io.save_images` is True (e.g. feature visualizations at save steps).
+
+Preprocessing caches (SAM2 masks, CLIP features) are stored under the scene's cache directory and reused across runs.
 
-## Scene Transform Outputs
+## Preprocessing pipeline and cache format
 
-### SAM2 Masks
+The pipeline (see `LangSplatV2PreprocessConfig` in `config.py`) runs in order: optional scene normalization, point filtering, image downsampling, filtering images by visible points, **ComputeMultiScaleSAM2Masks**, **ComputeCLIPFeatures**, and optional cropping.
 
-For each image, the SAM2 transform produces:
+### SAM2 masks (per image)
 
-- `{scale}_segmentations`: Binary masks, shape `[N, H, W]`
-- `{scale}_bboxes`: Bounding boxes in XYWH format, shape `[N, 4]`
-- `{scale}_areas`: Mask areas in pixels, shape `[N]`
-- `{scale}_predicted_ious`: SAM2's IoU predictions, shape `[N]`
-- `{scale}_stability_scores`: Mask stability scores, shape `[N]`
+- `{scale}_segmentations`: `[N, H, W]` binary masks
+- `{scale}_bboxes`: `[N, 4]` XYWH
+- `{scale}_areas`, `{scale}_predicted_ious`, `{scale}_stability_scores`
 
-where `{scale}` is one of: `default`, `s` (small), `m` (medium), `l` (large).
+Scales: `default`, `s` (small, <1% area), `m` (1–10%), `l` (≥10%).
 
-Masks are categorized by area ratio:
-- **Large (l)**: area >= 10% of image
-- **Medium (m)**: 1% <= area < 10%
-- **Small (s)**: area < 1%
-- **Default**: all masks
+### CLIP features (per image)
 
-### CLIP Features
+- `features`: `[N_total, 512]` L2-normalized OpenCLIP embeddings (one per mask, concatenated over scales).
+- `seg_maps`: `[4, H, W]` — pixel → feature index per scale (-1 = no mask).
+- `lengths`: `[4]` — number of masks per scale (default, s, m, l).
 
-For each image, the CLIP transform produces:
+Training uses a single `feature_level` (0–3) to choose which scale's seg map and features to use.
 
-- `features`: CLIP embeddings, shape `[N_total, 512]`
-- `seg_maps`: Segmentation maps, shape `[4, H, W]`
-- `lengths`: Number of masks per scale, shape `[4]`
+## Training details and comparison with original LangSplatV2
 
-The `seg_maps` tensor maps each pixel to a feature index (or -1 for unmasked pixels).
+- **Feature generation**: Same as the original — crop mask region → pad to square → resize to 224 → OpenCLIP encode → L2-normalize. Scale order and seg-map indexing (default → s → m → l, cumulative) match.
+- **Optimization**: Same language-feature LR (0.0025), layer schedule (every 10k steps), and cosine loss over valid pixels with gradient scaling via the mask fraction. The scalar `train/loss` is the (mask-fraction-scaled) total loss used for backprop. For a smoother, more interpretable curve when mask coverage varies across images, use `train/cosine_loss_valid`, the mean cosine loss over valid pixels only (no mask-fraction scaling); this is the curve we log.
+- **Data sampling**: One random permutation of all training views per "epoch" (InfiniteSampler with shuffle), one view per step when `batch_size=1`, matching the original's viewpoint-stack behavior.
 
 ## References
 
````
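To make the cache format concrete, here is a hedged sketch of expanding one cached image's `seg_maps` and `features` into a dense ground-truth feature image for a chosen `feature_level`. Tensor shapes follow the README above; the repo's actual dataloading code may differ:

```python
import torch


def gt_feature_image(features: torch.Tensor, seg_maps: torch.Tensor, feature_level: int):
    """Return a dense [H, W, 512] GT feature image plus a validity mask for one scale.

    features: [N_total, 512] L2-normalized CLIP embeddings, all scales concatenated.
    seg_maps: [4, H, W] per-scale pixel -> feature index (-1 = no mask).
    """
    seg = seg_maps[feature_level]  # [H, W] indices into features, or -1
    valid = seg >= 0               # pixels covered by some SAM2 mask at this scale
    gt = torch.zeros(*seg.shape, features.shape[-1], device=features.device)
    gt[valid] = features[seg[valid].long()]  # gather one CLIP embedding per pixel
    return gt, valid
```

The `valid` mask is exactly the set of pixels the cosine loss is restricted to.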

open_vocabulary_segmentation/langsplatv2/langsplatv2/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -8,11 +8,13 @@
 )
 from .loss import calculate_langsplatv2_loss
 from .model import LangSplatV2Model
+from .training.langsplatv2_writer import LangSplatV2WriterConfig
 
 __all__ = [
     "LangSplatV2PreprocessConfig",
     "LangSplatV2TrainingConfig",
     "LangSplatV2ModelConfig",
+    "LangSplatV2WriterConfig",
     "LangSplatV2Model",
     "calculate_langsplatv2_loss",
 ]
```
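
With the config exported at the package root, configuring the writer from Python might look like the sketch below. The field names follow the README's `io.*` flags; the exact constructor signature is an assumption, not the repo's confirmed API:

```python
# Hedged sketch: field names follow the README's io.* flags; the exact
# LangSplatV2WriterConfig constructor is an assumption about the API.
from langsplatv2 import LangSplatV2WriterConfig

io_config = LangSplatV2WriterConfig(
    use_tensorboard=True,   # scalars (and optionally images) under tensorboard/
    save_images=True,       # feature visualizations under images/
    save_checkpoints=True,  # checkpoints/<step>/langsplatv2_ckpt.pt
)
```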

open_vocabulary_segmentation/langsplatv2/langsplatv2/config.py

Lines changed: 33 additions & 1 deletion
```diff
@@ -30,12 +30,29 @@ class SAM2Config:
     points_per_side: int = 32
     """Grid density for point prompts."""
 
+    points_per_batch: int = 64
+    """Points processed simultaneously by SAM2."""
+
     pred_iou_thresh: float = 0.7
     """Predicted IoU threshold for mask filtering."""
 
     stability_score_thresh: float = 0.85
     """Stability score threshold for mask filtering."""
 
+    crop_n_layers: int = 1
+    """Number of crop layers. 1 = also run SAM on image crops (matching
+    the original LangSplatV2, which uses ``crop_n_layers=1``)."""
+
+    crop_n_points_downscale_factor: int = 1
+    """Point grid downscale factor per crop layer."""
+
+    min_mask_region_area: int = 100
+    """Minimum mask region area for post-processing (matching the original
+    LangSplatV2, which uses ``min_mask_region_area=100``)."""
+
+    box_nms_thresh: float = 0.7
+    """Box NMS IoU threshold within each crop."""
+
     nms_iou_thr: float = 0.8
     """IoU threshold for mask NMS post-processing."""
 
@@ -172,8 +189,13 @@ def build_scene_transforms(
             ComputeMultiScaleSAM2Masks(
                 checkpoint=self.sam2.checkpoint,
                 points_per_side=self.sam2.points_per_side,
+                points_per_batch=self.sam2.points_per_batch,
                 pred_iou_thresh=self.sam2.pred_iou_thresh,
                 stability_score_thresh=self.sam2.stability_score_thresh,
+                crop_n_layers=self.sam2.crop_n_layers,
+                crop_n_points_downscale_factor=self.sam2.crop_n_points_downscale_factor,
+                min_mask_region_area=self.sam2.min_mask_region_area,
+                box_nms_thresh=self.sam2.box_nms_thresh,
                 nms_iou_thr=self.sam2.nms_iou_thr,
                 nms_score_thr=self.sam2.nms_score_thr,
                 nms_inner_thr=self.sam2.nms_inner_thr,
@@ -213,8 +235,13 @@ def build_sam2_transform(self):
         return ComputeMultiScaleSAM2Masks(
             checkpoint=self.sam2.checkpoint,
             points_per_side=self.sam2.points_per_side,
+            points_per_batch=self.sam2.points_per_batch,
             pred_iou_thresh=self.sam2.pred_iou_thresh,
             stability_score_thresh=self.sam2.stability_score_thresh,
+            crop_n_layers=self.sam2.crop_n_layers,
+            crop_n_points_downscale_factor=self.sam2.crop_n_points_downscale_factor,
+            min_mask_region_area=self.sam2.min_mask_region_area,
+            box_nms_thresh=self.sam2.box_nms_thresh,
             nms_iou_thr=self.sam2.nms_iou_thr,
             nms_score_thr=self.sam2.nms_score_thr,
             nms_inner_thr=self.sam2.nms_inner_thr,
@@ -294,7 +321,12 @@ class LangSplatV2TrainingConfig:
     model: LangSplatV2ModelConfig = field(default_factory=LangSplatV2ModelConfig)
     """Model architecture configuration."""
 
-    eval_at_percent: list[int] = field(default_factory=lambda: [25, 50, 75, 100])
+    log_test_images: bool = False
+    """Whether to log visualization images (PCA features, error heatmaps)
+    during training steps. Eval images are always logged when the writer
+    supports image output, regardless of this flag."""
+
+    eval_at_percent: list[int] = field(default_factory=lambda: [5, 10, 20, 30, 40, 50, 75, 100])
     """Percentages of total epochs at which to run evaluation."""
 
    save_at_percent: list[int] = field(default_factory=lambda: [50, 100])
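```

The new fields mirror the original LangSplatV2's SAM2 automatic-mask-generator settings. A hedged usage sketch, assuming `SAM2Config` and `LangSplatV2PreprocessConfig` are importable from `langsplatv2.config` as the README suggests:

```python
# Hedged sketch; import paths and construction are assumptions based on the diff.
from langsplatv2.config import LangSplatV2PreprocessConfig, SAM2Config

preprocess = LangSplatV2PreprocessConfig(
    sam2=SAM2Config(
        points_per_batch=64,               # SAM2 point-prompt batch size
        crop_n_layers=1,                   # also run SAM2 on image crops, as in the original
        crop_n_points_downscale_factor=1,  # keep full point density on crops
        min_mask_region_area=100,          # drop tiny mask regions in post-processing
        box_nms_thresh=0.7,                # per-crop box NMS IoU threshold
    )
)
sam2_transform = preprocess.build_sam2_transform()  # ComputeMultiScaleSAM2Masks
```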

open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py

Lines changed: 22 additions & 14 deletions
```diff
@@ -63,25 +63,14 @@ def calculate_langsplatv2_loss(
     """
     assert use_cosine_loss or use_l1_loss, "At least one loss type must be enabled"
 
-    # Apply mask: only compute loss on valid pixels
-    mask_expanded = mask.unsqueeze(-1).float()  # [B, H, W, 1]
-
     # Optionally normalize predicted features
     if normalize_features:
         predicted_features = predicted_features / (predicted_features.norm(dim=-1, keepdim=True) + 1e-10)
 
-    # Mask both prediction and target
-    pred_masked = predicted_features * mask_expanded
-    gt_masked = gt_features * mask_expanded
-
-    # Only compute on valid pixels to avoid diluting the loss
-    valid_pred = pred_masked[mask]  # [N_valid, clip_n_dims]
-    valid_gt = gt_masked[mask]  # [N_valid, clip_n_dims]
-
     loss_dict: dict[str, torch.Tensor] = {}
     total_loss = torch.tensor(0.0, device=predicted_features.device)
 
-    if valid_pred.shape[0] == 0:
+    if not mask.any():
         # No valid pixels - return zero loss
         loss_dict["total_loss"] = total_loss
         if use_cosine_loss:
@@ -90,13 +79,32 @@
         loss_dict["l1_loss"] = total_loss
         return loss_dict
 
+    # Gather only valid pixels (clean signal, no NaN risk from torch.empty).
+    valid_pred = predicted_features[mask]  # [N_valid, clip_n_dims]
+    valid_gt = gt_features[mask]  # [N_valid, clip_n_dims]
+
+    # The original LangSplatV2 computes .mean() over ALL H*W pixels, where
+    # masked-out pixels are zero-vectors that contribute ~0 to the sum but
+    # inflate the denominator. This implicitly scales gradients down by
+    # (N_valid / N_total). We replicate this by computing the loss on valid
+    # pixels only (clean, interpretable values) and multiplying by the mask
+    # coverage fraction so that gradient magnitudes match the original exactly:
+    #
+    #   grad_original = (1/N_total) * sum_valid(dL_i)
+    #   grad_ours     = (mask_fraction/N_valid) * sum_valid(dL_i)
+    #                 = (1/N_total) * sum_valid(dL_i)   [identical]
+    mask_fraction = mask.sum().float() / mask.numel()
+
     if use_cosine_loss:
-        cos_loss = cosine_loss(valid_pred, valid_gt)
+        cos_loss_raw = cosine_loss(valid_pred, valid_gt)
+        cos_loss = cos_loss_raw * mask_fraction
         loss_dict["cosine_loss"] = cos_loss
+        # Mean over valid pixels only (no mask_fraction); stable for logging when coverage varies
+        loss_dict["cosine_loss_valid"] = cos_loss_raw
         total_loss = total_loss + cos_loss
 
     if use_l1_loss:
-        l1 = l1_loss(valid_pred, valid_gt)
+        l1 = l1_loss(valid_pred, valid_gt) * mask_fraction
         loss_dict["l1_loss"] = l1
         total_loss = total_loss + l1
 
```
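The gradient-equivalence argument in the new comment is easy to check numerically. A standalone sketch using a plain cosine loss rather than the repo's `cosine_loss()` helper:

```python
# Standalone check of the mask-fraction gradient argument; plain cosine loss,
# not the repo's cosine_loss()/l1_loss() helpers.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
H, W, C = 8, 8, 16
gt = F.normalize(torch.randn(H, W, C), dim=-1)
mask = torch.rand(H, W) > 0.5

# Original LangSplatV2 style: zero out both tensors, mean over ALL H*W pixels.
pred_a = torch.randn(H, W, C, requires_grad=True)
m = mask.unsqueeze(-1).float()
(1 - F.cosine_similarity(pred_a * m, gt * m, dim=-1)).mean().backward()

# This implementation: mean over valid pixels only, scaled by the mask fraction.
pred_b = pred_a.detach().clone().requires_grad_(True)
mask_fraction = mask.sum().float() / mask.numel()
((1 - F.cosine_similarity(pred_b[mask], gt[mask], dim=-1)).mean() * mask_fraction).backward()

print(torch.allclose(pred_a.grad, pred_b.grad, atol=1e-6))  # True: identical gradients
```

Both variants zero the gradient on masked-out pixels and scale valid-pixel gradients by `1/N_total`, which is what the `allclose` check confirms.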
