
CV-Based Masking for EHR Foundation Models

Implementation of Coefficient of Variation Masking (CV-Masking), a volatility-aware pretraining strategy for Electronic Health Record (EHR) foundation models using Masked Autoencoders (MAE).

Overview

This repository contains the code for our ML4H 2025 paper on CV-based masking. Our approach adaptively adjusts masking probabilities based on the intrinsic variability of clinical features, improving reconstruction accuracy and downstream task performance.

Key Results:

  • 71% win rate over random masking on laboratory test reconstruction
  • 50% faster convergence during pretraining
  • Improved performance on mortality and readmission prediction tasks

Key Contributions

1. Value-Only Masked Autoencoder (VO-MAE)

A novel MAE architecture for EHR triplets (time, code, value) that masks only numeric values while preserving temporal and categorical context:

Input Triplet:     (t=10.5, code=GLUCOSE, value=120.5)
Masked Triplet:    (t=10.5, code=GLUCOSE, value=[MASK])
Model predicts:    value = 120.5

Why Value-Only Masking?

  • Clinical alignment: When labs are ordered, the timing and test type are known; only the results are uncertain
  • Efficient learning: The model focuses on the hard prediction task (numeric values) rather than on easy reconstruction (codes/timestamps)
  • Reduced memory: 25% less training memory vs. full triplet masking
  • Better representations: Forces the model to leverage context (time + code) to predict values

Implementation:

  • Masking logic: src/meds_triplet_mae/datamodules/mae_datamodule.py
  • Model architecture: src/meds_triplet_mae/models/model.py
  • The encoder sees: full triplets with masked values replaced by learned mask tokens
  • The decoder predicts: only the numeric values at masked positions
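As a minimal sketch of the value-only masking step (illustrative only; the actual logic lives in mae_datamodule.py, and the helper name `mask_values` is made up here), times and codes pass through untouched while each value is independently replaced by a mask token:

```python
import numpy as np

def mask_values(values, mask_prob, rng, mask_token=0.0):
    """Mask only the numeric values of (time, code, value) triplets.

    Times and codes are left untouched; the boolean mask records which
    value positions the decoder must reconstruct.
    """
    values = np.asarray(values, dtype=float)
    mask = rng.random(values.shape) < mask_prob   # per-position Bernoulli draw
    masked = np.where(mask, mask_token, values)   # replace masked values only
    return masked, mask

rng = np.random.default_rng(0)
values = [110.0, 2.1, 105.0]
masked, mask = mask_values(values, mask_prob=1.0, rng=rng)
# with mask_prob=1.0 every value is replaced by the mask token
```

Because `mask_prob` broadcasts, the same function accepts a per-code probability array, which is how CV-based weights plug in.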

2. CV-Based Masking Strategy

Adaptively adjusts masking probabilities based on coefficient of variation (CV):

  • High CV (volatile labs like Lactate): 80% masking probability → model learns challenging patterns
  • Low CV (stable labs like Sodium): 20% masking probability → less focus on predictable values
  • Threshold: 75th percentile CV distinguishes volatile from stable biomarkers
  • Implementation: experiments/cv_based_masking_fixed/scripts/create_cv_masking_weights.py
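The thresholding scheme above can be sketched as follows (a simplified illustration, not the actual script, which additionally applies a temperature parameter to soften the weights; `cv_masking_probs` is a hypothetical helper name):

```python
import numpy as np

def cv_masking_probs(means, stds, low=0.2, high=0.8, pct=75):
    """Assign per-code masking probabilities from the coefficient of variation.

    Codes whose CV exceeds the percentile threshold are treated as
    volatile (high masking probability); the rest as stable.
    """
    cv = np.asarray(stds) / np.maximum(np.abs(means), 1e-8)  # CV = std / |mean|
    threshold = np.percentile(cv, pct)                        # 75th-percentile cutoff
    return np.where(cv > threshold, high, low)

# e.g. a lactate-like code (volatile) vs. sodium/glucose-like codes (stable)
means = np.array([2.0, 140.0, 100.0, 14.0])
stds  = np.array([1.5,   3.0,  10.0,  1.5])
probs = cv_masking_probs(means, stds)  # -> [0.8, 0.2, 0.2, 0.2]
```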

How They Work Together

Our approach combines what to mask (VO-MAE) with how often to mask (CV-Masking):

| Component           | Traditional MAE        | Random VO-MAE (baseline) | CV-VO-MAE (ours)    |
|---------------------|------------------------|--------------------------|---------------------|
| What to mask        | Entire tokens/triplets | ✅ Only values           | ✅ Only values      |
| Masking probability | Uniform 25%            | Uniform 25%              | ✅ Adaptive 20-80%  |
| Clinical alignment  | ❌ No                  | ✅ Yes                   | ✅ Yes              |
| Curriculum learning | ❌ No                  | ❌ No                    | ✅ Yes              |

Example: Patient with glucose and lactate measurements

# Input sequence
[(t=1.0, code=GLUCOSE, value=110),    # Stable biomarker (low CV)
 (t=2.5, code=LACTATE, value=2.1),    # Volatile biomarker (high CV)
 (t=4.0, code=GLUCOSE, value=105)]

# Traditional MAE: masks randomly (time, code, OR value)
[(t=[MASK], code=GLUCOSE, value=110),
 (t=2.5, code=[MASK], value=2.1),
 (t=4.0, code=GLUCOSE, value=[MASK])]  # 25% probability each

# Random VO-MAE: masks values uniformly
[(t=1.0, code=GLUCOSE, value=110),
 (t=2.5, code=LACTATE, value=[MASK]),   # 25% probability
 (t=4.0, code=GLUCOSE, value=[MASK])]   # 25% probability

# CV-VO-MAE (ours): masks values adaptively
[(t=1.0, code=GLUCOSE, value=110),      # Only 20% mask probability (stable)
 (t=2.5, code=LACTATE, value=[MASK]),   # 80% mask probability (volatile)
 (t=4.0, code=GLUCOSE, value=105)]      # Only 20% mask probability (stable)

Quick Start

Prerequisites

  • Python 3.11+
  • CUDA-capable GPU (recommended)
  • MIMIC-IV v2.2 data in MEDS format

Installation

# Clone repository
git clone https://github.com/rajna-fani/meds-triplet-mae.git
cd meds-triplet-mae

# Create virtual environment
python -m venv venv
source venv/bin/activate  # Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt

Environment Setup

Set the data directory environment variable:

export MEDS_DATA_DIR=/path/to/your/meds/data

Or create a .env file:

cp env.template .env
# Edit .env with your paths

Reproducing Paper Results

Prerequisites: MEDS-Formatted Data

This codebase requires MEDS-formatted EHR data.

Required MEDS outputs:

$MEDS_DATA_DIR/
├── train/                    # Training data
│   └── *.nrt
├── tuning/                   # Validation data
│   └── *.nrt
├── held_out/                 # Test data
│   └── *.nrt
├── tokenization/
│   ├── schemas/              # Data schemas
│   │   └── *.parquet
│   └── vocabulary/
│       └── codes.parquet     # Code vocabulary
└── fit_normalization/
    └── codes.parquet         # Normalization statistics (mean, std, CV)

Key requirement: The fit_normalization/codes.parquet file must contain columns:

  • code: Medical code string
  • values/sum: Sum of values for CV calculation
  • values/sum_sqd: Sum of squared values
  • values/n_occurrences: Number of occurrences
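From these running aggregates, the per-code mean, standard deviation, and CV can be recovered with the standard identity var = E[x²] − E[x]². A minimal sketch (the helper name `cv_from_aggregates` is illustrative, not part of the codebase):

```python
def cv_from_aggregates(values_sum, values_sum_sqd, n):
    """Recover mean, std, and CV from the aggregates stored in
    fit_normalization/codes.parquet (values/sum, values/sum_sqd,
    values/n_occurrences)."""
    mean = values_sum / n
    var = max(values_sum_sqd / n - mean ** 2, 0.0)  # clamp tiny negative rounding errors
    std = var ** 0.5
    return mean, std, std / abs(mean)               # CV = std / |mean|

# e.g. three glucose-like measurements: 100, 110, 120
mean, std, cv = cv_from_aggregates(330.0, 36500.0, 3)
```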

Step 1: Generate CV-Based Masking Weights

python experiments/cv_based_masking_fixed/scripts/create_cv_masking_weights.py \
    --fit-normalization-codes $MEDS_DATA_DIR/fit_normalization/codes.parquet \
    --vocab-size 829 \
    --temperature 2.0 \
    --output-dir data/masking_weights

This generates:

  • data/masking_weights/masking_weights_cv_T2.0.pt (CV-based weights)
  • data/masking_weights/cv_analysis_summary.csv (analysis results)

Step 2: Train Models

All models use the Value-Only Masked Autoencoder (VO-MAE) architecture by default. The masking strategy (CV-based vs. random vs. variance-based) is controlled via the datamodule configuration.

Train with CV-based masking (our method):

python -m meds_triplet_mae.train \
    model=meds_mae_cv_masking \
    datamodule=mae_cv_masking \
    training=mae_cv_masking

Train baseline (random masking):

python -m meds_triplet_mae.train \
    model=meds_mae \
    datamodule=mae_lab \
    training=meds_mae

This uses:

  • ✅ Value-only masking objective (only numeric values masked)
  • ✅ Uniform random masking probability (25% for all features)
  • ✅ Same VO-MAE architecture as CV-based method (fair comparison)

Train variance-based baseline:

# First generate variance weights
python experiments/variance_based_masking_fixed/scripts/create_complete_variance_weights.py \
    --fit-normalization-codes $MEDS_DATA_DIR/fit_normalization/codes.parquet \
    --vocab-size 829

# Then train
python -m meds_triplet_mae.train \
    config=_train_mae_variance_masking

Step 3: Evaluate Reconstruction

Generate top 100 lab codes list (if not already done):

python scripts/data_inspection/count_lab_codes_from_nrt.py \
    --data-dir $MEDS_DATA_DIR/train \
    --output top100_labs.csv

Run reconstruction analysis:

python scripts/evaluation/evaluate_mae_reconstruction.py \
    --checkpoint outputs/cv_masking/checkpoints/best.ckpt \
    --data-dir $MEDS_DATA_DIR \
    --top100-csv top100_labs.csv \
    --codes-parquet $MEDS_DATA_DIR/tokenization/vocabulary/codes.parquet \
    --fit-normalization-codes-parquet $MEDS_DATA_DIR/fit_normalization/codes.parquet \
    --output-dir results/cv_reconstruction

Compare across methods:

python scripts/evaluation/comprehensive_masking_analysis.py \
    --cv-results results/cv_reconstruction \
    --random-results results/random_reconstruction \
    --variance-results results/variance_reconstruction \
    --output-dir results/comparison

Step 4: Perturbation Analysis

python scripts/evaluation/perturbation/run_complete_perturbation_analysis.py \
    --cv-model-path outputs/cv_masking/checkpoints/best.ckpt \
    --random-model-path outputs/random_masking/checkpoints/best.ckpt \
    --data-dir $MEDS_DATA_DIR \
    --output-dir results/perturbation

Step 5: Downstream Tasks (Optional)

Fine-tune on clinical prediction tasks:

python experiments/downstream_finetuning/scripts/finetune_downstream.py \
    --config experiments/downstream_finetuning/configs/mortality_cv_masking.yaml \
    --pretrained-checkpoint outputs/cv_masking/checkpoints/best.ckpt

Project Structure

meds-triplet-mae/
├── configs/                              # Hydra configuration files
│   ├── model/                           # Model architectures
│   │   └── meds_mae_cv_masking.yaml    # VO-MAE architecture specs
│   ├── training/                        # Training hyperparameters  
│   │   ├── mae_cv_masking.yaml         # CV-based training config
│   │   └── meds_mae.yaml               # Random masking baseline
│   └── datamodule/                      # Data configurations
│       ├── mae_cv_masking.yaml         # CV-based data loading
│       └── mae_lab.yaml                # Random masking data
│
├── src/meds_triplet_mae/                # Core implementation
│   ├── models/
│   │   └── model.py                    # VO-MAE encoder-decoder architecture
│   ├── datamodules/
│   │   └── mae_datamodule.py           # Value-only masking logic
│   ├── lightning_module/
│   │   └── mae_lab.py                  # Training loop & loss computation
│   └── train.py                         # Main training script
│
├── experiments/                         # Masking weight generation
│   ├── cv_based_masking_fixed/
│   │   └── scripts/
│   │       └── create_cv_masking_weights.py      # CV-based weights
│   └── variance_based_masking_fixed/
│       └── scripts/
│           └── create_complete_variance_weights.py  # Variance weights
│
├── scripts/                             # Analysis & evaluation
│   ├── evaluation/
│   │   ├── evaluate_mae_reconstruction.py         # Reconstruction metrics
│   │   ├── comprehensive_masking_analysis.py      # Cross-method comparison
│   │   └── perturbation/
│   │       └── run_complete_perturbation_analysis.py  # Contextual learning
│   └── data_inspection/
│       └── count_lab_codes_from_nrt.py            # Data statistics
│
└── requirements.txt                     # Python dependencies

Key Implementation Files

Value-Only Masking Architecture:

  • src/meds_triplet_mae/models/model.py - Asymmetric encoder-decoder with mask tokens
  • src/meds_triplet_mae/datamodules/mae_datamodule.py - Value masking logic (lines 150-200)
  • src/meds_triplet_mae/lightning_module/mae_lab.py - Joint loss (masked + unmasked)

CV-Based Masking Strategy:

  • experiments/cv_based_masking_fixed/scripts/create_cv_masking_weights.py - Weight generation
  • configs/datamodule/mae_cv_masking.yaml - Masking weight loading configuration

Configuration

Models use Hydra for configuration. Key configs:

  • configs/model/meds_mae_cv_masking.yaml - CV-based MAE architecture
  • configs/training/mae_cv_masking.yaml - Training hyperparameters
  • configs/datamodule/mae_cv_masking.yaml - Data processing

Override configs via command line:

python -m meds_triplet_mae.train \
    model.embed_dim=512 \
    training.batch_size=64 \
    training.learning_rate=1e-3

Model Architecture

  • Encoder: 8-layer Transformer (d=256, 8 heads)
  • Decoder: 4-layer Transformer (d=128, 4 heads)
  • Masking: 25% ratio, CV-weighted (0.8 for volatile labs, 0.2 for stable)
  • Loss: MSE with λ=0.1 auxiliary loss on visible positions
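The loss described above, MSE on masked positions plus a λ-weighted auxiliary MSE on visible positions, can be sketched as follows (a simplified numpy illustration of the objective, not the actual Lightning implementation in mae_lab.py):

```python
import numpy as np

def vo_mae_loss(pred, target, mask, lam=0.1):
    """MSE on masked values plus a lambda-weighted auxiliary MSE on
    visible (unmasked) values."""
    sq_err = (pred - target) ** 2
    masked_loss = sq_err[mask].mean() if mask.any() else 0.0      # main objective
    visible_loss = sq_err[~mask].mean() if (~mask).any() else 0.0  # auxiliary term
    return masked_loss + lam * visible_loss

pred   = np.array([1.0, 2.0, 3.0])
target = np.array([1.0, 4.0, 3.5])
mask   = np.array([False, True, False])
loss = vo_mae_loss(pred, target, mask)  # 4.0 + 0.1 * mean([0.0, 0.25]) = 4.0125
```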

See docs/meds_mae_architecture_report.md for details.

Citation

If you use this code, please cite our paper:

@inproceedings{fani2025cv,
  title={Coefficient of Variation Masking: A Volatility-Aware Strategy for EHR Foundation Models},
  author={Fani, Rajna and Al Attrach, Rafi and Restrepo, David and Jia, Yugang and Celi, Leo Anthony and Sch{\"u}ffler, Peter},
  booktitle={Machine Learning for Health (ML4H)},
  year={2025}
}

Data

Using MIMIC-IV v2.2

Our paper uses MIMIC-IV v2.2, available through PhysioNet after completing CITI training and signing a data use agreement: https://physionet.org/

To convert MIMIC-IV to MEDS format:

  1. Follow the MEDS_transforms pipeline
  2. Ensure fit_normalization/codes.parquet is generated (contains mean, std, CV statistics)

Using Other Datasets

This codebase works with any MEDS-formatted EHR data. To adapt for your dataset:

  1. Convert to MEDS format: Follow the MEDS specification
  2. Generate normalization statistics: Run MEDS_transforms fit_normalization stage
  3. Update vocabulary size: Set --vocab-size parameter based on your dataset
  4. Adjust configs: Update configs/ files with your data paths

License

MIT License - see LICENSE file.

Contact

For questions about the code or paper, please open an issue on GitHub.
