Implementation of Coefficient of Variation Masking (CV-Masking), a volatility-aware pretraining strategy for Electronic Health Record (EHR) foundation models using Masked Autoencoders (MAE).
This repository contains the code for our ML4H 2025 paper on CV-based masking. Our approach adaptively adjusts masking probabilities based on the intrinsic variability of clinical features, improving reconstruction accuracy and downstream task performance.
Key Results:
- 71% win rate over random masking on laboratory test reconstruction
- 50% faster convergence during pretraining
- Improved performance on mortality prediction and readmission tasks
Our Value-Only MAE (VO-MAE) is a novel architecture for EHR triplets (time, code, value) that masks only numeric values while preserving temporal and categorical context:
```text
Input Triplet:  (t=10.5, code=GLUCOSE, value=120.5)
Masked Triplet: (t=10.5, code=GLUCOSE, value=[MASK])
Model predicts: value = 120.5
```
Why Value-Only Masking?
- ✅ Clinical alignment: When labs are ordered, the timing and test type are known—only results are uncertain
- ✅ Efficient learning: Model focuses on the challenging prediction task (numeric values) rather than easy reconstruction (codes/timestamps)
- ✅ Reduced memory: 25% less training memory vs. full triplet masking
- ✅ Better representations: Forces the model to leverage context (time + code) to predict values
Implementation:
- Masking logic: `src/meds_triplet_mae/datamodules/mae_datamodule.py`
- Model architecture: `src/meds_triplet_mae/models/model.py`
- The encoder sees full triplets, with masked values replaced by learned mask tokens
- The decoder predicts only the numeric values at masked positions
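To make the objective concrete, here is a minimal sketch of value-only masking (illustrative only, not the repository code; the actual logic in `mae_datamodule.py` substitutes a learned mask token rather than a constant):

```python
import torch

def value_only_mask(values: torch.Tensor, mask_prob: float = 0.25):
    """Hide numeric values only; times and codes stay visible.

    values: (batch, seq_len) tensor of normalized lab values.
    Returns the masked values and the boolean mask (True = hidden),
    which the loss later uses to select predictions to score.
    """
    mask = torch.rand_like(values) < mask_prob
    # Placeholder constant; the real model swaps in a learned
    # mask-token embedding at these positions instead.
    masked_values = values.masked_fill(mask, 0.0)
    return masked_values, mask
```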
CV-Masking adaptively adjusts per-feature masking probabilities based on the coefficient of variation (CV):
- High CV (volatile labs like Lactate): 80% masking probability → model learns challenging patterns
- Low CV (stable labs like Sodium): 20% masking probability → less focus on predictable values
- Threshold: 75th percentile CV distinguishes volatile from stable biomarkers
- Implementation: `experiments/cv_based_masking_fixed/scripts/create_cv_masking_weights.py`
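The thresholding rule above reduces to a few lines; a minimal sketch (ours, not the script at that path, which additionally applies `--temperature` smoothing):

```python
import numpy as np

def cv_masking_probs(cv: np.ndarray, low: float = 0.2, high: float = 0.8) -> np.ndarray:
    """Map per-code CVs to masking probabilities.

    Codes with CV above the 75th percentile (volatile labs, e.g. lactate)
    receive the high probability; the rest (stable labs, e.g. sodium)
    receive the low one.
    """
    threshold = np.percentile(cv, 75)           # 75th-percentile cutoff
    return np.where(cv > threshold, high, low)  # volatile -> 0.8, stable -> 0.2
```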
Our approach combines what to mask (VO-MAE) with how often to mask (CV-Masking):
| Component | Traditional MAE | Random VO-MAE (baseline) | CV-VO-MAE (ours) |
|---|---|---|---|
| What to mask | Entire tokens/triplets | ✅ Only values | ✅ Only values |
| Masking probability | Uniform 25% | Uniform 25% | ✅ Adaptive 20-80% |
| Clinical alignment | ❌ No | ✅ Yes | ✅ Yes |
| Curriculum learning | ❌ No | ❌ No | ✅ Yes |
Example: Patient with glucose and lactate measurements
```text
# Input sequence
[(t=1.0, code=GLUCOSE, value=110),   # Stable biomarker (low CV)
 (t=2.5, code=LACTATE, value=2.1),   # Volatile biomarker (high CV)
 (t=4.0, code=GLUCOSE, value=105)]

# Traditional MAE: masks any field (time, code, OR value) at 25% each
[(t=[MASK], code=GLUCOSE, value=110),
 (t=2.5, code=[MASK], value=2.1),
 (t=4.0, code=GLUCOSE, value=[MASK])]

# Random VO-MAE: masks values uniformly
[(t=1.0, code=GLUCOSE, value=110),
 (t=2.5, code=LACTATE, value=[MASK]),  # 25% probability
 (t=4.0, code=GLUCOSE, value=[MASK])]  # 25% probability

# CV-VO-MAE (ours): masks values adaptively
[(t=1.0, code=GLUCOSE, value=110),     # only 20% mask probability (stable)
 (t=2.5, code=LACTATE, value=[MASK]),  # 80% mask probability (volatile)
 (t=4.0, code=GLUCOSE, value=105)]     # only 20% mask probability (stable)
```

Prerequisites:
- Python 3.11+
- CUDA-capable GPU (recommended)
- MIMIC-IV v2.2 data in MEDS format
```bash
# Clone repository
git clone https://github.com/rajna-fani/meds-triplet-mae.git
cd meds-triplet-mae

# Create virtual environment
python -m venv venv
source venv/bin/activate  # Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt
```

Set the data directory environment variable:
```bash
export MEDS_DATA_DIR=/path/to/your/meds/data
```

Or create a `.env` file:
```bash
cp env.template .env
# Edit .env with your paths
```

This codebase requires MEDS-formatted EHR data. You can use:
- MIMIC-IV v2.2: Convert using MEDS_transforms
- Your own EHR data: Convert to MEDS format following the MEDS specification
Required MEDS outputs:
```text
$MEDS_DATA_DIR/
├── train/                  # Training data
│   └── *.nrt
├── tuning/                 # Validation data
│   └── *.nrt
├── held_out/               # Test data
│   └── *.nrt
├── tokenization/
│   ├── schemas/            # Data schemas
│   │   └── *.parquet
│   └── vocabulary/
│       └── codes.parquet   # Code vocabulary
└── fit_normalization/
    └── codes.parquet       # Normalization statistics (mean, std, CV)
```
Key requirement: the `fit_normalization/codes.parquet` file must contain the following columns:
- `code`: Medical code string
- `values/sum`: Sum of values (used for the CV calculation)
- `values/sum_sqd`: Sum of squared values
- `values/n_occurrences`: Number of occurrences
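From these columns, the per-code mean, standard deviation, and CV can be recovered in closed form; a minimal pandas sketch (ours, not the repository code), assuming the column names above:

```python
import numpy as np
import pandas as pd

stats = pd.read_parquet("fit_normalization/codes.parquet")
n = stats["values/n_occurrences"]
mean = stats["values/sum"] / n
var = stats["values/sum_sqd"] / n - mean**2             # E[x^2] - (E[x])^2
stats["cv"] = np.sqrt(var.clip(lower=0)) / mean.abs()   # CV = std / |mean|
```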
Generate the CV-based masking weights:

```bash
python experiments/cv_based_masking_fixed/scripts/create_cv_masking_weights.py \
    --fit-normalization-codes $MEDS_DATA_DIR/fit_normalization/codes.parquet \
    --vocab-size 829 \
    --temperature 2.0 \
    --output-dir data/masking_weights
```

This generates:
- `data/masking_weights/masking_weights_cv_T2.0.pt` (CV-based weights)
- `data/masking_weights/cv_analysis_summary.csv` (analysis results)
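To sanity-check the output you can load the weights file directly; note that the exact contents are an assumption here (the filename suggests a per-code masking-probability tensor):

```python
import torch

# Inspect the generated weights (file contents assumed, not confirmed)
weights = torch.load("data/masking_weights/masking_weights_cv_T2.0.pt")
print(type(weights))  # e.g., a tensor with one masking probability per vocab code (829)
```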
All models use the Value-Only Masked Autoencoder (VO-MAE) architecture by default. The masking strategy (CV-based vs. random vs. variance-based) is controlled via the datamodule configuration.
Train with CV-based masking (our method):

```bash
python -m meds_triplet_mae.train \
    model=meds_mae_cv_masking \
    datamodule=mae_cv_masking \
    training=mae_cv_masking
```

Train baseline (random masking):

```bash
python -m meds_triplet_mae.train \
    model=meds_mae \
    datamodule=mae_lab \
    training=meds_mae
```

This uses:
- ✅ Value-only masking objective (only numeric values masked)
- ✅ Uniform random masking probability (25% for all features)
- ✅ Same VO-MAE architecture as CV-based method (fair comparison)
Train variance-based baseline:

```bash
# First generate variance weights
python experiments/variance_based_masking_fixed/scripts/create_complete_variance_weights.py \
    --fit-normalization-codes $MEDS_DATA_DIR/fit_normalization/codes.parquet \
    --vocab-size 829

# Then train
python -m meds_triplet_mae.train \
    config=_train_mae_variance_masking
```

Generate the top-100 lab codes list (if not already done):
```bash
python scripts/data_inspection/count_lab_codes_from_nrt.py \
    --data-dir $MEDS_DATA_DIR/train \
    --output top100_labs.csv
```

Run reconstruction analysis:
```bash
python scripts/evaluation/evaluate_mae_reconstruction.py \
    --checkpoint outputs/cv_masking/checkpoints/best.ckpt \
    --data-dir $MEDS_DATA_DIR \
    --top100-csv top100_labs.csv \
    --codes-parquet $MEDS_DATA_DIR/tokenization/vocabulary/codes.parquet \
    --fit-normalization-codes-parquet $MEDS_DATA_DIR/fit_normalization/codes.parquet \
    --output-dir results/cv_reconstruction
```

Compare across methods:
```bash
python scripts/evaluation/comprehensive_masking_analysis.py \
    --cv-results results/cv_reconstruction \
    --random-results results/random_reconstruction \
    --variance-results results/variance_reconstruction \
    --output-dir results/comparison
```

Run the perturbation analysis:

```bash
python scripts/evaluation/perturbation/run_complete_perturbation_analysis.py \
    --cv-model-path outputs/cv_masking/checkpoints/best.ckpt \
    --random-model-path outputs/random_masking/checkpoints/best.ckpt \
    --data-dir $MEDS_DATA_DIR \
    --output-dir results/perturbation
```

Fine-tune on clinical prediction tasks:
```bash
python experiments/downstream_finetuning/scripts/finetune_downstream.py \
    --config experiments/downstream_finetuning/configs/mortality_cv_masking.yaml \
    --pretrained-checkpoint outputs/cv_masking/checkpoints/best.ckpt
```

Repository structure:

```text
meds-triplet-mae/
├── configs/ # Hydra configuration files
│ ├── model/ # Model architectures
│ │ └── meds_mae_cv_masking.yaml # VO-MAE architecture specs
│ ├── training/ # Training hyperparameters
│ │ ├── mae_cv_masking.yaml # CV-based training config
│ │ └── meds_mae.yaml # Random masking baseline
│ └── datamodule/ # Data configurations
│ ├── mae_cv_masking.yaml # CV-based data loading
│ └── mae_lab.yaml # Random masking data
│
├── src/meds_triplet_mae/ # Core implementation
│ ├── models/
│ │ └── model.py # VO-MAE encoder-decoder architecture
│ ├── datamodules/
│ │ └── mae_datamodule.py # Value-only masking logic
│ ├── lightning_module/
│ │ └── mae_lab.py # Training loop & loss computation
│ └── train.py # Main training script
│
├── experiments/ # Masking weight generation
│ ├── cv_based_masking_fixed/
│ │ └── scripts/
│ │ └── create_cv_masking_weights.py # CV-based weights
│ └── variance_based_masking_fixed/
│ └── scripts/
│ └── create_complete_variance_weights.py # Variance weights
│
├── scripts/ # Analysis & evaluation
│ ├── evaluation/
│ │ ├── evaluate_mae_reconstruction.py # Reconstruction metrics
│ │ ├── comprehensive_masking_analysis.py # Cross-method comparison
│ │ └── perturbation/
│ │ └── run_complete_perturbation_analysis.py # Contextual learning
│ └── data_inspection/
│ └── count_lab_codes_from_nrt.py # Data statistics
│
└── requirements.txt                      # Python dependencies
```
Value-Only Masking Architecture:
- `src/meds_triplet_mae/models/model.py` - Asymmetric encoder-decoder with mask tokens
- `src/meds_triplet_mae/datamodules/mae_datamodule.py` - Value masking logic (lines 150-200)
- `src/meds_triplet_mae/lightning_module/mae_lab.py` - Joint loss (masked + unmasked)

CV-Based Masking Strategy:
- `experiments/cv_based_masking_fixed/scripts/create_cv_masking_weights.py` - Weight generation
- `configs/datamodule/mae_cv_masking.yaml` - Masking weight loading configuration
Models use Hydra for configuration. Key configs:
- `configs/model/meds_mae_cv_masking.yaml` - CV-based MAE architecture
- `configs/training/mae_cv_masking.yaml` - Training hyperparameters
- `configs/datamodule/mae_cv_masking.yaml` - Data processing
Override configs via command line:

```bash
python -m meds_triplet_mae.train \
    model.embed_dim=512 \
    training.batch_size=64 \
    training.learning_rate=1e-3
```

Default architecture:
- Encoder: 8-layer Transformer (d=256, 8 heads)
- Decoder: 4-layer Transformer (d=128, 4 heads)
- Masking: 25% overall ratio, CV-weighted (0.8 for volatile labs, 0.2 for stable)
- Loss: MSE with λ=0.1 auxiliary loss on visible positions
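For clarity, a minimal sketch of that joint objective (ours, not the code in `mae_lab.py`): the masked-position MSE is the main term, and visible positions contribute a λ-weighted auxiliary term:

```python
import torch.nn.functional as F

def joint_mse_loss(pred, target, mask, lam: float = 0.1):
    """MSE at masked positions plus a lambda-weighted auxiliary MSE at visible ones.

    pred/target: (batch, seq_len) value predictions and ground truth.
    mask: boolean tensor, True where the value was hidden from the encoder.
    """
    masked_loss = F.mse_loss(pred[mask], target[mask])      # main reconstruction term
    visible_loss = F.mse_loss(pred[~mask], target[~mask])   # auxiliary term
    return masked_loss + lam * visible_loss
```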
See `docs/meds_mae_architecture_report.md` for details.
If you use this code, please cite our paper:
```bibtex
@inproceedings{fani2025cv,
  title={Coefficient of Variation Masking: A Volatility-Aware Strategy for EHR Foundation Models},
  author={Fani, Rajna and Al Attrach, Rafi and Restrepo, David and Jia, Yugang and Celi, Leo Anthony and Sch{\"u}ffler, Peter},
  booktitle={Machine Learning for Health (ML4H)},
  year={2025}
}
```

Our paper uses MIMIC-IV v2.2, available through PhysioNet after completing CITI training and signing a data use agreement: https://physionet.org/
To convert MIMIC-IV to MEDS format:
- Follow the MEDS_transforms pipeline
- Ensure `fit_normalization/codes.parquet` is generated (contains mean, std, CV statistics)
This codebase works with any MEDS-formatted EHR data. To adapt for your dataset:
- Convert to MEDS format: Follow the MEDS specification
- Generate normalization statistics: Run the MEDS_transforms `fit_normalization` stage
- Update vocabulary size: Set the `--vocab-size` parameter based on your dataset
- Adjust configs: Update the `configs/` files with your data paths
MIT License - see LICENSE file.
For questions about the code or paper, please open an issue on GitHub.
Acknowledgments:
- MIMIC-IV data contributors
- MEDS format specification: https://github.com/Medical-Event-Data-Standard/meds
- PyTorch Lightning and Hydra frameworks