Train MLP classifiers on Clay foundation model embeddings for mangrove detection using Global Mangrove Watch (GMW) labels.
```bash
python train_clay_model.py --config config/train_clay.yaml
```

```yaml
data:
  embeddings_path: "data/embeddings.parquet"  # Clay embeddings
  mangrove_raster_dir: "data/gmw_v3_2020/"    # GMW raster tiles
  aoi_filter_path: null                       # Optional: filter by AOI
  test_region_path: null                      # Held-out test region (GeoJSON)
```
```yaml
labels:
  use_soft_labels: true         # Continuous [0,1] vs binary labels
  presence_threshold_pct: 0.1   # % coverage for binary positive
  target_coverage_percent: 0.1  # Coverage % for positive threshold
  max_coverage_percent: 1.0     # Coverage % for label=1.0
```

Soft labels: when `use_soft_labels: true`, labels are continuous values:

```
label = min(coverage_pct / max_coverage_percent, 1.0)
```

`target_coverage_percent` defines the coverage threshold for counting a chip as positive in metrics.
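The soft-label formula above can be sketched in a few lines of Python (the function name is illustrative, not taken from the source):

```python
def soft_label(coverage_pct: float, max_coverage_percent: float = 1.0) -> float:
    """Map per-chip mangrove coverage (%) to a continuous [0, 1] label.

    Coverage scales linearly up to max_coverage_percent, then saturates at 1.0.
    """
    return min(coverage_pct / max_coverage_percent, 1.0)

print(soft_label(0.5))  # 0.5 -- halfway to saturation
print(soft_label(3.0))  # 1.0 -- clipped at full coverage
```

With `max_coverage_percent: 1.0` (the default above), any chip with at least 1% mangrove coverage receives a label of 1.0.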
```yaml
training:
  n_trials: 30                  # Optuna hyperparameter search trials
  optimize_metric: 'f1'         # 'f1' or 'f2'
  use_checkerboard_split: true  # Spatial train/val/test split
  cell_size_km: 200             # Checkerboard grid cell size
  final_training: false         # Final training mode (see below)
```

When `final_training: true`:
- The test region (`test_region_path`) is held out completely
- The remaining data is split 80/20 train/val using the checkerboard
- Optuna tunes hyperparameters on train/val
- The final model is retrained on train+val with the best hyperparameters
- The model is evaluated on the held-out test region
```yaml
training:
  final_training: true
data:
  test_region_path: "data/test_region.geojson"
```

Labels embeddings using GMW raster tiles: computes mangrove coverage per chip and generates soft or binary labels.
Creates spatially-stratified train/val/test split (60/20/20) using 5-group checkerboard pattern. Optimizes grid offset to balance positive rates across splits.
For final training mode: 80/20 train/val split (test region excluded separately).
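The checkerboard assignment might look like the following minimal sketch, assuming projected coordinates in km; the function name, diagonal-striping scheme, and offset handling are illustrative, not taken from the source:

```python
import math

def checkerboard_group(x_km: float, y_km: float,
                       cell_size_km: float = 200.0,
                       n_groups: int = 5,
                       offset_km: float = 0.0) -> int:
    """Assign a point to one of n_groups via a repeating grid pattern.

    Diagonal striping (col + 2*row mod 5) guarantees that horizontally,
    vertically, and diagonally adjacent cells never share a group.
    """
    col = math.floor((x_km + offset_km) / cell_size_km)
    row = math.floor((y_km + offset_km) / cell_size_km)
    return (col + 2 * row) % n_groups
```

Optimizing the grid offset, as the script does, then amounts to sweeping `offset_km` and picking the value that best balances positive label rates across the resulting groups.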
Optuna objective function optimizing:
- Learning rate, epochs, patience
- Class weights, batch size, dropout
- Network architecture (layers, hidden sizes)
- Classification threshold
Trains model with fixed hyperparameters for final training mode.
```bash
python run_soft_label_experiments.py \
    --config config/train_clay.yaml \
    --output results/soft_label_sweep.csv \
    --target-coverages 0.05 0.1 1.0 \
    --max-coverages 1.0 5.0 10.0
```

Each run creates a timestamped directory in `runs/`:

- `config.yaml` - copy of the config used
- `metrics.json` - train/val/test metrics
- `*.pt` - model weights
- `*_thresholds.pkl` - classification thresholds
- `*_test_predictions.parquet` - test set predictions with geometries
- `checkerboard_grid.geojson` - spatial split grid
- `labeled_embeddings_map.png` - label visualization
MLPBinaryClassifier: Configurable MLP with:
- Variable hidden layers (1-4 layers, 64-512 units)
- Dropout regularization
- Xavier weight initialization
- BCELoss with class weighting (supports soft labels via interpolation)
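A minimal PyTorch sketch of such a classifier follows; the constructor signature, layer sizes, and the 768-dim input are assumptions for illustration, not the project's actual class:

```python
import torch
import torch.nn as nn

class MLPBinaryClassifier(nn.Module):
    """Sketch of a configurable MLP head over embedding vectors."""

    def __init__(self, in_dim: int, hidden_sizes=(256, 128), dropout: float = 0.3):
        super().__init__()
        layers, prev = [], in_dim
        for h in hidden_sizes:
            layers += [nn.Linear(prev, h), nn.ReLU(), nn.Dropout(dropout)]
            prev = h
        layers += [nn.Linear(prev, 1), nn.Sigmoid()]
        self.net = nn.Sequential(*layers)
        # Xavier initialization for every linear layer, as described above.
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x).squeeze(-1)

model = MLPBinaryClassifier(in_dim=768)
probs = model(torch.randn(4, 768))
# BCELoss accepts continuous targets in [0, 1], so soft labels work directly.
loss = nn.BCELoss()(probs, torch.tensor([0.0, 0.3, 0.7, 1.0]))
```

Because the output passes through a sigmoid, `BCELoss` can be used with either hard 0/1 labels or the soft labels described earlier; per-sample class weighting would be applied by rescaling the loss terms.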
- PyTorch
- GeoPandas, Rasterio
- Optuna (hyperparameter tuning)
- Weights & Biases (experiment tracking)