Ode-PBLLC/mangroves-agu

Mangrove Detection with Clay Embeddings

Train MLP classifiers on Clay foundation model embeddings for mangrove detection using Global Mangrove Watch (GMW) labels.

Quick Start

python train_clay_model.py --config config/train_clay.yaml

Configuration

Core Settings (config/train_clay.yaml)

Data Paths

data:
  embeddings_path: "data/embeddings.parquet"    # Clay embeddings
  mangrove_raster_dir: "data/gmw_v3_2020/"      # GMW raster tiles
  aoi_filter_path: null                          # Optional: filter by AOI
  test_region_path: null                         # Held-out test region (GeoJSON)

Label Configuration

labels:
  use_soft_labels: true              # true: continuous labels in [0,1]; false: binary labels
  presence_threshold_pct: 0.1        # Min % coverage for a binary positive label
  target_coverage_percent: 0.1       # Coverage % counted as positive in metrics (soft labels)
  max_coverage_percent: 1.0          # Coverage % that maps to label=1.0

Soft Labels: When use_soft_labels: true, labels are continuous values in [0, 1]:

  • label = min(coverage_pct / max_coverage_percent, 1.0)
  • target_coverage_percent sets the coverage threshold at which a chip counts as positive when computing metrics
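The soft-label rule above can be sketched in a few lines of Python. This is an illustration, not the repository's code: the function names `soft_label` and `is_positive` are hypothetical, and the `>=` comparison for the metric threshold is an assumption (the source only says the value "defines the threshold").

```python
def soft_label(coverage_pct: float, max_coverage_percent: float = 1.0) -> float:
    """Continuous label in [0, 1]: coverage scaled by max_coverage_percent, capped at 1."""
    return min(coverage_pct / max_coverage_percent, 1.0)


def is_positive(coverage_pct: float, target_coverage_percent: float = 0.1) -> bool:
    """Binary ground truth used for metrics. The >= comparison is an assumption."""
    return coverage_pct >= target_coverage_percent


# Coverage values are in percent, matching the config defaults above.
for cov in (0.0, 0.05, 0.5, 2.0):
    print(cov, soft_label(cov), is_positive(cov))
```

With the defaults shown, any chip with at least 1% coverage saturates at label 1.0, while chips between 0.1% and 1% get a fractional label but still count as positive in metrics.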

Training Parameters

training:
  n_trials: 30                       # Optuna hyperparameter search trials
  optimize_metric: 'f1'              # 'f1' or 'f2'
  use_checkerboard_split: true       # Spatial train/val/test split
  cell_size_km: 200                  # Checkerboard grid cell size
  final_training: false              # Final training mode (see below)
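For reference, the two values accepted by optimize_metric correspond to the standard F-beta score with beta = 1 and beta = 2. The sketch below computes it from confusion-matrix counts; the function name `fbeta` is hypothetical and is not taken from the repository.

```python
def fbeta(tp: int, fp: int, fn: int, beta: float = 1.0) -> float:
    """F-beta from confusion counts; beta=1 gives F1, beta=2 gives F2."""
    b2 = beta * beta
    denom = (1 + b2) * tp + b2 * fn + fp
    # Return 0.0 when there are no positives and no positive predictions.
    return (1 + b2) * tp / denom if denom else 0.0


# Example counts: 8 true positives, 2 false positives, 4 false negatives.
print("F1:", fbeta(8, 2, 4, beta=1.0))
print("F2:", fbeta(8, 2, 4, beta=2.0))
```

F2 weights recall more heavily than precision, so choosing 'f2' favors models that miss fewer mangrove chips at the cost of more false positives.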

Final Training Mode

When final_training: true:

  1. The test region (test_region_path) is held out completely
  2. The remaining data is split 80/20 into train/val using the checkerboard
  3. Optuna tunes hyperparameters on train/val
  4. The final model is retrained on train+val with the best hyperparameters
  5. The final model is evaluated on the held-out test region

To enable it, set:

training:
  final_training: true
data:
  test_region_path: "data/test_region.geojson"

Key Functions

CalculateMangrovePresence

Labels embeddings using GMW raster tiles. Computes mangrove coverage per chip and generates soft or binary labels.

CheckerboardSpatialSplit

Creates a spatially stratified train/val/test split (60/20/20) using a 5-group checkerboard pattern. Optimizes the grid offset to balance positive rates across splits.
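One way a 5-group checkerboard can yield a 60/20/20 split is sketched below. This is an assumption about the general technique, not the repository's implementation: the group formula `(col + 2 * row) % 5`, the group-to-split mapping, and the name `checkerboard_split` are all hypothetical, and coordinates are assumed to be projected and expressed in km. The offset is taken as a parameter here; the search that balances positive rates is not shown.

```python
import math

# Hypothetical mapping: 3 of 5 groups train, 1 val, 1 test (60/20/20 of cells).
GROUP_TO_SPLIT = {0: "train", 1: "train", 2: "train", 3: "val", 4: "test"}


def checkerboard_split(x_km, y_km, cell_size_km=200.0, offset_km=(0.0, 0.0)):
    """Map a point (projected coordinates in km) to a split via its grid cell."""
    col = math.floor((x_km - offset_km[0]) / cell_size_km)
    row = math.floor((y_km - offset_km[1]) / cell_size_km)
    # Any 5 consecutive cells in a row or a column cover all 5 groups once.
    group = (col + 2 * row) % 5
    return GROUP_TO_SPLIT[group]


print(checkerboard_split(10.0, 10.0))    # cell (0, 0)
print(checkerboard_split(650.0, 10.0))   # cell (3, 0)
print(checkerboard_split(850.0, 10.0))   # cell (4, 0)
```

Because whole cells go to a single split, nearby chips tend to land in the same split, which limits spatial leakage between train and evaluation data.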

CheckerboardSpatialSplitFinal

Used in final training mode: creates an 80/20 train/val split. The test region is excluded separately.

objective

Optuna objective function optimizing:

  • Learning rate, epochs, patience
  • Class weights, batch size, dropout
  • Network architecture (layers, hidden sizes)
  • Classification threshold
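To illustrate the last item, the sketch below picks a classification threshold by maximizing F1 on validation predictions. Note the substitution: the repository tunes the threshold inside the Optuna objective, whereas this example sweeps a fixed grid of candidates. The function name `best_threshold` and the candidate grid are hypothetical.

```python
def best_threshold(probs, labels, thresholds=None):
    """Return (threshold, f1) maximizing F1; ties keep the lowest threshold."""
    if thresholds is None:
        thresholds = [i / 100 for i in range(5, 100, 5)]  # 0.05 ... 0.95
    best = (0.5, -1.0)
    for t in thresholds:
        tp = sum(1 for p, y in zip(probs, labels) if p >= t and y)
        fp = sum(1 for p, y in zip(probs, labels) if p >= t and not y)
        fn = sum(1 for p, y in zip(probs, labels) if p < t and y)
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        if f1 > best[1]:
            best = (t, f1)
    return best


# Toy validation set: predicted probabilities and binary ground truth.
val_probs = [0.10, 0.40, 0.35, 0.80, 0.90]
val_labels = [0, 0, 1, 1, 1]
print(best_threshold(val_probs, val_labels))
```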

train_final_model

Trains the model with fixed hyperparameters for final training mode.

Experiment Runners

Soft Label Grid Search

python run_soft_label_experiments.py \
  --config config/train_clay.yaml \
  --output results/soft_label_sweep.csv \
  --target-coverages 0.05 0.1 1.0 \
  --max-coverages 1.0 5.0 10.0
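As a rough guide to the size of this sweep, the sketch below enumerates the parameter combinations for the command above. It assumes the runner trains one model per (target, max) pair, i.e. a full Cartesian product; the source calls it a grid search but does not spell this out.

```python
from itertools import product

target_coverages = [0.05, 0.1, 1.0]   # values passed to --target-coverages
max_coverages = [1.0, 5.0, 10.0]      # values passed to --max-coverages

# Assumption: one training run per (target, max) pair.
grid = list(product(target_coverages, max_coverages))
for target, max_cov in grid:
    print(f"target_coverage_percent={target}, max_coverage_percent={max_cov}")
print(len(grid), "runs")
```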

Output Files

Each run creates a timestamped directory in runs/:

  • config.yaml - Copy of config used
  • metrics.json - Train/val/test metrics
  • *.pt - Model weights
  • *_thresholds.pkl - Classification thresholds
  • *_test_predictions.parquet - Test set predictions with geometries
  • checkerboard_grid.geojson - Spatial split grid
  • labeled_embeddings_map.png - Label visualization

Model Architecture

MLPBinaryClassifier: Configurable MLP with:

  • Variable hidden layers (1-4 layers, 64-512 units)
  • Dropout regularization
  • Xavier weight initialization
  • BCELoss with class weighting (supports soft labels via interpolation)
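The source does not spell out what "via interpolation" means. One plausible reading, sketched below in plain Python for a single sample, is that the per-sample weight is a blend of the positive and negative class weights according to the soft label. The function name `weighted_bce` and its parameters are hypothetical, and the actual PyTorch implementation may differ.

```python
import math


def weighted_bce(prob, label, pos_weight=1.0, neg_weight=1.0, eps=1e-7):
    """Binary cross-entropy for one sample with a soft label in [0, 1]."""
    p = min(max(prob, eps), 1.0 - eps)  # clamp to avoid log(0)
    bce = -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))
    # Assumed interpolation: blend the class weights by the soft label.
    weight = label * pos_weight + (1.0 - label) * neg_weight
    return weight * bce


# A hard positive, a hard negative, and a half-covered chip.
for y in (1.0, 0.0, 0.5):
    print(y, weighted_bce(0.7, y, pos_weight=3.0, neg_weight=1.0))
```

Under this reading, hard labels (0 or 1) reduce to ordinary class-weighted BCE, and a label of 0.5 receives the average of the two class weights.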

Dependencies

  • PyTorch
  • GeoPandas, Rasterio
  • Optuna (hyperparameter tuning)
  • Weights & Biases (experiment tracking)

AEF experiments to be added soon

About

Using geospatial foundation models to map mangroves
