diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index e13fec8d..09f8f4a1 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -14,6 +14,9 @@ dataset: test_split_filename: "split_mapping/test_split.csv" tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" + stratified_kfold_run_id: "c7eafdffa32743aa9eb6dd2bf3a185b5" + stratified_group_kfold_run_id: "382b41d2fa894514908e8067949c4326" + embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" @@ -56,4 +59,4 @@ dataset: "44 Klarzelliges Nierenzellkarzinom_1": "kidney" "50 Muzinöses Zystadenom_1": "breast" "85 Mammakarzinom NST": "breast" - "28 Zöliakie": "small intestine" \ No newline at end of file + "28 Zöliakie": "small intestine" diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml new file mode 100644 index 00000000..4d92561f --- /dev/null +++ b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml @@ -0,0 +1,27 @@ +# @package _global_ + +defaults: + - /experiment/ml/linear_classifier_stratified_group_kfold + - _self_ + +trainer: + max_epochs: 10 + +data: + batch_size: 1000000000 + train_shuffle: false + train_drop_last: false + num_workers: 0 + +model: + optimizer: lbfgs + learning_rate: 1.0 + lbfgs: + max_iter: 100 + max_eval: null + tolerance_grad: 1.0e-7 + tolerance_change: 1.0e-9 + history_size: 100 + line_search_fn: strong_wolfe + accumulate_batches: 1 + accumulate_on_cpu: false diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml b/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml new file mode 100644 index 00000000..bd3c10b3 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml @@ -0,0 +1,26 @@ +# 
@package _global_ + +defaults: + - /experiment/ml/linear_classifier_stratified_kfold + - _self_ + +trainer: + max_epochs: 10 + +data: + batch_size: 1000000000 + train_shuffle: false + train_drop_last: false + +model: + optimizer: lbfgs + learning_rate: 1.0 + lbfgs: + max_iter: 100 + max_eval: null + tolerance_grad: 1.0e-7 + tolerance_change: 1.0e-9 + history_size: 100 + line_search_fn: strong_wolfe + accumulate_batches: 1 + accumulate_on_cpu: false diff --git a/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml b/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml new file mode 100644 index 00000000..471f5a36 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/linear_classifier + - _self_ + +kfold_strategy: stratified_group +kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} diff --git a/configs/experiment/ml/linear_classifier_stratified_kfold.yaml b/configs/experiment/ml/linear_classifier_stratified_kfold.yaml new file mode 100644 index 00000000..c01fbbf9 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_stratified_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/linear_classifier + - _self_ + +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} diff --git a/configs/ml/data/embedding.yaml b/configs/ml/data/embedding.yaml new file mode 100644 index 00000000..40ff4b71 --- /dev/null +++ b/configs/ml/data/embedding.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +data: + batch_size: 1024 + num_workers: 4 + train_shuffle: true + train_drop_last: true + + train: + _target_: ml.data.datasets.EmbeddingTilesDataset + embedding_uri: ${train_embedding_uri} + metadata_uri: ${train_metadata_uri} + class_indices: ${class_indices} + thresholds: ${thresholds} + tissue_prop_min: ${tissue_prop_min} + exclude_folds: + - ${val_fold} + + val: + _target_: 
ml.data.datasets.EmbeddingTilesDataset + embedding_uri: ${train_embedding_uri} + metadata_uri: ${train_metadata_uri} + class_indices: ${class_indices} + thresholds: ${thresholds} + tissue_prop_min: ${tissue_prop_min} + include_folds: + - ${val_fold} + + test: + _target_: ml.data.datasets.EmbeddingTilesDataset + embedding_uri: ${test_embedding_uri} + metadata_uri: ${test_metadata_uri} + class_indices: ${class_indices} + thresholds: ${thresholds} + tissue_prop_min: ${tissue_prop_min} diff --git a/configs/ml/linear_classifier.yaml b/configs/ml/linear_classifier.yaml new file mode 100644 index 00000000..d3393372 --- /dev/null +++ b/configs/ml/linear_classifier.yaml @@ -0,0 +1,54 @@ +# @package _global_ + +defaults: + - /data: dataset + - /class_mapping: collapse_alterations_to_other + - /ml/trainer: default + - /ml/data: embedding + - /ml/model: linear_classifier + - _self_ + +mode: fit + +embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} +filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} + +train_embedding_uri: runs:/${embedding_run_id}/train/tiles +test_embedding_uri: runs:/${embedding_run_id}/test/tiles +train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet +test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet + +val_fold: 0 + +tissue_prop_min: 0.2 +thresholds: + Nerve: 0.0 + Blood: 0.0 + Connective-Tissue: 0.4 + Fat: 0.6 + Epithelium: 0.2 + Muscle: 0.5 + Other: 0.5 + +mlflow_artifact_path: linear_classifier + +metadata: + run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} opt=${model.optimizer} wd=${model.weight_decay} + description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." 
+ hyperparams: + embedding_run_id: ${embedding_run_id} + kfold_strategy: ${kfold_strategy} + kfold_run_id: ${kfold_run_id} + filter_tiles_run_id: ${filter_tiles_run_id} + val_fold: ${val_fold} + tissue_prop_min: ${tissue_prop_min} + thresholds: ${thresholds} + optimizer: ${model.optimizer} + learning_rate: ${model.learning_rate} + weight_decay: ${model.weight_decay} + lbfgs: ${model.lbfgs} + batch_size: ${data.batch_size} + train_shuffle: ${data.train_shuffle} + train_drop_last: ${data.train_drop_last} diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml new file mode 100644 index 00000000..4b4d9e83 --- /dev/null +++ b/configs/ml/model/linear_classifier.yaml @@ -0,0 +1,25 @@ +# @package _global_ + +model: + backbone: + _target_: torch.nn.Identity + + decode_head: + _target_: torch.nn.Linear + in_features: 2560 + out_features: ${len:${class_indices}} + + class_indices: ${class_indices} + + optimizer: adamw + learning_rate: 1.0e-4 + weight_decay: 0.0 + lbfgs: + max_iter: 100 + max_eval: null + tolerance_grad: 1.0e-7 + tolerance_change: 1.0e-9 + history_size: 100 + line_search_fn: strong_wolfe + accumulate_batches: 1 + accumulate_on_cpu: false diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/default.yaml new file mode 100644 index 00000000..8615025e --- /dev/null +++ b/configs/ml/trainer/default.yaml @@ -0,0 +1,30 @@ +# @package _global_ + +trainer: + max_epochs: 500 + accelerator: auto + devices: auto + precision: 32 + log_every_n_steps: 50 + deterministic: false + + callbacks: + early_stopping: + _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: train/loss_epoch + mode: min + patience: 1 + min_delta: 1.0e-4 + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + monitor: train/loss_epoch + mode: min + save_top_k: 1 + filename: "epoch={epoch}-train_loss={train/loss_epoch:.4f}" + auto_insert_metric_name: false + lr_monitor: + _target_: 
# --- configs/ml/trainer/default.yaml (tail of the `lr_monitor` entry, carried over verbatim) ---
#       lightning.pytorch.callbacks.LearningRateMonitor
#       logging_interval: epoch
#     prediction_writer:
#       _target_: ml.callbacks.ParquetPredictionWriter
#       output_filename: predictions.parquet

# --- ml/__init__.py (new, empty) ---

# --- ml/__main__.py ---
from random import randint

import hydra
import mlflow
from lightning import seed_everything
from omegaconf import DictConfig, OmegaConf
from rationai.mlkit import Trainer, autolog
from rationai.mlkit.lightning.loggers import MLFlowLogger

from ml.data import DataModule
from ml.meta_arch import MetaArch


# `${random_seed:}` draws one integer and caches it for the lifetime of the config.
OmegaConf.register_new_resolver(
    "random_seed", lambda: randint(0, 2**31), use_cache=True
)
# `${len:${some_container}}` resolves to the container's length.
OmegaConf.register_new_resolver("len", lambda x: len(x))


@hydra.main(config_path="../configs", config_name="ml", version_base=None)
@autolog
def main(config: DictConfig, logger: MLFlowLogger) -> None:
    """Instantiate data, model and trainer from ``config`` and run ``config.mode``.

    Args:
        config: Composed Hydra config; reads ``mode``, ``seed``, ``data``,
            ``model``, ``trainer`` and ``checkpoint``.
        logger: Logger handed to the trainer (supplied by the ``autolog`` decorator).

    Raises:
        ValueError: If ``config.mode`` is not one of fit/test/validate/predict.
    """
    # Fail fast: reject an unknown mode before seeding or building any object.
    allowed_modes = ["fit", "test", "validate", "predict"]
    if config.mode not in allowed_modes:
        raise ValueError(f"Invalid mode {config.mode!r}. Allowed: {allowed_modes}")

    seed_everything(config.seed, workers=True)

    # `_recursive_=False` keeps the per-split dataset configs un-instantiated so
    # DataModule.setup can build them per stage.
    data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule)
    model = hydra.utils.instantiate(config.model, _target_=MetaArch)
    trainer = hydra.utils.instantiate(config.trainer, _target_=Trainer, logger=logger)

    # Dispatch to trainer.fit / .test / .validate / .predict.
    getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint)
    mlflow.end_run()


if __name__ == "__main__":
    main()


# --- ml/callbacks/__init__.py ---
from ml.callbacks.parquet_prediction_writer import ParquetPredictionWriter


__all__ = ["ParquetPredictionWriter"]


# --- ml/callbacks/parquet_prediction_writer.py ---
"""Aggregate ``predict_step`` outputs and write them as a parquet artifact."""

from pathlib import Path
from typing import Any

import lightning as pl
import mlflow
import numpy as np
import pandas as pd
from lightning.pytorch.callbacks import BasePredictionWriter


class ParquetPredictionWriter(BasePredictionWriter):
    """Collect per-tile predictions and write them as a parquet artifact.

    Aggregates ``predict_step`` outputs across the predict loop and writes one
    parquet file with ``slide_id``, ``target``, ``pred`` and one probability
    column per class. Probability columns are named ``prob_<class name>`` when
    the module exposes ``class_names`` of matching length, otherwise
    ``prob_<class index>``. The file is logged to the active MLflow run, if any.
    """

    def __init__(self, output_filename: str = "predictions.parquet") -> None:
        """Store the output file name; predictions are written once per epoch."""
        super().__init__(write_interval="epoch")
        self.output_filename = output_filename

    def write_on_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        predictions: Any,
        batch_indices: Any,
    ) -> None:
        """Flatten the epoch's predictions into a DataFrame and persist it.

        Only global rank 0 writes. ``predictions`` is either a flat list of
        per-batch dicts or a list of such lists (one per dataloader); each dict
        must carry ``slide_id``, ``target``, ``pred`` and ``probs``.
        Nothing is written when there are no predictions.
        """
        if trainer.global_rank != 0:
            return

        # Single dataloader -> list of dicts; several -> list of lists of dicts.
        if not predictions or isinstance(predictions[0], dict):
            batches = predictions
        else:
            batches = [b for dataloader_preds in predictions for b in dataloader_preds]

        slide_ids: list[str] = []
        targets: list[int] = []
        preds: list[int] = []
        probs: list[np.ndarray] = []
        for b in batches:
            slide_ids.extend(b["slide_id"])
            targets.extend(b["target"].tolist())
            preds.extend(b["pred"].tolist())
            # detach + cpu so tensors left on an accelerator or still attached
            # to a graph do not make .numpy() raise.
            probs.append(b["probs"].detach().cpu().numpy())

        if not slide_ids:
            return

        prob_matrix = np.concatenate(probs, axis=0)
        num_prob_cols = prob_matrix.shape[1]

        class_names = getattr(pl_module, "class_names", None)
        if class_names is not None and len(class_names) == num_prob_cols:
            prob_columns = [f"prob_{c}" for c in class_names]
        else:
            prob_columns = [f"prob_{i}" for i in range(num_prob_cols)]

        df = pd.DataFrame({"slide_id": slide_ids, "target": targets, "pred": preds})
        df = pd.concat([df, pd.DataFrame(prob_matrix, columns=prob_columns)], axis=1)

        out_path = Path(trainer.default_root_dir) / self.output_filename
        out_path.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(out_path, index=False)

        # Only log when a run is already active.
        active = mlflow.active_run()
        if active is not None:
            mlflow.log_artifact(str(out_path), artifact_path="predictions")


# --- ml/data/__init__.py ---
from ml.data.data_module import DataModule


__all__ = ["DataModule"]
Iterable + +from hydra.utils import instantiate +from lightning import LightningDataModule +from omegaconf import DictConfig +from torch.utils.data import DataLoader + +from ml.typing import Input + + +class DataModule(LightningDataModule): + """Generic Lightning datamodule that instantiates datasets lazily per stage. + + Mirrors the template pattern: ``**datasets`` accepts ``train``, ``val``, + ``test`` (or ``predict``) DictConfigs whose targets resolve to ``Dataset``s. + """ + + def __init__( + self, + batch_size: int, + num_workers: int = 0, + train_shuffle: bool = True, + train_drop_last: bool = True, + **datasets: DictConfig, + ) -> None: + super().__init__() + self.batch_size = batch_size + self.num_workers = num_workers + self.train_shuffle = train_shuffle + self.train_drop_last = train_drop_last + self.datasets = datasets + + def setup(self, stage: str) -> None: + match stage: + case "fit": + self.train = instantiate(self.datasets["train"]) + self.val = instantiate(self.datasets["val"]) + case "validate": + self.val = instantiate(self.datasets["val"]) + case "test": + self.test = instantiate(self.datasets["test"]) + case "predict": + dataset_cfg = self.datasets.get("predict") or self.datasets.get("test") + if dataset_cfg is None: + raise KeyError("Neither 'predict' nor 'test' dataset configured") + self.predict = instantiate(dataset_cfg) + + def train_dataloader(self) -> Iterable[Input]: + return DataLoader( + self.train, + batch_size=self.batch_size, + shuffle=self.train_shuffle, + drop_last=self.train_drop_last, + num_workers=self.num_workers, + persistent_workers=self.num_workers > 0, + ) + + def val_dataloader(self) -> Iterable[Input]: + return DataLoader( + self.val, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=self.num_workers > 0, + ) + + def test_dataloader(self) -> Iterable[Input]: + return DataLoader( + self.test, batch_size=self.batch_size, num_workers=self.num_workers + ) + + def predict_dataloader(self) -> 
Iterable[Input]: + return DataLoader( + self.predict, batch_size=self.batch_size, num_workers=self.num_workers + ) diff --git a/ml/data/datasets/__init__.py b/ml/data/datasets/__init__.py new file mode 100644 index 00000000..cd2f91a9 --- /dev/null +++ b/ml/data/datasets/__init__.py @@ -0,0 +1,4 @@ +from ml.data.datasets.embedding_tiles import EmbeddingTilesDataset + + +__all__ = ["EmbeddingTilesDataset"] diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py new file mode 100644 index 00000000..5160ba18 --- /dev/null +++ b/ml/data/datasets/embedding_tiles.py @@ -0,0 +1,197 @@ +"""Tile-embedding dataset. + +Joins precomputed tile embeddings with tile metadata (k-fold parquet for train, +filter_tiles parquet for test) and applies tissue + per-class thresholds at +load time to produce ``(embedding, class_index, slide_id)`` triples. +""" + +from functools import cache +from pathlib import Path + +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.dataset as pads +import torch +from mlflow.artifacts import download_artifacts +from torch.utils.data import Dataset + +from ml.typing import Sample + + +class EmbeddingTilesDataset(Dataset[Sample]): + """Tile-level embedding dataset with on-the-fly filtering and labeling. + + Inner-joins ``embedding`` parquet with ``metadata`` parquet on + ``(slide_id, x, y)``. Metadata must contain ``roi_coverage_*`` columns; + label is the dominant class whose coverage meets its threshold. Tiles + failing the tissue proportion floor, with more than one annotated class, + or whose dominant class falls below its threshold are dropped. + + For train/val: pass the k-fold parquet as ``metadata_uri`` and use + ``include_folds`` / ``exclude_folds`` to split. For test: pass the + filter_tiles parquet (no fold column). 
+ """ + + def __init__( + self, + embedding_uri: str | Path, + metadata_uri: str | Path, + class_indices: dict[str, int], + thresholds: dict[str, float], + tissue_prop_min: float, + include_folds: list[int] | None = None, + exclude_folds: list[int] | None = None, + ) -> None: + meta_df = self._filter_metadata( + metadata_uri, + thresholds, + tissue_prop_min, + include_folds, + exclude_folds, + ) + + emb_dir = self._resolve_uri(embedding_uri) + emb_table = pads.dataset(emb_dir, format="parquet").to_table( + columns=["slide_id", "x", "y", "embedding"] + ) + + emb_col = emb_table.column("embedding") + if pa.types.is_list(emb_col.type): + target_type = pa.large_list(emb_col.type.value_type) + emb_col = pa.chunked_array( + [c.cast(target_type) for c in emb_col.chunks], type=target_type + ) + + emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) + emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) + del emb_table + + meta_table = pa.Table.from_pandas(meta_df, preserve_index=False) + del meta_df + joined_keys = meta_table.join( + emb_keys, keys=["slide_id", "x", "y"], join_type="inner" + ) + del emb_keys, meta_table + if joined_keys.num_rows == 0: + raise RuntimeError("inner join with embeddings produced empty dataset") + + _idx_col = joined_keys.column("_emb_idx") + if isinstance(_idx_col, pa.ChunkedArray): + _idx_col = _idx_col.combine_chunks() + indices_np = _idx_col.to_numpy() + + first_chunk = emb_col.chunks[0] + embedding_dim = len(first_chunk.values) // len(first_chunk) + + # sort indices for sequential per-chunk access; restore order afterwards + sort_order = np.argsort(indices_np) + sorted_indices = indices_np[sort_order] + + chunk_offsets = np.concatenate( + [[0], np.cumsum([len(c) for c in emb_col.chunks])] + ) + embeddings = np.empty((len(indices_np), embedding_dim), dtype=np.float32) + for ci, chunk in enumerate(emb_col.chunks): + lo, hi = chunk_offsets[ci], chunk_offsets[ci + 1] + mask = (sorted_indices >= lo) & 
(sorted_indices < hi) + if not mask.any(): + continue + local_idx = sorted_indices[mask] - lo + chunk_np = ( + chunk.values.to_numpy(zero_copy_only=False) + .reshape(len(chunk), embedding_dim) + .astype(np.float32) + ) + embeddings[sort_order[mask]] = chunk_np[local_idx] + del emb_col + + self.embeddings = embeddings + labels = joined_keys.column("label").to_pandas() + unknown = set(labels.unique()) - set(class_indices.keys()) + if unknown: + raise ValueError( + f"labels in data not present in class_indices: {sorted(unknown)}" + ) + self.labels = labels.map(class_indices).to_numpy(dtype=np.int64) + self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() + + def __len__(self) -> int: + return len(self.labels) + + def __getitem__(self, idx: int) -> Sample: + return ( + torch.from_numpy(self.embeddings[idx]), + int(self.labels[idx]), + str(self.slide_ids[idx]), + ) + + @staticmethod + def _filter_metadata( + metadata_uri: str | Path, + thresholds: dict[str, float], + tissue_prop_min: float, + include_folds: list[int] | None, + exclude_folds: list[int] | None, + ) -> pd.DataFrame: + local = EmbeddingTilesDataset._resolve_uri(metadata_uri) + df = pd.read_parquet(local) + + roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] + if not roi_cols: + raise ValueError( + "metadata parquet has no roi_coverage_* columns; cannot label" + ) + + classes_in_data = {c.removeprefix("roi_coverage_") for c in roi_cols} + missing_thresholds = classes_in_data - set(thresholds.keys()) + if missing_thresholds: + raise ValueError( + f"thresholds missing entries for classes present in data: " + f"{sorted(missing_thresholds)}" + ) + + tissue_prop = df[roi_cols].sum(axis=1).to_numpy() + df = df.loc[tissue_prop >= tissue_prop_min] + if df.empty: + raise RuntimeError("all tiles dropped by tissue_prop_min filter") + + nonzero_classes = (df[roi_cols].to_numpy() > 0).sum(axis=1) + df = df.loc[pd.Series(nonzero_classes <= 1, index=df.index)] + if df.empty: + raise 
RuntimeError("all tiles dropped by single-class filter") + + roi_only = df[roi_cols] + dominant = roi_only.idxmax(axis=1).str.removeprefix("roi_coverage_") + dominant_value = roi_only.max(axis=1).to_numpy() + threshold_per_row = dominant.map(thresholds).to_numpy() + keep = dominant_value >= threshold_per_row + df = df.loc[pd.Series(keep, index=df.index)].copy() + df["label"] = dominant.to_numpy()[keep] + if df.empty: + raise RuntimeError("all tiles dropped by per-class thresholds") + + if include_folds is not None or exclude_folds is not None: + if "fold" not in df.columns: + raise RuntimeError( + "fold filter requested but 'fold' column not in metadata" + ) + if include_folds is not None: + df = df[df["fold"].isin(include_folds)] + if exclude_folds is not None: + df = df[~df["fold"].isin(exclude_folds)] + if df.empty: + raise RuntimeError("all tiles dropped by fold filter") + + return df[["slide_id", "x", "y", "label"]] + + @staticmethod + def _resolve_uri(path_or_uri: str | Path) -> str: + return EmbeddingTilesDataset._resolve_uri_cached(str(path_or_uri)) + + @staticmethod + @cache + def _resolve_uri_cached(uri: str) -> str: + if uri.startswith(("mlflow-artifacts:/", "runs:/")): + return download_artifacts(artifact_uri=uri) + return uri diff --git a/ml/meta_arch.py b/ml/meta_arch.py new file mode 100644 index 00000000..ab6882e9 --- /dev/null +++ b/ml/meta_arch.py @@ -0,0 +1,390 @@ +from collections import defaultdict +from collections.abc import Iterable +from typing import Any, cast + +import mlflow +import numpy as np +import pandas as pd +import torch +from lightning import LightningModule +from lightning.pytorch.utilities import grad_norm +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from torch import Tensor, nn +from torch.optim.optimizer import Optimizer +from torchmetrics import MetricCollection +from torchmetrics.classification import ( + MulticlassAccuracy, + MulticlassConfusionMatrix, + MulticlassF1Score, +) + +from 
ml.typing import Input, Outputs + + +class MetaArch(LightningModule): + """Top-level classification architecture: backbone + decode_head. + + For linear probing on precomputed embeddings, ``backbone`` is typically + ``nn.Identity`` and ``decode_head`` is a single ``nn.Linear``. + Criterion is class-weighted CrossEntropyLoss, computed from training labels in setup(). + """ + + def __init__( + self, + backbone: nn.Module, + decode_head: nn.Module, + class_indices: dict[str, int], + learning_rate: float = 1e-3, + weight_decay: float = 0.0, + optimizer: str = "adamw", + lbfgs: dict[str, Any] | None = None, + ) -> None: + super().__init__() + self.save_hyperparameters(ignore=["backbone", "decode_head"]) + + if optimizer not in {"adamw", "lbfgs"}: + raise ValueError(f"Unsupported optimizer {optimizer!r}") + if optimizer == "lbfgs": + self.automatic_optimization = False + + self.backbone = backbone + self.decode_head = decode_head + self._lbfgs_batches: list[tuple[Tensor, Tensor]] = [] + + self.class_names = [ + n for n, _ in sorted(class_indices.items(), key=lambda kv: kv[1]) + ] + num_classes = len(self.class_names) + self.criterion = nn.CrossEntropyLoss(weight=torch.ones(num_classes)) + + macro_metrics = MetricCollection( + { + "acc_macro": MulticlassAccuracy( + num_classes=num_classes, average="macro" + ), + "f1_macro": MulticlassF1Score(num_classes=num_classes, average="macro"), + } + ) + per_class_metrics = MetricCollection( + { + "acc_per_class": MulticlassAccuracy( + num_classes=num_classes, average=None + ), + "f1_per_class": MulticlassF1Score( + num_classes=num_classes, average=None + ), + } + ) + self.val_metrics = macro_metrics.clone(prefix="validation/") + self.test_metrics = macro_metrics.clone(prefix="test/") + self.val_per_class = per_class_metrics.clone(prefix="validation/") + self.test_per_class = per_class_metrics.clone(prefix="test/") + self.val_confmat = MulticlassConfusionMatrix(num_classes=num_classes) + self.test_confmat = 
MulticlassConfusionMatrix(num_classes=num_classes) + + self._test_slide_correct: dict[str, int] = defaultdict(int) + self._test_slide_total: dict[str, int] = defaultdict(int) + + def setup(self, stage: str) -> None: + if stage == "fit": + datamodule = cast("Any", self.trainer).datamodule + labels = datamodule.train.labels + if self.hparams["optimizer"] == "lbfgs": + self._validate_lbfgs_full_batch(datamodule, len(labels)) + num_classes = len(self.class_names) + counts = np.bincount(labels, minlength=num_classes).astype(float) + counts = np.maximum(counts, 1.0) + weights = len(labels) / (num_classes * counts) + self.criterion = nn.CrossEntropyLoss( + weight=torch.tensor(weights, dtype=torch.float32) + ) + for cls, w in zip(self.class_names, weights.tolist(), strict=True): + mlflow.log_metric(f"class_weight/{cls}", w) + + def forward(self, x: Tensor) -> Outputs: + features = self.backbone(x) + return self.decode_head(features) + + def training_step(self, batch: Input, batch_idx: int) -> Tensor: + if self.hparams["optimizer"] == "lbfgs": + return self._lbfgs_training_step(batch, batch_idx) + + inputs, targets, _ = batch + outputs = self(inputs) + loss = self.criterion(outputs, targets) + self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + return loss + + def on_before_optimizer_step(self, optimizer: Optimizer) -> None: + if self.hparams["optimizer"] == "lbfgs": + return + norms = grad_norm(self, norm_type=2) + self.log( + "train/grad_norm", + norms["grad_2.0_norm_total"], + on_step=False, + on_epoch=True, + prog_bar=True, + ) + + def validation_step(self, batch: Input, batch_idx: int) -> None: + inputs, targets, _ = batch + outputs = self(inputs) + loss = self.criterion(outputs, targets) + self.log("validation/loss", loss, on_epoch=True, prog_bar=True) + self.val_metrics.update(outputs, targets) + self.val_per_class.update(outputs, targets) + self.val_confmat.update(outputs, targets) + self.log_dict(self.val_metrics, on_epoch=True) + + def 
on_validation_epoch_end(self) -> None: + self._log_per_class(self.val_per_class, "validation") + self._log_confmat(self.val_confmat, "validation") + + def test_step(self, batch: Input, batch_idx: int) -> None: + inputs, targets, slide_ids = batch + outputs = self(inputs) + self.test_metrics.update(outputs, targets) + self.test_per_class.update(outputs, targets) + self.test_confmat.update(outputs, targets) + self.log_dict(self.test_metrics, on_epoch=True) + + preds = outputs.argmax(dim=1) + correct = (preds == targets).cpu().tolist() + for slide_id, ok in zip(slide_ids, correct, strict=True): + self._test_slide_correct[slide_id] += int(ok) + self._test_slide_total[slide_id] += 1 + + def on_test_epoch_end(self) -> None: + self._log_per_class(self.test_per_class, "test") + self._log_confmat(self.test_confmat, "test") + self._log_per_slide_accuracy() + self._test_slide_correct.clear() + self._test_slide_total.clear() + + def predict_step( + self, batch: Input, batch_idx: int, dataloader_idx: int = 0 + ) -> dict[str, Any]: + inputs, targets, slide_ids = batch + outputs = self(inputs) + probs = outputs.softmax(dim=1) + preds = outputs.argmax(dim=1) + return { + "slide_id": list(slide_ids), + "target": targets.cpu(), + "pred": preds.cpu(), + "probs": probs.cpu(), + } + + def configure_optimizers(self) -> Optimizer: + if self.hparams["optimizer"] == "lbfgs": + lbfgs = self.hparams.get("lbfgs") or {} + return torch.optim.LBFGS( + self.parameters(), + lr=self.hparams["learning_rate"], + max_iter=lbfgs.get("max_iter", 100), + max_eval=lbfgs.get("max_eval"), + tolerance_grad=lbfgs.get("tolerance_grad", 1.0e-7), + tolerance_change=lbfgs.get("tolerance_change", 1.0e-9), + history_size=lbfgs.get("history_size", 100), + line_search_fn=lbfgs.get("line_search_fn", "strong_wolfe"), + ) + + return torch.optim.AdamW( + self.parameters(), + lr=self.hparams["learning_rate"], + weight_decay=self.hparams["weight_decay"], + ) + + def _lbfgs_training_step(self, batch: Input, batch_idx: int) 
-> Tensor: + inputs, targets, _ = batch + self._lbfgs_batches.append(self._prepare_lbfgs_batch(inputs, targets)) + lbfgs = self.hparams.get("lbfgs") or {} + accumulation_steps = int(lbfgs.get("accumulate_batches", 1)) + is_last_batch = batch_idx + 1 == self.trainer.num_training_batches + should_step = len(self._lbfgs_batches) >= accumulation_steps or is_last_batch + if not should_step: + with torch.no_grad(): + return self.criterion(self(inputs), targets) + + optimizer = cast("Any", self.optimizers()) + total_samples = sum(targets.numel() for _, targets in self._lbfgs_batches) + + def closure() -> Tensor: + optimizer.zero_grad() + loss, _, _ = self._lbfgs_buffered_loss(total_samples) + if not torch.isfinite(loss): + raise FloatingPointError(f"non-finite LBFGS loss: {loss.item()}") + self.manual_backward(loss) + return loss + + step_loss = optimizer.step(closure=closure) + if not isinstance(step_loss, Tensor): + step_loss = torch.as_tensor(step_loss, device=self.device) + + optimizer.zero_grad() + loss, ce_loss, l2_loss = self._lbfgs_buffered_loss(total_samples) + if not torch.isfinite(loss): + raise FloatingPointError(f"non-finite LBFGS post-step loss: {loss.item()}") + self.manual_backward(loss) + grad_norm = self._total_grad_norm() + if grad_norm is not None and not torch.isfinite(grad_norm): + raise FloatingPointError( + f"non-finite LBFGS gradient norm: {grad_norm.item()}" + ) + optimizer.zero_grad() + self._lbfgs_batches.clear() + + self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + self.log("train/ce_loss", ce_loss, on_step=True, on_epoch=True) + self.log("train/l2_loss", l2_loss, on_step=True, on_epoch=True) + if grad_norm is not None: + self.log( + "train/grad_norm", + grad_norm, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + self.log("train/lbfgs_step_loss", step_loss.detach(), on_step=True) + return loss + + def _prepare_lbfgs_batch( + self, inputs: Tensor, targets: Tensor + ) -> tuple[Tensor, Tensor]: + lbfgs = 
self.hparams.get("lbfgs") or {} + if lbfgs.get("accumulate_on_cpu", False): + return inputs.detach().cpu(), targets.detach().cpu() + return inputs, targets + + def _validate_lbfgs_full_batch(self, datamodule: Any, train_size: int) -> None: + lbfgs = self.hparams.get("lbfgs") or {} + batch_size = int(datamodule.batch_size) + accumulation_steps = int(lbfgs.get("accumulate_batches", 1)) + effective_batch_size = batch_size * accumulation_steps + + if datamodule.train_shuffle: + raise ValueError("LBFGS requires data.train_shuffle=false.") + if datamodule.train_drop_last: + raise ValueError("LBFGS requires data.train_drop_last=false.") + if effective_batch_size < train_size: + raise ValueError( + "LBFGS requires a deterministic full-batch objective. Set " + "data.batch_size >= len(train) or set " + "model.lbfgs.accumulate_batches >= ceil(len(train) / " + "data.batch_size). Current effective batch size is " + f"{effective_batch_size} for {train_size} training samples." + ) + + def _lbfgs_buffered_loss(self, total_samples: int) -> tuple[Tensor, Tensor, Tensor]: + ce_loss = torch.zeros((), device=self.device) + for micro_inputs, micro_targets in self._lbfgs_batches: + micro_inputs = micro_inputs.to(self.device) + micro_targets = micro_targets.to(self.device) + outputs = self(micro_inputs) + weight = micro_targets.numel() / total_samples + ce_loss = ce_loss + self.criterion(outputs, micro_targets) * weight + + l2_loss = self._l2_loss() + return ce_loss + l2_loss, ce_loss, l2_loss + + def _objective_loss(self, outputs: Tensor, targets: Tensor) -> Tensor: + return self.criterion(outputs, targets) + self._l2_loss() + + def _l2_loss(self) -> Tensor: + weight_decay = self.hparams["weight_decay"] + if weight_decay == 0: + return torch.zeros((), device=self.device) + + penalty = torch.zeros((), device=self.device) + for name, param in self.named_parameters(): + if param.requires_grad and name.endswith("weight"): + penalty = penalty + param.square().sum() + return 0.5 * weight_decay 
* penalty + + def _total_grad_norm(self) -> Tensor | None: + grads = [ + param.grad.detach().norm(2) + for param in self.parameters() + if param.grad is not None + ] + if not grads: + return None + return torch.linalg.vector_norm(torch.stack(grads), ord=2) + + def _log_per_class(self, collection: MetricCollection, split: str) -> None: + computed = collection.compute() + for metric_name, values in computed.items(): + tag = metric_name.split("/")[-1] # e.g. "acc_per_class" + for cls_name, val in zip(self.class_names, values.tolist(), strict=True): + self.log(f"{split}/{tag}/{cls_name}", val, on_epoch=True) + collection.reset() + + def _log_confmat(self, confmat: MulticlassConfusionMatrix, split: str) -> None: + matrix = confmat.compute().cpu().numpy() + confmat.reset() + fig = _confmat_figure(matrix, self.class_names, title=f"{split} confmat") + artifact_file = f"confusion_matrix/{split}_epoch_{self.current_epoch}.png" + try: + mlflow.log_figure(fig, artifact_file=artifact_file) + finally: + plt.close(fig) + + def _log_per_slide_accuracy(self) -> None: + accs = [ + self._test_slide_correct[s] / self._test_slide_total[s] + for s in self._test_slide_total + ] + if not accs: + return + self.log("test/slide_acc_mean", float(np.mean(accs)), on_epoch=True) + self.log("test/slide_acc_median", float(np.median(accs)), on_epoch=True) + self.log("test/slide_acc_min", float(np.min(accs)), on_epoch=True) + + rows = [ + { + "slide_id": s, + "tile_accuracy": self._test_slide_correct[s] / n, + "n_tiles": n, + } + for s, n in self._test_slide_total.items() + ] + mlflow.log_table( + data=pd.DataFrame(rows), + artifact_file="per_slide/test_tile_accuracy.json", + ) + + +def _confmat_figure( + matrix: np.ndarray, class_names: Iterable[str], title: str +) -> Figure: + row_sums = matrix.sum(axis=1, keepdims=True) + normalized = np.divide( + matrix, row_sums, where=row_sums > 0, out=np.zeros_like(matrix, dtype=float) + ) + + fig, ax = plt.subplots(figsize=(6, 5)) + im = 
ax.imshow(normalized, cmap="Blues", vmin=0, vmax=1) + ax.set_title(title) + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + names = list(class_names) + ax.set_xticks(range(len(names))) + ax.set_yticks(range(len(names))) + ax.set_xticklabels(names, rotation=45, ha="right") + ax.set_yticklabels(names) + for i in range(matrix.shape[0]): + for j in range(matrix.shape[1]): + ax.text( + j, + i, + str(matrix[i, j]), + ha="center", + va="center", + color="white" if normalized[i, j] > 0.5 else "black", + fontsize=8, + ) + fig.colorbar(im, ax=ax) + fig.tight_layout() + return fig diff --git a/ml/modeling/__init__.py b/ml/modeling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ml/typing.py b/ml/typing.py new file mode 100644 index 00000000..7060f5e4 --- /dev/null +++ b/ml/typing.py @@ -0,0 +1,6 @@ +import torch + + +type Sample = tuple[torch.Tensor, int, str] +type Input = tuple[torch.Tensor, torch.Tensor, list[str]] +type Outputs = torch.Tensor diff --git a/preprocessing/_labels.py b/preprocessing/_labels.py new file mode 100644 index 00000000..229f7dc6 --- /dev/null +++ b/preprocessing/_labels.py @@ -0,0 +1,24 @@ +"""Shared helpers for deriving tile labels from roi_coverage_* columns.""" + +from collections.abc import Mapping +from typing import Any + +import numpy as np +import pandas as pd + + +def compute_label_and_tissue_prop( + roi_data: Mapping[str, Any], + roi_cols: list[str], +) -> tuple[np.ndarray, np.ndarray]: + """Compute (label, tissue_prop) from roi_coverage_* columns. + + label = argmax across roi_cols (with ``roi_coverage_`` prefix stripped), + falling back to ``"background"`` whenever all coverages are zero. + tissue_prop = sum across roi_cols. 
+ """ + roi_df = pd.DataFrame({col: roi_data[col] for col in roi_cols}) + tp = roi_df.sum(axis=1).to_numpy() + lbl = roi_df.idxmax(axis=1).str.removeprefix("roi_coverage_").to_numpy() + lbl[tp == 0] = "background" + return lbl, tp diff --git a/pyproject.toml b/pyproject.toml index 183bdd91..450c4922 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,19 +19,18 @@ dependencies = [ "tqdm>=4.66.0", "rationai-sdk", "ratiopath>=1.2.0", - "pyarrow>=19.0.0", - "datasets>=3.0.0", - "scikit-learn>=1.8.0", + "pyarrow>=19.0.1", + "datasets>=4.0.0", "numpy>=2.3.5", "rationai-tiling>=1.1.1", "tifffile>=2025.12.20", "torch>=2.0.0", "torchvision>=0.15.0", + "lightning>=2.0.0", + "torchmetrics>=1.0.0", "timm>=1.0.0", "einops>=0.8.0", "matplotlib>=3.10.7", - "pyarrow>=19.0.1", - "datasets>=4.0.0", ] [dependency-groups] diff --git a/scripts/submit_train_linear_probe.py b/scripts/submit_train_linear_probe.py new file mode 100644 index 00000000..3f7ecb1a --- /dev/null +++ b/scripts/submit_train_linear_probe.py @@ -0,0 +1,18 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-train-linear", + username="vcifka", + cpu=8, + memory="64Gi", + gpu=None, + public=False, + script=[ + "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv run python -m ml +experiment=ml/linear_classifier_stratified_group_kfold val_fold=0,1,2,3,4 model.weight_decay=0,1e-5,1e-4,1e-3,1e-2 --multirun", + ], + storage=[storage.secure.PROJECTS], +) diff --git a/split/kfold_split.py b/split/kfold_split.py index 6aca4596..7f96ffc5 100644 --- a/split/kfold_split.py +++ b/split/kfold_split.py @@ -12,6 +12,8 @@ from rationai.mlkit.lightning.loggers import MLFlowLogger from sklearn.model_selection import StratifiedGroupKFold, StratifiedKFold +from preprocessing._labels import compute_label_and_tissue_prop + def derive_labels( dataset: Dataset, @@ -20,10 +22,7 @@ def derive_labels( 
"""Derive label, tissue_prop, and slide_id arrays from the dataset.""" def compute(batch: dict[str, Any]) -> dict[str, Any]: - roi_df = pd.DataFrame({col: batch[col] for col in roi_cols}) - tp = roi_df.sum(axis=1).values - lbl = roi_df.idxmax(axis=1).str.removeprefix("roi_coverage_").values - lbl[tp == 0] = "background" + lbl, tp = compute_label_and_tissue_prop(batch, roi_cols) return {"label": lbl.tolist(), "tissue_prop": tp.tolist()} label_ds = dataset.select_columns(["slide_id", *roi_cols]).map( diff --git a/uv.lock b/uv.lock index 30b783b9..1e1ee3aa 100644 --- a/uv.lock +++ b/uv.lock @@ -2292,6 +2292,7 @@ dependencies = [ { name = "datasets" }, { name = "einops" }, { name = "hydra-core" }, + { name = "lightning" }, { name = "matplotlib" }, { name = "mlflow" }, { name = "numpy" }, @@ -2305,10 +2306,10 @@ dependencies = [ { name = "rationai-tiling" }, { name = "ratiopath" }, { name = "ray" }, - { name = "scikit-learn" }, { name = "tifffile" }, { name = "timm" }, { name = "torch" }, + { name = "torchmetrics" }, { name = "torchvision" }, { name = "tqdm" }, ] @@ -2323,17 +2324,16 @@ dev = [ [package.metadata] requires-dist = [ - { name = "datasets", specifier = ">=3.0.0" }, { name = "datasets", specifier = ">=4.0.0" }, { name = "einops", specifier = ">=0.8.0" }, { name = "hydra-core", specifier = ">=1.3.2" }, + { name = "lightning", specifier = ">=2.0.0" }, { name = "matplotlib", specifier = ">=3.10.7" }, { name = "mlflow", specifier = "<3.0.0" }, { name = "numpy", specifier = ">=2.3.5" }, { name = "omegaconf", specifier = ">=2.3.0" }, { name = "openslide-python", specifier = ">=1.4.2" }, { name = "pandas", specifier = ">=2.0.0" }, - { name = "pyarrow", specifier = ">=19.0.0" }, { name = "pyarrow", specifier = ">=19.0.1" }, { name = "rationai-masks" }, { name = "rationai-mlkit", git = "https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/mlkit.git" }, @@ -2341,10 +2341,10 @@ requires-dist = [ { name = "rationai-tiling", git = 
"https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/tiling.git" }, { name = "ratiopath", specifier = ">=1.2.0" }, { name = "ray", specifier = ">=2.51.1" }, - { name = "scikit-learn", specifier = ">=1.8.0" }, { name = "tifffile", specifier = ">=2025.12.20" }, { name = "timm", specifier = ">=1.0.0" }, { name = "torch", specifier = ">=2.0.0" }, + { name = "torchmetrics", specifier = ">=1.0.0" }, { name = "torchvision", specifier = ">=0.15.0" }, { name = "tqdm", specifier = ">=4.66.0" }, ]