From bef70dfd0502f33ac4224986da6a2bb19ed6cc1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 8 May 2026 21:27:58 +0200 Subject: [PATCH 01/53] feat: add embedding dataset build pipeline Extract derive_labels logic to shared preprocessing/_labels.py, then use it in both split/kfold_split.py and the new embedding_dataset pipeline. The new pipeline joins k-fold (train) / filter_tiles (test) tile metadata with precomputed embeddings after applying tissue + per-dominant-class ROI thresholds, and emits a SlidesTilesLoader-compatible Parquet dataset as an MLflow artifact. Co-Authored-By: Claude Sonnet 4.6 --- .../preprocessing/embedding_dataset.yaml | 16 ++ configs/preprocessing/embedding_dataset.yaml | 13 ++ preprocessing/_labels.py | 24 ++ preprocessing/embedding_dataset.py | 212 ++++++++++++++++++ scripts/submit_embedding_dataset.py | 18 ++ split/kfold_split.py | 7 +- 6 files changed, 286 insertions(+), 4 deletions(-) create mode 100644 configs/experiment/preprocessing/embedding_dataset.yaml create mode 100644 configs/preprocessing/embedding_dataset.yaml create mode 100644 preprocessing/_labels.py create mode 100644 preprocessing/embedding_dataset.py create mode 100644 scripts/submit_embedding_dataset.py diff --git a/configs/experiment/preprocessing/embedding_dataset.yaml b/configs/experiment/preprocessing/embedding_dataset.yaml new file mode 100644 index 0000000..bfe24a0 --- /dev/null +++ b/configs/experiment/preprocessing/embedding_dataset.yaml @@ -0,0 +1,16 @@ +# @package _global_ + +defaults: + - /data: dataset + - _self_ + +tissue_prop_min: 0.5 +thresholds: ??? + +metadata: + run_name: Embedding dataset ${dataset.name} + description: "Join k-fold (${dataset.mlflow_artifacts.kfold_run_id}) and filter_tiles (${dataset.mlflow_artifacts.filter_tiles_run_id}) tile metadata with embeddings (${dataset.mlflow_artifacts.embedding_run_id})." + hyperparams: + kfold_run_id: ${dataset.mlflow_artifacts.kfold_run_id} + filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} + embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} diff --git a/configs/preprocessing/embedding_dataset.yaml b/configs/preprocessing/embedding_dataset.yaml new file mode 100644 index 0000000..f4af56a --- /dev/null +++ b/configs/preprocessing/embedding_dataset.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +mlflow_artifact_path: embedding_dataset + +tissue_prop_min: ??? +thresholds: ??? + +metadata: + run_name: "Embedding dataset ${dataset.name}" + description: "Build embedding training dataset by joining k-fold/filter_tiles tile metadata with precomputed embeddings." + hyperparams: + tissue_prop_min: ${tissue_prop_min} + thresholds: ${thresholds} diff --git a/preprocessing/_labels.py b/preprocessing/_labels.py new file mode 100644 index 0000000..229f7dc --- /dev/null +++ b/preprocessing/_labels.py @@ -0,0 +1,24 @@ +"""Shared helpers for deriving tile labels from roi_coverage_* columns.""" + +from collections.abc import Mapping +from typing import Any + +import numpy as np +import pandas as pd + + +def compute_label_and_tissue_prop( + roi_data: Mapping[str, Any], + roi_cols: list[str], +) -> tuple[np.ndarray, np.ndarray]: + """Compute (label, tissue_prop) from roi_coverage_* columns. + + label = argmax across roi_cols (with ``roi_coverage_`` prefix stripped), + falling back to ``"background"`` whenever all coverages are zero. + tissue_prop = sum across roi_cols. + """ + roi_df = pd.DataFrame({col: roi_data[col] for col in roi_cols}) + tp = roi_df.sum(axis=1).to_numpy() + lbl = roi_df.idxmax(axis=1).str.removeprefix("roi_coverage_").to_numpy() + lbl[tp == 0] = "background" + return lbl, tp diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py new file mode 100644 index 0000000..1e9028c --- /dev/null +++ b/preprocessing/embedding_dataset.py @@ -0,0 +1,212 @@ +"""Build an embedding training dataset by joining tile metadata with embeddings. + +Joins precomputed tile embeddings with k-fold metadata (train) / filter_tiles +metadata (test), applies tissue + per-class ROI thresholds before the join, and +emits a training-ready Parquet dataset (per-split ``slides.parquet`` + +``tiles.parquet``) ready for ``rationai.mlkit.data.datasets.SlidesTilesLoader``. +""" + +import shutil +import tempfile +from pathlib import Path + +import hydra +import mlflow +import mlflow.artifacts +import pandas as pd +import pyarrow.dataset as pads +from omegaconf import DictConfig, OmegaConf +from rationai.mlkit import autolog, with_cli_args +from rationai.mlkit.lightning.loggers import MLFlowLogger + +from preprocessing._labels import compute_label_and_tissue_prop + + +def apply_thresholds( + df: pd.DataFrame, + tissue_prop_min: float, + thresholds: dict[str, float], + roi_cols: list[str], +) -> tuple[pd.DataFrame, int]: + """Filter df by tissue_prop_min then by per-dominant-class roi threshold. + + Returns ``(filtered_df, after_tissue_count)`` so the caller can log both + intermediate counts. + """ + df = df[df["tissue_prop"] >= tissue_prop_min] + after_tissue = len(df) + if df.empty: + return df, after_tissue + + roi_only = df[roi_cols] + dominant = roi_only.idxmax(axis=1).str.removeprefix("roi_coverage_") + dominant_value = roi_only.max(axis=1).to_numpy() + threshold_per_row = dominant.map(thresholds).to_numpy() + keep = dominant_value >= threshold_per_row + return df[keep].copy(), after_tissue + + +def join_embeddings( + tiles_df: pd.DataFrame, + embedding_run_id: str, + embedding_split: str, +) -> tuple[pd.DataFrame, int]: + """Join filtered tile metadata with embeddings on (slide_id, x, y).""" + emb_dir = mlflow.artifacts.download_artifacts( + run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" + ) + emb_table = pads.dataset(emb_dir, format="parquet").to_table( + columns=["slide_id", "x", "y", "embedding"] + ) + emb_df = emb_table.to_pandas() + del emb_table + + merged = tiles_df.merge(emb_df, on=["slide_id", "x", "y"], how="inner") + dropped_no_embedding = len(tiles_df) - len(merged) + return merged, dropped_no_embedding + + +def process_split( + split_name: str, + src_run_id: str, + src_artifact_path: str, + embedding_run_id: str, + tissue_prop_min: float, + thresholds: dict[str, float], + output_split_dir: Path, + derive: bool, +) -> dict[str, int]: + src_local = mlflow.artifacts.download_artifacts( + run_id=src_run_id, artifact_path=src_artifact_path + ) + df = pads.dataset(src_local, format="parquet").to_table().to_pandas() + input_count = len(df) + + roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] + if not roi_cols: + raise RuntimeError( + f"No roi_coverage_* columns in {src_artifact_path}. " + "Cannot apply class thresholds." + ) + + classes_in_data = {c.removeprefix("roi_coverage_") for c in roi_cols} + missing = classes_in_data - set(thresholds.keys()) + if missing: + raise ValueError( + f"thresholds is missing entries for roi_coverage_* classes present " + f"in data: {sorted(missing)}" + ) + + if derive: + lbl, tp = compute_label_and_tissue_prop(df, roi_cols) + df["label"] = lbl + df["tissue_prop"] = tp + + df, after_tissue_filter = apply_thresholds( + df, tissue_prop_min, thresholds, roi_cols + ) + after_class_threshold = len(df) + if after_class_threshold == 0: + raise RuntimeError( + f"All {input_count} tiles dropped by thresholds for split '{split_name}'." + ) + + drop_cols = [ + c for c in df.columns if c.startswith(("roi_coverage_", "tile_coverage_")) + ] + df = df.drop(columns=drop_cols) + + merged, dropped_no_embedding = join_embeddings(df, embedding_run_id, split_name) + if dropped_no_embedding != 0: + print( + f"WARNING: {dropped_no_embedding} tiles in split '{split_name}' have " + "no matching embedding and were dropped on join.", + flush=True, + ) + + merged = merged.sort_values("slide_id", kind="stable").reset_index(drop=True) + + output_split_dir.mkdir(parents=True, exist_ok=True) + merged.to_parquet(output_split_dir / "tiles.parquet", index=False) + + slides_local = mlflow.artifacts.download_artifacts( + run_id=embedding_run_id, artifact_path=f"{split_name}/slides.parquet" + ) + shutil.copy(slides_local, output_split_dir / "slides.parquet") + + log_label_distributions(split_name, merged) + + return { + "input_count": input_count, + "after_tissue_filter": after_tissue_filter, + "after_class_threshold": after_class_threshold, + "after_join": len(merged), + "dropped_no_embedding": dropped_no_embedding, + } + + +def log_label_distributions(split_name: str, df: pd.DataFrame) -> None: + label_dist = ( + df["label"].value_counts().rename_axis("label").reset_index(name="count") + ) + mlflow.log_table( + data=label_dist, + artifact_file=f"fold_statistics/{split_name}_label_distribution.json", + ) + + if "fold" in df.columns: + fold_dist = ( + df.groupby(["fold", "label"]).size().unstack(fill_value=0).reset_index() + ) + mlflow.log_table( + data=fold_dist, + artifact_file=f"fold_statistics/{split_name}_fold_label_distribution.json", + ) + + +@with_cli_args(["+preprocessing=embedding_dataset"]) +@hydra.main(config_path="../configs", config_name="preprocessing", version_base=None) +@autolog +def main(config: DictConfig, logger: MLFlowLogger) -> None: + artifacts = config.dataset.mlflow_artifacts + kfold_run_id = artifacts.kfold_run_id + filter_tiles_run_id = artifacts.filter_tiles_run_id + embedding_run_id = artifacts.embedding_run_id + + tissue_prop_min = float(config.tissue_prop_min) + if tissue_prop_min <= 0: + raise ValueError( + f"tissue_prop_min must be > 0 (got {tissue_prop_min}); " + "otherwise background tiles are not filtered out." + ) + raw_thresholds = OmegaConf.to_container(config.thresholds, resolve=True) + if not isinstance(raw_thresholds, dict): + raise TypeError("config.thresholds must be a mapping of class -> threshold") + thresholds = {str(k): float(v) for k, v in raw_thresholds.items()} + + splits = [ + ("train", kfold_run_id, "kfold_split/kfold_tiles.parquet", False), + ("test", filter_tiles_run_id, "filter_tiles/test_tiles.parquet", True), + ] + + with tempfile.TemporaryDirectory() as tmp_root: + tmp_root_path = Path(tmp_root) + for split_name, src_run_id, src_artifact_path, derive in splits: + stats = process_split( + split_name=split_name, + src_run_id=src_run_id, + src_artifact_path=src_artifact_path, + embedding_run_id=embedding_run_id, + tissue_prop_min=tissue_prop_min, + thresholds=thresholds, + output_split_dir=tmp_root_path / split_name, + derive=derive, + ) + for key, value in stats.items(): + mlflow.log_metric(f"{split_name}_{key}", value) + + mlflow.log_artifacts(str(tmp_root_path), config.mlflow_artifact_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/submit_embedding_dataset.py b/scripts/submit_embedding_dataset.py new file mode 100644 index 0000000..bbe4063 --- /dev/null +++ b/scripts/submit_embedding_dataset.py @@ -0,0 +1,18 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-embedding-dataset", + username=..., + cpu=8, + memory="32Gi", + gpu=None, + public=False, + script=[ + "git clone https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv run python -m preprocessing.embedding_dataset +experiment=...", + ], + storage=[storage.secure.PROJECTS], +) diff --git a/split/kfold_split.py b/split/kfold_split.py index 150961e..a17c299 100644 --- a/split/kfold_split.py +++ b/split/kfold_split.py @@ -12,6 +12,8 @@ from rationai.mlkit.lightning.loggers import MLFlowLogger from sklearn.model_selection import StratifiedKFold +from preprocessing._labels import compute_label_and_tissue_prop + def derive_labels( dataset: Dataset, @@ -20,10 +22,7 @@ def derive_labels( """Derive label, tissue_prop, and slide_id arrays from the dataset.""" def compute(batch: dict[str, Any]) -> dict[str, Any]: - roi_df = pd.DataFrame({col: batch[col] for col in roi_cols}) - tp = roi_df.sum(axis=1).values - lbl = roi_df.idxmax(axis=1).str.removeprefix("roi_coverage_").values - lbl[tp == 0] = "background" + lbl, tp = compute_label_and_tissue_prop(batch, roi_cols) return {"label": lbl.tolist(), "tissue_prop": tp.tolist()} label_ds = dataset.select_columns(["slide_id", *roi_cols]).map( From 911bec2c7b40149a681ab7be97854ab32e06bdff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 8 May 2026 21:35:16 +0200 Subject: [PATCH 02/53] feat: add class tresholds and run ids --- configs/data/dataset.yaml | 2 ++ .../experiment/preprocessing/embedding_dataset.yaml | 11 +++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index e13fec8..732575b 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -14,6 +14,8 @@ dataset: test_split_filename: "split_mapping/test_split.csv" tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" + kfold_run_id: "850c81506684450b9af92296acfd045a" + embedding_run_id: "06d2d8eb088c4e04b04435940774c7aa" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" diff --git a/configs/experiment/preprocessing/embedding_dataset.yaml b/configs/experiment/preprocessing/embedding_dataset.yaml index bfe24a0..71f3e68 100644 --- a/configs/experiment/preprocessing/embedding_dataset.yaml +++ b/configs/experiment/preprocessing/embedding_dataset.yaml @@ -4,8 +4,15 @@ defaults: - /data: dataset - _self_ -tissue_prop_min: 0.5 -thresholds: ??? +tissue_prop_min: 0.2 +thresholds: + Nerve: 0.0 + Blood: 0.0 + Connective-Tissue: 0.0 + Fat: 0.0 + Epithelium: 0.0 + Muscle: 0.0 + Other: 0.0 metadata: run_name: Embedding dataset ${dataset.name} From 1a0239537e59eddb586889085fd1fec9ad193e2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 8 May 2026 21:43:31 +0200 Subject: [PATCH 03/53] fix: wrong run id --- configs/data/dataset.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 732575b..0497d47 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -15,7 +15,7 @@ dataset: tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" kfold_run_id: "850c81506684450b9af92296acfd045a" - embedding_run_id: "06d2d8eb088c4e04b04435940774c7aa" + embedding_run_id: "f05076dcd5e64cb2839efe5fb20a22ae" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" From b38465e6185acae770428314752d2c0cf26c7541 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 16:34:37 +0200 Subject: [PATCH 04/53] feat: add timing --- configs/data/dataset.yaml | 2 +- preprocessing/embedding_dataset.py | 28 +++++++++++++++++++++++++--- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 0497d47..0cf33e2 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -15,7 +15,7 @@ dataset: tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" kfold_run_id: "850c81506684450b9af92296acfd045a" - embedding_run_id: "f05076dcd5e64cb2839efe5fb20a22ae" + embedding_run_id: "5f323d5ef5a74026846ecbe8fbc007fb" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 1e9028c..741ffdd 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -8,6 +8,7 @@ import shutil import tempfile +import time from pathlib import Path import hydra @@ -52,16 +53,28 @@ def join_embeddings( embedding_split: str, ) -> tuple[pd.DataFrame, int]: """Join filtered tile metadata with embeddings on (slide_id, x, y).""" + t0 = time.time() emb_dir = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" ) - emb_table = pads.dataset(emb_dir, format="parquet").to_table( - columns=["slide_id", "x", "y", "embedding"] - ) + print(f"[timing] download embeddings: {time.time() - t0:.1f}s", flush=True) + + t0 = time.time() + emb_ds = pads.dataset(emb_dir, format="parquet") + print(f"[timing] embedding dataset has {emb_ds.count_rows()} rows", flush=True) + emb_table = emb_ds.to_table(columns=["slide_id", "x", "y", "embedding"]) + print(f"[timing] to_table: {time.time() - t0:.1f}s", flush=True) + + t0 = time.time() emb_df = emb_table.to_pandas() del emb_table + print(f"[timing] to_pandas: {time.time() - t0:.1f}s shape={emb_df.shape}", flush=True) + t0 = time.time() merged = tiles_df.merge(emb_df, on=["slide_id", "x", "y"], how="inner") + print(f"[timing] merge: {time.time() - t0:.1f}s shape={merged.shape}", flush=True) + del emb_df + dropped_no_embedding = len(tiles_df) - len(merged) return merged, dropped_no_embedding @@ -76,11 +89,14 @@ def process_split( output_split_dir: Path, derive: bool, ) -> dict[str, int]: + print(f"[{split_name}] downloading src tiles...", flush=True) + t0 = time.time() src_local = mlflow.artifacts.download_artifacts( run_id=src_run_id, artifact_path=src_artifact_path ) df = pads.dataset(src_local, format="parquet").to_table().to_pandas() input_count = len(df) + print(f"[{split_name}] src tiles loaded: {input_count} rows {time.time() - t0:.1f}s", flush=True) roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] if not roi_cols: @@ -115,6 +131,7 @@ def process_split( c for c in df.columns if c.startswith(("roi_coverage_", "tile_coverage_")) ] df = df.drop(columns=drop_cols) + print(f"[{split_name}] after thresholds: {after_class_threshold} rows, joining embeddings...", flush=True) merged, dropped_no_embedding = join_embeddings(df, embedding_run_id, split_name) if dropped_no_embedding != 0: @@ -124,11 +141,16 @@ def process_split( flush=True, ) + t0 = time.time() merged = merged.sort_values("slide_id", kind="stable").reset_index(drop=True) + print(f"[{split_name}] sort: {time.time() - t0:.1f}s", flush=True) output_split_dir.mkdir(parents=True, exist_ok=True) + t0 = time.time() merged.to_parquet(output_split_dir / "tiles.parquet", index=False) + print(f"[{split_name}] write parquet: {time.time() - t0:.1f}s", flush=True) + print(f"[{split_name}] downloading slides.parquet...", flush=True) slides_local = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{split_name}/slides.parquet" ) From bfc9578a83747a4a07eeb82a1d20ccc67368e5cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 16:50:40 +0200 Subject: [PATCH 05/53] refactor: use pyarrow to avoid to pandas conversion --- preprocessing/embedding_dataset.py | 49 ++++++++++++++++++------------ 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 741ffdd..771a229 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -15,7 +15,10 @@ import mlflow import mlflow.artifacts import pandas as pd +import pyarrow as pa +import pyarrow.compute as pc import pyarrow.dataset as pads +import pyarrow.parquet as pq from omegaconf import DictConfig, OmegaConf from rationai.mlkit import autolog, with_cli_args from rationai.mlkit.lightning.loggers import MLFlowLogger @@ -48,11 +51,15 @@ def apply_thresholds( def join_embeddings( - tiles_df: pd.DataFrame, + tiles_table: pa.Table, embedding_run_id: str, embedding_split: str, -) -> tuple[pd.DataFrame, int]: - """Join filtered tile metadata with embeddings on (slide_id, x, y).""" +) -> tuple[pa.Table, int]: + """Join filtered tile metadata with embeddings on (slide_id, x, y) using Arrow join. + + Stays entirely in Arrow to avoid the slow fixed-size-list to_pandas() conversion + on the embedding column. + """ t0 = time.time() emb_dir = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" @@ -66,17 +73,12 @@ def join_embeddings( print(f"[timing] to_table: {time.time() - t0:.1f}s", flush=True) t0 = time.time() - emb_df = emb_table.to_pandas() + joined = tiles_table.join(emb_table, keys=["slide_id", "x", "y"], join_type="inner") del emb_table - print(f"[timing] to_pandas: {time.time() - t0:.1f}s shape={emb_df.shape}", flush=True) + print(f"[timing] arrow join: {time.time() - t0:.1f}s rows={joined.num_rows}", flush=True) - t0 = time.time() - merged = tiles_df.merge(emb_df, on=["slide_id", "x", "y"], how="inner") - print(f"[timing] merge: {time.time() - t0:.1f}s shape={merged.shape}", flush=True) - del emb_df - - dropped_no_embedding = len(tiles_df) - len(merged) - return merged, dropped_no_embedding + dropped_no_embedding = tiles_table.num_rows - joined.num_rows + return joined, dropped_no_embedding def process_split( @@ -133,7 +135,11 @@ def process_split( df = df.drop(columns=drop_cols) print(f"[{split_name}] after thresholds: {after_class_threshold} rows, joining embeddings...", flush=True) - merged, dropped_no_embedding = join_embeddings(df, embedding_run_id, split_name) + tiles_table = pa.Table.from_pandas(df, preserve_index=False) + del df + + merged_table, dropped_no_embedding = join_embeddings(tiles_table, embedding_run_id, split_name) + del tiles_table if dropped_no_embedding != 0: print( f"WARNING: {dropped_no_embedding} tiles in split '{split_name}' have " @@ -142,12 +148,13 @@ def process_split( ) t0 = time.time() - merged = merged.sort_values("slide_id", kind="stable").reset_index(drop=True) + sort_indices = pc.sort_indices(merged_table, sort_keys=[("slide_id", "ascending")]) + merged_table = merged_table.take(sort_indices) print(f"[{split_name}] sort: {time.time() - t0:.1f}s", flush=True) output_split_dir.mkdir(parents=True, exist_ok=True) t0 = time.time() - merged.to_parquet(output_split_dir / "tiles.parquet", index=False) + pq.write_table(merged_table, str(output_split_dir / "tiles.parquet")) print(f"[{split_name}] write parquet: {time.time() - t0:.1f}s", flush=True) print(f"[{split_name}] downloading slides.parquet...", flush=True) @@ -156,18 +163,22 @@ def process_split( ) shutil.copy(slides_local, output_split_dir / "slides.parquet") - log_label_distributions(split_name, merged) + log_label_distributions(split_name, merged_table) return { "input_count": input_count, "after_tissue_filter": after_tissue_filter, "after_class_threshold": after_class_threshold, - "after_join": len(merged), + "after_join": merged_table.num_rows, "dropped_no_embedding": dropped_no_embedding, } -def log_label_distributions(split_name: str, df: pd.DataFrame) -> None: +def log_label_distributions(split_name: str, table: pa.Table) -> None: + has_fold = "fold" in table.schema.names + cols = ["label", "fold"] if has_fold else ["label"] + df = table.select(cols).to_pandas() + label_dist = ( df["label"].value_counts().rename_axis("label").reset_index(name="count") ) @@ -176,7 +187,7 @@ def log_label_distributions(split_name: str, df: pd.DataFrame) -> None: artifact_file=f"fold_statistics/{split_name}_label_distribution.json", ) - if "fold" in df.columns: + if has_fold: fold_dist = ( df.groupby(["fold", "label"]).size().unstack(fill_value=0).reset_index() ) From eb213c6abc1312bd281b5bbb40a8abacf24c4d71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 20:23:21 +0200 Subject: [PATCH 06/53] fix: join on keys only --- preprocessing/embedding_dataset.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 771a229..061949e 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -72,9 +72,18 @@ def join_embeddings( emb_table = emb_ds.to_table(columns=["slide_id", "x", "y", "embedding"]) print(f"[timing] to_table: {time.time() - t0:.1f}s", flush=True) + # Arrow Acero join doesn't support list in non-key fields, so join on + # keys only using a row-index column, then pull embeddings via take(). t0 = time.time() - joined = tiles_table.join(emb_table, keys=["slide_id", "x", "y"], join_type="inner") - del emb_table + emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) + emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) + + joined_keys = tiles_table.join(emb_keys, keys=["slide_id", "x", "y"], join_type="inner") + embeddings = emb_table.column("embedding").take(joined_keys.column("_emb_idx")) + del emb_table, emb_keys, emb_idx + + joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) + del joined_keys, embeddings print(f"[timing] arrow join: {time.time() - t0:.1f}s rows={joined.num_rows}", flush=True) dropped_no_embedding = tiles_table.num_rows - joined.num_rows From c92d9a1a5879ae1e8d61231a5b4184983eb4d633 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 20:35:35 +0200 Subject: [PATCH 07/53] fix: typing --- preprocessing/embedding_dataset.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 061949e..a193b85 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -74,13 +74,25 @@ def join_embeddings( # Arrow Acero join doesn't support list in non-key fields, so join on # keys only using a row-index column, then pull embeddings via take(). + # Cast embedding to large_list first: 1.1M rows * 768 doubles overflows int32 + # list offsets when chunks are concatenated by take(). t0 = time.time() + emb_col = emb_table.column("embedding") + if pa.types.is_list(emb_col.type): + emb_col = emb_col.cast(pa.large_list(emb_col.type.value_type)) + elif pa.types.is_fixed_size_list(emb_col.type): + pass # fixed_size_list has no offsets, no overflow risk + else: + emb_col = emb_col.cast(pa.large_list(pa.float64())) + emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) + del emb_table, emb_idx joined_keys = tiles_table.join(emb_keys, keys=["slide_id", "x", "y"], join_type="inner") - embeddings = emb_table.column("embedding").take(joined_keys.column("_emb_idx")) - del emb_table, emb_keys, emb_idx + del emb_keys + embeddings = emb_col.take(joined_keys.column("_emb_idx")) + del emb_col joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) del joined_keys, embeddings From 01cc39450e00b0eba3a9573e08d493498cd71172 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 21:24:45 +0200 Subject: [PATCH 08/53] fix: add prints --- preprocessing/embedding_dataset.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index a193b85..4d868a8 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -74,29 +74,40 @@ def join_embeddings( # Arrow Acero join doesn't support list in non-key fields, so join on # keys only using a row-index column, then pull embeddings via take(). - # Cast embedding to large_list first: 1.1M rows * 768 doubles overflows int32 - # list offsets when chunks are concatenated by take(). - t0 = time.time() emb_col = emb_table.column("embedding") + print(f"[timing] embedding column type={emb_col.type}, num_chunks={emb_col.num_chunks}", flush=True) + + # Cast per chunk to large_list to avoid the int32 offset overflow that hits + # when take() concatenates chunks of list. Per-chunk casts touch + # the offset buffer only (each chunk individually fits int32). + t0 = time.time() if pa.types.is_list(emb_col.type): - emb_col = emb_col.cast(pa.large_list(emb_col.type.value_type)) - elif pa.types.is_fixed_size_list(emb_col.type): - pass # fixed_size_list has no offsets, no overflow risk - else: - emb_col = emb_col.cast(pa.large_list(pa.float64())) + target_type = pa.large_list(emb_col.type.value_type) + new_chunks = [c.cast(target_type) for c in emb_col.chunks] + emb_col = pa.chunked_array(new_chunks, type=target_type) + del new_chunks + print(f"[timing] cast embedding to large_list: {time.time() - t0:.1f}s", flush=True) + t0 = time.time() emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table, emb_idx + print(f"[timing] build emb_keys: {time.time() - t0:.1f}s", flush=True) + t0 = time.time() joined_keys = tiles_table.join(emb_keys, keys=["slide_id", "x", "y"], join_type="inner") del emb_keys + print(f"[timing] arrow key-join: {time.time() - t0:.1f}s rows={joined_keys.num_rows}", flush=True) + + t0 = time.time() embeddings = emb_col.take(joined_keys.column("_emb_idx")) del emb_col + print(f"[timing] take embeddings: {time.time() - t0:.1f}s", flush=True) + t0 = time.time() joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) del joined_keys, embeddings - print(f"[timing] arrow join: {time.time() - t0:.1f}s rows={joined.num_rows}", flush=True) + print(f"[timing] assemble joined table: {time.time() - t0:.1f}s rows={joined.num_rows}", flush=True) dropped_no_embedding = tiles_table.num_rows - joined.num_rows return joined, dropped_no_embedding From cad0d376e0d4eb132125ade9b7654be4c36288df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 21:35:58 +0200 Subject: [PATCH 09/53] refactor: use combine chunks --- preprocessing/embedding_dataset.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 4d868a8..73987dd 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -100,8 +100,16 @@ def join_embeddings( print(f"[timing] arrow key-join: {time.time() - t0:.1f}s rows={joined_keys.num_rows}", flush=True) t0 = time.time() - embeddings = emb_col.take(joined_keys.column("_emb_idx")) + emb_array = emb_col.combine_chunks() del emb_col + print(f"[timing] combine_chunks: {time.time() - t0:.1f}s", flush=True) + + t0 = time.time() + indices = joined_keys.column("_emb_idx") + if isinstance(indices, pa.ChunkedArray): + indices = indices.combine_chunks() + embeddings = emb_array.take(indices) + del emb_array print(f"[timing] take embeddings: {time.time() - t0:.1f}s", flush=True) t0 = time.time() From ae045526dd1f97096ff02a23740eab3fed44faf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 22:21:49 +0200 Subject: [PATCH 10/53] fix: lazy-cast embeddings to large_list and stay in Arrow during join Joining 1M+ rows of list embeddings was either OOMing on to_pandas() or hitting int32 list-offset overflow inside take(). The fix: - read embeddings into Arrow only and cast each chunk to large_list so take() concatenation uses int64 offsets; - run the join on keys plus a synthetic row index because Acero refuses list columns in non-key fields, then pull embeddings via take(); - combine_chunks() before take() for an O(N) single-pass copy; - write the parquet straight from Arrow, never materialising the embedding column in pandas. Also bumps the kube job memory to 64Gi to give the combined-chunks + take() peak some headroom, and trims the verbose [timing] prints down to one progress line per split. Co-Authored-By: Claude Sonnet 4.6 --- preprocessing/embedding_dataset.py | 83 +++++++++++------------------ scripts/submit_embedding_dataset.py | 2 +- 2 files changed, 31 insertions(+), 54 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 73987dd..098759c 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -8,7 +8,6 @@ import shutil import tempfile -import time from pathlib import Path import hydra @@ -55,68 +54,44 @@ def join_embeddings( embedding_run_id: str, embedding_split: str, ) -> tuple[pa.Table, int]: - """Join filtered tile metadata with embeddings on (slide_id, x, y) using Arrow join. + """Join filtered tile metadata with embeddings on (slide_id, x, y). - Stays entirely in Arrow to avoid the slow fixed-size-list to_pandas() conversion - on the embedding column. + Stays in Arrow throughout to avoid the very slow list -> pandas + conversion. Acero's join engine doesn't accept list columns in non-key + fields, so we join on keys plus a synthetic row index, then pull embeddings + via take(). The embedding column is cast per chunk to large_list to avoid + int32 offset overflow that bites take() when chunks are concatenated. """ - t0 = time.time() emb_dir = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" ) - print(f"[timing] download embeddings: {time.time() - t0:.1f}s", flush=True) - - t0 = time.time() - emb_ds = pads.dataset(emb_dir, format="parquet") - print(f"[timing] embedding dataset has {emb_ds.count_rows()} rows", flush=True) - emb_table = emb_ds.to_table(columns=["slide_id", "x", "y", "embedding"]) - print(f"[timing] to_table: {time.time() - t0:.1f}s", flush=True) + emb_table = pads.dataset(emb_dir, format="parquet").to_table( + columns=["slide_id", "x", "y", "embedding"] + ) - # Arrow Acero join doesn't support list in non-key fields, so join on - # keys only using a row-index column, then pull embeddings via take(). emb_col = emb_table.column("embedding") - print(f"[timing] embedding column type={emb_col.type}, num_chunks={emb_col.num_chunks}", flush=True) - - # Cast per chunk to large_list to avoid the int32 offset overflow that hits - # when take() concatenates chunks of list. Per-chunk casts touch - # the offset buffer only (each chunk individually fits int32). - t0 = time.time() if pa.types.is_list(emb_col.type): target_type = pa.large_list(emb_col.type.value_type) - new_chunks = [c.cast(target_type) for c in emb_col.chunks] - emb_col = pa.chunked_array(new_chunks, type=target_type) - del new_chunks - print(f"[timing] cast embedding to large_list: {time.time() - t0:.1f}s", flush=True) + emb_col = pa.chunked_array( + [c.cast(target_type) for c in emb_col.chunks], type=target_type + ) - t0 = time.time() emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table, emb_idx - print(f"[timing] build emb_keys: {time.time() - t0:.1f}s", flush=True) - t0 = time.time() - joined_keys = tiles_table.join(emb_keys, keys=["slide_id", "x", "y"], join_type="inner") + joined_keys = tiles_table.join( + emb_keys, keys=["slide_id", "x", "y"], join_type="inner" + ) del emb_keys - print(f"[timing] arrow key-join: {time.time() - t0:.1f}s rows={joined_keys.num_rows}", flush=True) - - t0 = time.time() - emb_array = emb_col.combine_chunks() - del emb_col - print(f"[timing] combine_chunks: {time.time() - t0:.1f}s", flush=True) - t0 = time.time() indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - embeddings = emb_array.take(indices) - del emb_array - print(f"[timing] take embeddings: {time.time() - t0:.1f}s", flush=True) + embeddings = emb_col.combine_chunks().take(indices) + del emb_col - t0 = time.time() joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) - del joined_keys, embeddings - print(f"[timing] assemble joined table: {time.time() - t0:.1f}s rows={joined.num_rows}", flush=True) - dropped_no_embedding = tiles_table.num_rows - joined.num_rows return joined, dropped_no_embedding @@ -131,14 +106,12 @@ def process_split( output_split_dir: Path, derive: bool, ) -> dict[str, int]: - print(f"[{split_name}] downloading src tiles...", flush=True) - t0 = time.time() + print(f"[{split_name}] downloading source tiles", flush=True) src_local = mlflow.artifacts.download_artifacts( run_id=src_run_id, artifact_path=src_artifact_path ) df = pads.dataset(src_local, format="parquet").to_table().to_pandas() input_count = len(df) - print(f"[{split_name}] src tiles loaded: {input_count} rows {time.time() - t0:.1f}s", flush=True) roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] if not roi_cols: @@ -173,12 +146,18 @@ def process_split( c for c in df.columns if c.startswith(("roi_coverage_", "tile_coverage_")) ] df = df.drop(columns=drop_cols) - print(f"[{split_name}] after thresholds: {after_class_threshold} rows, joining embeddings...", flush=True) + print( + f"[{split_name}] {input_count} -> {after_tissue_filter} (tissue) " + f"-> {after_class_threshold} (class threshold), joining embeddings", + flush=True, + ) tiles_table = pa.Table.from_pandas(df, preserve_index=False) del df - merged_table, dropped_no_embedding = join_embeddings(tiles_table, embedding_run_id, split_name) + merged_table, dropped_no_embedding = join_embeddings( + tiles_table, embedding_run_id, split_name + ) del tiles_table if dropped_no_embedding != 0: print( @@ -187,23 +166,21 @@ def process_split( flush=True, ) - t0 = time.time() - sort_indices = pc.sort_indices(merged_table, sort_keys=[("slide_id", "ascending")]) + sort_indices = pc.sort_indices( + merged_table, sort_keys=[("slide_id", "ascending")] + ) merged_table = merged_table.take(sort_indices) - print(f"[{split_name}] sort: {time.time() - t0:.1f}s", flush=True) output_split_dir.mkdir(parents=True, exist_ok=True) - t0 = time.time() pq.write_table(merged_table, str(output_split_dir / "tiles.parquet")) - print(f"[{split_name}] write parquet: {time.time() - t0:.1f}s", flush=True) - print(f"[{split_name}] downloading slides.parquet...", flush=True) slides_local = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{split_name}/slides.parquet" ) shutil.copy(slides_local, output_split_dir / "slides.parquet") log_label_distributions(split_name, merged_table) + print(f"[{split_name}] wrote {merged_table.num_rows} rows", flush=True) return { "input_count": input_count, diff --git a/scripts/submit_embedding_dataset.py b/scripts/submit_embedding_dataset.py index bbe4063..23977df 100644 --- a/scripts/submit_embedding_dataset.py +++ b/scripts/submit_embedding_dataset.py @@ -5,7 +5,7 @@ job_name="tissue-classification-embedding-dataset", username=..., cpu=8, - memory="32Gi", + memory="64Gi", gpu=None, public=False, script=[ From 82320db480999dec51cac0f3ae32f59501d83386 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 22:23:03 +0200 Subject: [PATCH 11/53] fix: validate label/tissue_prop columns when derive=False Without this guard a malformed train artifact would crash deep inside apply_thresholds with a confusing KeyError. Surface a clear error that points at the expected upstream artifact instead. Co-Authored-By: Claude Sonnet 4.6 --- preprocessing/embedding_dataset.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 098759c..22242fc 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -132,6 +132,15 @@ def process_split( lbl, tp = compute_label_and_tissue_prop(df, roi_cols) df["label"] = lbl df["tissue_prop"] = tp + else: + required = {"label", "tissue_prop"} + missing_required = required - set(df.columns) + if missing_required: + raise RuntimeError( + f"Source split '{split_name}' (derive=False) is missing required " + f"columns {sorted(missing_required)} in {src_artifact_path}. " + "Expected the kfold_split artifact, which writes label/tissue_prop/fold." + ) df, after_tissue_filter = apply_thresholds( df, tissue_prop_min, thresholds, roi_cols From 3b0137f95bff66c83b4516c1e0fee89928cf028b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sat, 9 May 2026 22:35:46 +0200 Subject: [PATCH 12/53] chore: remove time --- preprocessing/embedding_dataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 22242fc..5fd263b 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -175,9 +175,7 @@ def process_split( flush=True, ) - sort_indices = pc.sort_indices( - merged_table, sort_keys=[("slide_id", "ascending")] - ) + sort_indices = pc.sort_indices(merged_table, sort_keys=[("slide_id", "ascending")]) merged_table = merged_table.take(sort_indices) output_split_dir.mkdir(parents=True, exist_ok=True) From 8df47aae009fedc541dddbf4f28aa37e37838b1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sun, 10 May 2026 17:20:58 +0200 Subject: [PATCH 13/53] feat: add timing --- preprocessing/embedding_dataset.py | 36 +++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 5fd263b..9a7eea6 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -8,6 +8,7 @@ import shutil import tempfile +import time from pathlib import Path import hydra @@ -62,34 +63,67 @@ def join_embeddings( via take(). The embedding column is cast per chunk to large_list to avoid int32 offset overflow that bites take() when chunks are concatenated. """ + t = time.time() + print("[join] downloading embeddings", flush=True) emb_dir = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" ) + print(f"[join] download: {time.time() - t:.1f}s", flush=True) + + t = time.time() + print("[join] to_table", flush=True) emb_table = pads.dataset(emb_dir, format="parquet").to_table( columns=["slide_id", "x", "y", "embedding"] ) + print( + f"[join] to_table: {time.time() - t:.1f}s rows={emb_table.num_rows} " + f"chunks={emb_table.column('embedding').num_chunks}", + flush=True, + ) + t = time.time() emb_col = emb_table.column("embedding") if pa.types.is_list(emb_col.type): target_type = pa.large_list(emb_col.type.value_type) emb_col = pa.chunked_array( [c.cast(target_type) for c in emb_col.chunks], type=target_type ) + print(f"[join] cast to large_list: {time.time() - t:.1f}s", flush=True) + t = time.time() emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table, emb_idx + print(f"[join] build keys table: {time.time() - t:.1f}s", flush=True) + t = time.time() + print("[join] arrow key-join", flush=True) joined_keys = tiles_table.join( emb_keys, keys=["slide_id", "x", "y"], join_type="inner" ) del emb_keys + print( + f"[join] arrow key-join: {time.time() - t:.1f}s rows={joined_keys.num_rows}", + flush=True, + ) + t = time.time() indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - embeddings = emb_col.combine_chunks().take(indices) + print(f"[join] combine indices: {time.time() - t:.1f}s", flush=True) + + t = time.time() + print("[join] combine_chunks(embeddings)", flush=True) + emb_contig = emb_col.combine_chunks() del emb_col + print(f"[join] combine_chunks: {time.time() - t:.1f}s", flush=True) + + t = time.time() + print("[join] take(embeddings)", flush=True) + embeddings = emb_contig.take(indices) + del emb_contig + print(f"[join] take: {time.time() - t:.1f}s", flush=True) joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) dropped_no_embedding = tiles_table.num_rows - joined.num_rows From 926753d54e072d70b121b70c13cc840809d20527 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sun, 10 May 2026 17:48:01 +0200 Subject: [PATCH 14/53] chore: revert to the previous state --- preprocessing/embedding_dataset.py | 36 +----------------------------- 1 file changed, 1 insertion(+), 35 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 9a7eea6..5fd263b 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -8,7 +8,6 @@ import shutil import tempfile -import time from pathlib import Path import hydra @@ -63,67 +62,34 @@ def join_embeddings( via take(). The embedding column is cast per chunk to large_list to avoid int32 offset overflow that bites take() when chunks are concatenated. """ - t = time.time() - print("[join] downloading embeddings", flush=True) emb_dir = mlflow.artifacts.download_artifacts( run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" ) - print(f"[join] download: {time.time() - t:.1f}s", flush=True) - - t = time.time() - print("[join] to_table", flush=True) emb_table = pads.dataset(emb_dir, format="parquet").to_table( columns=["slide_id", "x", "y", "embedding"] ) - print( - f"[join] to_table: {time.time() - t:.1f}s rows={emb_table.num_rows} " - f"chunks={emb_table.column('embedding').num_chunks}", - flush=True, - ) - t = time.time() emb_col = emb_table.column("embedding") if pa.types.is_list(emb_col.type): target_type = pa.large_list(emb_col.type.value_type) emb_col = pa.chunked_array( [c.cast(target_type) for c in emb_col.chunks], type=target_type ) - print(f"[join] cast to large_list: {time.time() - t:.1f}s", flush=True) - t = time.time() emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table, emb_idx - print(f"[join] build keys table: {time.time() - t:.1f}s", flush=True) - t = time.time() - print("[join] arrow key-join", flush=True) joined_keys = tiles_table.join( emb_keys, keys=["slide_id", "x", "y"], join_type="inner" ) del emb_keys - print( - f"[join] arrow key-join: {time.time() - t:.1f}s rows={joined_keys.num_rows}", - flush=True, - ) - t = time.time() indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - print(f"[join] combine indices: {time.time() - t:.1f}s", flush=True) - - t = time.time() - print("[join] combine_chunks(embeddings)", flush=True) - emb_contig = emb_col.combine_chunks() + embeddings = emb_col.combine_chunks().take(indices) del emb_col - print(f"[join] combine_chunks: {time.time() - t:.1f}s", flush=True) - - t = time.time() - print("[join] take(embeddings)", flush=True) - embeddings = emb_contig.take(indices) - del emb_contig - print(f"[join] take: {time.time() - t:.1f}s", flush=True) joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) dropped_no_embedding = tiles_table.num_rows - joined.num_rows From b0e9ba4290f3078a2f173168f0bb76b3970f3acf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sun, 10 May 2026 20:21:51 +0200 Subject: [PATCH 15/53] feat: add prints --- preprocessing/embedding_dataset.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 5fd263b..c201020 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -8,6 +8,7 @@ import shutil import tempfile +import time from pathlib import Path import hydra @@ -80,16 +81,29 @@ def join_embeddings( emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table, emb_idx + t = time.time() joined_keys = tiles_table.join( emb_keys, keys=["slide_id", "x", "y"], join_type="inner" ) del emb_keys + print( + f"[join] arrow key-join: {time.time() - t:.1f}s rows={joined_keys.num_rows}", + flush=True, + ) indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - embeddings = emb_col.combine_chunks().take(indices) + + t = time.time() + emb_contig = emb_col.combine_chunks() del emb_col + print(f"[join] combine_chunks: {time.time() - t:.1f}s", flush=True) + + t = time.time() + embeddings = emb_contig.take(indices) + del emb_contig + print(f"[join] take: {time.time() - t:.1f}s", flush=True) joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) dropped_no_embedding = tiles_table.num_rows - joined.num_rows From 6a915de65f4c8cf62208dcf04cf02ca5d742f1fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 10:03:36 +0200 Subject: [PATCH 16/53] refactor: use discusssed thresholds --- .../experiment/preprocessing/embedding_dataset.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/experiment/preprocessing/embedding_dataset.yaml b/configs/experiment/preprocessing/embedding_dataset.yaml index 71f3e68..8004e2e 100644 --- a/configs/experiment/preprocessing/embedding_dataset.yaml +++ b/configs/experiment/preprocessing/embedding_dataset.yaml @@ -8,11 +8,11 @@ tissue_prop_min: 0.2 thresholds: Nerve: 0.0 Blood: 0.0 - Connective-Tissue: 0.0 - Fat: 0.0 - Epithelium: 0.0 - Muscle: 0.0 - Other: 0.0 + Connective-Tissue: 0.4 + Fat: 0.5 + Epithelium: 0.2 + Muscle: 0.4 + Other: 0.5 metadata: run_name: Embedding dataset ${dataset.name} From 0f50307daba3d5e7433b9bf7dddceda14a05f785 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 10:15:05 +0200 Subject: [PATCH 17/53] refactor: use different labeling strategy --- preprocessing/embedding_dataset.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index c201020..a1e6545 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -14,6 +14,7 @@ import hydra import mlflow import mlflow.artifacts +import numpy as np import pandas as pd import pyarrow as pa import pyarrow.compute as pc @@ -32,22 +33,31 @@ def apply_thresholds( thresholds: dict[str, float], roi_cols: list[str], ) -> tuple[pd.DataFrame, int]: - """Filter df by tissue_prop_min then by per-dominant-class roi threshold. + """Filter df by tissue_prop_min, then keep tiles where ANY class meets its + threshold; among passing classes, the highest-coverage one becomes the label. Returns ``(filtered_df, after_tissue_count)`` so the caller can log both - intermediate counts. + intermediate counts. The returned df has its ``label`` column rewritten to + reflect the argmax-over-passers rule. """ df = df[df["tissue_prop"] >= tissue_prop_min] after_tissue = len(df) if df.empty: return df, after_tissue - roi_only = df[roi_cols] - dominant = roi_only.idxmax(axis=1).str.removeprefix("roi_coverage_") - dominant_value = roi_only.max(axis=1).to_numpy() - threshold_per_row = dominant.map(thresholds).to_numpy() - keep = dominant_value >= threshold_per_row - return df[keep].copy(), after_tissue + class_names = np.array([c.removeprefix("roi_coverage_") for c in roi_cols]) + thr = np.array([thresholds[c] for c in class_names], dtype=float) + roi = df[roi_cols].to_numpy() + passes = roi >= thr + keep = passes.any(axis=1) + + masked = np.where(passes, roi, -np.inf) + label_idx = masked.argmax(axis=1) + new_labels = class_names[label_idx] + + out = df[keep].copy() + out["label"] = new_labels[keep] + return out, after_tissue def join_embeddings( From 4d953dca0d0a619dd10291168c357725b96fed6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 12:41:07 +0200 Subject: [PATCH 18/53] feat: implement training pipeline --- configs/data/dataset.yaml | 1 + configs/experiment/ml/linear_classifier.yaml | 28 +++ configs/ml/data/embedding.yaml | 24 ++ configs/ml/model/linear_classifier.yaml | 18 ++ configs/ml/trainer/default.yaml | 29 +++ ml/__init__.py | 0 ml/__main__.py | 31 +++ ml/callbacks/__init__.py | 4 + ml/callbacks/parquet_prediction_writer.py | 77 +++++++ ml/data/__init__.py | 4 + ml/data/data_module.py | 64 ++++++ ml/data/datasets/__init__.py | 4 + ml/data/datasets/embedding_tiles.py | 78 +++++++ ml/meta_arch.py | 220 +++++++++++++++++++ ml/modeling/__init__.py | 0 ml/typing.py | 6 + preprocessing/embedding_dataset.py | 11 +- pyproject.toml | 2 + scripts/submit_train_linear.py | 18 ++ uv.lock | 4 + 20 files changed, 619 insertions(+), 4 deletions(-) create mode 100644 configs/experiment/ml/linear_classifier.yaml create mode 100644 configs/ml/data/embedding.yaml create mode 100644 configs/ml/model/linear_classifier.yaml create mode 100644 configs/ml/trainer/default.yaml create mode 100644 ml/__init__.py create mode 100644 ml/__main__.py create mode 100644 ml/callbacks/__init__.py create mode 100644 ml/callbacks/parquet_prediction_writer.py create mode 100644 ml/data/__init__.py create mode 100644 ml/data/data_module.py create mode 100644 ml/data/datasets/__init__.py create mode 100644 ml/data/datasets/embedding_tiles.py create mode 100644 ml/meta_arch.py create mode 100644 ml/modeling/__init__.py create mode 100644 ml/typing.py create mode 100644 scripts/submit_train_linear.py diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 0cf33e2..57e1925 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -16,6 +16,7 @@ dataset: filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" kfold_run_id: "850c81506684450b9af92296acfd045a" embedding_run_id: "5f323d5ef5a74026846ecbe8fbc007fb" + embedding_dataset_run_id: "3ab86e376d38481dbac5bc352f7ac7c9" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" diff --git a/configs/experiment/ml/linear_classifier.yaml b/configs/experiment/ml/linear_classifier.yaml new file mode 100644 index 0000000..6796489 --- /dev/null +++ b/configs/experiment/ml/linear_classifier.yaml @@ -0,0 +1,28 @@ +# @package _global_ + +defaults: + - /data: dataset + - /class_mapping: collapse_alterations_to_other + - /ml/trainer: default + - /ml/data: embedding + - /ml/model: linear_classifier + - _self_ + +mode: fit + +embedding_dataset_run_id: ${dataset.mlflow_artifacts.embedding_dataset_run_id} +train_tiles_uri: runs:/${embedding_dataset_run_id}/embedding_dataset/train/tiles.parquet +test_tiles_uri: runs:/${embedding_dataset_run_id}/embedding_dataset/test/tiles.parquet +val_fold: 0 + +mlflow_artifact_path: linear_classifier + +metadata: + run_name: Linear Classifier ${dataset.name} fold=${val_fold} + description: "Linear probe over frozen Virchow2 embeddings produced by embedding_dataset run ${embedding_dataset_run_id}." + hyperparams: + embedding_dataset_run_id: ${embedding_dataset_run_id} + val_fold: ${val_fold} + learning_rate: ${model.learning_rate} + weight_decay: ${model.weight_decay} + batch_size: ${data.batch_size} diff --git a/configs/ml/data/embedding.yaml b/configs/ml/data/embedding.yaml new file mode 100644 index 0000000..d80ca0c --- /dev/null +++ b/configs/ml/data/embedding.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +data: + batch_size: 1024 + num_workers: 4 + + train: + _target_: ml.data.datasets.EmbeddingTilesDataset + path_or_uri: ${train_tiles_uri} + class_indices: ${class_indices} + exclude_folds: + - ${val_fold} + + val: + _target_: ml.data.datasets.EmbeddingTilesDataset + path_or_uri: ${train_tiles_uri} + class_indices: ${class_indices} + include_folds: + - ${val_fold} + + test: + _target_: ml.data.datasets.EmbeddingTilesDataset + path_or_uri: ${test_tiles_uri} + class_indices: ${class_indices} diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml new file mode 100644 index 0000000..dfff43c --- /dev/null +++ b/configs/ml/model/linear_classifier.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +model: + backbone: + _target_: torch.nn.Identity + + decode_head: + _target_: torch.nn.Linear + in_features: 2560 + out_features: ${len:${class_indices}} + + criterion: + _target_: torch.nn.CrossEntropyLoss + + class_indices: ${class_indices} + + learning_rate: 1.0e-3 + weight_decay: 0.0 diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/default.yaml new file mode 100644 index 0000000..cf5766d --- /dev/null +++ b/configs/ml/trainer/default.yaml @@ -0,0 +1,29 @@ +# @package _global_ + +trainer: + max_epochs: 50 + accelerator: auto + devices: auto + precision: 32 + log_every_n_steps: 50 + deterministic: false + + callbacks: + early_stopping: + _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: validation/loss + mode: min + patience: 5 + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + monitor: validation/loss + mode: min + save_top_k: 1 + filename: "epoch={epoch}-val_loss={validation/loss:.4f}" + auto_insert_metric_name: false + lr_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: epoch + prediction_writer: + _target_: ml.callbacks.ParquetPredictionWriter + output_filename: predictions.parquet diff --git a/ml/__init__.py b/ml/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ml/__main__.py b/ml/__main__.py new file mode 100644 index 0000000..61ef400 --- /dev/null +++ b/ml/__main__.py @@ -0,0 +1,31 @@ +from random import randint + +import hydra +from lightning import seed_everything +from omegaconf import DictConfig, OmegaConf +from rationai.mlkit import Trainer, autolog +from rationai.mlkit.lightning.loggers import MLFlowLogger + +from ml.data import DataModule +from ml.meta_arch import MetaArch + + +OmegaConf.register_new_resolver( + "random_seed", lambda: randint(0, 2**31), use_cache=True +) +OmegaConf.register_new_resolver("len", lambda x: len(x)) + + +@hydra.main(config_path="../configs", config_name="ml", version_base=None) +@autolog +def main(config: DictConfig, logger: MLFlowLogger) -> None: + seed_everything(config.seed, workers=True) + + data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) + model = hydra.utils.instantiate(config.model, _target_=MetaArch) + trainer = hydra.utils.instantiate(config.trainer, _target_=Trainer, logger=logger) + getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint) + + +if __name__ == "__main__": + main() diff --git a/ml/callbacks/__init__.py b/ml/callbacks/__init__.py new file mode 100644 index 0000000..e9c20c4 --- /dev/null +++ b/ml/callbacks/__init__.py @@ -0,0 +1,4 @@ +from ml.callbacks.parquet_prediction_writer import ParquetPredictionWriter + + +__all__ = ["ParquetPredictionWriter"] diff --git a/ml/callbacks/parquet_prediction_writer.py b/ml/callbacks/parquet_prediction_writer.py new file mode 100644 index 0000000..d3b91f7 --- /dev/null +++ b/ml/callbacks/parquet_prediction_writer.py @@ -0,0 +1,77 @@ +"""Aggregate ``predict_step`` outputs and write them as a parquet artifact.""" + +from collections.abc import Sequence +from pathlib import Path +from typing import Any + +import lightning as pl +import mlflow +import numpy as np +import pandas as pd +from lightning.pytorch.callbacks import BasePredictionWriter + + +class ParquetPredictionWriter(BasePredictionWriter): + """Collect per-tile predictions and write them as a parquet artifact. + + Aggregates ``predict_step`` outputs across the predict loop, writes one + parquet file with ``slide_id``, ``target``, ``pred``, ``prob_`` + columns and logs it to the active MLflow run. + """ + def __init__(self, output_filename: str = "predictions.parquet") -> None: + super().__init__(write_interval="epoch") + self.output_filename = output_filename + self._batches: list[dict[str, Any]] = [] + + def write_on_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + prediction: dict[str, Any], + batch_indices: Sequence[int] | None, + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + self._batches.append(prediction) + + def write_on_epoch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + predictions: Any, + batch_indices: Any, + ) -> None: + if not self._batches: + return + + slide_ids: list[str] = [] + targets: list[int] = [] + preds: list[int] = [] + probs: list[np.ndarray] = [] + for b in self._batches: + slide_ids.extend(b["slide_id"]) + targets.extend(b["target"].tolist()) + preds.extend(b["pred"].tolist()) + probs.append(b["probs"].numpy()) + prob_matrix = np.concatenate(probs, axis=0) + + class_names = getattr(pl_module, "class_names", None) + prob_columns = ( + [f"prob_{c}" for c in class_names] + if class_names is not None and len(class_names) == prob_matrix.shape[1] + else [f"prob_{i}" for i in range(prob_matrix.shape[1])] + ) + + df = pd.DataFrame({"slide_id": slide_ids, "target": targets, "pred": preds}) + df = pd.concat([df, pd.DataFrame(prob_matrix, columns=prob_columns)], axis=1) + + out_path = Path(trainer.default_root_dir) / self.output_filename + out_path.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(out_path, index=False) + + active = mlflow.active_run() + if active is not None: + mlflow.log_artifact(str(out_path), artifact_path="predictions") + + self._batches.clear() diff --git a/ml/data/__init__.py b/ml/data/__init__.py new file mode 100644 index 0000000..e7058ee --- /dev/null +++ b/ml/data/__init__.py @@ -0,0 +1,4 @@ +from ml.data.data_module import DataModule + + +__all__ = ["DataModule"] diff --git a/ml/data/data_module.py b/ml/data/data_module.py new file mode 100644 index 0000000..c39b950 --- /dev/null +++ b/ml/data/data_module.py @@ -0,0 +1,64 @@ +from collections.abc import Iterable + +from hydra.utils import instantiate +from lightning import LightningDataModule +from omegaconf import DictConfig +from torch.utils.data import DataLoader + +from ml.typing import Input + + +class DataModule(LightningDataModule): + """Generic Lightning datamodule that instantiates datasets lazily per stage. + + Mirrors the template pattern: ``**datasets`` accepts ``train``, ``val``, + ``test`` (or ``predict``) DictConfigs whose targets resolve to ``Dataset``s. + """ + + def __init__( + self, batch_size: int, num_workers: int = 0, **datasets: DictConfig + ) -> None: + super().__init__() + self.batch_size = batch_size + self.num_workers = num_workers + self.datasets = datasets + + def setup(self, stage: str) -> None: + match stage: + case "fit": + self.train = instantiate(self.datasets["train"]) + self.val = instantiate(self.datasets["val"]) + case "validate": + self.val = instantiate(self.datasets["val"]) + case "test": + self.test = instantiate(self.datasets["test"]) + case "predict": + self.predict = instantiate(self.datasets["predict"]) + + def train_dataloader(self) -> Iterable[Input]: + return DataLoader( + self.train, + batch_size=self.batch_size, + shuffle=True, + drop_last=True, + num_workers=self.num_workers, + persistent_workers=self.num_workers > 0, + ) + + def val_dataloader(self) -> Iterable[Input]: + return DataLoader( + self.val, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=self.num_workers > 0, + ) + + def test_dataloader(self) -> Iterable[Input]: + return DataLoader( + self.test, batch_size=self.batch_size, num_workers=self.num_workers + ) + + def predict_dataloader(self) -> Iterable[Input]: + return DataLoader( + self.predict, batch_size=self.batch_size, num_workers=self.num_workers + ) diff --git a/ml/data/datasets/__init__.py b/ml/data/datasets/__init__.py new file mode 100644 index 0000000..cd2f91a --- /dev/null +++ b/ml/data/datasets/__init__.py @@ -0,0 +1,4 @@ +from ml.data.datasets.embedding_tiles import EmbeddingTilesDataset + + +__all__ = ["EmbeddingTilesDataset"] diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py new file mode 100644 index 0000000..ad4d104 --- /dev/null +++ b/ml/data/datasets/embedding_tiles.py @@ -0,0 +1,78 @@ +"""Tile-embedding dataset. + +Reads the parquet artifact produced by ``preprocessing.embedding_dataset``. +""" + + +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from mlflow.artifacts import download_artifacts +from torch.utils.data import Dataset + +from ml.typing import Sample + + +class EmbeddingTilesDataset(Dataset[Sample]): + """Returns ``(embedding, class_index, slide_id)`` triples from a tiles parquet. + + A single dataset instance corresponds to one parquet (train or test); + fold-based CV is expressed by ``include_folds`` / ``exclude_folds`` + filters applied to the train parquet via separate dataset configs. + """ + + REQUIRED_COLUMNS = ("embedding", "label", "slide_id") + + def __init__( + self, + path_or_uri: str | Path, + class_indices: dict[str, int], + include_folds: list[int] | None = None, + exclude_folds: list[int] | None = None, + ) -> None: + df = self._load_parquet(path_or_uri) + + missing = set(self.REQUIRED_COLUMNS) - set(df.columns) + if missing: + raise ValueError(f"tiles parquet missing columns: {sorted(missing)}") + + if include_folds is not None or exclude_folds is not None: + if "fold" not in df.columns: + raise RuntimeError( + "fold filter requested but 'fold' column not in parquet" + ) + if include_folds is not None: + df = df[df["fold"].isin(include_folds)] + if exclude_folds is not None: + df = df[~df["fold"].isin(exclude_folds)] + + unknown = set(df["label"].unique()) - set(class_indices.keys()) + if unknown: + raise ValueError( + f"labels in tiles not present in class_indices: {sorted(unknown)}" + ) + + self.embeddings = np.stack(df["embedding"].tolist()).astype(np.float32) + self.labels = df["label"].map(class_indices).to_numpy(dtype=np.int64) + self.slide_ids = df["slide_id"].to_numpy() + + def __len__(self) -> int: + return len(self.labels) + + def __getitem__(self, idx: int) -> Sample: + return ( + torch.from_numpy(self.embeddings[idx]), + int(self.labels[idx]), + str(self.slide_ids[idx]), + ) + + @staticmethod + def _load_parquet(path_or_uri: str | Path) -> pd.DataFrame: + s = str(path_or_uri) + if s.startswith(("mlflow-artifacts:/", "runs:/")): + local = download_artifacts(artifact_uri=s) + else: + local = s + return pd.read_parquet(local) diff --git a/ml/meta_arch.py b/ml/meta_arch.py new file mode 100644 index 0000000..602693a --- /dev/null +++ b/ml/meta_arch.py @@ -0,0 +1,220 @@ +from collections import defaultdict +from collections.abc import Iterable +from typing import Any + +import mlflow +import numpy as np +import pandas as pd +import torch +from lightning import LightningModule +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from torch import Tensor, nn +from torch.optim.optimizer import Optimizer +from torchmetrics import MetricCollection +from torchmetrics.classification import ( + MulticlassAccuracy, + MulticlassConfusionMatrix, + MulticlassF1Score, +) + +from ml.typing import Input, Outputs + + +class MetaArch(LightningModule): + """Top-level classification architecture: backbone + decode_head + criterion. + + For linear probing on precomputed embeddings, ``backbone`` is typically + ``nn.Identity`` and ``decode_head`` is a single ``nn.Linear``. + """ + + # TODO: support class_weights for CE loss when class distribution is heavily + # imbalanced. Inject either via config (list[float]) or compute from the + # train fold label distribution at setup(). + + def __init__( + self, + backbone: nn.Module, + decode_head: nn.Module, + criterion: nn.Module, + class_indices: dict[str, int], + learning_rate: float = 1e-3, + weight_decay: float = 0.0, + ) -> None: + super().__init__() + self.save_hyperparameters(ignore=["backbone", "decode_head", "criterion"]) + + self.backbone = backbone + self.decode_head = decode_head + self.criterion = criterion + + self.class_names = [ + n for n, _ in sorted(class_indices.items(), key=lambda kv: kv[1]) + ] + num_classes = len(self.class_names) + + macro_metrics = MetricCollection( + { + "acc_macro": MulticlassAccuracy( + num_classes=num_classes, average="macro" + ), + "f1_macro": MulticlassF1Score(num_classes=num_classes, average="macro"), + } + ) + per_class_metrics = MetricCollection( + { + "acc_per_class": MulticlassAccuracy( + num_classes=num_classes, average=None + ), + "f1_per_class": MulticlassF1Score( + num_classes=num_classes, average=None + ), + } + ) + self.val_metrics = macro_metrics.clone(prefix="validation/") + self.test_metrics = macro_metrics.clone(prefix="test/") + self.val_per_class = per_class_metrics.clone(prefix="validation/") + self.test_per_class = per_class_metrics.clone(prefix="test/") + self.val_confmat = MulticlassConfusionMatrix(num_classes=num_classes) + self.test_confmat = MulticlassConfusionMatrix(num_classes=num_classes) + + self._test_slide_correct: dict[str, int] = defaultdict(int) + self._test_slide_total: dict[str, int] = defaultdict(int) + + def forward(self, x: Tensor) -> Outputs: + features = self.backbone(x) + return self.decode_head(features) + + def training_step(self, batch: Input, batch_idx: int) -> Tensor: + inputs, targets, _ = batch + outputs = self(inputs) + loss = self.criterion(outputs, targets) + self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + return loss + + def validation_step(self, batch: Input, batch_idx: int) -> None: + inputs, targets, _ = batch + outputs = self(inputs) + loss = self.criterion(outputs, targets) + self.log("validation/loss", loss, on_epoch=True, prog_bar=True) + self.val_metrics.update(outputs, targets) + self.val_per_class.update(outputs, targets) + self.val_confmat.update(outputs, targets) + self.log_dict(self.val_metrics, on_epoch=True) + + def on_validation_epoch_end(self) -> None: + self._log_per_class(self.val_per_class, "validation") + self._log_confmat(self.val_confmat, "validation") + + def test_step(self, batch: Input, batch_idx: int) -> None: + inputs, targets, slide_ids = batch + outputs = self(inputs) + self.test_metrics.update(outputs, targets) + self.test_per_class.update(outputs, targets) + self.test_confmat.update(outputs, targets) + self.log_dict(self.test_metrics, on_epoch=True) + + preds = outputs.argmax(dim=1) + correct = (preds == targets).cpu().tolist() + for slide_id, ok in zip(slide_ids, correct, strict=True): + self._test_slide_correct[slide_id] += int(ok) + self._test_slide_total[slide_id] += 1 + + def on_test_epoch_end(self) -> None: + self._log_per_class(self.test_per_class, "test") + self._log_confmat(self.test_confmat, "test") + self._log_per_slide_accuracy() + self._test_slide_correct.clear() + self._test_slide_total.clear() + + def predict_step( + self, batch: Input, batch_idx: int, dataloader_idx: int = 0 + ) -> dict[str, Any]: + inputs, targets, slide_ids = batch + outputs = self(inputs) + probs = outputs.softmax(dim=1) + preds = outputs.argmax(dim=1) + return { + "slide_id": list(slide_ids), + "target": targets.cpu(), + "pred": preds.cpu(), + "probs": probs.cpu(), + } + + def configure_optimizers(self) -> Optimizer: + return torch.optim.AdamW( + self.parameters(), + lr=self.hparams["learning_rate"], + weight_decay=self.hparams["weight_decay"], + ) + + def _log_per_class(self, collection: MetricCollection, split: str) -> None: + computed = collection.compute() + for metric_name, values in computed.items(): + tag = metric_name.split("/")[-1] # e.g. "acc_per_class" + for cls_name, val in zip(self.class_names, values.tolist(), strict=True): + self.log(f"{split}/{tag}/{cls_name}", val, on_epoch=True) + collection.reset() + + def _log_confmat(self, confmat: MulticlassConfusionMatrix, split: str) -> None: + matrix = confmat.compute().cpu().numpy() + confmat.reset() + fig = _confmat_figure(matrix, self.class_names, title=f"{split} confmat") + artifact_file = f"confusion_matrix/{split}_epoch_{self.current_epoch}.png" + try: + mlflow.log_figure(fig, artifact_file=artifact_file) + finally: + plt.close(fig) + + def _log_per_slide_accuracy(self) -> None: + accs = [ + self._test_slide_correct[s] / self._test_slide_total[s] + for s in self._test_slide_total + ] + if not accs: + return + self.log("test/slide_acc_mean", float(np.mean(accs)), on_epoch=True) + self.log("test/slide_acc_median", float(np.median(accs)), on_epoch=True) + self.log("test/slide_acc_min", float(np.min(accs)), on_epoch=True) + + rows = [ + { + "slide_id": s, + "tile_accuracy": self._test_slide_correct[s] / n, + "n_tiles": n, + } + for s, n in self._test_slide_total.items() + ] + mlflow.log_table( + data=pd.DataFrame(rows), + artifact_file="per_slide/test_tile_accuracy.json", + ) + + +def _confmat_figure( + matrix: np.ndarray, class_names: Iterable[str], title: str +) -> Figure: + fig, ax = plt.subplots(figsize=(6, 5)) + im = ax.imshow(matrix, cmap="Blues") + ax.set_title(title) + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + names = list(class_names) + ax.set_xticks(range(len(names))) + ax.set_yticks(range(len(names))) + ax.set_xticklabels(names, rotation=45, ha="right") + ax.set_yticklabels(names) + for i in range(matrix.shape[0]): + for j in range(matrix.shape[1]): + ax.text( + j, + i, + str(matrix[i, j]), + ha="center", + va="center", + color="white" if matrix[i, j] > matrix.max() / 2 else "black", + fontsize=8, + ) + fig.colorbar(im, ax=ax) + fig.tight_layout() + return fig diff --git a/ml/modeling/__init__.py b/ml/modeling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ml/typing.py b/ml/typing.py new file mode 100644 index 0000000..7060f5e --- /dev/null +++ b/ml/typing.py @@ -0,0 +1,6 @@ +import torch + + +type Sample = tuple[torch.Tensor, int, str] +type Input = tuple[torch.Tensor, torch.Tensor, list[str]] +type Outputs = torch.Tensor diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index a1e6545..7669db9 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -33,14 +33,17 @@ def apply_thresholds( thresholds: dict[str, float], roi_cols: list[str], ) -> tuple[pd.DataFrame, int]: - """Filter df by tissue_prop_min, then keep tiles where ANY class meets its - threshold; among passing classes, the highest-coverage one becomes the label. + """Filter tiles by tissue + per-class thresholds and rewrite labels. + + Filters ``df`` by ``tissue_prop_min``, then keeps tiles where ANY class + meets its threshold; among passing classes, the highest-coverage one + becomes the label. Returns ``(filtered_df, after_tissue_count)`` so the caller can log both intermediate counts. The returned df has its ``label`` column rewritten to reflect the argmax-over-passers rule. """ - df = df[df["tissue_prop"] >= tissue_prop_min] + df = df.loc[df["tissue_prop"] >= tissue_prop_min] after_tissue = len(df) if df.empty: return df, after_tissue @@ -55,7 +58,7 @@ def apply_thresholds( label_idx = masked.argmax(axis=1) new_labels = class_names[label_idx] - out = df[keep].copy() + out = df.loc[pd.Series(keep, index=df.index)].copy() out["label"] = new_labels[keep] return out, after_tissue diff --git a/pyproject.toml b/pyproject.toml index 183bdd9..bc54e93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,8 @@ dependencies = [ "tifffile>=2025.12.20", "torch>=2.0.0", "torchvision>=0.15.0", + "lightning>=2.0.0", + "torchmetrics>=1.0.0", "timm>=1.0.0", "einops>=0.8.0", "matplotlib>=3.10.7", diff --git a/scripts/submit_train_linear.py b/scripts/submit_train_linear.py new file mode 100644 index 0000000..11fc786 --- /dev/null +++ b/scripts/submit_train_linear.py @@ -0,0 +1,18 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-train-linear", + username=..., + cpu=8, + memory="64Gi", + gpu="A40", + public=False, + script=[ + "git clone https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv run python -m ml +experiment=... val_fold=0,1,2,3,4 --mutlirun", + ], + storage=[storage.secure.PROJECTS], +) diff --git a/uv.lock b/uv.lock index 30b783b..d884e87 100644 --- a/uv.lock +++ b/uv.lock @@ -2292,6 +2292,7 @@ dependencies = [ { name = "datasets" }, { name = "einops" }, { name = "hydra-core" }, + { name = "lightning" }, { name = "matplotlib" }, { name = "mlflow" }, { name = "numpy" }, @@ -2309,6 +2310,7 @@ dependencies = [ { name = "tifffile" }, { name = "timm" }, { name = "torch" }, + { name = "torchmetrics" }, { name = "torchvision" }, { name = "tqdm" }, ] @@ -2327,6 +2329,7 @@ requires-dist = [ { name = "datasets", specifier = ">=4.0.0" }, { name = "einops", specifier = ">=0.8.0" }, { name = "hydra-core", specifier = ">=1.3.2" }, + { name = "lightning", specifier = ">=2.0.0" }, { name = "matplotlib", specifier = ">=3.10.7" }, { name = "mlflow", specifier = "<3.0.0" }, { name = "numpy", specifier = ">=2.3.5" }, @@ -2345,6 +2348,7 @@ requires-dist = [ { name = "tifffile", specifier = ">=2025.12.20" }, { name = "timm", specifier = ">=1.0.0" }, { name = "torch", specifier = ">=2.0.0" }, + { name = "torchmetrics", specifier = ">=1.0.0" }, { name = "torchvision", specifier = ">=0.15.0" }, { name = "tqdm", specifier = ">=4.66.0" }, ] From d5798bc3cb9ef1c19d167039496cccd594539252 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 18:26:58 +0200 Subject: [PATCH 19/53] feat: add class weights --- configs/data/dataset.yaml | 2 +- ml/callbacks/parquet_prediction_writer.py | 1 + ml/data/datasets/embedding_tiles.py | 1 - ml/meta_arch.py | 12 ++++++++++++ scripts/submit_train_linear.py | 8 ++++---- 5 files changed, 18 insertions(+), 6 deletions(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 57e1925..7419d6c 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -16,7 +16,7 @@ dataset: filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" kfold_run_id: "850c81506684450b9af92296acfd045a" embedding_run_id: "5f323d5ef5a74026846ecbe8fbc007fb" - embedding_dataset_run_id: "3ab86e376d38481dbac5bc352f7ac7c9" + embedding_dataset_run_id: "b4a937ef6b334533807f08a191083401" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" diff --git a/ml/callbacks/parquet_prediction_writer.py b/ml/callbacks/parquet_prediction_writer.py index d3b91f7..1fe8558 100644 --- a/ml/callbacks/parquet_prediction_writer.py +++ b/ml/callbacks/parquet_prediction_writer.py @@ -18,6 +18,7 @@ class ParquetPredictionWriter(BasePredictionWriter): parquet file with ``slide_id``, ``target``, ``pred``, ``prob_`` columns and logs it to the active MLflow run. """ + def __init__(self, output_filename: str = "predictions.parquet") -> None: super().__init__(write_interval="epoch") self.output_filename = output_filename diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index ad4d104..e60bc8f 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -3,7 +3,6 @@ Reads the parquet artifact produced by ``preprocessing.embedding_dataset``. """ - from pathlib import Path import numpy as np diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 602693a..1ee2575 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -81,6 +81,18 @@ def __init__( self._test_slide_correct: dict[str, int] = defaultdict(int) self._test_slide_total: dict[str, int] = defaultdict(int) + def setup(self, stage: str) -> None: + if stage == "fit": + labels = self.trainer.datamodule.train.labels + num_classes = len(self.class_names) + counts = np.bincount(labels, minlength=num_classes).astype(float) + weights = len(labels) / (num_classes * counts) + self.criterion = nn.CrossEntropyLoss( + weight=torch.tensor(weights, dtype=torch.float32) + ) + for cls, w in zip(self.class_names, weights.tolist(), strict=True): + mlflow.log_metric(f"class_weight/{cls}", w) + def forward(self, x: Tensor) -> Outputs: features = self.backbone(x) return self.decode_head(features) diff --git a/scripts/submit_train_linear.py b/scripts/submit_train_linear.py index 11fc786..d19f674 100644 --- a/scripts/submit_train_linear.py +++ b/scripts/submit_train_linear.py @@ -3,16 +3,16 @@ submit_job( job_name="tissue-classification-train-linear", - username=..., + username="vcifka", cpu=8, memory="64Gi", - gpu="A40", + gpu=None, public=False, script=[ - "git clone https://github.com/RationAI/tissue-classification.git workdir", + "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=... val_fold=0,1,2,3,4 --mutlirun", + "uv run python -m ml +experiment=ml/linear_classifier val_fold=0,1,2,3,4 --multirun", ], storage=[storage.secure.PROJECTS], ) From ae45cd54244702ad392e6aa15265da6bdbacd5d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 20:18:05 +0200 Subject: [PATCH 20/53] refactor: join embeddings with metadata while loading the dataset --- configs/data/dataset.yaml | 3 +- configs/experiment/ml/linear_classifier.yaml | 30 +++- configs/ml/data/embedding.yaml | 15 +- ml/data/datasets/embedding_tiles.py | 159 +++++++++++++++---- ml/meta_arch.py | 6 +- pyproject.toml | 6 +- 6 files changed, 171 insertions(+), 48 deletions(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 7419d6c..172e48f 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -15,8 +15,7 @@ dataset: tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" kfold_run_id: "850c81506684450b9af92296acfd045a" - embedding_run_id: "5f323d5ef5a74026846ecbe8fbc007fb" - embedding_dataset_run_id: "b4a937ef6b334533807f08a191083401" + embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" diff --git a/configs/experiment/ml/linear_classifier.yaml b/configs/experiment/ml/linear_classifier.yaml index 6796489..2e9a952 100644 --- a/configs/experiment/ml/linear_classifier.yaml +++ b/configs/experiment/ml/linear_classifier.yaml @@ -10,19 +10,39 @@ defaults: mode: fit -embedding_dataset_run_id: ${dataset.mlflow_artifacts.embedding_dataset_run_id} -train_tiles_uri: runs:/${embedding_dataset_run_id}/embedding_dataset/train/tiles.parquet -test_tiles_uri: runs:/${embedding_dataset_run_id}/embedding_dataset/test/tiles.parquet +embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} +kfold_run_id: ${dataset.mlflow_artifacts.kfold_run_id} +filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} + +train_embedding_uri: runs:/${embedding_run_id}/train/tiles +test_embedding_uri: runs:/${embedding_run_id}/test/tiles +train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet +test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet + val_fold: 0 +tissue_prop_min: 0.2 +thresholds: + Nerve: 0.0 + Blood: 0.0 + Connective-Tissue: 0.4 + Fat: 0.6 + Epithelium: 0.2 + Muscle: 0.5 + Other: 0.5 + mlflow_artifact_path: linear_classifier metadata: run_name: Linear Classifier ${dataset.name} fold=${val_fold} - description: "Linear probe over frozen Virchow2 embeddings produced by embedding_dataset run ${embedding_dataset_run_id}." + description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), kfold metadata ${kfold_run_id}." hyperparams: - embedding_dataset_run_id: ${embedding_dataset_run_id} + embedding_run_id: ${embedding_run_id} + kfold_run_id: ${kfold_run_id} + filter_tiles_run_id: ${filter_tiles_run_id} val_fold: ${val_fold} + tissue_prop_min: ${tissue_prop_min} + thresholds: ${thresholds} learning_rate: ${model.learning_rate} weight_decay: ${model.weight_decay} batch_size: ${data.batch_size} diff --git a/configs/ml/data/embedding.yaml b/configs/ml/data/embedding.yaml index d80ca0c..597e012 100644 --- a/configs/ml/data/embedding.yaml +++ b/configs/ml/data/embedding.yaml @@ -6,19 +6,28 @@ data: train: _target_: ml.data.datasets.EmbeddingTilesDataset - path_or_uri: ${train_tiles_uri} + embedding_uri: ${train_embedding_uri} + metadata_uri: ${train_metadata_uri} class_indices: ${class_indices} + thresholds: ${thresholds} + tissue_prop_min: ${tissue_prop_min} exclude_folds: - ${val_fold} val: _target_: ml.data.datasets.EmbeddingTilesDataset - path_or_uri: ${train_tiles_uri} + embedding_uri: ${train_embedding_uri} + metadata_uri: ${train_metadata_uri} class_indices: ${class_indices} + thresholds: ${thresholds} + tissue_prop_min: ${tissue_prop_min} include_folds: - ${val_fold} test: _target_: ml.data.datasets.EmbeddingTilesDataset - path_or_uri: ${test_tiles_uri} + embedding_uri: ${test_embedding_uri} + metadata_uri: ${test_metadata_uri} class_indices: ${class_indices} + thresholds: ${thresholds} + tissue_prop_min: ${tissue_prop_min} diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index e60bc8f..6cf2c1d 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -1,12 +1,16 @@ """Tile-embedding dataset. -Reads the parquet artifact produced by ``preprocessing.embedding_dataset``. +Joins precomputed tile embeddings with tile metadata (k-fold parquet for train, +filter_tiles parquet for test) and applies tissue + per-class thresholds at +load time to produce ``(embedding, class_index, slide_id)`` triples. """ from pathlib import Path import numpy as np import pandas as pd +import pyarrow as pa +import pyarrow.dataset as pads import torch from mlflow.artifacts import download_artifacts from torch.utils.data import Dataset @@ -15,47 +19,81 @@ class EmbeddingTilesDataset(Dataset[Sample]): - """Returns ``(embedding, class_index, slide_id)`` triples from a tiles parquet. + """Tile-level embedding dataset with on-the-fly filtering and labeling. - A single dataset instance corresponds to one parquet (train or test); - fold-based CV is expressed by ``include_folds`` / ``exclude_folds`` - filters applied to the train parquet via separate dataset configs. - """ + Inner-joins ``embedding`` parquet with ``metadata`` parquet on + ``(slide_id, x, y)``. Metadata must contain ``roi_coverage_*`` columns; + label is the dominant class whose coverage meets its threshold. Tiles + failing the tissue proportion floor, with more than one annotated class, + or whose dominant class falls below its threshold are dropped. - REQUIRED_COLUMNS = ("embedding", "label", "slide_id") + For train/val: pass the k-fold parquet as ``metadata_uri`` and use + ``include_folds`` / ``exclude_folds`` to split. For test: pass the + filter_tiles parquet (no fold column). + """ def __init__( self, - path_or_uri: str | Path, + embedding_uri: str | Path, + metadata_uri: str | Path, class_indices: dict[str, int], + thresholds: dict[str, float], + tissue_prop_min: float, include_folds: list[int] | None = None, exclude_folds: list[int] | None = None, ) -> None: - df = self._load_parquet(path_or_uri) + meta_df = self._filter_metadata( + metadata_uri, + thresholds, + tissue_prop_min, + include_folds, + exclude_folds, + ) - missing = set(self.REQUIRED_COLUMNS) - set(df.columns) - if missing: - raise ValueError(f"tiles parquet missing columns: {sorted(missing)}") + emb_dir = self._resolve_uri(embedding_uri) + emb_table = pads.dataset(emb_dir, format="parquet").to_table( + columns=["slide_id", "x", "y", "embedding"] + ) - if include_folds is not None or exclude_folds is not None: - if "fold" not in df.columns: - raise RuntimeError( - "fold filter requested but 'fold' column not in parquet" - ) - if include_folds is not None: - df = df[df["fold"].isin(include_folds)] - if exclude_folds is not None: - df = df[~df["fold"].isin(exclude_folds)] + emb_col = emb_table.column("embedding") + if pa.types.is_list(emb_col.type): + target_type = pa.large_list(emb_col.type.value_type) + emb_col = pa.chunked_array( + [c.cast(target_type) for c in emb_col.chunks], type=target_type + ) + + emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) + emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) + del emb_table - unknown = set(df["label"].unique()) - set(class_indices.keys()) + meta_table = pa.Table.from_pandas(meta_df, preserve_index=False) + del meta_df + joined_keys = meta_table.join( + emb_keys, keys=["slide_id", "x", "y"], join_type="inner" + ) + del emb_keys, meta_table + if joined_keys.num_rows == 0: + raise RuntimeError("inner join with embeddings produced empty dataset") + + indices = joined_keys.column("_emb_idx") + if isinstance(indices, pa.ChunkedArray): + indices = indices.combine_chunks() + embeddings_arrow = emb_col.combine_chunks().take(indices) + + embedding_dim = len(embeddings_arrow.values) // len(embeddings_arrow) + self.embeddings = ( + embeddings_arrow.values.to_numpy(zero_copy_only=False) + .astype(np.float32) + .reshape(len(embeddings_arrow), embedding_dim) + ) + labels = joined_keys.column("label").to_pandas() + unknown = set(labels.unique()) - set(class_indices.keys()) if unknown: raise ValueError( - f"labels in tiles not present in class_indices: {sorted(unknown)}" + f"labels in data not present in class_indices: {sorted(unknown)}" ) - - self.embeddings = np.stack(df["embedding"].tolist()).astype(np.float32) - self.labels = df["label"].map(class_indices).to_numpy(dtype=np.int64) - self.slide_ids = df["slide_id"].to_numpy() + self.labels = labels.map(class_indices).to_numpy(dtype=np.int64) + self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() def __len__(self) -> int: return len(self.labels) @@ -68,10 +106,67 @@ def __getitem__(self, idx: int) -> Sample: ) @staticmethod - def _load_parquet(path_or_uri: str | Path) -> pd.DataFrame: + def _filter_metadata( + metadata_uri: str | Path, + thresholds: dict[str, float], + tissue_prop_min: float, + include_folds: list[int] | None, + exclude_folds: list[int] | None, + ) -> pd.DataFrame: + local = EmbeddingTilesDataset._resolve_uri(metadata_uri) + df = pd.read_parquet(local) + + roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] + if not roi_cols: + raise ValueError( + "metadata parquet has no roi_coverage_* columns; cannot label" + ) + + classes_in_data = {c.removeprefix("roi_coverage_") for c in roi_cols} + missing_thresholds = classes_in_data - set(thresholds.keys()) + if missing_thresholds: + raise ValueError( + f"thresholds missing entries for classes present in data: " + f"{sorted(missing_thresholds)}" + ) + + tissue_prop = df[roi_cols].sum(axis=1).to_numpy() + df = df.loc[tissue_prop >= tissue_prop_min] + if df.empty: + raise RuntimeError("all tiles dropped by tissue_prop_min filter") + + nonzero_classes = (df[roi_cols].to_numpy() > 0).sum(axis=1) + df = df.loc[pd.Series(nonzero_classes <= 1, index=df.index)] + if df.empty: + raise RuntimeError("all tiles dropped by single-class filter") + + roi_only = df[roi_cols] + dominant = roi_only.idxmax(axis=1).str.removeprefix("roi_coverage_") + dominant_value = roi_only.max(axis=1).to_numpy() + threshold_per_row = dominant.map(thresholds).to_numpy() + keep = dominant_value >= threshold_per_row + df = df.loc[pd.Series(keep, index=df.index)].copy() + df["label"] = dominant.to_numpy()[keep] + if df.empty: + raise RuntimeError("all tiles dropped by per-class thresholds") + + if include_folds is not None or exclude_folds is not None: + if "fold" not in df.columns: + raise RuntimeError( + "fold filter requested but 'fold' column not in metadata" + ) + if include_folds is not None: + df = df[df["fold"].isin(include_folds)] + if exclude_folds is not None: + df = df[~df["fold"].isin(exclude_folds)] + if df.empty: + raise RuntimeError("all tiles dropped by fold filter") + + return df[["slide_id", "x", "y", "label"]] + + @staticmethod + def _resolve_uri(path_or_uri: str | Path) -> str: s = str(path_or_uri) if s.startswith(("mlflow-artifacts:/", "runs:/")): - local = download_artifacts(artifact_uri=s) - else: - local = s - return pd.read_parquet(local) + return download_artifacts(artifact_uri=s) + return s diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 1ee2575..0d2cc55 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -1,6 +1,6 @@ from collections import defaultdict from collections.abc import Iterable -from typing import Any +from typing import Any, cast import mlflow import numpy as np @@ -83,9 +83,11 @@ def __init__( def setup(self, stage: str) -> None: if stage == "fit": - labels = self.trainer.datamodule.train.labels + datamodule = cast(Any, self.trainer).datamodule + labels = datamodule.train.labels num_classes = len(self.class_names) counts = np.bincount(labels, minlength=num_classes).astype(float) + counts = np.maximum(counts, 1.0) weights = len(labels) / (num_classes * counts) self.criterion = nn.CrossEntropyLoss( weight=torch.tensor(weights, dtype=torch.float32) diff --git a/pyproject.toml b/pyproject.toml index bc54e93..0cff386 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,8 +19,8 @@ dependencies = [ "tqdm>=4.66.0", "rationai-sdk", "ratiopath>=1.2.0", - "pyarrow>=19.0.0", - "datasets>=3.0.0", + "pyarrow>=19.0.1", + "datasets>=4.0.0", "scikit-learn>=1.8.0", "numpy>=2.3.5", "rationai-tiling>=1.1.1", @@ -32,8 +32,6 @@ dependencies = [ "timm>=1.0.0", "einops>=0.8.0", "matplotlib>=3.10.7", - "pyarrow>=19.0.1", - "datasets>=4.0.0", ] [dependency-groups] From bdce760107fd4fcba4ecdbadeff4d5f6003399ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 20:30:18 +0200 Subject: [PATCH 21/53] feat: add prints --- ml/data/datasets/embedding_tiles.py | 59 ++++++++++++++++++++++++++++- uv.lock | 2 - 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 6cf2c1d..bd86542 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -5,6 +5,7 @@ load time to produce ``(embedding, class_index, slide_id)`` triples. """ +import time from pathlib import Path import numpy as np @@ -42,6 +43,15 @@ def __init__( include_folds: list[int] | None = None, exclude_folds: list[int] | None = None, ) -> None: + tag = ( + f"include_folds={include_folds}" + if include_folds is not None + else f"exclude_folds={exclude_folds}" + if exclude_folds is not None + else "no_folds" + ) + print(f"[dataset] init: {tag}, metadata={metadata_uri}", flush=True) + t0 = time.time() meta_df = self._filter_metadata( metadata_uri, thresholds, @@ -49,37 +59,73 @@ def __init__( include_folds, exclude_folds, ) + print( + f"[dataset] metadata filtered: {len(meta_df)} rows " + f"({time.time() - t0:.1f}s)", + flush=True, + ) + t = time.time() emb_dir = self._resolve_uri(embedding_uri) + print( + f"[dataset] embedding artifacts resolved in {time.time() - t:.1f}s", + flush=True, + ) + + t = time.time() emb_table = pads.dataset(emb_dir, format="parquet").to_table( columns=["slide_id", "x", "y", "embedding"] ) + print( + f"[dataset] embedding parquet loaded: {emb_table.num_rows} rows " + f"({time.time() - t:.1f}s)", + flush=True, + ) + t = time.time() emb_col = emb_table.column("embedding") if pa.types.is_list(emb_col.type): target_type = pa.large_list(emb_col.type.value_type) emb_col = pa.chunked_array( [c.cast(target_type) for c in emb_col.chunks], type=target_type ) + print( + f"[dataset] embedding column cast in {time.time() - t:.1f}s", flush=True + ) emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table + t = time.time() meta_table = pa.Table.from_pandas(meta_df, preserve_index=False) del meta_df joined_keys = meta_table.join( emb_keys, keys=["slide_id", "x", "y"], join_type="inner" ) del emb_keys, meta_table + print( + f"[dataset] arrow join: {joined_keys.num_rows} rows " + f"({time.time() - t:.1f}s)", + flush=True, + ) if joined_keys.num_rows == 0: raise RuntimeError("inner join with embeddings produced empty dataset") + t = time.time() indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - embeddings_arrow = emb_col.combine_chunks().take(indices) + emb_contig = emb_col.combine_chunks() + print( + f"[dataset] combine_chunks done in {time.time() - t:.1f}s", flush=True + ) + + t = time.time() + embeddings_arrow = emb_contig.take(indices) + print(f"[dataset] take done in {time.time() - t:.1f}s", flush=True) + t = time.time() embedding_dim = len(embeddings_arrow.values) // len(embeddings_arrow) self.embeddings = ( embeddings_arrow.values.to_numpy(zero_copy_only=False) @@ -94,6 +140,11 @@ def __init__( ) self.labels = labels.map(class_indices).to_numpy(dtype=np.int64) self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() + print( + f"[dataset] numpy conversion done in {time.time() - t:.1f}s, " + f"total={time.time() - t0:.1f}s", + flush=True, + ) def __len__(self) -> int: return len(self.labels) @@ -113,8 +164,14 @@ def _filter_metadata( include_folds: list[int] | None, exclude_folds: list[int] | None, ) -> pd.DataFrame: + t = time.time() local = EmbeddingTilesDataset._resolve_uri(metadata_uri) df = pd.read_parquet(local) + print( + f"[dataset] metadata parquet loaded: {len(df)} rows " + f"({time.time() - t:.1f}s)", + flush=True, + ) roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] if not roi_cols: diff --git a/uv.lock b/uv.lock index d884e87..c4ad730 100644 --- a/uv.lock +++ b/uv.lock @@ -2325,7 +2325,6 @@ dev = [ [package.metadata] requires-dist = [ - { name = "datasets", specifier = ">=3.0.0" }, { name = "datasets", specifier = ">=4.0.0" }, { name = "einops", specifier = ">=0.8.0" }, { name = "hydra-core", specifier = ">=1.3.2" }, @@ -2336,7 +2335,6 @@ requires-dist = [ { name = "omegaconf", specifier = ">=2.3.0" }, { name = "openslide-python", specifier = ">=1.4.2" }, { name = "pandas", specifier = ">=2.0.0" }, - { name = "pyarrow", specifier = ">=19.0.0" }, { name = "pyarrow", specifier = ">=19.0.1" }, { name = "rationai-masks" }, { name = "rationai-mlkit", git = "https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/mlkit.git" }, From ac633d5717e1419a950c5b0fd5ff9b0c61f0265c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 20:40:09 +0200 Subject: [PATCH 22/53] fix: use chunks --- ml/data/datasets/embedding_tiles.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index bd86542..73421fa 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -116,15 +116,17 @@ def __init__( indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - emb_contig = emb_col.combine_chunks() + # take first on chunked array (avoids combine_chunks on full 1.1M rows) + embeddings_arrow = emb_col.take(indices) + print(f"[dataset] take done in {time.time() - t:.1f}s", flush=True) + + t = time.time() + if isinstance(embeddings_arrow, pa.ChunkedArray): + embeddings_arrow = embeddings_arrow.combine_chunks() print( f"[dataset] combine_chunks done in {time.time() - t:.1f}s", flush=True ) - t = time.time() - embeddings_arrow = emb_contig.take(indices) - print(f"[dataset] take done in {time.time() - t:.1f}s", flush=True) - t = time.time() embedding_dim = len(embeddings_arrow.values) // len(embeddings_arrow) self.embeddings = ( From 2793562c4c98fa21cf5d7477e978f8f8060a1882 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 20:54:17 +0200 Subject: [PATCH 23/53] fix: use numpy chunks --- ml/data/datasets/embedding_tiles.py | 47 +++++++++++++++++++---------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 73421fa..601dd08 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -113,27 +113,42 @@ def __init__( raise RuntimeError("inner join with embeddings produced empty dataset") t = time.time() - indices = joined_keys.column("_emb_idx") - if isinstance(indices, pa.ChunkedArray): - indices = indices.combine_chunks() - # take first on chunked array (avoids combine_chunks on full 1.1M rows) - embeddings_arrow = emb_col.take(indices) - print(f"[dataset] take done in {time.time() - t:.1f}s", flush=True) + _idx_col = joined_keys.column("_emb_idx") + if isinstance(_idx_col, pa.ChunkedArray): + _idx_col = _idx_col.combine_chunks() + indices_np = _idx_col.to_numpy() - t = time.time() - if isinstance(embeddings_arrow, pa.ChunkedArray): - embeddings_arrow = embeddings_arrow.combine_chunks() + first_chunk = emb_col.chunks[0] + embedding_dim = len(first_chunk.values) // len(first_chunk) + + # sort indices for sequential per-chunk access; restore order afterwards + sort_order = np.argsort(indices_np) + sorted_indices = indices_np[sort_order] + + chunk_offsets = np.concatenate( + [[0], np.cumsum([len(c) for c in emb_col.chunks])] + ) + embeddings = np.empty((len(indices_np), embedding_dim), dtype=np.float32) + for ci, chunk in enumerate(emb_col.chunks): + lo, hi = chunk_offsets[ci], chunk_offsets[ci + 1] + mask = (sorted_indices >= lo) & (sorted_indices < hi) + if not mask.any(): + continue + local_idx = sorted_indices[mask] - lo + chunk_np = ( + chunk.values.to_numpy(zero_copy_only=False) + .reshape(len(chunk), embedding_dim) + .astype(np.float32) + ) + embeddings[sort_order[mask]] = chunk_np[local_idx] + del emb_col print( - f"[dataset] combine_chunks done in {time.time() - t:.1f}s", flush=True + f"[dataset] chunk-wise extraction done in {time.time() - t:.1f}s", + flush=True, ) t = time.time() - embedding_dim = len(embeddings_arrow.values) // len(embeddings_arrow) - self.embeddings = ( - embeddings_arrow.values.to_numpy(zero_copy_only=False) - .astype(np.float32) - .reshape(len(embeddings_arrow), embedding_dim) - ) + self.embeddings = embeddings labels = joined_keys.column("label").to_pandas() unknown = set(labels.unique()) - set(class_indices.keys()) if unknown: From e81973eadb7f6a758f7fa7ee6eb1d16bb5cc1b91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 21:17:32 +0200 Subject: [PATCH 24/53] fix: call end at the end of the main --- ml/__main__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ml/__main__.py b/ml/__main__.py index 61ef400..d531a08 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -1,6 +1,7 @@ from random import randint import hydra +import mlflow from lightning import seed_everything from omegaconf import DictConfig, OmegaConf from rationai.mlkit import Trainer, autolog @@ -25,6 +26,7 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: model = hydra.utils.instantiate(config.model, _target_=MetaArch) trainer = hydra.utils.instantiate(config.trainer, _target_=Trainer, logger=logger) getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint) + mlflow.end_run() if __name__ == "__main__": From 0071592f12c05301fa88ca3b22594b5ba4ca2938 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 22:12:05 +0200 Subject: [PATCH 25/53] chore: remove prints --- ml/data/datasets/embedding_tiles.py | 54 ----------------------------- ml/meta_arch.py | 6 +--- preprocessing/embedding_dataset.py | 24 ------------- scripts/submit_train_linear.py | 6 ++-- 4 files changed, 4 insertions(+), 86 deletions(-) diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 601dd08..791a734 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -5,7 +5,6 @@ load time to produce ``(embedding, class_index, slide_id)`` triples. """ -import time from pathlib import Path import numpy as np @@ -43,15 +42,6 @@ def __init__( include_folds: list[int] | None = None, exclude_folds: list[int] | None = None, ) -> None: - tag = ( - f"include_folds={include_folds}" - if include_folds is not None - else f"exclude_folds={exclude_folds}" - if exclude_folds is not None - else "no_folds" - ) - print(f"[dataset] init: {tag}, metadata={metadata_uri}", flush=True) - t0 = time.time() meta_df = self._filter_metadata( metadata_uri, thresholds, @@ -59,60 +49,32 @@ def __init__( include_folds, exclude_folds, ) - print( - f"[dataset] metadata filtered: {len(meta_df)} rows " - f"({time.time() - t0:.1f}s)", - flush=True, - ) - t = time.time() emb_dir = self._resolve_uri(embedding_uri) - print( - f"[dataset] embedding artifacts resolved in {time.time() - t:.1f}s", - flush=True, - ) - - t = time.time() emb_table = pads.dataset(emb_dir, format="parquet").to_table( columns=["slide_id", "x", "y", "embedding"] ) - print( - f"[dataset] embedding parquet loaded: {emb_table.num_rows} rows " - f"({time.time() - t:.1f}s)", - flush=True, - ) - t = time.time() emb_col = emb_table.column("embedding") if pa.types.is_list(emb_col.type): target_type = pa.large_list(emb_col.type.value_type) emb_col = pa.chunked_array( [c.cast(target_type) for c in emb_col.chunks], type=target_type ) - print( - f"[dataset] embedding column cast in {time.time() - t:.1f}s", flush=True - ) emb_idx = pa.array(range(emb_table.num_rows), type=pa.int64()) emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table - t = time.time() meta_table = pa.Table.from_pandas(meta_df, preserve_index=False) del meta_df joined_keys = meta_table.join( emb_keys, keys=["slide_id", "x", "y"], join_type="inner" ) del emb_keys, meta_table - print( - f"[dataset] arrow join: {joined_keys.num_rows} rows " - f"({time.time() - t:.1f}s)", - flush=True, - ) if joined_keys.num_rows == 0: raise RuntimeError("inner join with embeddings produced empty dataset") - t = time.time() _idx_col = joined_keys.column("_emb_idx") if isinstance(_idx_col, pa.ChunkedArray): _idx_col = _idx_col.combine_chunks() @@ -142,12 +104,7 @@ def __init__( ) embeddings[sort_order[mask]] = chunk_np[local_idx] del emb_col - print( - f"[dataset] chunk-wise extraction done in {time.time() - t:.1f}s", - flush=True, - ) - t = time.time() self.embeddings = embeddings labels = joined_keys.column("label").to_pandas() unknown = set(labels.unique()) - set(class_indices.keys()) @@ -157,11 +114,6 @@ def __init__( ) self.labels = labels.map(class_indices).to_numpy(dtype=np.int64) self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy() - print( - f"[dataset] numpy conversion done in {time.time() - t:.1f}s, " - f"total={time.time() - t0:.1f}s", - flush=True, - ) def __len__(self) -> int: return len(self.labels) @@ -181,14 +133,8 @@ def _filter_metadata( include_folds: list[int] | None, exclude_folds: list[int] | None, ) -> pd.DataFrame: - t = time.time() local = EmbeddingTilesDataset._resolve_uri(metadata_uri) df = pd.read_parquet(local) - print( - f"[dataset] metadata parquet loaded: {len(df)} rows " - f"({time.time() - t:.1f}s)", - flush=True, - ) roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] if not roi_cols: diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 0d2cc55..7f09cc1 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -28,10 +28,6 @@ class MetaArch(LightningModule): ``nn.Identity`` and ``decode_head`` is a single ``nn.Linear``. """ - # TODO: support class_weights for CE loss when class distribution is heavily - # imbalanced. Inject either via config (list[float]) or compute from the - # train fold label distribution at setup(). - def __init__( self, backbone: nn.Module, @@ -83,7 +79,7 @@ def __init__( def setup(self, stage: str) -> None: if stage == "fit": - datamodule = cast(Any, self.trainer).datamodule + datamodule = cast("Any", self.trainer).datamodule labels = datamodule.train.labels num_classes = len(self.class_names) counts = np.bincount(labels, minlength=num_classes).astype(float) diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py index 7669db9..86091d4 100644 --- a/preprocessing/embedding_dataset.py +++ b/preprocessing/embedding_dataset.py @@ -8,7 +8,6 @@ import shutil import tempfile -import time from pathlib import Path import hydra @@ -94,29 +93,19 @@ def join_embeddings( emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) del emb_table, emb_idx - t = time.time() joined_keys = tiles_table.join( emb_keys, keys=["slide_id", "x", "y"], join_type="inner" ) del emb_keys - print( - f"[join] arrow key-join: {time.time() - t:.1f}s rows={joined_keys.num_rows}", - flush=True, - ) indices = joined_keys.column("_emb_idx") if isinstance(indices, pa.ChunkedArray): indices = indices.combine_chunks() - t = time.time() emb_contig = emb_col.combine_chunks() del emb_col - print(f"[join] combine_chunks: {time.time() - t:.1f}s", flush=True) - - t = time.time() embeddings = emb_contig.take(indices) del emb_contig - print(f"[join] take: {time.time() - t:.1f}s", flush=True) joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) dropped_no_embedding = tiles_table.num_rows - joined.num_rows @@ -133,7 +122,6 @@ def process_split( output_split_dir: Path, derive: bool, ) -> dict[str, int]: - print(f"[{split_name}] downloading source tiles", flush=True) src_local = mlflow.artifacts.download_artifacts( run_id=src_run_id, artifact_path=src_artifact_path ) @@ -182,11 +170,6 @@ def process_split( c for c in df.columns if c.startswith(("roi_coverage_", "tile_coverage_")) ] df = df.drop(columns=drop_cols) - print( - f"[{split_name}] {input_count} -> {after_tissue_filter} (tissue) " - f"-> {after_class_threshold} (class threshold), joining embeddings", - flush=True, - ) tiles_table = pa.Table.from_pandas(df, preserve_index=False) del df @@ -195,12 +178,6 @@ def process_split( tiles_table, embedding_run_id, split_name ) del tiles_table - if dropped_no_embedding != 0: - print( - f"WARNING: {dropped_no_embedding} tiles in split '{split_name}' have " - "no matching embedding and were dropped on join.", - flush=True, - ) sort_indices = pc.sort_indices(merged_table, sort_keys=[("slide_id", "ascending")]) merged_table = merged_table.take(sort_indices) @@ -214,7 +191,6 @@ def process_split( shutil.copy(slides_local, output_split_dir / "slides.parquet") log_label_distributions(split_name, merged_table) - print(f"[{split_name}] wrote {merged_table.num_rows} rows", flush=True) return { "input_count": input_count, diff --git a/scripts/submit_train_linear.py b/scripts/submit_train_linear.py index d19f674..93cf686 100644 --- a/scripts/submit_train_linear.py +++ b/scripts/submit_train_linear.py @@ -3,16 +3,16 @@ submit_job( job_name="tissue-classification-train-linear", - username="vcifka", + username=..., cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", + "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=ml/linear_classifier val_fold=0,1,2,3,4 --multirun", + "uv run python -m ml +experiment=... val_fold=0,1,2,3,4 --multirun", ], storage=[storage.secure.PROJECTS], ) From c0a7499de7fa602e72a2eb30d1840f519d97e350 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 22:18:46 +0200 Subject: [PATCH 26/53] chore: remove debug prints, stale TODO, and unused preprocessing pipeline - Drop all print/logging/timing instrumentation from embedding_dataset.py and ml/data/datasets/embedding_tiles.py - Remove stale TODO comment in meta_arch.py (class_weights already implemented via setup() computing balanced weights from train fold label distribution) - Delete preprocessing/embedding_dataset.py and related configs/scripts (embedding dataset build pipeline not needed for this branch) - Add PR.md with title and description Co-Authored-By: Claude Sonnet 4.6 --- PR.md | 34 +++ .../preprocessing/embedding_dataset.yaml | 23 -- configs/preprocessing/embedding_dataset.yaml | 13 - preprocessing/embedding_dataset.py | 272 ------------------ scripts/submit_embedding_dataset.py | 18 -- 5 files changed, 34 insertions(+), 326 deletions(-) create mode 100644 PR.md delete mode 100644 configs/experiment/preprocessing/embedding_dataset.yaml delete mode 100644 configs/preprocessing/embedding_dataset.yaml delete mode 100644 preprocessing/embedding_dataset.py delete mode 100644 scripts/submit_embedding_dataset.py diff --git a/PR.md b/PR.md new file mode 100644 index 0000000..7c4eb76 --- /dev/null +++ b/PR.md @@ -0,0 +1,34 @@ +# feat: linear classifier training pipeline on precomputed embeddings + +## Summary + +Adds an end-to-end ML training pipeline for linear probing on precomputed tile +embeddings. Introduces the embedding dataset preprocessing step, a PyTorch +Lightning training module, and all supporting configs and submission scripts. + +## Changes + +### Preprocessing +- `preprocessing/_labels.py` — shared label/tissue-prop derivation logic. + +### ML training +- `ml/meta_arch.py` — `MetaArch` Lightning module: backbone + decode head + + CrossEntropyLoss with balanced class weights computed from the train fold. + Logs per-class metrics, confusion matrices, and per-slide accuracy. +- `ml/data/datasets/embedding_tiles.py` — `EmbeddingTilesDataset`: loads the + embedding parquet, inner-joins with metadata, and serves `(embedding, label, + slide_id)` triples. Stays in Arrow for the join to avoid large-list → pandas + conversion overhead. +- `ml/data/data_module.py` — Lightning `DataModule` wrapping train/val/test splits. +- `ml/callbacks/parquet_prediction_writer.py` — writes model predictions to Parquet. +- `configs/experiment/ml/linear_classifier.yaml` — full experiment config. +- `configs/ml/` — model, data, and trainer sub-configs. +- `scripts/submit_train_linear.py` — MLflow submission script. + +## Test plan + +- [ ] Run `submit_train_linear.py`; verify training converges and MLflow logs + loss, macro F1, per-class metrics, and confusion matrix figures. +- [ ] Check class weights are logged under `class_weight/` in MLflow. +- [ ] Confirm `parquet_prediction_writer` produces a valid predictions Parquet + on the test split. diff --git a/configs/experiment/preprocessing/embedding_dataset.yaml b/configs/experiment/preprocessing/embedding_dataset.yaml deleted file mode 100644 index 8004e2e..0000000 --- a/configs/experiment/preprocessing/embedding_dataset.yaml +++ /dev/null @@ -1,23 +0,0 @@ -# @package _global_ - -defaults: - - /data: dataset - - _self_ - -tissue_prop_min: 0.2 -thresholds: - Nerve: 0.0 - Blood: 0.0 - Connective-Tissue: 0.4 - Fat: 0.5 - Epithelium: 0.2 - Muscle: 0.4 - Other: 0.5 - -metadata: - run_name: Embedding dataset ${dataset.name} - description: "Join k-fold (${dataset.mlflow_artifacts.kfold_run_id}) and filter_tiles (${dataset.mlflow_artifacts.filter_tiles_run_id}) tile metadata with embeddings (${dataset.mlflow_artifacts.embedding_run_id})." - hyperparams: - kfold_run_id: ${dataset.mlflow_artifacts.kfold_run_id} - filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} - embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} diff --git a/configs/preprocessing/embedding_dataset.yaml b/configs/preprocessing/embedding_dataset.yaml deleted file mode 100644 index f4af56a..0000000 --- a/configs/preprocessing/embedding_dataset.yaml +++ /dev/null @@ -1,13 +0,0 @@ -# @package _global_ - -mlflow_artifact_path: embedding_dataset - -tissue_prop_min: ??? -thresholds: ??? - -metadata: - run_name: "Embedding dataset ${dataset.name}" - description: "Build embedding training dataset by joining k-fold/filter_tiles tile metadata with precomputed embeddings." - hyperparams: - tissue_prop_min: ${tissue_prop_min} - thresholds: ${thresholds} diff --git a/preprocessing/embedding_dataset.py b/preprocessing/embedding_dataset.py deleted file mode 100644 index 86091d4..0000000 --- a/preprocessing/embedding_dataset.py +++ /dev/null @@ -1,272 +0,0 @@ -"""Build an embedding training dataset by joining tile metadata with embeddings. - -Joins precomputed tile embeddings with k-fold metadata (train) / filter_tiles -metadata (test), applies tissue + per-class ROI thresholds before the join, and -emits a training-ready Parquet dataset (per-split ``slides.parquet`` + -``tiles.parquet``) ready for ``rationai.mlkit.data.datasets.SlidesTilesLoader``. -""" - -import shutil -import tempfile -from pathlib import Path - -import hydra -import mlflow -import mlflow.artifacts -import numpy as np -import pandas as pd -import pyarrow as pa -import pyarrow.compute as pc -import pyarrow.dataset as pads -import pyarrow.parquet as pq -from omegaconf import DictConfig, OmegaConf -from rationai.mlkit import autolog, with_cli_args -from rationai.mlkit.lightning.loggers import MLFlowLogger - -from preprocessing._labels import compute_label_and_tissue_prop - - -def apply_thresholds( - df: pd.DataFrame, - tissue_prop_min: float, - thresholds: dict[str, float], - roi_cols: list[str], -) -> tuple[pd.DataFrame, int]: - """Filter tiles by tissue + per-class thresholds and rewrite labels. - - Filters ``df`` by ``tissue_prop_min``, then keeps tiles where ANY class - meets its threshold; among passing classes, the highest-coverage one - becomes the label. - - Returns ``(filtered_df, after_tissue_count)`` so the caller can log both - intermediate counts. The returned df has its ``label`` column rewritten to - reflect the argmax-over-passers rule. - """ - df = df.loc[df["tissue_prop"] >= tissue_prop_min] - after_tissue = len(df) - if df.empty: - return df, after_tissue - - class_names = np.array([c.removeprefix("roi_coverage_") for c in roi_cols]) - thr = np.array([thresholds[c] for c in class_names], dtype=float) - roi = df[roi_cols].to_numpy() - passes = roi >= thr - keep = passes.any(axis=1) - - masked = np.where(passes, roi, -np.inf) - label_idx = masked.argmax(axis=1) - new_labels = class_names[label_idx] - - out = df.loc[pd.Series(keep, index=df.index)].copy() - out["label"] = new_labels[keep] - return out, after_tissue - - -def join_embeddings( - tiles_table: pa.Table, - embedding_run_id: str, - embedding_split: str, -) -> tuple[pa.Table, int]: - """Join filtered tile metadata with embeddings on (slide_id, x, y). - - Stays in Arrow throughout to avoid the very slow list -> pandas - conversion. Acero's join engine doesn't accept list columns in non-key - fields, so we join on keys plus a synthetic row index, then pull embeddings - via take(). The embedding column is cast per chunk to large_list to avoid - int32 offset overflow that bites take() when chunks are concatenated. - """ - emb_dir = mlflow.artifacts.download_artifacts( - run_id=embedding_run_id, artifact_path=f"{embedding_split}/tiles" - ) - emb_table = pads.dataset(emb_dir, format="parquet").to_table( - columns=["slide_id", "x", "y", "embedding"] - ) - - emb_col = emb_table.column("embedding") - if pa.types.is_list(emb_col.type): - target_type = pa.large_list(emb_col.type.value_type) - emb_col = pa.chunked_array( - [c.cast(target_type) for c in emb_col.chunks], type=target_type - ) - - emb_idx = pa.array(range(emb_table.num_rows), type=pa.int32()) - emb_keys = emb_table.drop(["embedding"]).append_column("_emb_idx", emb_idx) - del emb_table, emb_idx - - joined_keys = tiles_table.join( - emb_keys, keys=["slide_id", "x", "y"], join_type="inner" - ) - del emb_keys - - indices = joined_keys.column("_emb_idx") - if isinstance(indices, pa.ChunkedArray): - indices = indices.combine_chunks() - - emb_contig = emb_col.combine_chunks() - del emb_col - embeddings = emb_contig.take(indices) - del emb_contig - - joined = joined_keys.drop(["_emb_idx"]).append_column("embedding", embeddings) - dropped_no_embedding = tiles_table.num_rows - joined.num_rows - return joined, dropped_no_embedding - - -def process_split( - split_name: str, - src_run_id: str, - src_artifact_path: str, - embedding_run_id: str, - tissue_prop_min: float, - thresholds: dict[str, float], - output_split_dir: Path, - derive: bool, -) -> dict[str, int]: - src_local = mlflow.artifacts.download_artifacts( - run_id=src_run_id, artifact_path=src_artifact_path - ) - df = pads.dataset(src_local, format="parquet").to_table().to_pandas() - input_count = len(df) - - roi_cols = [c for c in df.columns if c.startswith("roi_coverage_")] - if not roi_cols: - raise RuntimeError( - f"No roi_coverage_* columns in {src_artifact_path}. " - "Cannot apply class thresholds." - ) - - classes_in_data = {c.removeprefix("roi_coverage_") for c in roi_cols} - missing = classes_in_data - set(thresholds.keys()) - if missing: - raise ValueError( - f"thresholds is missing entries for roi_coverage_* classes present " - f"in data: {sorted(missing)}" - ) - - if derive: - lbl, tp = compute_label_and_tissue_prop(df, roi_cols) - df["label"] = lbl - df["tissue_prop"] = tp - else: - required = {"label", "tissue_prop"} - missing_required = required - set(df.columns) - if missing_required: - raise RuntimeError( - f"Source split '{split_name}' (derive=False) is missing required " - f"columns {sorted(missing_required)} in {src_artifact_path}. " - "Expected the kfold_split artifact, which writes label/tissue_prop/fold." - ) - - df, after_tissue_filter = apply_thresholds( - df, tissue_prop_min, thresholds, roi_cols - ) - after_class_threshold = len(df) - if after_class_threshold == 0: - raise RuntimeError( - f"All {input_count} tiles dropped by thresholds for split '{split_name}'." - ) - - drop_cols = [ - c for c in df.columns if c.startswith(("roi_coverage_", "tile_coverage_")) - ] - df = df.drop(columns=drop_cols) - - tiles_table = pa.Table.from_pandas(df, preserve_index=False) - del df - - merged_table, dropped_no_embedding = join_embeddings( - tiles_table, embedding_run_id, split_name - ) - del tiles_table - - sort_indices = pc.sort_indices(merged_table, sort_keys=[("slide_id", "ascending")]) - merged_table = merged_table.take(sort_indices) - - output_split_dir.mkdir(parents=True, exist_ok=True) - pq.write_table(merged_table, str(output_split_dir / "tiles.parquet")) - - slides_local = mlflow.artifacts.download_artifacts( - run_id=embedding_run_id, artifact_path=f"{split_name}/slides.parquet" - ) - shutil.copy(slides_local, output_split_dir / "slides.parquet") - - log_label_distributions(split_name, merged_table) - - return { - "input_count": input_count, - "after_tissue_filter": after_tissue_filter, - "after_class_threshold": after_class_threshold, - "after_join": merged_table.num_rows, - "dropped_no_embedding": dropped_no_embedding, - } - - -def log_label_distributions(split_name: str, table: pa.Table) -> None: - has_fold = "fold" in table.schema.names - cols = ["label", "fold"] if has_fold else ["label"] - df = table.select(cols).to_pandas() - - label_dist = ( - df["label"].value_counts().rename_axis("label").reset_index(name="count") - ) - mlflow.log_table( - data=label_dist, - artifact_file=f"fold_statistics/{split_name}_label_distribution.json", - ) - - if has_fold: - fold_dist = ( - df.groupby(["fold", "label"]).size().unstack(fill_value=0).reset_index() - ) - mlflow.log_table( - data=fold_dist, - artifact_file=f"fold_statistics/{split_name}_fold_label_distribution.json", - ) - - -@with_cli_args(["+preprocessing=embedding_dataset"]) -@hydra.main(config_path="../configs", config_name="preprocessing", version_base=None) -@autolog -def main(config: DictConfig, logger: MLFlowLogger) -> None: - artifacts = config.dataset.mlflow_artifacts - kfold_run_id = artifacts.kfold_run_id - filter_tiles_run_id = artifacts.filter_tiles_run_id - embedding_run_id = artifacts.embedding_run_id - - tissue_prop_min = float(config.tissue_prop_min) - if tissue_prop_min <= 0: - raise ValueError( - f"tissue_prop_min must be > 0 (got {tissue_prop_min}); " - "otherwise background tiles are not filtered out." - ) - raw_thresholds = OmegaConf.to_container(config.thresholds, resolve=True) - if not isinstance(raw_thresholds, dict): - raise TypeError("config.thresholds must be a mapping of class -> threshold") - thresholds = {str(k): float(v) for k, v in raw_thresholds.items()} - - splits = [ - ("train", kfold_run_id, "kfold_split/kfold_tiles.parquet", False), - ("test", filter_tiles_run_id, "filter_tiles/test_tiles.parquet", True), - ] - - with tempfile.TemporaryDirectory() as tmp_root: - tmp_root_path = Path(tmp_root) - for split_name, src_run_id, src_artifact_path, derive in splits: - stats = process_split( - split_name=split_name, - src_run_id=src_run_id, - src_artifact_path=src_artifact_path, - embedding_run_id=embedding_run_id, - tissue_prop_min=tissue_prop_min, - thresholds=thresholds, - output_split_dir=tmp_root_path / split_name, - derive=derive, - ) - for key, value in stats.items(): - mlflow.log_metric(f"{split_name}_{key}", value) - - mlflow.log_artifacts(str(tmp_root_path), config.mlflow_artifact_path) - - -if __name__ == "__main__": - main() diff --git a/scripts/submit_embedding_dataset.py b/scripts/submit_embedding_dataset.py deleted file mode 100644 index 23977df..0000000 --- a/scripts/submit_embedding_dataset.py +++ /dev/null @@ -1,18 +0,0 @@ -from kube_jobs import storage, submit_job - - -submit_job( - job_name="tissue-classification-embedding-dataset", - username=..., - cpu=8, - memory="64Gi", - gpu=None, - public=False, - script=[ - "git clone https://github.com/RationAI/tissue-classification.git workdir", - "cd workdir", - "uv sync", - "uv run python -m preprocessing.embedding_dataset +experiment=...", - ], - storage=[storage.secure.PROJECTS], -) From fe918d152a75498982c4953e03b9ce81ce919de0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Mon, 11 May 2026 22:25:56 +0200 Subject: [PATCH 27/53] chore: remove markdown file --- PR.md | 34 ---------------------------------- 1 file changed, 34 deletions(-) delete mode 100644 PR.md diff --git a/PR.md b/PR.md deleted file mode 100644 index 7c4eb76..0000000 --- a/PR.md +++ /dev/null @@ -1,34 +0,0 @@ -# feat: linear classifier training pipeline on precomputed embeddings - -## Summary - -Adds an end-to-end ML training pipeline for linear probing on precomputed tile -embeddings. Introduces the embedding dataset preprocessing step, a PyTorch -Lightning training module, and all supporting configs and submission scripts. - -## Changes - -### Preprocessing -- `preprocessing/_labels.py` — shared label/tissue-prop derivation logic. - -### ML training -- `ml/meta_arch.py` — `MetaArch` Lightning module: backbone + decode head + - CrossEntropyLoss with balanced class weights computed from the train fold. - Logs per-class metrics, confusion matrices, and per-slide accuracy. -- `ml/data/datasets/embedding_tiles.py` — `EmbeddingTilesDataset`: loads the - embedding parquet, inner-joins with metadata, and serves `(embedding, label, - slide_id)` triples. Stays in Arrow for the join to avoid large-list → pandas - conversion overhead. -- `ml/data/data_module.py` — Lightning `DataModule` wrapping train/val/test splits. -- `ml/callbacks/parquet_prediction_writer.py` — writes model predictions to Parquet. -- `configs/experiment/ml/linear_classifier.yaml` — full experiment config. -- `configs/ml/` — model, data, and trainer sub-configs. -- `scripts/submit_train_linear.py` — MLflow submission script. - -## Test plan - -- [ ] Run `submit_train_linear.py`; verify training converges and MLflow logs - loss, macro F1, per-class metrics, and confusion matrix figures. -- [ ] Check class weights are logged under `class_weight/` in MLflow. -- [ ] Confirm `parquet_prediction_writer` produces a valid predictions Parquet - on the test split. From 6b7d1e8ec9bc2f45e5685ea03204416b018120d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Tue, 12 May 2026 08:51:51 +0200 Subject: [PATCH 28/53] fix: edge cases --- ml/__main__.py | 3 +++ ml/callbacks/parquet_prediction_writer.py | 33 ++++++++--------------- ml/data/data_module.py | 5 +++- 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/ml/__main__.py b/ml/__main__.py index d531a08..318c37c 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -25,6 +25,9 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) model = hydra.utils.instantiate(config.model, _target_=MetaArch) trainer = hydra.utils.instantiate(config.trainer, _target_=Trainer, logger=logger) + allowed_modes = ["fit", "test", "validate", "predict"] + if config.mode not in allowed_modes: + raise ValueError(f"Invalid mode {config.mode!r}. Allowed: {allowed_modes}") getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint) mlflow.end_run() diff --git a/ml/callbacks/parquet_prediction_writer.py b/ml/callbacks/parquet_prediction_writer.py index 1fe8558..38d86fd 100644 --- a/ml/callbacks/parquet_prediction_writer.py +++ b/ml/callbacks/parquet_prediction_writer.py @@ -1,6 +1,5 @@ """Aggregate ``predict_step`` outputs and write them as a parquet artifact.""" -from collections.abc import Sequence from pathlib import Path from typing import Any @@ -22,19 +21,6 @@ class ParquetPredictionWriter(BasePredictionWriter): def __init__(self, output_filename: str = "predictions.parquet") -> None: super().__init__(write_interval="epoch") self.output_filename = output_filename - self._batches: list[dict[str, Any]] = [] - - def write_on_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - prediction: dict[str, Any], - batch_indices: Sequence[int] | None, - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: - self._batches.append(prediction) def write_on_epoch_end( self, @@ -43,18 +29,23 @@ def write_on_epoch_end( predictions: Any, batch_indices: Any, ) -> None: - if not self._batches: + if trainer.global_rank != 0: return slide_ids: list[str] = [] targets: list[int] = [] preds: list[int] = [] probs: list[np.ndarray] = [] - for b in self._batches: - slide_ids.extend(b["slide_id"]) - targets.extend(b["target"].tolist()) - preds.extend(b["pred"].tolist()) - probs.append(b["probs"].numpy()) + for dataloader_preds in predictions: + for b in dataloader_preds: + slide_ids.extend(b["slide_id"]) + targets.extend(b["target"].tolist()) + preds.extend(b["pred"].tolist()) + probs.append(b["probs"].numpy()) + + if not slide_ids: + return + prob_matrix = np.concatenate(probs, axis=0) class_names = getattr(pl_module, "class_names", None) @@ -74,5 +65,3 @@ def write_on_epoch_end( active = mlflow.active_run() if active is not None: mlflow.log_artifact(str(out_path), artifact_path="predictions") - - self._batches.clear() diff --git a/ml/data/data_module.py b/ml/data/data_module.py index c39b950..7302be5 100644 --- a/ml/data/data_module.py +++ b/ml/data/data_module.py @@ -33,7 +33,10 @@ def setup(self, stage: str) -> None: case "test": self.test = instantiate(self.datasets["test"]) case "predict": - self.predict = instantiate(self.datasets["predict"]) + dataset_cfg = self.datasets.get("predict") or self.datasets.get("test") + if dataset_cfg is None: + raise KeyError("Neither 'predict' nor 'test' dataset configured") + self.predict = instantiate(dataset_cfg) def train_dataloader(self) -> Iterable[Input]: return DataLoader( From 4ff988ef071f8ccb9fbacd4be701209db3d10c31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Tue, 12 May 2026 09:55:59 +0200 Subject: [PATCH 29/53] feat: normalize the confusion matrix rows per class recall --- ml/meta_arch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 7f09cc1..9ee227a 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -204,8 +204,11 @@ def _log_per_slide_accuracy(self) -> None: def _confmat_figure( matrix: np.ndarray, class_names: Iterable[str], title: str ) -> Figure: + row_sums = matrix.sum(axis=1, keepdims=True) + normalized = np.divide(matrix, row_sums, where=row_sums > 0, out=np.zeros_like(matrix, dtype=float)) + fig, ax = plt.subplots(figsize=(6, 5)) - im = ax.imshow(matrix, cmap="Blues") + im = ax.imshow(normalized, cmap="Blues", vmin=0, vmax=1) ax.set_title(title) ax.set_xlabel("Predicted") ax.set_ylabel("True") @@ -222,7 +225,7 @@ def _confmat_figure( str(matrix[i, j]), ha="center", va="center", - color="white" if matrix[i, j] > matrix.max() / 2 else "black", + color="white" if normalized[i, j] > 0.5 else "black", fontsize=8, ) fig.colorbar(im, ax=ax) From 32375b27f5753cc17a018492f74b617c3578614f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Tue, 12 May 2026 09:56:46 +0200 Subject: [PATCH 30/53] fix: format --- ml/meta_arch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 9ee227a..04de757 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -205,7 +205,9 @@ def _confmat_figure( matrix: np.ndarray, class_names: Iterable[str], title: str ) -> Figure: row_sums = matrix.sum(axis=1, keepdims=True) - normalized = np.divide(matrix, row_sums, where=row_sums > 0, out=np.zeros_like(matrix, dtype=float)) + normalized = np.divide( + matrix, row_sums, where=row_sums > 0, out=np.zeros_like(matrix, dtype=float) + ) fig, ax = plt.subplots(figsize=(6, 5)) im = ax.imshow(normalized, cmap="Blues", vmin=0, vmax=1) From af9538a09acde88a87627400e4511b743e658122 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Tue, 12 May 2026 15:13:42 +0200 Subject: [PATCH 31/53] feat: use stratified k fold run --- configs/data/dataset.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 172e48f..7392686 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -14,7 +14,7 @@ dataset: test_split_filename: "split_mapping/test_split.csv" tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" - kfold_run_id: "850c81506684450b9af92296acfd045a" + kfold_run_id: "814611e8987d4d569b255b7a4749bc90" embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" From bc0819a1434324a96a4c4b21328ba47823f70746 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Tue, 12 May 2026 18:22:58 +0200 Subject: [PATCH 32/53] fix: remove criterion --- ml/meta_arch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 04de757..07d1431 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -22,27 +22,27 @@ class MetaArch(LightningModule): - """Top-level classification architecture: backbone + decode_head + criterion. + """Top-level classification architecture: backbone + decode_head. For linear probing on precomputed embeddings, ``backbone`` is typically ``nn.Identity`` and ``decode_head`` is a single ``nn.Linear``. + Criterion is class-weighted CrossEntropyLoss, computed from training labels in setup(). """ def __init__( self, backbone: nn.Module, decode_head: nn.Module, - criterion: nn.Module, class_indices: dict[str, int], learning_rate: float = 1e-3, weight_decay: float = 0.0, ) -> None: super().__init__() - self.save_hyperparameters(ignore=["backbone", "decode_head", "criterion"]) + self.save_hyperparameters(ignore=["backbone", "decode_head"]) self.backbone = backbone self.decode_head = decode_head - self.criterion = criterion + self.criterion: nn.Module self.class_names = [ n for n, _ in sorted(class_indices.items(), key=lambda kv: kv[1]) From b8e85e0af5c6d7fb576b815c74befcf1687b71a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Tue, 12 May 2026 18:27:33 +0200 Subject: [PATCH 33/53] fix: remove criterion from configs --- configs/ml/model/linear_classifier.yaml | 3 --- 1 file changed, 3 deletions(-) diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml index dfff43c..86c0ed2 100644 --- a/configs/ml/model/linear_classifier.yaml +++ b/configs/ml/model/linear_classifier.yaml @@ -9,9 +9,6 @@ model: in_features: 2560 out_features: ${len:${class_indices}} - criterion: - _target_: torch.nn.CrossEntropyLoss - class_indices: ${class_indices} learning_rate: 1.0e-3 From 3cc670de01f98c9774f56b2c8bcf944ac76fee7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Wed, 13 May 2026 22:26:12 +0200 Subject: [PATCH 34/53] feat: add option to use different kfold strategies --- configs/data/dataset.yaml | 5 +++-- .../ml/linear_classifier_stratified_group_kfold.yaml | 8 ++++++++ .../experiment/ml/linear_classifier_stratified_kfold.yaml | 8 ++++++++ configs/{experiment => }/ml/linear_classifier.yaml | 8 +++++--- 4 files changed, 24 insertions(+), 5 deletions(-) create mode 100644 configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml create mode 100644 configs/experiment/ml/linear_classifier_stratified_kfold.yaml rename configs/{experiment => }/ml/linear_classifier.yaml (81%) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 7392686..0ab4e0d 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -14,7 +14,8 @@ dataset: test_split_filename: "split_mapping/test_split.csv" tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" - kfold_run_id: "814611e8987d4d569b255b7a4749bc90" + stratified_kfold_run_id: "814611e8987d4d569b255b7a4749bc90" + stratified_group_kfold_run_id: "382b41d2fa894514908e8067949c4326" embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a" @@ -58,4 +59,4 @@ dataset: "44 Klarzelliges Nierenzellkarzinom_1": "kidney" "50 Muzinöses Zystadenom_1": "breast" "85 Mammakarzinom NST": "breast" - "28 Zöliakie": "small intestine" \ No newline at end of file + "28 Zöliakie": "small intestine" diff --git a/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml b/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml new file mode 100644 index 0000000..471f5a3 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/linear_classifier + - _self_ + +kfold_strategy: stratified_group +kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} diff --git a/configs/experiment/ml/linear_classifier_stratified_kfold.yaml b/configs/experiment/ml/linear_classifier_stratified_kfold.yaml new file mode 100644 index 0000000..c01fbbf --- /dev/null +++ b/configs/experiment/ml/linear_classifier_stratified_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/linear_classifier + - _self_ + +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} diff --git a/configs/experiment/ml/linear_classifier.yaml b/configs/ml/linear_classifier.yaml similarity index 81% rename from configs/experiment/ml/linear_classifier.yaml rename to configs/ml/linear_classifier.yaml index 2e9a952..5606f0a 100644 --- a/configs/experiment/ml/linear_classifier.yaml +++ b/configs/ml/linear_classifier.yaml @@ -11,7 +11,8 @@ defaults: mode: fit embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} -kfold_run_id: ${dataset.mlflow_artifacts.kfold_run_id} +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} train_embedding_uri: runs:/${embedding_run_id}/train/tiles @@ -34,10 +35,11 @@ thresholds: mlflow_artifact_path: linear_classifier metadata: - run_name: Linear Classifier ${dataset.name} fold=${val_fold} - description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), kfold metadata ${kfold_run_id}." + run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} + description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." hyperparams: embedding_run_id: ${embedding_run_id} + kfold_strategy: ${kfold_strategy} kfold_run_id: ${kfold_run_id} filter_tiles_run_id: ${filter_tiles_run_id} val_fold: ${val_fold} From 27ceea344ca0d9ce230864ef1d83301e533273ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 00:05:41 +0200 Subject: [PATCH 35/53] fix: lower LR and patience --- configs/ml/model/linear_classifier.yaml | 2 +- configs/ml/trainer/default.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml index 86c0ed2..fc9bfed 100644 --- a/configs/ml/model/linear_classifier.yaml +++ b/configs/ml/model/linear_classifier.yaml @@ -11,5 +11,5 @@ model: class_indices: ${class_indices} - learning_rate: 1.0e-3 + learning_rate: 1.0e-4 weight_decay: 0.0 diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/default.yaml index cf5766d..a465a5c 100644 --- a/configs/ml/trainer/default.yaml +++ b/configs/ml/trainer/default.yaml @@ -13,7 +13,7 @@ trainer: _target_: lightning.pytorch.callbacks.EarlyStopping monitor: validation/loss mode: min - patience: 5 + patience: 2 model_checkpoint: _target_: lightning.pytorch.callbacks.ModelCheckpoint monitor: validation/loss From efde82ab3d76965011dc1a2335b8a6e84b9ca132 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 13:26:57 +0200 Subject: [PATCH 36/53] fix: use f1 macro as a monitor --- configs/ml/trainer/default.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/default.yaml index a465a5c..63e3c15 100644 --- a/configs/ml/trainer/default.yaml +++ b/configs/ml/trainer/default.yaml @@ -11,15 +11,15 @@ trainer: callbacks: early_stopping: _target_: lightning.pytorch.callbacks.EarlyStopping - monitor: validation/loss - mode: min + monitor: validation/f1_macro + mode: max patience: 2 model_checkpoint: _target_: lightning.pytorch.callbacks.ModelCheckpoint - monitor: validation/loss - mode: min + monitor: validation/f1_macro + mode: max save_top_k: 1 - filename: "epoch={epoch}-val_loss={validation/loss:.4f}" + filename: "epoch={epoch}-val_f1={validation/f1_macro:.4f}" auto_insert_metric_name: false lr_monitor: _target_: lightning.pytorch.callbacks.LearningRateMonitor From c8102de253e7c1c19baf530881f7415b8a80b1dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 15:22:12 +0200 Subject: [PATCH 37/53] fix: rever back to validation loss --- configs/ml/trainer/default.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/default.yaml index 63e3c15..a465a5c 100644 --- a/configs/ml/trainer/default.yaml +++ b/configs/ml/trainer/default.yaml @@ -11,15 +11,15 @@ trainer: callbacks: early_stopping: _target_: lightning.pytorch.callbacks.EarlyStopping - monitor: validation/f1_macro - mode: max + monitor: validation/loss + mode: min patience: 2 model_checkpoint: _target_: lightning.pytorch.callbacks.ModelCheckpoint - monitor: validation/f1_macro - mode: max + monitor: validation/loss + mode: min save_top_k: 1 - filename: "epoch={epoch}-val_f1={validation/f1_macro:.4f}" + filename: "epoch={epoch}-val_loss={validation/loss:.4f}" auto_insert_metric_name: false lr_monitor: _target_: lightning.pytorch.callbacks.LearningRateMonitor From c5bab90c0dadc9ddd74fdf55101fbefeaae966ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 17:25:36 +0200 Subject: [PATCH 38/53] fix: add weight decay 1e-3 to linear classifier Train loss ~0.02 vs val loss ~0.32 indicated severe overfit on the linear probe. AdamW weight_decay was 0; bump to 1e-3 to regularize the head. --- configs/ml/model/linear_classifier.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml index fc9bfed..d555804 100644 --- a/configs/ml/model/linear_classifier.yaml +++ b/configs/ml/model/linear_classifier.yaml @@ -12,4 +12,4 @@ model: class_indices: ${class_indices} learning_rate: 1.0e-4 - weight_decay: 0.0 + weight_decay: 1.0e-3 From 475b67c7a05af8bd5ab071182cf9cafb81c43498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 19:05:00 +0200 Subject: [PATCH 39/53] Revert "fix: add weight decay 1e-3 to linear classifier" This reverts commit c5bab90c0dadc9ddd74fdf55101fbefeaae966ee. --- configs/ml/model/linear_classifier.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml index d555804..fc9bfed 100644 --- a/configs/ml/model/linear_classifier.yaml +++ b/configs/ml/model/linear_classifier.yaml @@ -12,4 +12,4 @@ model: class_indices: ${class_indices} learning_rate: 1.0e-4 - weight_decay: 1.0e-3 + weight_decay: 0.0 From 43663a9326e5d54dbf72335ee1d0d423b53bcda3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 20:28:59 +0200 Subject: [PATCH 40/53] feat: add logistic regression --- ...tic_regression_stratified_group_kfold.yaml | 8 + ..._logistic_regression_stratified_kfold.yaml | 8 + configs/ml/lbfgs_logistic_regression.yaml | 64 +++++++ ml/__main__.py | 7 + ml/sklearn_linear.py | 176 ++++++++++++++++++ 5 files changed, 263 insertions(+) create mode 100644 configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml create mode 100644 configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml create mode 100644 configs/ml/lbfgs_logistic_regression.yaml create mode 100644 ml/sklearn_linear.py diff --git a/configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml b/configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml new file mode 100644 index 0000000..49c954e --- /dev/null +++ b/configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/lbfgs_logistic_regression + - _self_ + +kfold_strategy: stratified_group +kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} diff --git a/configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml b/configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml new file mode 100644 index 0000000..3d15556 --- /dev/null +++ b/configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/lbfgs_logistic_regression + - _self_ + +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} diff --git a/configs/ml/lbfgs_logistic_regression.yaml b/configs/ml/lbfgs_logistic_regression.yaml new file mode 100644 index 0000000..c6ef999 --- /dev/null +++ b/configs/ml/lbfgs_logistic_regression.yaml @@ -0,0 +1,64 @@ +# @package _global_ + +defaults: + - /data: dataset + - /class_mapping: collapse_alterations_to_other + - /ml/data: embedding + - _self_ + +mode: fit +runner: sklearn_linear + +embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} +filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} + +train_embedding_uri: runs:/${embedding_run_id}/train/tiles +test_embedding_uri: runs:/${embedding_run_id}/test/tiles +train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet +test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet + +val_fold: 0 + +tissue_prop_min: 0.2 +thresholds: + Nerve: 0.0 + Blood: 0.0 + Connective-Tissue: 0.4 + Fat: 0.6 + Epithelium: 0.2 + Muscle: 0.5 + Other: 0.5 + +model: + solver: lbfgs + penalty: l2 + C: 1.0 + class_weight: balanced + max_iter: 1000 + tol: 1.0e-4 + n_jobs: null + verbose: 0 + standardize: true + +mlflow_artifact_path: lbfgs_logistic_regression + +metadata: + run_name: LBFGS Logistic Regression ${dataset.name} ${kfold_strategy} fold=${val_fold} + description: "LBFGS multinomial logistic regression over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." + hyperparams: + embedding_run_id: ${embedding_run_id} + kfold_strategy: ${kfold_strategy} + kfold_run_id: ${kfold_run_id} + filter_tiles_run_id: ${filter_tiles_run_id} + val_fold: ${val_fold} + tissue_prop_min: ${tissue_prop_min} + thresholds: ${thresholds} + solver: ${model.solver} + penalty: ${model.penalty} + C: ${model.C} + class_weight: ${model.class_weight} + max_iter: ${model.max_iter} + tol: ${model.tol} + standardize: ${model.standardize} diff --git a/ml/__main__.py b/ml/__main__.py index 318c37c..620c882 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -20,6 +20,13 @@ @hydra.main(config_path="../configs", config_name="ml", version_base=None) @autolog def main(config: DictConfig, logger: MLFlowLogger) -> None: + if config.get("runner") == "sklearn_linear": + from ml.sklearn_linear import run as run_sklearn_linear + + run_sklearn_linear(config) + mlflow.end_run() + return + seed_everything(config.seed, workers=True) data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) diff --git a/ml/sklearn_linear.py b/ml/sklearn_linear.py new file mode 100644 index 0000000..f8fb536 --- /dev/null +++ b/ml/sklearn_linear.py @@ -0,0 +1,176 @@ +from pathlib import Path +from random import randint +from typing import Any + +import hydra +import joblib +import mlflow +import numpy as np +import pandas as pd +from matplotlib import pyplot as plt +from omegaconf import DictConfig, OmegaConf +from rationai.mlkit import autolog +from rationai.mlkit.lightning.loggers import MLFlowLogger +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import confusion_matrix, f1_score, recall_score +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler + +from ml.data import DataModule +from ml.meta_arch import _confmat_figure + + +if not OmegaConf.has_resolver("random_seed"): + OmegaConf.register_new_resolver( + "random_seed", lambda: randint(0, 2**31), use_cache=True + ) +if not OmegaConf.has_resolver("len"): + OmegaConf.register_new_resolver("len", lambda x: len(x)) + + +@hydra.main(config_path="../configs", config_name="ml", version_base=None) +@autolog +def main(config: DictConfig, logger: MLFlowLogger) -> None: + run(config) + mlflow.end_run() + + +def run(config: DictConfig) -> None: + if config.mode != "fit": + raise ValueError("sklearn_linear currently supports only mode='fit'") + + np.random.seed(config.seed) + data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) + data.setup("fit") + + x_train = data.train.embeddings + y_train = data.train.labels + x_val = data.val.embeddings + y_val = data.val.labels + + class_names = [ + n for n, _ in sorted(config.class_indices.items(), key=lambda kv: kv[1]) + ] + model = _build_model(config) + model.fit(x_train, y_train) + + _log_split_metrics( + model, x_val, y_val, data.val.slide_ids, class_names, "validation" + ) + _log_model(model) + + mlflow.log_params( + { + "model_type": "sklearn_logistic_regression", + "solver": config.model.solver, + "penalty": config.model.penalty, + "C": config.model.C, + "max_iter": config.model.max_iter, + "tol": config.model.tol, + "class_weight": config.model.class_weight, + "standardize": config.model.standardize, + "train_tiles": len(y_train), + "validation_tiles": len(y_val), + } + ) + + +def _build_model(config: DictConfig) -> Pipeline: + steps: list[tuple[str, Any]] = [] + if config.model.standardize: + steps.append(("scaler", StandardScaler())) + steps.append( + ( + "classifier", + LogisticRegression( + C=config.model.C, + class_weight=config.model.class_weight, + max_iter=config.model.max_iter, + n_jobs=config.model.n_jobs, + penalty=config.model.penalty, + random_state=config.seed, + solver=config.model.solver, + tol=config.model.tol, + verbose=config.model.verbose, + ), + ) + ) + return Pipeline(steps) + + +def _log_split_metrics( + model: Pipeline, + inputs: np.ndarray, + targets: np.ndarray, + slide_ids: np.ndarray, + class_names: list[str], + split: str, +) -> None: + labels = np.arange(len(class_names)) + preds = model.predict(inputs) + probs = _predict_proba_for_all_classes(model, inputs, labels) + + mlflow.log_metric( + f"{split}/acc_macro", + recall_score(targets, preds, labels=labels, average="macro", zero_division=0), + ) + mlflow.log_metric( + f"{split}/f1_macro", + f1_score(targets, preds, average="macro", zero_division=0), + ) + + per_class_acc = recall_score( + targets, preds, labels=labels, average=None, zero_division=0 + ) + per_class_f1 = f1_score( + targets, preds, labels=labels, average=None, zero_division=0 + ) + for cls_name, acc, f1 in zip( + class_names, per_class_acc.tolist(), per_class_f1.tolist(), strict=True + ): + mlflow.log_metric(f"{split}/acc_per_class/{cls_name}", acc) + mlflow.log_metric(f"{split}/f1_per_class/{cls_name}", f1) + + matrix = confusion_matrix(targets, preds, labels=labels) + fig = _confmat_figure(matrix, class_names, title=f"{split} confmat") + try: + mlflow.log_figure(fig, artifact_file=f"confusion_matrix/{split}.png") + finally: + plt.close(fig) + + prob_columns = [f"prob_{c}" for c in class_names] + predictions = pd.DataFrame( + { + "slide_id": slide_ids, + "target": targets, + "pred": preds, + } + ) + predictions = pd.concat( + [predictions, pd.DataFrame(probs, columns=prob_columns)], axis=1 + ) + out_path = Path(f"{split}_predictions.parquet") + predictions.to_parquet(out_path, index=False) + mlflow.log_artifact(str(out_path), artifact_path="predictions") + + +def _predict_proba_for_all_classes( + model: Pipeline, inputs: np.ndarray, labels: np.ndarray +) -> np.ndarray: + raw_probs = model.predict_proba(inputs) + probs = np.zeros((len(inputs), len(labels)), dtype=raw_probs.dtype) + for source_idx, class_idx in enumerate(model.classes_): + matching = np.flatnonzero(labels == class_idx) + if len(matching) == 1: + probs[:, matching[0]] = raw_probs[:, source_idx] + return probs + + +def _log_model(model: Pipeline) -> None: + out_path = Path("model.joblib") + joblib.dump(model, out_path) + mlflow.log_artifact(str(out_path), artifact_path="model") + + +if __name__ == "__main__": + main() From a2fe451b122ac4a856557f3f31691121b05f9da8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 20:53:54 +0200 Subject: [PATCH 41/53] feat: polish and add two distinct submission scripts --- configs/ml/lbfgs_logistic_regression.yaml | 2 +- ml/sklearn_linear.py | 80 ++++++++++++++----- ...linear.py => submit_train_linear_probe.py} | 9 ++- scripts/submit_train_logistic_regression.py | 24 ++++++ 4 files changed, 94 insertions(+), 21 deletions(-) rename scripts/{submit_train_linear.py => submit_train_linear_probe.py} (61%) create mode 100644 scripts/submit_train_logistic_regression.py diff --git a/configs/ml/lbfgs_logistic_regression.yaml b/configs/ml/lbfgs_logistic_regression.yaml index c6ef999..43e1bdd 100644 --- a/configs/ml/lbfgs_logistic_regression.yaml +++ b/configs/ml/lbfgs_logistic_regression.yaml @@ -45,7 +45,7 @@ model: mlflow_artifact_path: lbfgs_logistic_regression metadata: - run_name: LBFGS Logistic Regression ${dataset.name} ${kfold_strategy} fold=${val_fold} + run_name: LBFGS Logistic Regression ${dataset.name} ${kfold_strategy} fold=${val_fold} C=${model.C} std=${model.standardize} description: "LBFGS multinomial logistic regression over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." hyperparams: embedding_run_id: ${embedding_run_id} diff --git a/ml/sklearn_linear.py b/ml/sklearn_linear.py index f8fb536..08ea754 100644 --- a/ml/sklearn_linear.py +++ b/ml/sklearn_linear.py @@ -3,8 +3,8 @@ from typing import Any import hydra -import joblib import mlflow +import mlflow.sklearn import numpy as np import pandas as pd from matplotlib import pyplot as plt @@ -42,6 +42,8 @@ def run(config: DictConfig) -> None: np.random.seed(config.seed) data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) data.setup("fit") + if "test" in data.datasets: + data.setup("test") x_train = data.train.embeddings y_train = data.train.labels @@ -53,26 +55,42 @@ def run(config: DictConfig) -> None: ] model = _build_model(config) model.fit(x_train, y_train) + _log_convergence(model, config) _log_split_metrics( model, x_val, y_val, data.val.slide_ids, class_names, "validation" ) + if hasattr(data, "test"): + _log_split_metrics( + model, + data.test.embeddings, + data.test.labels, + data.test.slide_ids, + class_names, + "test", + ) _log_model(model) - mlflow.log_params( - { - "model_type": "sklearn_logistic_regression", - "solver": config.model.solver, - "penalty": config.model.penalty, - "C": config.model.C, - "max_iter": config.model.max_iter, - "tol": config.model.tol, - "class_weight": config.model.class_weight, - "standardize": config.model.standardize, - "train_tiles": len(y_train), - "validation_tiles": len(y_val), - } - ) + params = { + "model_type": "sklearn_logistic_regression", + "solver": config.model.solver, + "penalty": config.model.penalty, + "C": config.model.C, + "max_iter": config.model.max_iter, + "tol": config.model.tol, + "class_weight": config.model.class_weight, + "standardize": config.model.standardize, + "val_fold": config.val_fold, + "kfold_strategy": config.kfold_strategy, + "embedding_run_id": config.embedding_run_id, + "kfold_run_id": config.kfold_run_id, + "filter_tiles_run_id": config.filter_tiles_run_id, + "train_tiles": len(y_train), + "validation_tiles": len(y_val), + } + if hasattr(data, "test"): + params["test_tiles"] = len(data.test.labels) + mlflow.log_params(params) def _build_model(config: DictConfig) -> Pipeline: @@ -152,6 +170,27 @@ def _log_split_metrics( out_path = Path(f"{split}_predictions.parquet") predictions.to_parquet(out_path, index=False) mlflow.log_artifact(str(out_path), artifact_path="predictions") + _log_per_slide_accuracy(predictions, split) + + +def _log_per_slide_accuracy(predictions: pd.DataFrame, split: str) -> None: + rows = [] + for slide_id, slide_df in predictions.groupby("slide_id"): + rows.append( + { + "slide_id": slide_id, + "tile_accuracy": float((slide_df["pred"] == slide_df["target"]).mean()), + "n_tiles": len(slide_df), + } + ) + if not rows: + return + + per_slide = pd.DataFrame(rows) + mlflow.log_metric(f"{split}/slide_acc_mean", per_slide["tile_accuracy"].mean()) + mlflow.log_metric(f"{split}/slide_acc_median", per_slide["tile_accuracy"].median()) + mlflow.log_metric(f"{split}/slide_acc_min", per_slide["tile_accuracy"].min()) + mlflow.log_table(per_slide, artifact_file=f"per_slide/{split}_tile_accuracy.json") def _predict_proba_for_all_classes( @@ -166,10 +205,15 @@ def _predict_proba_for_all_classes( return probs +def _log_convergence(model: Pipeline, config: DictConfig) -> None: + classifier = model.named_steps["classifier"] + n_iter = int(classifier.n_iter_.max()) + mlflow.log_metric("n_iter", n_iter) + mlflow.log_param("converged", n_iter < config.model.max_iter) + + def _log_model(model: Pipeline) -> None: - out_path = Path("model.joblib") - joblib.dump(model, out_path) - mlflow.log_artifact(str(out_path), artifact_path="model") + mlflow.sklearn.log_model(model, artifact_path="model") if __name__ == "__main__": diff --git a/scripts/submit_train_linear.py b/scripts/submit_train_linear_probe.py similarity index 61% rename from scripts/submit_train_linear.py rename to scripts/submit_train_linear_probe.py index 93cf686..1c7b609 100644 --- a/scripts/submit_train_linear.py +++ b/scripts/submit_train_linear_probe.py @@ -2,7 +2,7 @@ submit_job( - job_name="tissue-classification-train-linear", + job_name="tissue-classification-train-linear-probe", username=..., cpu=8, memory="64Gi", @@ -12,7 +12,12 @@ "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=... val_fold=0,1,2,3,4 --multirun", + ( + "uv run python -m ml " + "+experiment=ml/..." + "val_fold=0,1,2,3,4 " + "--multirun" + ), ], storage=[storage.secure.PROJECTS], ) diff --git a/scripts/submit_train_logistic_regression.py b/scripts/submit_train_logistic_regression.py new file mode 100644 index 0000000..858ec8c --- /dev/null +++ b/scripts/submit_train_logistic_regression.py @@ -0,0 +1,24 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-train-logistic-regression", + username=..., + cpu=8, + memory="64Gi", + gpu=None, + public=False, + script=[ + "git clone https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + ( + "uv run python -m ml " + "+experiment=ml/..." + "val_fold=0,1,2,3,4 " + "model.C=0.001,0.01,0.1,1,10,100 " + "--multirun" + ), + ], + storage=[storage.secure.PROJECTS], +) From 31ecf6d3de399581e8bbf28a7752e38515b07433 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 21:03:31 +0200 Subject: [PATCH 42/53] fix: submission scripts --- scripts/submit_train_linear_probe.py | 9 ++------- scripts/submit_train_logistic_regression.py | 8 +------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/scripts/submit_train_linear_probe.py b/scripts/submit_train_linear_probe.py index 1c7b609..93cf686 100644 --- a/scripts/submit_train_linear_probe.py +++ b/scripts/submit_train_linear_probe.py @@ -2,7 +2,7 @@ submit_job( - job_name="tissue-classification-train-linear-probe", + job_name="tissue-classification-train-linear", username=..., cpu=8, memory="64Gi", @@ -12,12 +12,7 @@ "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - ( - "uv run python -m ml " - "+experiment=ml/..." - "val_fold=0,1,2,3,4 " - "--multirun" - ), + "uv run python -m ml +experiment=... val_fold=0,1,2,3,4 --multirun", ], storage=[storage.secure.PROJECTS], ) diff --git a/scripts/submit_train_logistic_regression.py b/scripts/submit_train_logistic_regression.py index 858ec8c..b043252 100644 --- a/scripts/submit_train_logistic_regression.py +++ b/scripts/submit_train_logistic_regression.py @@ -12,13 +12,7 @@ "git clone https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - ( - "uv run python -m ml " - "+experiment=ml/..." - "val_fold=0,1,2,3,4 " - "model.C=0.001,0.01,0.1,1,10,100 " - "--multirun" - ), + "uv run python -m ml +experiment=ml/... val_fold=0,1,2,3,4 model.C=0.001,0.01,0.1,1,10,100 --multirun", ], storage=[storage.secure.PROJECTS], ) From ff8d0bf45c2ff2aa1f0a6c0ef4ff024bd5e14500 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 21:33:59 +0200 Subject: [PATCH 43/53] feat: implement knn --- .../ml/knn_stratified_group_kfold.yaml | 8 ++ .../experiment/ml/knn_stratified_kfold.yaml | 8 ++ configs/ml/knn.yaml | 62 ++++++++++ ml/__main__.py | 6 + ml/sklearn_knn.py | 111 ++++++++++++++++++ scripts/submit_train_knn.py | 18 +++ 6 files changed, 213 insertions(+) create mode 100644 configs/experiment/ml/knn_stratified_group_kfold.yaml create mode 100644 configs/experiment/ml/knn_stratified_kfold.yaml create mode 100644 configs/ml/knn.yaml create mode 100644 ml/sklearn_knn.py create mode 100644 scripts/submit_train_knn.py diff --git a/configs/experiment/ml/knn_stratified_group_kfold.yaml b/configs/experiment/ml/knn_stratified_group_kfold.yaml new file mode 100644 index 0000000..3c9cafe --- /dev/null +++ b/configs/experiment/ml/knn_stratified_group_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/knn + - _self_ + +kfold_strategy: stratified_group +kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} diff --git a/configs/experiment/ml/knn_stratified_kfold.yaml b/configs/experiment/ml/knn_stratified_kfold.yaml new file mode 100644 index 0000000..87876c2 --- /dev/null +++ b/configs/experiment/ml/knn_stratified_kfold.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - /ml/knn + - _self_ + +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} diff --git a/configs/ml/knn.yaml b/configs/ml/knn.yaml new file mode 100644 index 0000000..b2a6561 --- /dev/null +++ b/configs/ml/knn.yaml @@ -0,0 +1,62 @@ +# @package _global_ + +defaults: + - /data: dataset + - /class_mapping: collapse_alterations_to_other + - /ml/data: embedding + - _self_ + +mode: fit +runner: sklearn_knn + +embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} +kfold_strategy: stratified +kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} +filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} + +train_embedding_uri: runs:/${embedding_run_id}/train/tiles +test_embedding_uri: runs:/${embedding_run_id}/test/tiles +train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet +test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet + +val_fold: 0 + +tissue_prop_min: 0.2 +thresholds: + Nerve: 0.0 + Blood: 0.0 + Connective-Tissue: 0.4 + Fat: 0.6 + Epithelium: 0.2 + Muscle: 0.5 + Other: 0.5 + +model: + n_neighbors: 25 + weights: distance + metric: cosine + algorithm: brute + n_jobs: -1 + standardize: false + log_model: false + +mlflow_artifact_path: knn + +metadata: + run_name: kNN ${dataset.name} k=${model.n_neighbors} ${model.weights} ${model.metric} ${kfold_strategy} fold=${val_fold} + description: "kNN classifier over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." + hyperparams: + embedding_run_id: ${embedding_run_id} + kfold_strategy: ${kfold_strategy} + kfold_run_id: ${kfold_run_id} + filter_tiles_run_id: ${filter_tiles_run_id} + val_fold: ${val_fold} + tissue_prop_min: ${tissue_prop_min} + thresholds: ${thresholds} + n_neighbors: ${model.n_neighbors} + weights: ${model.weights} + metric: ${model.metric} + algorithm: ${model.algorithm} + n_jobs: ${model.n_jobs} + standardize: ${model.standardize} + log_model: ${model.log_model} diff --git a/ml/__main__.py b/ml/__main__.py index 620c882..3c4cad0 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -26,6 +26,12 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: run_sklearn_linear(config) mlflow.end_run() return + if config.get("runner") == "sklearn_knn": + from ml.sklearn_knn import run as run_sklearn_knn + + run_sklearn_knn(config) + mlflow.end_run() + return seed_everything(config.seed, workers=True) diff --git a/ml/sklearn_knn.py b/ml/sklearn_knn.py new file mode 100644 index 0000000..e955dd8 --- /dev/null +++ b/ml/sklearn_knn.py @@ -0,0 +1,111 @@ +from random import randint +from typing import Any + +import hydra +import mlflow +import numpy as np +from omegaconf import DictConfig, OmegaConf +from rationai.mlkit import autolog +from rationai.mlkit.lightning.loggers import MLFlowLogger +from sklearn.neighbors import KNeighborsClassifier +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler + +from ml.data import DataModule +from ml.sklearn_linear import _log_model, _log_split_metrics + + +if not OmegaConf.has_resolver("random_seed"): + OmegaConf.register_new_resolver( + "random_seed", lambda: randint(0, 2**31), use_cache=True + ) +if not OmegaConf.has_resolver("len"): + OmegaConf.register_new_resolver("len", lambda x: len(x)) + + +@hydra.main(config_path="../configs", config_name="ml", version_base=None) +@autolog +def main(config: DictConfig, logger: MLFlowLogger) -> None: + run(config) + mlflow.end_run() + + +def run(config: DictConfig) -> None: + if config.mode != "fit": + raise ValueError("sklearn_knn currently supports only mode='fit'") + + np.random.seed(config.seed) + data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) + data.setup("fit") + if "test" in data.datasets: + data.setup("test") + + x_train = data.train.embeddings + y_train = data.train.labels + x_val = data.val.embeddings + y_val = data.val.labels + + class_names = [ + n for n, _ in sorted(config.class_indices.items(), key=lambda kv: kv[1]) + ] + model = _build_model(config) + model.fit(x_train, y_train) + + _log_split_metrics( + model, x_val, y_val, data.val.slide_ids, class_names, "validation" + ) + if "test" in data.datasets: + _log_split_metrics( + model, + data.test.embeddings, + data.test.labels, + data.test.slide_ids, + class_names, + "test", + ) + if config.model.get("log_model", False): + _log_model(model) + + params = { + "model_type": "sklearn_knn", + "n_neighbors": config.model.n_neighbors, + "weights": config.model.weights, + "metric": config.model.metric, + "algorithm": config.model.algorithm, + "n_jobs": config.model.n_jobs, + "standardize": config.model.standardize, + "log_model": config.model.get("log_model", False), + "val_fold": config.val_fold, + "kfold_strategy": config.kfold_strategy, + "embedding_run_id": config.embedding_run_id, + "kfold_run_id": config.kfold_run_id, + "filter_tiles_run_id": config.filter_tiles_run_id, + "train_tiles": len(y_train), + "validation_tiles": len(y_val), + } + if "test" in data.datasets: + params["test_tiles"] = len(data.test.labels) + mlflow.log_params(params) + + +def _build_model(config: DictConfig) -> Pipeline: + steps: list[tuple[str, Any]] = [] + if config.model.standardize: + steps.append(("scaler", StandardScaler())) + steps.append( + ( + "classifier", + KNeighborsClassifier( + n_neighbors=config.model.n_neighbors, + weights=config.model.weights, + metric=config.model.metric, + algorithm=config.model.algorithm, + n_jobs=config.model.n_jobs, + ), + ) + ) + return Pipeline(steps) + + +if __name__ == "__main__": + main() diff --git a/scripts/submit_train_knn.py b/scripts/submit_train_knn.py new file mode 100644 index 0000000..5959cd2 --- /dev/null +++ b/scripts/submit_train_knn.py @@ -0,0 +1,18 @@ +from kube_jobs import storage, submit_job + + +submit_job( + job_name="tissue-classification-train-knn", + username="vcifka", + cpu=8, + memory="64Gi", + gpu=None, + public=False, + script=[ + "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", + "cd workdir", + "uv sync", + "uv run python -m ml +experiment=ml/knn_stratified_group_kfold val_fold=0,1,2,3,4 model.n_neighbors=1,3,5,11,25,51,101 model.weights=uniform,distance model.metric=cosine,euclidean --multirun", + ], + storage=[storage.secure.PROJECTS], +) From 1f87154bd87ebb2da57ef3811096126e32b08c7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 22:14:45 +0200 Subject: [PATCH 44/53] refactor: focus on convergence --- configs/ml/trainer/default.yaml | 5 +++-- ml/meta_arch.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/default.yaml index a465a5c..d3b2171 100644 --- a/configs/ml/trainer/default.yaml +++ b/configs/ml/trainer/default.yaml @@ -1,7 +1,7 @@ # @package _global_ trainer: - max_epochs: 50 + max_epochs: 500 accelerator: auto devices: auto precision: 32 @@ -13,7 +13,8 @@ trainer: _target_: lightning.pytorch.callbacks.EarlyStopping monitor: validation/loss mode: min - patience: 2 + patience: 1 + min_delta: 1.0e-4 model_checkpoint: _target_: lightning.pytorch.callbacks.ModelCheckpoint monitor: validation/loss diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 07d1431..e7c4c52 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -7,6 +7,7 @@ import pandas as pd import torch from lightning import LightningModule +from lightning.pytorch.utilities import grad_norm from matplotlib import pyplot as plt from matplotlib.figure import Figure from torch import Tensor, nn @@ -102,6 +103,16 @@ def training_step(self, batch: Input, batch_idx: int) -> Tensor: self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) return loss + def on_before_optimizer_step(self, optimizer: Optimizer) -> None: + norms = grad_norm(self, norm_type=2) + self.log( + "train/grad_norm", + norms["grad_2.0_norm_total"], + on_step=False, + on_epoch=True, + prog_bar=True, + ) + def validation_step(self, batch: Input, batch_idx: int) -> None: inputs, targets, _ = batch outputs = self(inputs) From 7039307d3cd995a1405fbf835f3d19168e5157ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 22:28:57 +0200 Subject: [PATCH 45/53] Remove kNN sklearn baseline --- .../ml/knn_stratified_group_kfold.yaml | 8 -- .../experiment/ml/knn_stratified_kfold.yaml | 8 -- configs/ml/knn.yaml | 62 ---------- ml/__main__.py | 7 -- ml/sklearn_knn.py | 111 ------------------ scripts/submit_train_knn.py | 18 --- 6 files changed, 214 deletions(-) delete mode 100644 configs/experiment/ml/knn_stratified_group_kfold.yaml delete mode 100644 configs/experiment/ml/knn_stratified_kfold.yaml delete mode 100644 configs/ml/knn.yaml delete mode 100644 ml/sklearn_knn.py delete mode 100644 scripts/submit_train_knn.py diff --git a/configs/experiment/ml/knn_stratified_group_kfold.yaml b/configs/experiment/ml/knn_stratified_group_kfold.yaml deleted file mode 100644 index 3c9cafe..0000000 --- a/configs/experiment/ml/knn_stratified_group_kfold.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# @package _global_ - -defaults: - - /ml/knn - - _self_ - -kfold_strategy: stratified_group -kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} diff --git a/configs/experiment/ml/knn_stratified_kfold.yaml b/configs/experiment/ml/knn_stratified_kfold.yaml deleted file mode 100644 index 87876c2..0000000 --- a/configs/experiment/ml/knn_stratified_kfold.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# @package _global_ - -defaults: - - /ml/knn - - _self_ - -kfold_strategy: stratified -kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} diff --git a/configs/ml/knn.yaml b/configs/ml/knn.yaml deleted file mode 100644 index b2a6561..0000000 --- a/configs/ml/knn.yaml +++ /dev/null @@ -1,62 +0,0 @@ -# @package _global_ - -defaults: - - /data: dataset - - /class_mapping: collapse_alterations_to_other - - /ml/data: embedding - - _self_ - -mode: fit -runner: sklearn_knn - -embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} -kfold_strategy: stratified -kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} -filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} - -train_embedding_uri: runs:/${embedding_run_id}/train/tiles -test_embedding_uri: runs:/${embedding_run_id}/test/tiles -train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet -test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet - -val_fold: 0 - -tissue_prop_min: 0.2 -thresholds: - Nerve: 0.0 - Blood: 0.0 - Connective-Tissue: 0.4 - Fat: 0.6 - Epithelium: 0.2 - Muscle: 0.5 - Other: 0.5 - -model: - n_neighbors: 25 - weights: distance - metric: cosine - algorithm: brute - n_jobs: -1 - standardize: false - log_model: false - -mlflow_artifact_path: knn - -metadata: - run_name: kNN ${dataset.name} k=${model.n_neighbors} ${model.weights} ${model.metric} ${kfold_strategy} fold=${val_fold} - description: "kNN classifier over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." - hyperparams: - embedding_run_id: ${embedding_run_id} - kfold_strategy: ${kfold_strategy} - kfold_run_id: ${kfold_run_id} - filter_tiles_run_id: ${filter_tiles_run_id} - val_fold: ${val_fold} - tissue_prop_min: ${tissue_prop_min} - thresholds: ${thresholds} - n_neighbors: ${model.n_neighbors} - weights: ${model.weights} - metric: ${model.metric} - algorithm: ${model.algorithm} - n_jobs: ${model.n_jobs} - standardize: ${model.standardize} - log_model: ${model.log_model} diff --git a/ml/__main__.py b/ml/__main__.py index 3c4cad0..ae5a82e 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -26,13 +26,6 @@ def main(config: DictConfig, logger: MLFlowLogger) -> None: run_sklearn_linear(config) mlflow.end_run() return - if config.get("runner") == "sklearn_knn": - from ml.sklearn_knn import run as run_sklearn_knn - - run_sklearn_knn(config) - mlflow.end_run() - return - seed_everything(config.seed, workers=True) data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) diff --git a/ml/sklearn_knn.py b/ml/sklearn_knn.py deleted file mode 100644 index e955dd8..0000000 --- a/ml/sklearn_knn.py +++ /dev/null @@ -1,111 +0,0 @@ -from random import randint -from typing import Any - -import hydra -import mlflow -import numpy as np -from omegaconf import DictConfig, OmegaConf -from rationai.mlkit import autolog -from rationai.mlkit.lightning.loggers import MLFlowLogger -from sklearn.neighbors import KNeighborsClassifier -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import StandardScaler - -from ml.data import DataModule -from ml.sklearn_linear import _log_model, _log_split_metrics - - -if not OmegaConf.has_resolver("random_seed"): - OmegaConf.register_new_resolver( - "random_seed", lambda: randint(0, 2**31), use_cache=True - ) -if not OmegaConf.has_resolver("len"): - OmegaConf.register_new_resolver("len", lambda x: len(x)) - - -@hydra.main(config_path="../configs", config_name="ml", version_base=None) -@autolog -def main(config: DictConfig, logger: MLFlowLogger) -> None: - run(config) - mlflow.end_run() - - -def run(config: DictConfig) -> None: - if config.mode != "fit": - raise ValueError("sklearn_knn currently supports only mode='fit'") - - np.random.seed(config.seed) - data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) - data.setup("fit") - if "test" in data.datasets: - data.setup("test") - - x_train = data.train.embeddings - y_train = data.train.labels - x_val = data.val.embeddings - y_val = data.val.labels - - class_names = [ - n for n, _ in sorted(config.class_indices.items(), key=lambda kv: kv[1]) - ] - model = _build_model(config) - model.fit(x_train, y_train) - - _log_split_metrics( - model, x_val, y_val, data.val.slide_ids, class_names, "validation" - ) - if "test" in data.datasets: - _log_split_metrics( - model, - data.test.embeddings, - data.test.labels, - data.test.slide_ids, - class_names, - "test", - ) - if config.model.get("log_model", False): - _log_model(model) - - params = { - "model_type": "sklearn_knn", - "n_neighbors": config.model.n_neighbors, - "weights": config.model.weights, - "metric": config.model.metric, - "algorithm": config.model.algorithm, - "n_jobs": config.model.n_jobs, - "standardize": config.model.standardize, - "log_model": config.model.get("log_model", False), - "val_fold": config.val_fold, - "kfold_strategy": config.kfold_strategy, - "embedding_run_id": config.embedding_run_id, - "kfold_run_id": config.kfold_run_id, - "filter_tiles_run_id": config.filter_tiles_run_id, - "train_tiles": len(y_train), - "validation_tiles": len(y_val), - } - if "test" in data.datasets: - params["test_tiles"] = len(data.test.labels) - mlflow.log_params(params) - - -def _build_model(config: DictConfig) -> Pipeline: - steps: list[tuple[str, Any]] = [] - if config.model.standardize: - steps.append(("scaler", StandardScaler())) - steps.append( - ( - "classifier", - KNeighborsClassifier( - n_neighbors=config.model.n_neighbors, - weights=config.model.weights, - metric=config.model.metric, - algorithm=config.model.algorithm, - n_jobs=config.model.n_jobs, - ), - ) - ) - return Pipeline(steps) - - -if __name__ == "__main__": - main() diff --git a/scripts/submit_train_knn.py b/scripts/submit_train_knn.py deleted file mode 100644 index 5959cd2..0000000 --- a/scripts/submit_train_knn.py +++ /dev/null @@ -1,18 +0,0 @@ -from kube_jobs import storage, submit_job - - -submit_job( - job_name="tissue-classification-train-knn", - username="vcifka", - cpu=8, - memory="64Gi", - gpu=None, - public=False, - script=[ - "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", - "cd workdir", - "uv sync", - "uv run python -m ml +experiment=ml/knn_stratified_group_kfold val_fold=0,1,2,3,4 model.n_neighbors=1,3,5,11,25,51,101 model.weights=uniform,distance model.metric=cosine,euclidean --multirun", - ], - storage=[storage.secure.PROJECTS], -) From 729eccd54f6c1b38744eddf27a35cf6fe4463457 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 14 May 2026 22:44:32 +0200 Subject: [PATCH 46/53] fix: change monitor to focus on train losss --- configs/ml/trainer/default.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/ml/trainer/default.yaml b/configs/ml/trainer/default.yaml index d3b2171..8615025 100644 --- a/configs/ml/trainer/default.yaml +++ b/configs/ml/trainer/default.yaml @@ -11,16 +11,16 @@ trainer: callbacks: early_stopping: _target_: lightning.pytorch.callbacks.EarlyStopping - monitor: validation/loss + monitor: train/loss_epoch mode: min patience: 1 min_delta: 1.0e-4 model_checkpoint: _target_: lightning.pytorch.callbacks.ModelCheckpoint - monitor: validation/loss + monitor: train/loss_epoch mode: min save_top_k: 1 - filename: "epoch={epoch}-val_loss={validation/loss:.4f}" + filename: "epoch={epoch}-train_loss={train/loss_epoch:.4f}" auto_insert_metric_name: false lr_monitor: _target_: lightning.pytorch.callbacks.LearningRateMonitor From d3ed2ed642159335aa49fa9d7770d72893120126 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 00:20:08 +0200 Subject: [PATCH 47/53] feat: add run name --- configs/ml/linear_classifier.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/ml/linear_classifier.yaml b/configs/ml/linear_classifier.yaml index 5606f0a..821d786 100644 --- a/configs/ml/linear_classifier.yaml +++ b/configs/ml/linear_classifier.yaml @@ -35,7 +35,7 @@ thresholds: mlflow_artifact_path: linear_classifier metadata: - run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} + run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} wd=${model.weight_decay} description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." hyperparams: embedding_run_id: ${embedding_run_id} From e9fd559e18ab6f4eb0034d8a4ada68505f6fb3c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 10:46:48 +0200 Subject: [PATCH 48/53] chore: remove logistic regression --- ...tic_regression_stratified_group_kfold.yaml | 8 - ..._logistic_regression_stratified_kfold.yaml | 8 - configs/ml/lbfgs_logistic_regression.yaml | 64 ----- ml/__main__.py | 6 - ml/sklearn_linear.py | 220 ------------------ pyproject.toml | 1 - scripts/submit_train_linear_probe.py | 6 +- scripts/submit_train_logistic_regression.py | 18 -- uv.lock | 2 - 9 files changed, 3 insertions(+), 330 deletions(-) delete mode 100644 configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml delete mode 100644 configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml delete mode 100644 configs/ml/lbfgs_logistic_regression.yaml delete mode 100644 ml/sklearn_linear.py delete mode 100644 scripts/submit_train_logistic_regression.py diff --git a/configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml b/configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml deleted file mode 100644 index 49c954e..0000000 --- a/configs/experiment/ml/lbfgs_logistic_regression_stratified_group_kfold.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# @package _global_ - -defaults: - - /ml/lbfgs_logistic_regression - - _self_ - -kfold_strategy: stratified_group -kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} diff --git a/configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml b/configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml deleted file mode 100644 index 3d15556..0000000 --- a/configs/experiment/ml/lbfgs_logistic_regression_stratified_kfold.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# @package _global_ - -defaults: - - /ml/lbfgs_logistic_regression - - _self_ - -kfold_strategy: stratified -kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} diff --git a/configs/ml/lbfgs_logistic_regression.yaml b/configs/ml/lbfgs_logistic_regression.yaml deleted file mode 100644 index 43e1bdd..0000000 --- a/configs/ml/lbfgs_logistic_regression.yaml +++ /dev/null @@ -1,64 +0,0 @@ -# @package _global_ - -defaults: - - /data: dataset - - /class_mapping: collapse_alterations_to_other - - /ml/data: embedding - - _self_ - -mode: fit -runner: sklearn_linear - -embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} -kfold_strategy: stratified -kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} -filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} - -train_embedding_uri: runs:/${embedding_run_id}/train/tiles -test_embedding_uri: runs:/${embedding_run_id}/test/tiles -train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet -test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet - -val_fold: 0 - -tissue_prop_min: 0.2 -thresholds: - Nerve: 0.0 - Blood: 0.0 - Connective-Tissue: 0.4 - Fat: 0.6 - Epithelium: 0.2 - Muscle: 0.5 - Other: 0.5 - -model: - solver: lbfgs - penalty: l2 - C: 1.0 - class_weight: balanced - max_iter: 1000 - tol: 1.0e-4 - n_jobs: null - verbose: 0 - standardize: true - -mlflow_artifact_path: lbfgs_logistic_regression - -metadata: - run_name: LBFGS Logistic Regression ${dataset.name} ${kfold_strategy} fold=${val_fold} C=${model.C} std=${model.standardize} - description: "LBFGS multinomial logistic regression over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." - hyperparams: - embedding_run_id: ${embedding_run_id} - kfold_strategy: ${kfold_strategy} - kfold_run_id: ${kfold_run_id} - filter_tiles_run_id: ${filter_tiles_run_id} - val_fold: ${val_fold} - tissue_prop_min: ${tissue_prop_min} - thresholds: ${thresholds} - solver: ${model.solver} - penalty: ${model.penalty} - C: ${model.C} - class_weight: ${model.class_weight} - max_iter: ${model.max_iter} - tol: ${model.tol} - standardize: ${model.standardize} diff --git a/ml/__main__.py b/ml/__main__.py index ae5a82e..318c37c 100644 --- a/ml/__main__.py +++ b/ml/__main__.py @@ -20,12 +20,6 @@ @hydra.main(config_path="../configs", config_name="ml", version_base=None) @autolog def main(config: DictConfig, logger: MLFlowLogger) -> None: - if config.get("runner") == "sklearn_linear": - from ml.sklearn_linear import run as run_sklearn_linear - - run_sklearn_linear(config) - mlflow.end_run() - return seed_everything(config.seed, workers=True) data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) diff --git a/ml/sklearn_linear.py b/ml/sklearn_linear.py deleted file mode 100644 index 08ea754..0000000 --- a/ml/sklearn_linear.py +++ /dev/null @@ -1,220 +0,0 @@ -from pathlib import Path -from random import randint -from typing import Any - -import hydra -import mlflow -import mlflow.sklearn -import numpy as np -import pandas as pd -from matplotlib import pyplot as plt -from omegaconf import DictConfig, OmegaConf -from rationai.mlkit import autolog -from rationai.mlkit.lightning.loggers import MLFlowLogger -from sklearn.linear_model import LogisticRegression -from sklearn.metrics import confusion_matrix, f1_score, recall_score -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import StandardScaler - -from ml.data import DataModule -from ml.meta_arch import _confmat_figure - - -if not OmegaConf.has_resolver("random_seed"): - OmegaConf.register_new_resolver( - "random_seed", lambda: randint(0, 2**31), use_cache=True - ) -if not OmegaConf.has_resolver("len"): - OmegaConf.register_new_resolver("len", lambda x: len(x)) - - -@hydra.main(config_path="../configs", config_name="ml", version_base=None) -@autolog -def main(config: DictConfig, logger: MLFlowLogger) -> None: - run(config) - mlflow.end_run() - - -def run(config: DictConfig) -> None: - if config.mode != "fit": - raise ValueError("sklearn_linear currently supports only mode='fit'") - - np.random.seed(config.seed) - data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) - data.setup("fit") - if "test" in data.datasets: - data.setup("test") - - x_train = data.train.embeddings - y_train = data.train.labels - x_val = data.val.embeddings - y_val = data.val.labels - - class_names = [ - n for n, _ in sorted(config.class_indices.items(), key=lambda kv: kv[1]) - ] - model = _build_model(config) - model.fit(x_train, y_train) - _log_convergence(model, config) - - _log_split_metrics( - model, x_val, y_val, data.val.slide_ids, class_names, "validation" - ) - if hasattr(data, "test"): - _log_split_metrics( - model, - data.test.embeddings, - data.test.labels, - data.test.slide_ids, - class_names, - "test", - ) - _log_model(model) - - params = { - "model_type": "sklearn_logistic_regression", - "solver": config.model.solver, - "penalty": config.model.penalty, - "C": config.model.C, - "max_iter": config.model.max_iter, - "tol": config.model.tol, - "class_weight": config.model.class_weight, - "standardize": config.model.standardize, - "val_fold": config.val_fold, - "kfold_strategy": config.kfold_strategy, - "embedding_run_id": config.embedding_run_id, - "kfold_run_id": config.kfold_run_id, - "filter_tiles_run_id": config.filter_tiles_run_id, - "train_tiles": len(y_train), - "validation_tiles": len(y_val), - } - if hasattr(data, "test"): - params["test_tiles"] = len(data.test.labels) - mlflow.log_params(params) - - -def _build_model(config: DictConfig) -> Pipeline: - steps: list[tuple[str, Any]] = [] - if config.model.standardize: - steps.append(("scaler", StandardScaler())) - steps.append( - ( - "classifier", - LogisticRegression( - C=config.model.C, - class_weight=config.model.class_weight, - max_iter=config.model.max_iter, - n_jobs=config.model.n_jobs, - penalty=config.model.penalty, - random_state=config.seed, - solver=config.model.solver, - tol=config.model.tol, - verbose=config.model.verbose, - ), - ) - ) - return Pipeline(steps) - - -def _log_split_metrics( - model: Pipeline, - inputs: np.ndarray, - targets: np.ndarray, - slide_ids: np.ndarray, - class_names: list[str], - split: str, -) -> None: - labels = np.arange(len(class_names)) - preds = model.predict(inputs) - probs = _predict_proba_for_all_classes(model, inputs, labels) - - mlflow.log_metric( - f"{split}/acc_macro", - recall_score(targets, preds, labels=labels, average="macro", zero_division=0), - ) - mlflow.log_metric( - f"{split}/f1_macro", - f1_score(targets, preds, average="macro", zero_division=0), - ) - - per_class_acc = recall_score( - targets, preds, labels=labels, average=None, zero_division=0 - ) - per_class_f1 = f1_score( - targets, preds, labels=labels, average=None, zero_division=0 - ) - for cls_name, acc, f1 in zip( - class_names, per_class_acc.tolist(), per_class_f1.tolist(), strict=True - ): - mlflow.log_metric(f"{split}/acc_per_class/{cls_name}", acc) - mlflow.log_metric(f"{split}/f1_per_class/{cls_name}", f1) - - matrix = confusion_matrix(targets, preds, labels=labels) - fig = _confmat_figure(matrix, class_names, title=f"{split} confmat") - try: - mlflow.log_figure(fig, artifact_file=f"confusion_matrix/{split}.png") - finally: - plt.close(fig) - - prob_columns = [f"prob_{c}" for c in class_names] - predictions = pd.DataFrame( - { - "slide_id": slide_ids, - "target": targets, - "pred": preds, - } - ) - predictions = pd.concat( - [predictions, pd.DataFrame(probs, columns=prob_columns)], axis=1 - ) - out_path = Path(f"{split}_predictions.parquet") - predictions.to_parquet(out_path, index=False) - mlflow.log_artifact(str(out_path), artifact_path="predictions") - _log_per_slide_accuracy(predictions, split) - - -def _log_per_slide_accuracy(predictions: pd.DataFrame, split: str) -> None: - rows = [] - for slide_id, slide_df in predictions.groupby("slide_id"): - rows.append( - { - "slide_id": slide_id, - "tile_accuracy": float((slide_df["pred"] == slide_df["target"]).mean()), - "n_tiles": len(slide_df), - } - ) - if not rows: - return - - per_slide = pd.DataFrame(rows) - mlflow.log_metric(f"{split}/slide_acc_mean", per_slide["tile_accuracy"].mean()) - mlflow.log_metric(f"{split}/slide_acc_median", per_slide["tile_accuracy"].median()) - mlflow.log_metric(f"{split}/slide_acc_min", per_slide["tile_accuracy"].min()) - mlflow.log_table(per_slide, artifact_file=f"per_slide/{split}_tile_accuracy.json") - - -def _predict_proba_for_all_classes( - model: Pipeline, inputs: np.ndarray, labels: np.ndarray -) -> np.ndarray: - raw_probs = model.predict_proba(inputs) - probs = np.zeros((len(inputs), len(labels)), dtype=raw_probs.dtype) - for source_idx, class_idx in enumerate(model.classes_): - matching = np.flatnonzero(labels == class_idx) - if len(matching) == 1: - probs[:, matching[0]] = raw_probs[:, source_idx] - return probs - - -def _log_convergence(model: Pipeline, config: DictConfig) -> None: - classifier = model.named_steps["classifier"] - n_iter = int(classifier.n_iter_.max()) - mlflow.log_metric("n_iter", n_iter) - mlflow.log_param("converged", n_iter < config.model.max_iter) - - -def _log_model(model: Pipeline) -> None: - mlflow.sklearn.log_model(model, artifact_path="model") - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index 0cff386..450c492 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,6 @@ dependencies = [ "ratiopath>=1.2.0", "pyarrow>=19.0.1", "datasets>=4.0.0", - "scikit-learn>=1.8.0", "numpy>=2.3.5", "rationai-tiling>=1.1.1", "tifffile>=2025.12.20", diff --git a/scripts/submit_train_linear_probe.py b/scripts/submit_train_linear_probe.py index 93cf686..3f7ecb1 100644 --- a/scripts/submit_train_linear_probe.py +++ b/scripts/submit_train_linear_probe.py @@ -3,16 +3,16 @@ submit_job( job_name="tissue-classification-train-linear", - username=..., + username="vcifka", cpu=8, memory="64Gi", gpu=None, public=False, script=[ - "git clone https://github.com/RationAI/tissue-classification.git workdir", + "git clone --branch feature/ml-linear-classifier https://github.com/RationAI/tissue-classification.git workdir", "cd workdir", "uv sync", - "uv run python -m ml +experiment=... val_fold=0,1,2,3,4 --multirun", + "uv run python -m ml +experiment=ml/linear_classifier_stratified_group_kfold val_fold=0,1,2,3,4 model.weight_decay=0,1e-5,1e-4,1e-3,1e-2 --multirun", ], storage=[storage.secure.PROJECTS], ) diff --git a/scripts/submit_train_logistic_regression.py b/scripts/submit_train_logistic_regression.py deleted file mode 100644 index b043252..0000000 --- a/scripts/submit_train_logistic_regression.py +++ /dev/null @@ -1,18 +0,0 @@ -from kube_jobs import storage, submit_job - - -submit_job( - job_name="tissue-classification-train-logistic-regression", - username=..., - cpu=8, - memory="64Gi", - gpu=None, - public=False, - script=[ - "git clone https://github.com/RationAI/tissue-classification.git workdir", - "cd workdir", - "uv sync", - "uv run python -m ml +experiment=ml/... val_fold=0,1,2,3,4 model.C=0.001,0.01,0.1,1,10,100 --multirun", - ], - storage=[storage.secure.PROJECTS], -) diff --git a/uv.lock b/uv.lock index c4ad730..1e1ee3a 100644 --- a/uv.lock +++ b/uv.lock @@ -2306,7 +2306,6 @@ dependencies = [ { name = "rationai-tiling" }, { name = "ratiopath" }, { name = "ray" }, - { name = "scikit-learn" }, { name = "tifffile" }, { name = "timm" }, { name = "torch" }, @@ -2342,7 +2341,6 @@ requires-dist = [ { name = "rationai-tiling", git = "https://gitlab.ics.muni.cz/rationai/digital-pathology/libraries/tiling.git" }, { name = "ratiopath", specifier = ">=1.2.0" }, { name = "ray", specifier = ">=2.51.1" }, - { name = "scikit-learn", specifier = ">=1.8.0" }, { name = "tifffile", specifier = ">=2025.12.20" }, { name = "timm", specifier = ">=1.0.0" }, { name = "torch", specifier = ">=2.0.0" }, From 6dadbd75aa3cda770c14f545266c746226236177 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 11:16:29 +0200 Subject: [PATCH 49/53] feat: implement lbfgs --- ...assifier_lbfgs_stratified_group_kfold.yaml | 26 ++++ ...ear_classifier_lbfgs_stratified_kfold.yaml | 26 ++++ configs/ml/data/embedding.yaml | 2 + configs/ml/linear_classifier.yaml | 6 +- configs/ml/model/linear_classifier.yaml | 10 ++ ml/data/data_module.py | 13 +- ml/meta_arch.py | 144 ++++++++++++++++++ 7 files changed, 223 insertions(+), 4 deletions(-) create mode 100644 configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml create mode 100644 configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml new file mode 100644 index 0000000..8595986 --- /dev/null +++ b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +defaults: + - /experiment/ml/linear_classifier_stratified_group_kfold + - _self_ + +trainer: + max_epochs: 10 + +data: + batch_size: 1000000000 + train_shuffle: false + train_drop_last: false + +model: + optimizer: lbfgs + learning_rate: 1.0 + lbfgs: + max_iter: 100 + max_eval: null + tolerance_grad: 1.0e-7 + tolerance_change: 1.0e-9 + history_size: 100 + line_search_fn: strong_wolfe + accumulate_batches: 1 + accumulate_on_cpu: false diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml b/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml new file mode 100644 index 0000000..bd3c10b --- /dev/null +++ b/configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +defaults: + - /experiment/ml/linear_classifier_stratified_kfold + - _self_ + +trainer: + max_epochs: 10 + +data: + batch_size: 1000000000 + train_shuffle: false + train_drop_last: false + +model: + optimizer: lbfgs + learning_rate: 1.0 + lbfgs: + max_iter: 100 + max_eval: null + tolerance_grad: 1.0e-7 + tolerance_change: 1.0e-9 + history_size: 100 + line_search_fn: strong_wolfe + accumulate_batches: 1 + accumulate_on_cpu: false diff --git a/configs/ml/data/embedding.yaml b/configs/ml/data/embedding.yaml index 597e012..40ff4b7 100644 --- a/configs/ml/data/embedding.yaml +++ b/configs/ml/data/embedding.yaml @@ -3,6 +3,8 @@ data: batch_size: 1024 num_workers: 4 + train_shuffle: true + train_drop_last: true train: _target_: ml.data.datasets.EmbeddingTilesDataset diff --git a/configs/ml/linear_classifier.yaml b/configs/ml/linear_classifier.yaml index 821d786..d339337 100644 --- a/configs/ml/linear_classifier.yaml +++ b/configs/ml/linear_classifier.yaml @@ -35,7 +35,7 @@ thresholds: mlflow_artifact_path: linear_classifier metadata: - run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} wd=${model.weight_decay} + run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} opt=${model.optimizer} wd=${model.weight_decay} description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." hyperparams: embedding_run_id: ${embedding_run_id} @@ -45,6 +45,10 @@ metadata: val_fold: ${val_fold} tissue_prop_min: ${tissue_prop_min} thresholds: ${thresholds} + optimizer: ${model.optimizer} learning_rate: ${model.learning_rate} weight_decay: ${model.weight_decay} + lbfgs: ${model.lbfgs} batch_size: ${data.batch_size} + train_shuffle: ${data.train_shuffle} + train_drop_last: ${data.train_drop_last} diff --git a/configs/ml/model/linear_classifier.yaml b/configs/ml/model/linear_classifier.yaml index fc9bfed..4b4d9e8 100644 --- a/configs/ml/model/linear_classifier.yaml +++ b/configs/ml/model/linear_classifier.yaml @@ -11,5 +11,15 @@ model: class_indices: ${class_indices} + optimizer: adamw learning_rate: 1.0e-4 weight_decay: 0.0 + lbfgs: + max_iter: 100 + max_eval: null + tolerance_grad: 1.0e-7 + tolerance_change: 1.0e-9 + history_size: 100 + line_search_fn: strong_wolfe + accumulate_batches: 1 + accumulate_on_cpu: false diff --git a/ml/data/data_module.py b/ml/data/data_module.py index 7302be5..bfac118 100644 --- a/ml/data/data_module.py +++ b/ml/data/data_module.py @@ -16,11 +16,18 @@ class DataModule(LightningDataModule): """ def __init__( - self, batch_size: int, num_workers: int = 0, **datasets: DictConfig + self, + batch_size: int, + num_workers: int = 0, + train_shuffle: bool = True, + train_drop_last: bool = True, + **datasets: DictConfig, ) -> None: super().__init__() self.batch_size = batch_size self.num_workers = num_workers + self.train_shuffle = train_shuffle + self.train_drop_last = train_drop_last self.datasets = datasets def setup(self, stage: str) -> None: @@ -42,8 +49,8 @@ def train_dataloader(self) -> Iterable[Input]: return DataLoader( self.train, batch_size=self.batch_size, - shuffle=True, - drop_last=True, + shuffle=self.train_shuffle, + drop_last=self.train_drop_last, num_workers=self.num_workers, persistent_workers=self.num_workers > 0, ) diff --git a/ml/meta_arch.py b/ml/meta_arch.py index e7c4c52..73cd3f5 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -37,13 +37,21 @@ def __init__( class_indices: dict[str, int], learning_rate: float = 1e-3, weight_decay: float = 0.0, + optimizer: str = "adamw", + lbfgs: dict[str, Any] | None = None, ) -> None: super().__init__() self.save_hyperparameters(ignore=["backbone", "decode_head"]) + if optimizer not in {"adamw", "lbfgs"}: + raise ValueError(f"Unsupported optimizer {optimizer!r}") + if optimizer == "lbfgs": + self.automatic_optimization = False + self.backbone = backbone self.decode_head = decode_head self.criterion: nn.Module + self._lbfgs_batches: list[tuple[Tensor, Tensor]] = [] self.class_names = [ n for n, _ in sorted(class_indices.items(), key=lambda kv: kv[1]) @@ -82,6 +90,8 @@ def setup(self, stage: str) -> None: if stage == "fit": datamodule = cast("Any", self.trainer).datamodule labels = datamodule.train.labels + if self.hparams["optimizer"] == "lbfgs": + self._validate_lbfgs_full_batch(datamodule, len(labels)) num_classes = len(self.class_names) counts = np.bincount(labels, minlength=num_classes).astype(float) counts = np.maximum(counts, 1.0) @@ -97,6 +107,9 @@ def forward(self, x: Tensor) -> Outputs: return self.decode_head(features) def training_step(self, batch: Input, batch_idx: int) -> Tensor: + if self.hparams["optimizer"] == "lbfgs": + return self._lbfgs_training_step(batch, batch_idx) + inputs, targets, _ = batch outputs = self(inputs) loss = self.criterion(outputs, targets) @@ -104,6 +117,8 @@ def training_step(self, batch: Input, batch_idx: int) -> Tensor: return loss def on_before_optimizer_step(self, optimizer: Optimizer) -> None: + if self.hparams["optimizer"] == "lbfgs": + return norms = grad_norm(self, norm_type=2) self.log( "train/grad_norm", @@ -163,12 +178,141 @@ def predict_step( } def configure_optimizers(self) -> Optimizer: + if self.hparams["optimizer"] == "lbfgs": + lbfgs = self.hparams.get("lbfgs") or {} + return torch.optim.LBFGS( + self.parameters(), + lr=self.hparams["learning_rate"], + max_iter=lbfgs.get("max_iter", 100), + max_eval=lbfgs.get("max_eval"), + tolerance_grad=lbfgs.get("tolerance_grad", 1.0e-7), + tolerance_change=lbfgs.get("tolerance_change", 1.0e-9), + history_size=lbfgs.get("history_size", 100), + line_search_fn=lbfgs.get("line_search_fn", "strong_wolfe"), + ) + return torch.optim.AdamW( self.parameters(), lr=self.hparams["learning_rate"], weight_decay=self.hparams["weight_decay"], ) + def _lbfgs_training_step(self, batch: Input, batch_idx: int) -> Tensor: + inputs, targets, _ = batch + self._lbfgs_batches.append(self._prepare_lbfgs_batch(inputs, targets)) + lbfgs = self.hparams.get("lbfgs") or {} + accumulation_steps = int(lbfgs.get("accumulate_batches", 1)) + is_last_batch = batch_idx + 1 == self.trainer.num_training_batches + should_step = len(self._lbfgs_batches) >= accumulation_steps or is_last_batch + if not should_step: + with torch.no_grad(): + return self.criterion(self(inputs), targets) + + optimizer = self.optimizers() + total_samples = sum(targets.numel() for _, targets in self._lbfgs_batches) + + def closure() -> Tensor: + optimizer.zero_grad() + loss, _, _ = self._lbfgs_buffered_loss(total_samples) + if not torch.isfinite(loss): + raise FloatingPointError(f"non-finite LBFGS loss: {loss.item()}") + self.manual_backward(loss) + return loss + + step_loss = optimizer.step(closure=closure) + if not isinstance(step_loss, Tensor): + step_loss = torch.as_tensor(step_loss, device=self.device) + + optimizer.zero_grad() + loss, ce_loss, l2_loss = self._lbfgs_buffered_loss(total_samples) + if not torch.isfinite(loss): + raise FloatingPointError(f"non-finite LBFGS post-step loss: {loss.item()}") + self.manual_backward(loss) + grad_norm = self._total_grad_norm() + if grad_norm is not None and not torch.isfinite(grad_norm): + raise FloatingPointError( + f"non-finite LBFGS gradient norm: {grad_norm.item()}" + ) + optimizer.zero_grad() + self._lbfgs_batches.clear() + + self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + self.log("train/ce_loss", ce_loss, on_step=True, on_epoch=True) + self.log("train/l2_loss", l2_loss, on_step=True, on_epoch=True) + if grad_norm is not None: + self.log( + "train/grad_norm", + grad_norm, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + self.log("train/lbfgs_step_loss", step_loss.detach(), on_step=True) + return loss + + def _prepare_lbfgs_batch( + self, inputs: Tensor, targets: Tensor + ) -> tuple[Tensor, Tensor]: + lbfgs = self.hparams.get("lbfgs") or {} + if lbfgs.get("accumulate_on_cpu", False): + return inputs.detach().cpu(), targets.detach().cpu() + return inputs, targets + + def _validate_lbfgs_full_batch(self, datamodule: Any, train_size: int) -> None: + lbfgs = self.hparams.get("lbfgs") or {} + batch_size = int(datamodule.batch_size) + accumulation_steps = int(lbfgs.get("accumulate_batches", 1)) + effective_batch_size = batch_size * accumulation_steps + + if datamodule.train_shuffle: + raise ValueError("LBFGS requires data.train_shuffle=false.") + if datamodule.train_drop_last: + raise ValueError("LBFGS requires data.train_drop_last=false.") + if effective_batch_size < train_size: + raise ValueError( + "LBFGS requires a deterministic full-batch objective. Set " + "data.batch_size >= len(train) or set " + "model.lbfgs.accumulate_batches >= ceil(len(train) / " + "data.batch_size). Current effective batch size is " + f"{effective_batch_size} for {train_size} training samples." + ) + + def _lbfgs_buffered_loss(self, total_samples: int) -> tuple[Tensor, Tensor, Tensor]: + ce_loss = torch.zeros((), device=self.device) + for micro_inputs, micro_targets in self._lbfgs_batches: + micro_inputs = micro_inputs.to(self.device) + micro_targets = micro_targets.to(self.device) + outputs = self(micro_inputs) + weight = micro_targets.numel() / total_samples + ce_loss = ce_loss + self.criterion(outputs, micro_targets) * weight + + l2_loss = self._l2_loss() + return ce_loss + l2_loss, ce_loss, l2_loss + + def _objective_loss(self, outputs: Tensor, targets: Tensor) -> Tensor: + return self.criterion(outputs, targets) + self._l2_loss() + + def _l2_loss(self) -> Tensor: + weight_decay = self.hparams["weight_decay"] + if weight_decay == 0: + return torch.zeros((), device=self.device) + + penalty = torch.zeros((), device=self.device) + for name, param in self.named_parameters(): + if param.requires_grad and name.endswith("weight"): + penalty = penalty + param.square().sum() + return 0.5 * weight_decay * penalty + + def _total_grad_norm(self) -> Tensor | None: + grads = [ + param.grad.detach().norm(2) + for param in self.parameters() + if param.grad is not None + ] + if not grads: + return None + return torch.linalg.vector_norm(torch.stack(grads), ord=2) + def _log_per_class(self, collection: MetricCollection, split: str) -> None: computed = collection.compute() for metric_name, values in computed.items(): From d5d3edd1655aa3a32652e4383175a4d9b5259b4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 11:32:41 +0200 Subject: [PATCH 50/53] fix: run id --- configs/data/dataset.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/data/dataset.yaml b/configs/data/dataset.yaml index 0ab4e0d..09f8f4a 100644 --- a/configs/data/dataset.yaml +++ b/configs/data/dataset.yaml @@ -14,7 +14,7 @@ dataset: test_split_filename: "split_mapping/test_split.csv" tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86" filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba" - stratified_kfold_run_id: "814611e8987d4d569b255b7a4749bc90" + stratified_kfold_run_id: "c7eafdffa32743aa9eb6dd2bf3a185b5" stratified_group_kfold_run_id: "382b41d2fa894514908e8067949c4326" embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6" tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f" From 216369965c1efc53a985174e240055047e0c267f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 11:37:43 +0200 Subject: [PATCH 51/53] fix: cache the tiles and embeddings so they do not need to be downloaded twice --- ml/data/datasets/embedding_tiles.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/ml/data/datasets/embedding_tiles.py b/ml/data/datasets/embedding_tiles.py index 791a734..5160ba1 100644 --- a/ml/data/datasets/embedding_tiles.py +++ b/ml/data/datasets/embedding_tiles.py @@ -5,6 +5,7 @@ load time to produce ``(embedding, class_index, slide_id)`` triples. """ +from functools import cache from pathlib import Path import numpy as np @@ -186,7 +187,11 @@ def _filter_metadata( @staticmethod def _resolve_uri(path_or_uri: str | Path) -> str: - s = str(path_or_uri) - if s.startswith(("mlflow-artifacts:/", "runs:/")): - return download_artifacts(artifact_uri=s) - return s + return EmbeddingTilesDataset._resolve_uri_cached(str(path_or_uri)) + + @staticmethod + @cache + def _resolve_uri_cached(uri: str) -> str: + if uri.startswith(("mlflow-artifacts:/", "runs:/")): + return download_artifacts(artifact_uri=uri) + return uri From 92868070c23d24c537ffae7ab965ce4f84f60b44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 12:43:21 +0200 Subject: [PATCH 52/53] fix: limit num of workers --- .../ml/linear_classifier_lbfgs_stratified_group_kfold.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml index 8595986..4d92561 100644 --- a/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml +++ b/configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml @@ -11,6 +11,7 @@ data: batch_size: 1000000000 train_shuffle: false train_drop_last: false + num_workers: 0 model: optimizer: lbfgs From bb8a043d5dc174768a7aa06ad9ae2de30d6f1bf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Fri, 15 May 2026 17:28:14 +0200 Subject: [PATCH 53/53] fix: support checkpoint test and prediction export --- ml/callbacks/parquet_prediction_writer.py | 17 +++++++++++------ ml/meta_arch.py | 4 ++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/ml/callbacks/parquet_prediction_writer.py b/ml/callbacks/parquet_prediction_writer.py index 38d86fd..a6f676b 100644 --- a/ml/callbacks/parquet_prediction_writer.py +++ b/ml/callbacks/parquet_prediction_writer.py @@ -32,16 +32,21 @@ def write_on_epoch_end( if trainer.global_rank != 0: return + batches = ( + predictions + if not predictions or isinstance(predictions[0], dict) + else [b for dataloader_preds in predictions for b in dataloader_preds] + ) + slide_ids: list[str] = [] targets: list[int] = [] preds: list[int] = [] probs: list[np.ndarray] = [] - for dataloader_preds in predictions: - for b in dataloader_preds: - slide_ids.extend(b["slide_id"]) - targets.extend(b["target"].tolist()) - preds.extend(b["pred"].tolist()) - probs.append(b["probs"].numpy()) + for b in batches: + slide_ids.extend(b["slide_id"]) + targets.extend(b["target"].tolist()) + preds.extend(b["pred"].tolist()) + probs.append(b["probs"].numpy()) if not slide_ids: return diff --git a/ml/meta_arch.py b/ml/meta_arch.py index 73cd3f5..ab6882e 100644 --- a/ml/meta_arch.py +++ b/ml/meta_arch.py @@ -50,13 +50,13 @@ def __init__( self.backbone = backbone self.decode_head = decode_head - self.criterion: nn.Module self._lbfgs_batches: list[tuple[Tensor, Tensor]] = [] self.class_names = [ n for n, _ in sorted(class_indices.items(), key=lambda kv: kv[1]) ] num_classes = len(self.class_names) + self.criterion = nn.CrossEntropyLoss(weight=torch.ones(num_classes)) macro_metrics = MetricCollection( { @@ -208,7 +208,7 @@ def _lbfgs_training_step(self, batch: Input, batch_idx: int) -> Tensor: with torch.no_grad(): return self.criterion(self(inputs), targets) - optimizer = self.optimizers() + optimizer = cast("Any", self.optimizers()) total_samples = sum(targets.numel() for _, targets in self._lbfgs_batches) def closure() -> Tensor: