diff --git a/models/oir_flatmount_segmentation/LICENSE b/models/oir_flatmount_segmentation/LICENSE new file mode 100644 index 00000000..2f1b4388 --- /dev/null +++ b/models/oir_flatmount_segmentation/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Hartnett Lab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/models/oir_flatmount_segmentation/README.md b/models/oir_flatmount_segmentation/README.md new file mode 100644 index 00000000..8340ab3e --- /dev/null +++ b/models/oir_flatmount_segmentation/README.md @@ -0,0 +1,51 @@ +# OIR Flatmount Segmentation (Hartnett Lab) + +### **Authors** +Neal S. Shah*, Aniket Ramshekar*, Bright Asare-Bediako, Morgan P. Tankersley, Heng-Chiao Huang, Shreya Beri, Eric Kunz, Aaron Y. Lee, M. Elizabeth Hartnett + +Byers Eye Institute Department of Ophthalmology, Stanford University School of Medicine, Stanford, CA, USA + +### **Tags** +Segmentation, Retinal Flatmount, Oxygen-Induced Retinopathy, OIR, Mouse, Rat, Intravitreal Neovascularization, Avascular Area + +## **Model Description** +This model performs automated segmentation of oxygen-induced retinopathy (OIR) retinal flatmount images into three regions: total retina (TR), intravitreal neovascularization (IVNV), and avascular area (AVA). +The architecture is a multi-task Attention U-Net with a ConvNeXt-Tiny encoder [1] and deep supervision (~8.7M trainable parameters). +For inference, the final release uses an ensemble of 5 cross-validation models, with test-time augmentation and per-class thresholding, to improve robustness across mouse and rat OIR images. + +## **Data** +Model development used three datasets: + +1. **Rat IVNV pretraining dataset:** 72 rat OIR flatmount images with IVNV-only annotations (used in intermediate Stage 2 training). +2. **Final development dataset:** 345 annotated images total (267 mouse, 78 rat), including: + - 127 expert human-annotated images (49 mouse, 78 rat) + - 218 curated open-source mouse images [2] with reviewed masks generated from a prior published model [3] +3. **Independent test dataset:** 37 images (18 mouse OIR, 19 rat OIR), held out from training/validation/model selection. + +For final model development, a modified 5-fold cross-validation strategy was used, with expert-annotated images serving as fold-level validation references and curated open-source mouse images used in training only. + +#### **Preprocessing** +Input retinal flatmount images were converted to grayscale, resized to 512×512, and intensity-normalized. +During training, joint image-mask augmentation was applied using random horizontal/vertical flips, random rotations (up to 180 degrees), brightness/contrast perturbation, CLAHE, Gaussian noise, elastic/grid/optical distortions, coarse dropout, motion blur, and random gamma adjustments. +These augmentations are implemented in the project training pipeline (`retrain_kfold_v2.py` / `train_with_split.py` using `dataset.py`). The MONAI `configs/train.json` file in this bundle is a compatibility template and keeps transform lists minimal. + +## **Performance** + +Dice agreement between model masks and human consensus masks was high for total retina (TR) and AVA, and moderate for IVNV in both species: +- Rat: TR Dice=0.983, AVA Dice=0.924, IVNV Dice=0.612 +- Mouse: TR Dice=0.975, AVA Dice=0.912, IVNV Dice=0.601 + +At the metric level, the deep learning model showed strong correlation with the mean of three graders for rat percent AVA (r=0.979) and rat percent IVNV (r=0.943). In mouse OIR, correlation was strong for percent AVA (r=0.957) but weak for percent IVNV (r=0.265), likely due to high inter-grader variability for mouse IVNV scoring. + +(For full analysis please refer to the manuscript.) + +## **System Configuration** +This model was trained on an Apple M2 pro 16GB Macbook. 5-fold cross-validation was run sequentially with batch size 4 and a maximum of 120 epochs per fold. The folds ran for 86, 88, 120, 61, and 80 epochs, with total training time of approximately 24 hours. + +## **Additional Usage Steps** +Model checkpoints are hosted externally and linked through `large_files.yml` (not committed directly in the repo due to file size limits). + +## **References** +1. Liu Z, Mao H, Wu CY, Feichtenhofer C, Darrell T, Xie S. A ConvNet for the 2020s. 2022:11966-11976. +2. Marra KV, Chen JS, Robles-Holmes HK, et al. Development of an Open-Source Dataset of Flat-Mounted Images for the Murine Oxygen-Induced Retinopathy Model of Ischemic Retinopathy. Transl Vis Sci Technol. Dec 2 2024;13(12):4. +3. Xiao S, Bucher F, Wu Y, et al. Fully automated, deep learning segmentation of oxygen-induced retinopathy images. JCI Insight. Dec 21 2017;2(24)doi:10.1172/jci.insight.97585 diff --git a/models/oir_flatmount_segmentation/configs/inference.json b/models/oir_flatmount_segmentation/configs/inference.json new file mode 100644 index 00000000..79f2d58c --- /dev/null +++ b/models/oir_flatmount_segmentation/configs/inference.json @@ -0,0 +1,71 @@ +{ + "bundle_root": ".", + "device": "$torch.device('cuda' if torch.cuda.is_available() else 'cpu')", + "output_dir": "$@bundle_root + '/infer_output'", + "input_dir": "$@bundle_root + '/inputs'", + "pred_threshold": 0.5, + "network_def": { + "_target_": "model_transformer.AttentionUNetTransformer", + "in_ch": 1, + "out_ch": 3, + "backbone": "convnext_tiny", + "pretrained": false + }, + "inferer": { + "_target_": "monai.inferers.SimpleInferer" + }, + "preprocessing": { + "_target_": "monai.transforms.Compose", + "transforms": [ + { + "_target_": "monai.transforms.LoadImaged", + "keys": [ + "image" + ], + "ensure_channel_first": true + }, + { + "_target_": "monai.transforms.ScaleIntensityRanged", + "keys": [ + "image" + ], + "a_min": 0, + "a_max": 255, + "b_min": -1.0, + "b_max": 1.0, + "clip": true + }, + { + "_target_": "monai.transforms.Resized", + "keys": [ + "image" + ], + "spatial_size": [ + 512, + 512 + ], + "mode": "area" + } + ] + }, + "postprocessing": { + "_target_": "monai.transforms.Compose", + "transforms": [ + { + "_target_": "monai.transforms.Activationsd", + "keys": [ + "pred" + ], + "sigmoid": true + }, + { + "_target_": "monai.transforms.AsDiscreted", + "keys": [ + "pred" + ], + "threshold": "$@pred_threshold" + } + ] + }, + "notes": "This config includes minimal standalone postprocessing (sigmoid + threshold) for bundle inference. For production results, use infer.py for 5-fold ensemble, D4 TTA, and calibrated per-class thresholds." +} diff --git a/models/oir_flatmount_segmentation/configs/metadata.json b/models/oir_flatmount_segmentation/configs/metadata.json new file mode 100644 index 00000000..48ea188c --- /dev/null +++ b/models/oir_flatmount_segmentation/configs/metadata.json @@ -0,0 +1,53 @@ +{ + "version": "0.1.0", + "changelog": { + "0.1.0": "Initial MONAI bundle packaging for OIR flatmount segmentation." + }, + "monai_version": "1.4.0", + "pytorch_version": "2.3.0", + "numpy_version": "1.26.4", + "name": "OIR Flatmount Segmentation (Hartnett Lab)", + "task": "pathology_segmentation", + "description": "Multi-task segmentation of total retina, intravitreal neovascularization, and avascular area in OIR flatmount images.", + "authors": "Neal Shah, Aniket Ramshekar, Bright Asare-Bediako, Morgan Tankersley, Heng-Chiao Huang, Shreya Beri, Eric Kunz, Aaron Y. Lee, M. Elizabeth Hartnett", + "copyright": "Hartnett Lab", + "data_source": "Hartnett Lab OIR flatmount datasets", + "data_type": "image", + "image_classes": "retinal_flatmount", + "intended_use": "Research", + "network_data_format": { + "inputs": { + "image": { + "type": "image", + "format": "png/jpg/tif/bmp", + "dtype": "float32", + "num_channels": 1, + "spatial_shape": [ + 512, + 512 + ] + } + }, + "outputs": { + "pred": { + "type": "image", + "dtype": "float32", + "num_channels": 3, + "channels": [ + "tr", + "ivnv", + "ava" + ] + } + } + }, + "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json", + "optional_packages_version": { + "timm": "", + "albumentations": "", + "opencv-python": "", + "pandas": "", + "matplotlib": "", + "openpyxl": "" + } +} diff --git a/models/oir_flatmount_segmentation/configs/train.json b/models/oir_flatmount_segmentation/configs/train.json new file mode 100644 index 00000000..3835a9c8 --- /dev/null +++ b/models/oir_flatmount_segmentation/configs/train.json @@ -0,0 +1,70 @@ +{ + "bundle_root": ".", + "dataset_dir": "$@bundle_root + '/data'", + "output_dir": "$@bundle_root + '/train_output'", + "device": "$torch.device('cuda' if torch.cuda.is_available() else 'cpu')", + "network_def": { + "_target_": "model_transformer.AttentionUNetTransformer", + "in_ch": 1, + "out_ch": 3, + "backbone": "convnext_tiny", + "pretrained": true + }, + "train": { + "trainer": { + "_target_": "monai.engines.SupervisedTrainer" + }, + "trainer#max_epochs": 120, + "dataset": { + "_target_": "dataset.OIRSegmentationDataset" + }, + "dataset#data": "$@dataset_dir", + "preprocessing": { + "_target_": "monai.transforms.Compose", + "transforms": [] + }, + "postprocessing": { + "_target_": "monai.transforms.Compose", + "transforms": [] + }, + "inferer": { + "_target_": "monai.inferers.SimpleInferer" + }, + "key_metric": { + "mean_dice": { + "_target_": "monai.metrics.DiceMetric", + "include_background": true + } + }, + "handlers": [] + }, + "validate": { + "evaluator": { + "_target_": "monai.engines.SupervisedEvaluator" + }, + "dataset": { + "_target_": "dataset.OIRSegmentationDataset" + }, + "dataset#data": "$@dataset_dir", + "preprocessing": { + "_target_": "monai.transforms.Compose", + "transforms": [] + }, + "postprocessing": { + "_target_": "monai.transforms.Compose", + "transforms": [] + }, + "inferer": { + "_target_": "monai.inferers.SimpleInferer" + }, + "key_metric": { + "mean_dice": { + "_target_": "monai.metrics.DiceMetric", + "include_background": true + } + }, + "handlers": [] + }, + "val_interval": 1, + "notes": "Primary training is executed via retrain_kfold_v2.py and train_with_split.py in the parent project. This file exists to satisfy MONAI preferred train-config keys. Override dataset_dir to your local dataset path when running training. Augmentation and dataset-specific preprocessing are implemented in the external training pipeline, so transform lists in this template config are intentionally minimal." +} diff --git a/models/oir_flatmount_segmentation/cv_summary.csv b/models/oir_flatmount_segmentation/cv_summary.csv new file mode 100644 index 00000000..c3122cca --- /dev/null +++ b/models/oir_flatmount_segmentation/cv_summary.csv @@ -0,0 +1,6 @@ +fold,epochs_ran,best_val_dice,last_val_dice_retina,last_val_dice_nv,last_val_dice_vo,thr_retina,thr_nv,thr_vo +0,86,0.803034,0.954282,0.455693,0.872625,0.3,0.95,0.85 +1,88,0.830743,0.974396,0.569426,0.907811,0.8,0.6,0.65 +2,120,0.825497,0.966939,0.591886,0.916625,0.75,0.65,0.85 +3,61,0.852035,0.980751,0.545412,0.932191,0.35,0.95,0.5 +4,80,0.815581,0.977167,0.534101,0.897727,0.9,0.75,0.4 diff --git a/models/oir_flatmount_segmentation/docs/README.md b/models/oir_flatmount_segmentation/docs/README.md new file mode 100644 index 00000000..3e7bcac6 --- /dev/null +++ b/models/oir_flatmount_segmentation/docs/README.md @@ -0,0 +1,93 @@ +# OIR Flatmount Segmentation (Hartnett Lab) + +## Overview + +This bundle provides automated segmentation of oxygen-induced retinopathy (OIR) flatmount images for: + +- Total Retina (TR) +- Intravitreal Neovascularization (IVNV) +- Avascular Area (AVA) + +The model is a multi-task Attention U-Net with a ConvNeXt-Tiny encoder and deep supervision, trained with fold-wise threshold calibration and ensemble inference. + +## Intended Use + +- Research use in preclinical retinal OIR image analysis. +- Automated quantification support for TR, IVNV, and AVA area measurements. +- Batch processing workflows for publication and reproducible analytics. + +## Not Intended For + +- Clinical diagnosis, triage, or autonomous treatment decisions. +- Use outside domain conditions (non-flatmount, unrelated species, unrelated disease context) without additional validation. + +## Input and Output + +- **Input**: Single-channel retinal flatmount image. RGB inputs are converted to grayscale. +- **Output**: + - Binary masks for TR, IVNV, AVA + - Overlay visualizations + - Per-image quantitative metrics (areas and percentages) + +## Training Summary + +- Data split strategy: + - Deduplicated by basename to avoid `.jpg/.tif` leakage. + - HQ expert annotations reserved for validation folds. + - Auto-generated images included only in training. +- Loss: + - BCE + Dice (all channels) + - Focal Tversky (TR/IVNV/AVA channel-weighted) + - Boundary Dice with Sobel gradients + - Deep supervision loss at decoder intermediates +- Optimization: + - AdamW + warmup cosine LR + - Discriminative learning rates (backbone vs decoder/heads) + - EMA weights + - Gradient clipping + - Early stopping + - Strong augmentation + +## Inference Summary + +- Ensemble over fold checkpoints (`fold_*/best.pth`) +- D4 test-time augmentation (rotations and flips) +- Threshold calibration and post-processing: + - mask binarization by per-class threshold + - retina-constrained IVNV/AVA + - optional component filtering and morphological closing + +## Reproducibility Artifacts + +Each training fold is expected to emit: + +- `training_history.csv` +- `learning_curves.png` +- `dataset_log.xlsx` +- `run_manifest.json` +- `best.pth`, `final.pth`, `thresholds.json` + +## Authors + +Neal Shah1*, Aniket Ramshekar1*, Bright Asare-Bediako1, Morgan Tankersley1, Heng-Chiao Huang1,2, Shreya Beri1, Eric Kunz3, Aaron Y. Lee4, M. Elizabeth Hartnett1,# + +1 Byers Eye Institute Department of Ophthalmology, Stanford University School of Medicine, Stanford, California, USA +2 Department of Ophthalmology, Chang Gung Memorial Hospital, Chiayi, Taiwan +3 John A. Moran Eye Center, University of Utah, Salt Lake City, Utah, USA +4 John F. Hardesty Department of Ophthalmology and Visual Sciences, Washington University in St. Louis, St. Louis, Missouri, USA + +## Contacts + +- Neal Shah: neals1@stanford.edu +- Aniket Ramshekar: aniket.ramshekar@stanford.edu +- M. Elizabeth Hartnett: me.hartnett@stanford.edu + +## Citation + +If you use this model, please cite the associated TVST publication (to be updated after acceptance) and acknowledge the Hartnett Lab. + +## Known Limitations + +- Performance may degrade for out-of-distribution scanners/prep protocols. +- Small IVNV lesions are sensitive to threshold and component filtering settings. +- Cross-species domain shift (mouse vs rat) should be evaluated explicitly per cohort. diff --git a/models/oir_flatmount_segmentation/docs/data_license.txt b/models/oir_flatmount_segmentation/docs/data_license.txt new file mode 100644 index 00000000..cc0da09d --- /dev/null +++ b/models/oir_flatmount_segmentation/docs/data_license.txt @@ -0,0 +1,24 @@ +Data licensing and access notes for `oir_flatmount_segmentation`. + +1. Data distribution: + - No training or evaluation images are distributed in this bundle. + - This bundle contains model code/configuration and references to externally hosted model weights only. + +2. Source datasets: + - Model development used retinal flatmount datasets from Hartnett Lab studies and curated open-source data. + - Each source dataset remains subject to its original license/terms of use. + +3. User responsibilities: + - Users are responsible for obtaining lawful access to any data used with this model. + - Users must comply with all applicable data use agreements, institutional policies, and local regulations. + +4. Privacy and ethics: + - Do not use data in ways that violate privacy, consent, or ethics approvals. + +5. Attribution: + - Please cite the associated manuscript and acknowledge the Hartnett Lab when using this model. + +6. Contact: + - neals1@stanford.edu + - aniket.ramshekar@stanford.edu + - me.hartnett@stanford.edu diff --git a/models/oir_flatmount_segmentation/large_files.yml b/models/oir_flatmount_segmentation/large_files.yml new file mode 100644 index 00000000..302daea4 --- /dev/null +++ b/models/oir_flatmount_segmentation/large_files.yml @@ -0,0 +1,21 @@ +large_files: + - path: "weights/fold_0/model.pth" + url: "https://github.com/hartnettlabteam/oir_flatmount_segmentation/releases/download/v1.0.0/oir_flatmount_segmentation_fold_0_model.pth" + hash_val: "1e4b607f1a89ffd8272d2edbcd5596b3096403177d6d00aeed384cef2a6aa078" + hash_type: "sha256" + - path: "weights/fold_1/model.pth" + url: "https://github.com/hartnettlabteam/oir_flatmount_segmentation/releases/download/v1.0.0/oir_flatmount_segmentation_fold_1_model.pth" + hash_val: "62d5c247c8025ddcd7ad4ae0d06d9a2921a8e86cc790f47ecd3d9efaf6e5fffd" + hash_type: "sha256" + - path: "weights/fold_2/model.pth" + url: "https://github.com/hartnettlabteam/oir_flatmount_segmentation/releases/download/v1.0.0/oir_flatmount_segmentation_fold_2_model.pth" + hash_val: "4fe96be0fbfe53350abd3dfc7b83260509b6989513818fa5b60c38675781ee72" + hash_type: "sha256" + - path: "weights/fold_3/model.pth" + url: "https://github.com/hartnettlabteam/oir_flatmount_segmentation/releases/download/v1.0.0/oir_flatmount_segmentation_fold_3_model.pth" + hash_val: "8a60f5096cd233b660374a1192cef153f3aebd2ea79d53f293b652a4fb78049d" + hash_type: "sha256" + - path: "weights/fold_4/model.pth" + url: "https://github.com/hartnettlabteam/oir_flatmount_segmentation/releases/download/v1.0.0/oir_flatmount_segmentation_fold_4_model.pth" + hash_val: "1891584577d473e3206b720635bca4c8e98add0b6eabc232e4f22c875ec1dbbd" + hash_type: "sha256" diff --git a/models/oir_flatmount_segmentation/model.py b/models/oir_flatmount_segmentation/model.py new file mode 100644 index 00000000..07e8e1cd --- /dev/null +++ b/models/oir_flatmount_segmentation/model.py @@ -0,0 +1,205 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvBNReLU(nn.Module): + def __init__(self, in_ch: int, out_ch: int, k: int = 3, s: int = 1, p: int = 1): + super().__init__() + self.conv = nn.Conv2d(in_ch, out_ch, k, s, p, bias=False) + self.bn = nn.BatchNorm2d(out_ch) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = x.contiguous() + x = self.conv(x).contiguous() + x = self.bn(x) + x = self.relu(x) + return x + + +class scSE(nn.Module): + # Concurrent spatial and channel squeeze & excitation + def __init__(self, ch: int, reduction: int = 16): + super().__init__() + self.cSE = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(ch, max(ch // reduction, 1), 1), + nn.ReLU(inplace=True), + nn.Conv2d(max(ch // reduction, 1), ch, 1), + nn.Sigmoid(), + ) + self.sSE = nn.Sequential( + nn.Conv2d(ch, 1, 1), + nn.Sigmoid(), + ) + + def forward(self, x): + c = self.cSE(x) + s = self.sSE(x) + return x * c + x * s + + +class AttentionGate(nn.Module): + # Gating from decoder g modulates skip x + def __init__(self, in_ch_x: int, in_ch_g: int, inter_ch: int): + super().__init__() + self.theta_x = nn.Conv2d(in_ch_x, inter_ch, kernel_size=1, bias=False) + self.phi_g = nn.Conv2d(in_ch_g, inter_ch, kernel_size=1, bias=True) + self.psi = nn.Conv2d(inter_ch, 1, kernel_size=1, bias=True) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, x_skip, g): + # Both x_skip and g should have the same spatial size here + theta_x = self.theta_x(x_skip) + phi_g = self.phi_g(g) + t = self.relu(theta_x + phi_g) + psi = self.sigmoid(self.psi(t)) + return x_skip * psi + + +class EncoderBlock(nn.Module): + def __init__(self, in_ch: int, out_ch: int): + super().__init__() + self.conv1 = ConvBNReLU(in_ch, out_ch) + self.conv2 = ConvBNReLU(out_ch, out_ch) + self.scse = scSE(out_ch) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.scse(x) + return x + + +class DecoderBlock(nn.Module): + def __init__(self, in_ch: int, skip_ch: int, out_ch: int): + super().__init__() + self.up = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False), + ) + # Gate uses decoder's out_ch as g input channels + self.att = AttentionGate(skip_ch, out_ch, inter_ch=max(out_ch // 2, 1)) + self.conv1 = ConvBNReLU(out_ch + skip_ch, out_ch) + self.conv2 = ConvBNReLU(out_ch, out_ch) + self.scse = scSE(out_ch) + + def forward(self, x, skip): + x = self.up(x) + skip = self.att(skip, x) + x = torch.cat([x, skip], dim=1).contiguous() + x = self.conv1(x) + x = self.conv2(x) + x = self.scse(x) + return x + + +class AttentionUNet(nn.Module): + def __init__(self, in_ch: int = 1, base_ch: int = 32, out_ch: int = 3): + super().__init__() + c1 = base_ch + c2 = base_ch * 2 + c3 = base_ch * 4 + c4 = base_ch * 8 + c5 = base_ch * 16 + + self.enc1 = EncoderBlock(in_ch, c1) + self.enc2 = EncoderBlock(c1, c2) + self.enc3 = EncoderBlock(c2, c3) + self.enc4 = EncoderBlock(c3, c4) + self.center = EncoderBlock(c4, c5) + + self.pool = nn.MaxPool2d(2) + + self.dec4 = DecoderBlock(c5, c4, c4) + self.dec3 = DecoderBlock(c4, c3, c3) + self.dec2 = DecoderBlock(c3, c2, c2) + self.dec1 = DecoderBlock(c2, c1, c1) + + # Separate heads for retina, NV, VO (main output at d1) + self.retina_head = nn.Conv2d(c1, 1, kernel_size=1) + self.nv_head = nn.Conv2d(c1, 1, kernel_size=1) + self.vo_head = nn.Conv2d(c1, 1, kernel_size=1) + + # Deep supervision heads at d2 and d3 + self.retina_ds2_head = nn.Conv2d(c2, 1, kernel_size=1) + self.nv_ds2_head = nn.Conv2d(c2, 1, kernel_size=1) + self.vo_ds2_head = nn.Conv2d(c2, 1, kernel_size=1) + self.retina_ds3_head = nn.Conv2d(c3, 1, kernel_size=1) + self.nv_ds3_head = nn.Conv2d(c3, 1, kernel_size=1) + self.vo_ds3_head = nn.Conv2d(c3, 1, kernel_size=1) + + def freeze_retina_head(self): + # Backward compatibility + for p in self.retina_head.parameters(): + p.requires_grad = False + + def freeze_retina_heads(self): + # Freeze all retina-related heads including deep supervision + for name, p in self.named_parameters(): + if ("retina_head" in name) or ("retina_ds" in name): + p.requires_grad = False + + def reset_nv_vo_heads(self): + # Reset main heads + for m in [self.nv_head, self.vo_head]: + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.zeros_(m.bias) + # Reset DS heads + for m in [self.nv_ds2_head, self.vo_ds2_head, self.nv_ds3_head, self.vo_ds3_head]: + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x): + e1 = self.enc1(x) + e2 = self.enc2(self.pool(e1)) + e3 = self.enc3(self.pool(e2)) + e4 = self.enc4(self.pool(e3)) + cent = self.center(self.pool(e4)) + + d4 = self.dec4(cent, e4) + d3 = self.dec3(d4, e3) + d2 = self.dec2(d3, e2) + d1 = self.dec1(d2, e1) + + logits_r = self.retina_head(d1) + logits_nv = self.nv_head(d1) + logits_vo = self.vo_head(d1) + logits = torch.cat([logits_r, logits_nv, logits_vo], dim=1) + return logits + + def forward_with_aux(self, x): + # Returns main logits and deep supervision logits upsampled to input size + e1 = self.enc1(x) + e2 = self.enc2(self.pool(e1)) + e3 = self.enc3(self.pool(e2)) + e4 = self.enc4(self.pool(e3)) + cent = self.center(self.pool(e4)) + + d4 = self.dec4(cent, e4) + d3 = self.dec3(d4, e3) + d2 = self.dec2(d3, e2) + d1 = self.dec1(d2, e1) + + # main + logits_r = self.retina_head(d1) + logits_nv = self.nv_head(d1) + logits_vo = self.vo_head(d1) + main = torch.cat([logits_r, logits_nv, logits_vo], dim=1) + # ds2 from d2 + ds2_r = F.interpolate(self.retina_ds2_head(d2), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + ds2_nv = F.interpolate(self.nv_ds2_head(d2), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + ds2_vo = F.interpolate(self.vo_ds2_head(d2), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + ds2 = torch.cat([ds2_r, ds2_nv, ds2_vo], dim=1) + # ds3 from d3 + ds3_r = F.interpolate(self.retina_ds3_head(d3), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + ds3_nv = F.interpolate(self.nv_ds3_head(d3), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + ds3_vo = F.interpolate(self.vo_ds3_head(d3), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + ds3 = torch.cat([ds3_r, ds3_nv, ds3_vo], dim=1) + return {"main": main, "ds2": ds2, "ds3": ds3} diff --git a/models/oir_flatmount_segmentation/model_transformer.py b/models/oir_flatmount_segmentation/model_transformer.py new file mode 100644 index 00000000..91b089fc --- /dev/null +++ b/models/oir_flatmount_segmentation/model_transformer.py @@ -0,0 +1,154 @@ +import timm +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model import AttentionGate, ConvBNReLU, scSE + +class TransformerEncoder(nn.Module): + def __init__(self, backbone: str = "swin_tiny_patch4_window7_224", in_ch: int = 1, pretrained: bool = True): + super().__init__() + self.backbone = timm.create_model(backbone, pretrained=pretrained, features_only=True, in_chans=in_ch) + self.channels = self.backbone.feature_info.channels + + def forward(self, x): + feats = self.backbone(x) + return feats # list [C1@H/4, C2@H/8, C3@H/16, C4@H/32] + + +class AttentionUNetTransformer(nn.Module): + def __init__(self, in_ch: int = 1, out_ch: int = 3, backbone: str = "swin_tiny_patch4_window7_224", pretrained: bool = True): + super().__init__() + self.enc = TransformerEncoder(backbone=backbone, in_ch=in_ch, pretrained=pretrained) + c1, c2, c3, c4 = self.enc.channels + + # project features to stable decoder widths + dec_c4 = 384 + dec_c3 = 192 + dec_c2 = 96 + dec_c1 = 48 + + self.proj4 = nn.Conv2d(c4, dec_c4, 1, bias=False) + self.proj3 = nn.Conv2d(c3, dec_c3, 1, bias=False) + self.proj2 = nn.Conv2d(c2, dec_c2, 1, bias=False) + self.proj1 = nn.Conv2d(c1, dec_c1, 1, bias=False) + + self.dec3 = self._decoder_block(dec_c4, dec_c3) + self.dec2 = self._decoder_block(dec_c3, dec_c2) + self.dec1 = self._decoder_block(dec_c2, dec_c1) + self.dec0 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + ConvBNReLU(dec_c1, dec_c1), + scSE(dec_c1), + ) + # Separate heads + self.retina_head = nn.Conv2d(dec_c1, 1, 1) + self.nv_head = nn.Conv2d(dec_c1, 1, 1) + self.vo_head = nn.Conv2d(dec_c1, 1, 1) + + # Deep supervision heads at y1 (dec1 output) and y2 (dec2 output) + self.retina_ds2_head = nn.Conv2d(dec_c2, 1, 1) + self.nv_ds2_head = nn.Conv2d(dec_c2, 1, 1) + self.vo_ds2_head = nn.Conv2d(dec_c2, 1, 1) + self.retina_ds3_head = nn.Conv2d(dec_c3, 1, 1) + self.nv_ds3_head = nn.Conv2d(dec_c3, 1, 1) + self.vo_ds3_head = nn.Conv2d(dec_c3, 1, 1) + + def freeze_retina_head(self): + for p in self.retina_head.parameters(): + p.requires_grad = False + + def freeze_retina_heads(self): + for name, p in self.named_parameters(): + if ("retina_head" in name) or ("retina_ds" in name): + p.requires_grad = False + + def reset_nv_vo_heads(self): + for m in [self.nv_head, self.vo_head]: + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.zeros_(m.bias) + for m in [self.nv_ds2_head, self.vo_ds2_head, self.nv_ds3_head, self.vo_ds3_head]: + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _decoder_block(self, in_c: int, skip_c: int): + return nn.ModuleDict({ + "up": nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + nn.Conv2d(in_c, skip_c, 1, bias=False), + ), + "att": AttentionGate(skip_c, skip_c, inter_ch=max(skip_c // 2, 1)), + "conv": nn.Sequential( + ConvBNReLU(skip_c * 2, skip_c), + ConvBNReLU(skip_c, skip_c), + scSE(skip_c), + ), + }) + + def forward(self, x): + # Encoder features + f1, f2, f3, f4 = self.enc(x) + p1 = self.proj1(f1) + p2 = self.proj2(f2) + p3 = self.proj3(f3) + p4 = self.proj4(f4) + + y3_u = self.dec3["up"](p4) + y3_s = self.dec3["att"](p3, y3_u) + y3 = self.dec3["conv"](torch.cat([y3_u, y3_s], dim=1).contiguous()) + + y2_u = self.dec2["up"](y3) + y2_s = self.dec2["att"](p2, y2_u) + y2 = self.dec2["conv"](torch.cat([y2_u, y2_s], dim=1).contiguous()) + + y1_u = self.dec1["up"](y2) + y1_s = self.dec1["att"](p1, y1_u) + y1 = self.dec1["conv"](torch.cat([y1_u, y1_s], dim=1).contiguous()) + + y0 = self.dec0(y1.contiguous()) + # y0 is typically 256x256 for 512 input; upsample heads to 512 + logits_r = F.interpolate(self.retina_head(y0), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + logits_nv = F.interpolate(self.nv_head(y0), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + logits_vo = F.interpolate(self.vo_head(y0), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + logits = torch.cat([logits_r, logits_nv, logits_vo], dim=1) + return logits + + def forward_with_aux(self, x): + # Encoder features + f1, f2, f3, f4 = self.enc(x) + p1 = self.proj1(f1) + p2 = self.proj2(f2) + p3 = self.proj3(f3) + p4 = self.proj4(f4) + + y3_u = self.dec3["up"](p4) + y3_s = self.dec3["att"](p3, y3_u) + y3 = self.dec3["conv"](torch.cat([y3_u, y3_s], dim=1).contiguous()) + + y2_u = self.dec2["up"](y3) + y2_s = self.dec2["att"](p2, y2_u) + y2 = self.dec2["conv"](torch.cat([y2_u, y2_s], dim=1).contiguous()) + + y1_u = self.dec1["up"](y2) + y1_s = self.dec1["att"](p1, y1_u) + y1 = self.dec1["conv"](torch.cat([y1_u, y1_s], dim=1).contiguous()) + + y0 = self.dec0(y1.contiguous()) + logits_r = F.interpolate(self.retina_head(y0), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + logits_nv = F.interpolate(self.nv_head(y0), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + logits_vo = F.interpolate(self.vo_head(y0), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + main = torch.cat([logits_r, logits_nv, logits_vo], dim=1) + + # ds2 from y2 + ds2_r = F.interpolate(self.retina_ds2_head(y2), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + ds2_nv = F.interpolate(self.nv_ds2_head(y2), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + ds2_vo = F.interpolate(self.vo_ds2_head(y2), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + ds2 = torch.cat([ds2_r, ds2_nv, ds2_vo], dim=1) + # ds3 from y3 + ds3_r = F.interpolate(self.retina_ds3_head(y3), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + ds3_nv = F.interpolate(self.nv_ds3_head(y3), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + ds3_vo = F.interpolate(self.vo_ds3_head(y3), size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + ds3 = torch.cat([ds3_r, ds3_nv, ds3_vo], dim=1) + return {"main": main, "ds2": ds2, "ds3": ds3} diff --git a/models/oir_flatmount_segmentation/release_manifest.json b/models/oir_flatmount_segmentation/release_manifest.json new file mode 100644 index 00000000..4a74e9c1 --- /dev/null +++ b/models/oir_flatmount_segmentation/release_manifest.json @@ -0,0 +1,71 @@ +{ + "model_name": "oir_flatmount_segmentation", + "release_id": "v2_20260317", + "created_utc": "2026-03-18T06:06:53Z", + "selection_method": "5-fold ensemble of best.pth checkpoints with per-fold calibrated thresholds", + "notes": "Selected as best method for robustness and generalization after completed 5-fold CV.", + "cv_aggregate": { + "best_val_dice_mean": 0.8253780000000001, + "best_val_dice_std": 0.018266284378603097, + "avg_thr_tr": 0.62, + "avg_thr_ivnv": 0.7799999999999999, + "avg_thr_ava": 0.65 + }, + "folds": [ + { + "fold": 0, + "model_path": "weights/fold_0/model.pth", + "thresholds_path": "weights/fold_0/thresholds.json", + "model_sha256": "1e4b607f1a89ffd8272d2edbcd5596b3096403177d6d00aeed384cef2a6aa078", + "thresholds": { + "tr": 0.3, + "ivnv": 0.95, + "ava": 0.85 + } + }, + { + "fold": 1, + "model_path": "weights/fold_1/model.pth", + "thresholds_path": "weights/fold_1/thresholds.json", + "model_sha256": "62d5c247c8025ddcd7ad4ae0d06d9a2921a8e86cc790f47ecd3d9efaf6e5fffd", + "thresholds": { + "tr": 0.8, + "ivnv": 0.6, + "ava": 0.65 + } + }, + { + "fold": 2, + "model_path": "weights/fold_2/model.pth", + "thresholds_path": "weights/fold_2/thresholds.json", + "model_sha256": "4fe96be0fbfe53350abd3dfc7b83260509b6989513818fa5b60c38675781ee72", + "thresholds": { + "tr": 0.75, + "ivnv": 0.65, + "ava": 0.85 + } + }, + { + "fold": 3, + "model_path": "weights/fold_3/model.pth", + "thresholds_path": "weights/fold_3/thresholds.json", + "model_sha256": "8a60f5096cd233b660374a1192cef153f3aebd2ea79d53f293b652a4fb78049d", + "thresholds": { + "tr": 0.35, + "ivnv": 0.95, + "ava": 0.5 + } + }, + { + "fold": 4, + "model_path": "weights/fold_4/model.pth", + "thresholds_path": "weights/fold_4/thresholds.json", + "model_sha256": "1891584577d473e3206b720635bca4c8e98add0b6eabc232e4f22c875ec1dbbd", + "thresholds": { + "tr": 0.9, + "ivnv": 0.75, + "ava": 0.4 + } + } + ] +} diff --git a/models/oir_flatmount_segmentation/requirements.txt b/models/oir_flatmount_segmentation/requirements.txt new file mode 100644 index 00000000..5dd43ff6 --- /dev/null +++ b/models/oir_flatmount_segmentation/requirements.txt @@ -0,0 +1,9 @@ +monai +torch +timm +numpy +pandas +opencv-python +matplotlib +albumentations +openpyxl diff --git a/models/oir_flatmount_segmentation/scripts/plot_learning_curves.py b/models/oir_flatmount_segmentation/scripts/plot_learning_curves.py new file mode 100644 index 00000000..c86ed935 --- /dev/null +++ b/models/oir_flatmount_segmentation/scripts/plot_learning_curves.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +"""Aggregate and plot fold-wise learning curves for OIR training outputs.""" + +import argparse +import os +from typing import List + +import matplotlib +import pandas as pd + +matplotlib.use("Agg") +import matplotlib.pyplot as plt # noqa: E402 + +def _load_histories(kfold_dir: str) -> List[pd.DataFrame]: + histories: List[pd.DataFrame] = [] + for fold in range(5): + csv_path = os.path.join(kfold_dir, f"fold_{fold}", "training_history.csv") + if os.path.exists(csv_path): + df = pd.read_csv(csv_path) + if not df.empty: + df["fold"] = fold + histories.append(df) + return histories + + +def main() -> None: + parser = argparse.ArgumentParser(description="Plot OIR k-fold learning curves.") + parser.add_argument("--kfold_dir", type=str, required=True, help="Path containing fold_*/training_history.csv") + parser.add_argument("--out", type=str, required=True, help="Output PNG path") + args = parser.parse_args() + + histories = _load_histories(args.kfold_dir) + if not histories: + raise FileNotFoundError(f"No training_history.csv files found in: {args.kfold_dir}") + + plt.style.use("seaborn-v0_8-whitegrid") + fig, axes = plt.subplots(2, 2, figsize=(13, 9)) + + for df in histories: + fold = int(df["fold"].iloc[0]) + axes[0, 0].plot(df["epoch"], df["train_loss"], alpha=0.8, label=f"fold_{fold}") + axes[0, 1].plot(df["epoch"], df["val_loss"], alpha=0.8, label=f"fold_{fold}") + axes[1, 0].plot(df["epoch"], df["val_dice_nv"], alpha=0.8, label=f"fold_{fold}") + axes[1, 1].plot(df["epoch"], df["val_dice_vo"], alpha=0.8, label=f"fold_{fold}") + + axes[0, 0].set_title("Train Loss") + axes[0, 1].set_title("Validation Loss") + axes[1, 0].set_title("Validation Dice - IVNV") + axes[1, 1].set_title("Validation Dice - AVA") + + for ax in axes.flatten(): + ax.set_xlabel("Epoch") + ax.legend(fontsize=8) + + axes[0, 0].set_ylabel("Loss") + axes[0, 1].set_ylabel("Loss") + axes[1, 0].set_ylabel("Dice") + axes[1, 1].set_ylabel("Dice") + + plt.tight_layout() + os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True) + plt.savefig(args.out, dpi=300, bbox_inches="tight") + plt.close(fig) + print(f"Saved curves: {args.out}") + + +if __name__ == "__main__": + main() diff --git a/models/oir_flatmount_segmentation/weights/fold_0/thresholds.json b/models/oir_flatmount_segmentation/weights/fold_0/thresholds.json new file mode 100644 index 00000000..3022273c --- /dev/null +++ b/models/oir_flatmount_segmentation/weights/fold_0/thresholds.json @@ -0,0 +1,5 @@ +{ + "tr": 0.3, + "ivnv": 0.95, + "ava": 0.85 +} diff --git a/models/oir_flatmount_segmentation/weights/fold_1/thresholds.json b/models/oir_flatmount_segmentation/weights/fold_1/thresholds.json new file mode 100644 index 00000000..b7ac94eb --- /dev/null +++ b/models/oir_flatmount_segmentation/weights/fold_1/thresholds.json @@ -0,0 +1,5 @@ +{ + "tr": 0.8, + "ivnv": 0.6, + "ava": 0.65 +} diff --git a/models/oir_flatmount_segmentation/weights/fold_2/thresholds.json b/models/oir_flatmount_segmentation/weights/fold_2/thresholds.json new file mode 100644 index 00000000..42dbca7d --- /dev/null +++ b/models/oir_flatmount_segmentation/weights/fold_2/thresholds.json @@ -0,0 +1,5 @@ +{ + "tr": 0.75, + "ivnv": 0.65, + "ava": 0.85 +} diff --git a/models/oir_flatmount_segmentation/weights/fold_3/thresholds.json b/models/oir_flatmount_segmentation/weights/fold_3/thresholds.json new file mode 100644 index 00000000..df41eb95 --- /dev/null +++ b/models/oir_flatmount_segmentation/weights/fold_3/thresholds.json @@ -0,0 +1,5 @@ +{ + "tr": 0.35, + "ivnv": 0.95, + "ava": 0.5 +} diff --git a/models/oir_flatmount_segmentation/weights/fold_4/thresholds.json b/models/oir_flatmount_segmentation/weights/fold_4/thresholds.json new file mode 100644 index 00000000..7e9ad29d --- /dev/null +++ b/models/oir_flatmount_segmentation/weights/fold_4/thresholds.json @@ -0,0 +1,5 @@ +{ + "tr": 0.9, + "ivnv": 0.75, + "ava": 0.4 +}