Merged
55 commits
bef70df
feat: add embedding dataset build pipeline
vojtech-cifka May 8, 2026
911bec2
feat: add class thresholds and run ids
vojtech-cifka May 8, 2026
1a02395
fix: wrong run id
vojtech-cifka May 8, 2026
08d7ba5
Merge remote-tracking branch 'origin/master' into feature/embedding-d…
vojtech-cifka May 9, 2026
b38465e
feat: add timing
vojtech-cifka May 9, 2026
bfc9578
refactor: use pyarrow to avoid to pandas conversion
vojtech-cifka May 9, 2026
eb213c6
fix: join on keys only
vojtech-cifka May 9, 2026
c92d9a1
fix: typing
vojtech-cifka May 9, 2026
01cc394
fix: add prints
vojtech-cifka May 9, 2026
cad0d37
refactor: use combine chunks
vojtech-cifka May 9, 2026
ae04552
fix: lazy-cast embeddings to large_list and stay in Arrow during join
vojtech-cifka May 9, 2026
82320db
fix: validate label/tissue_prop columns when derive=False
vojtech-cifka May 9, 2026
3b0137f
chore: remove time
vojtech-cifka May 9, 2026
8df47aa
feat: add timing
vojtech-cifka May 10, 2026
926753d
chore: revert to the previous state
vojtech-cifka May 10, 2026
b0e9ba4
feat: add prints
vojtech-cifka May 10, 2026
6a915de
refactor: use discussed thresholds
vojtech-cifka May 11, 2026
0f50307
refactor: use different labeling strategy
vojtech-cifka May 11, 2026
4d953dc
feat: implement training pipeline
vojtech-cifka May 11, 2026
d5798bc
feat: add class weights
vojtech-cifka May 11, 2026
ae45cd5
refactor: join embeddings with metadata while loading the dataset
vojtech-cifka May 11, 2026
bdce760
feat: add prints
vojtech-cifka May 11, 2026
ac633d5
fix: use chunks
vojtech-cifka May 11, 2026
2793562
fix: use numpy chunks
vojtech-cifka May 11, 2026
e81973e
fix: call end at the end of the main
vojtech-cifka May 11, 2026
0071592
chore: remove prints
vojtech-cifka May 11, 2026
c0a7499
chore: remove debug prints, stale TODO, and unused preprocessing pipe…
vojtech-cifka May 11, 2026
fe918d1
chore: remove markdown file
vojtech-cifka May 11, 2026
6b7d1e8
fix: edge cases
vojtech-cifka May 12, 2026
4ff988e
feat: normalize the confusion matrix rows per class recall
vojtech-cifka May 12, 2026
32375b2
fix: format
vojtech-cifka May 12, 2026
af9538a
feat: use stratified k fold run
vojtech-cifka May 12, 2026
bc0819a
fix: remove criterion
vojtech-cifka May 12, 2026
b8e85e0
fix: remove criterion from configs
vojtech-cifka May 12, 2026
ff4d307
Merge branch 'master' into feature/ml-linear-classifier
vojtech-cifka May 13, 2026
3cc670d
feat: add option to use different kfold strategies
vojtech-cifka May 13, 2026
27ceea3
fix: lower LR and patience
vojtech-cifka May 13, 2026
efde82a
fix: use f1 macro as a monitor
vojtech-cifka May 14, 2026
c8102de
fix: revert back to validation loss
vojtech-cifka May 14, 2026
c5bab90
fix: add weight decay 1e-3 to linear classifier
vojtech-cifka May 14, 2026
475b67c
Revert "fix: add weight decay 1e-3 to linear classifier"
vojtech-cifka May 14, 2026
43663a9
feat: add logistic regression
vojtech-cifka May 14, 2026
a2fe451
feat: polish and add two distinct submission scripts
vojtech-cifka May 14, 2026
31ecf6d
fix: submission scripts
vojtech-cifka May 14, 2026
ff8d0bf
feat: implement knn
vojtech-cifka May 14, 2026
1f87154
refactor: focus on convergence
vojtech-cifka May 14, 2026
7039307
Remove kNN sklearn baseline
vojtech-cifka May 14, 2026
729eccd
fix: change monitor to focus on train loss
vojtech-cifka May 14, 2026
d3ed2ed
feat: add run name
vojtech-cifka May 14, 2026
e9fd559
chore: remove logistic regression
vojtech-cifka May 15, 2026
6dadbd7
feat: implement lbfgs
vojtech-cifka May 15, 2026
d5d3edd
fix: run id
vojtech-cifka May 15, 2026
2163699
fix: cache the tiles and embeddings so they do not need to be downloa…
vojtech-cifka May 15, 2026
9286807
fix: limit num of workers
vojtech-cifka May 15, 2026
bb8a043
fix: support checkpoint test and prediction export
vojtech-cifka May 15, 2026
5 changes: 4 additions & 1 deletion configs/data/dataset.yaml
@@ -14,6 +14,9 @@ dataset:
test_split_filename: "split_mapping/test_split.csv"
tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86"
filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba"
stratified_kfold_run_id: "c7eafdffa32743aa9eb6dd2bf3a185b5"
stratified_group_kfold_run_id: "382b41d2fa894514908e8067949c4326"
embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6"
tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f"
tissue_stats_run_id: "16ae2d003d88471b924e5f332415232a"

@@ -56,4 +59,4 @@ dataset:
"44 Klarzelliges Nierenzellkarzinom_1": "kidney"
"50 Muzinöses Zystadenom_1": "breast"
"85 Mammakarzinom NST": "breast"
"28 Zöliakie": "small intestine"
"28 Zöliakie": "small intestine"
@@ -0,0 +1,27 @@
# @package _global_

defaults:
- /experiment/ml/linear_classifier_stratified_group_kfold
- _self_

trainer:
max_epochs: 10

data:
batch_size: 1000000000
train_shuffle: false
train_drop_last: false
num_workers: 0

model:
optimizer: lbfgs
learning_rate: 1.0
lbfgs:
max_iter: 100
max_eval: null
tolerance_grad: 1.0e-7
tolerance_change: 1.0e-9
history_size: 100
line_search_fn: strong_wolfe
accumulate_batches: 1
accumulate_on_cpu: false
@@ -0,0 +1,26 @@
# @package _global_

defaults:
- /experiment/ml/linear_classifier_stratified_kfold
- _self_

trainer:
max_epochs: 10

data:
batch_size: 1000000000
train_shuffle: false
train_drop_last: false

model:
optimizer: lbfgs
learning_rate: 1.0
lbfgs:
max_iter: 100
max_eval: null
tolerance_grad: 1.0e-7
tolerance_change: 1.0e-9
history_size: 100
line_search_fn: strong_wolfe
accumulate_batches: 1
accumulate_on_cpu: false
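The two L-BFGS experiment configs set `batch_size` far above any realistic tile count, so each epoch is a single full-batch step, and they switch `train_drop_last` off. A minimal sketch of why the second setting matters, assuming the standard PyTorch `DataLoader` batch-count rule (floor division when the last batch is dropped, ceiling otherwise). `num_batches` and the tile count are hypothetical and not part of this PR:

```python
import math


def num_batches(n_samples: int, batch_size: int, drop_last: bool) -> int:
    """Batches per epoch under the usual DataLoader length rule."""
    if drop_last:
        # An incomplete trailing batch is discarded.
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


n_tiles = 250_000  # illustrative dataset size, not taken from the PR
full_batch = 1_000_000_000  # value used in the L-BFGS configs

print(num_batches(n_tiles, full_batch, drop_last=False))  # 1
print(num_batches(n_tiles, full_batch, drop_last=True))   # 0
print(num_batches(n_tiles, 1024, drop_last=True))         # 244
```

With the default `train_drop_last: true` from `configs/ml/data/embedding.yaml`, the single oversized batch would be dropped and the epoch would contain no training step, which is presumably why the override is needed.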
@@ -0,0 +1,8 @@
# @package _global_

defaults:
- /ml/linear_classifier
- _self_

kfold_strategy: stratified_group
kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id}
8 changes: 8 additions & 0 deletions configs/experiment/ml/linear_classifier_stratified_kfold.yaml
@@ -0,0 +1,8 @@
# @package _global_

defaults:
- /ml/linear_classifier
- _self_

kfold_strategy: stratified
kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id}
35 changes: 35 additions & 0 deletions configs/ml/data/embedding.yaml
@@ -0,0 +1,35 @@
# @package _global_

data:
batch_size: 1024
num_workers: 4
train_shuffle: true
train_drop_last: true

train:
_target_: ml.data.datasets.EmbeddingTilesDataset
embedding_uri: ${train_embedding_uri}
metadata_uri: ${train_metadata_uri}
class_indices: ${class_indices}
thresholds: ${thresholds}
tissue_prop_min: ${tissue_prop_min}
exclude_folds:
- ${val_fold}

val:
_target_: ml.data.datasets.EmbeddingTilesDataset
embedding_uri: ${train_embedding_uri}
metadata_uri: ${train_metadata_uri}
class_indices: ${class_indices}
thresholds: ${thresholds}
tissue_prop_min: ${tissue_prop_min}
include_folds:
- ${val_fold}

test:
_target_: ml.data.datasets.EmbeddingTilesDataset
embedding_uri: ${test_embedding_uri}
metadata_uri: ${test_metadata_uri}
class_indices: ${class_indices}
thresholds: ${thresholds}
tissue_prop_min: ${tissue_prop_min}
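The train and validation datasets read the same embedding and metadata URIs and differ only in `exclude_folds` versus `include_folds`. A rough sketch of that split, assuming the k-fold metadata carries one fold index per tile. The `fold` key and the `select_rows` helper are hypothetical, and the real `EmbeddingTilesDataset` may filter differently:

```python
from typing import Iterable, Optional


def select_rows(
    rows: list[dict],
    include_folds: Optional[Iterable[int]] = None,
    exclude_folds: Optional[Iterable[int]] = None,
) -> list[dict]:
    """Keep rows whose fold is included (if given) and not excluded."""
    include = set(include_folds) if include_folds is not None else None
    exclude = set(exclude_folds) if exclude_folds is not None else set()
    return [
        row
        for row in rows
        if (include is None or row["fold"] in include) and row["fold"] not in exclude
    ]


# Ten toy tiles spread over five folds (assumed layout).
tiles = [{"tile_id": i, "fold": i % 5} for i in range(10)]
val_fold = 0

train = select_rows(tiles, exclude_folds=[val_fold])
val = select_rows(tiles, include_folds=[val_fold])

print(len(train), len(val))  # 8 2
```

Because both subsets are driven by the same `val_fold` value, they are disjoint and together cover every tile in the k-fold metadata.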
54 changes: 54 additions & 0 deletions configs/ml/linear_classifier.yaml
@@ -0,0 +1,54 @@
# @package _global_

defaults:
- /data: dataset
- /class_mapping: collapse_alterations_to_other
- /ml/trainer: default
- /ml/data: embedding
- /ml/model: linear_classifier
- _self_

mode: fit

embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id}
kfold_strategy: stratified
kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id}
filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id}

train_embedding_uri: runs:/${embedding_run_id}/train/tiles
test_embedding_uri: runs:/${embedding_run_id}/test/tiles
train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet
test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet

val_fold: 0

tissue_prop_min: 0.2
thresholds:
Nerve: 0.0
Blood: 0.0
Connective-Tissue: 0.4
Fat: 0.6
Epithelium: 0.2
Muscle: 0.5
Other: 0.5

mlflow_artifact_path: linear_classifier

metadata:
run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} opt=${model.optimizer} wd=${model.weight_decay}
description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}."
hyperparams:
embedding_run_id: ${embedding_run_id}
kfold_strategy: ${kfold_strategy}
kfold_run_id: ${kfold_run_id}
filter_tiles_run_id: ${filter_tiles_run_id}
val_fold: ${val_fold}
tissue_prop_min: ${tissue_prop_min}
thresholds: ${thresholds}
optimizer: ${model.optimizer}
learning_rate: ${model.learning_rate}
weight_decay: ${model.weight_decay}
lbfgs: ${model.lbfgs}
batch_size: ${data.batch_size}
train_shuffle: ${data.train_shuffle}
train_drop_last: ${data.train_drop_last}
25 changes: 25 additions & 0 deletions configs/ml/model/linear_classifier.yaml
@@ -0,0 +1,25 @@
# @package _global_

model:
backbone:
_target_: torch.nn.Identity

decode_head:
_target_: torch.nn.Linear
in_features: 2560
out_features: ${len:${class_indices}}

class_indices: ${class_indices}

optimizer: adamw
learning_rate: 1.0e-4
weight_decay: 0.0
lbfgs:
max_iter: 100
max_eval: null
tolerance_grad: 1.0e-7
tolerance_change: 1.0e-9
history_size: 100
line_search_fn: strong_wolfe
accumulate_batches: 1
accumulate_on_cpu: false
30 changes: 30 additions & 0 deletions configs/ml/trainer/default.yaml
@@ -0,0 +1,30 @@
# @package _global_

trainer:
max_epochs: 500
accelerator: auto
devices: auto
precision: 32
log_every_n_steps: 50
deterministic: false

callbacks:
early_stopping:
_target_: lightning.pytorch.callbacks.EarlyStopping
monitor: train/loss_epoch
mode: min
patience: 1
min_delta: 1.0e-4
model_checkpoint:
_target_: lightning.pytorch.callbacks.ModelCheckpoint
monitor: train/loss_epoch
mode: min
save_top_k: 1
filename: "epoch={epoch}-train_loss={train/loss_epoch:.4f}"
auto_insert_metric_name: false
lr_monitor:
_target_: lightning.pytorch.callbacks.LearningRateMonitor
logging_interval: epoch
prediction_writer:
_target_: ml.callbacks.ParquetPredictionWriter
output_filename: predictions.parquet
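The early-stopping callback above monitors `train/loss_epoch` in `min` mode with `patience: 1` and `min_delta: 1.0e-4`. A simplified sketch of what those numbers imply, assuming the usual min-mode rule (an epoch counts as an improvement only if the loss drops by more than `min_delta`). `stop_epoch` is a hypothetical helper, not Lightning's implementation:

```python
from typing import Optional, Sequence


def stop_epoch(
    losses: Sequence[float], patience: int = 1, min_delta: float = 1e-4
) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            # Improvement larger than min_delta: reset the counter.
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# The third epoch improves by only 5e-5, which is below min_delta.
print(stop_epoch([0.90, 0.70, 0.69995, 0.50]))  # 2
print(stop_epoch([0.90, 0.70, 0.50]))           # None
```

With `patience: 1`, a single epoch without a sufficiently large drop in training loss ends the run, so the criterion acts as a convergence check rather than an overfitting guard.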
Empty file added ml/__init__.py
Empty file.
36 changes: 36 additions & 0 deletions ml/__main__.py
@@ -0,0 +1,36 @@
from random import randint

import hydra
import mlflow
from lightning import seed_everything
from omegaconf import DictConfig, OmegaConf
from rationai.mlkit import Trainer, autolog
from rationai.mlkit.lightning.loggers import MLFlowLogger

from ml.data import DataModule
from ml.meta_arch import MetaArch


OmegaConf.register_new_resolver(
    "random_seed", lambda: randint(0, 2**31), use_cache=True
)
OmegaConf.register_new_resolver("len", lambda x: len(x))


@hydra.main(config_path="../configs", config_name="ml", version_base=None)
@autolog
def main(config: DictConfig, logger: MLFlowLogger) -> None:
    seed_everything(config.seed, workers=True)

    data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule)
    model = hydra.utils.instantiate(config.model, _target_=MetaArch)
    trainer = hydra.utils.instantiate(config.trainer, _target_=Trainer, logger=logger)
    allowed_modes = ["fit", "test", "validate", "predict"]
    if config.mode not in allowed_modes:
        raise ValueError(f"Invalid mode {config.mode!r}. Allowed: {allowed_modes}")
    getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint)
    mlflow.end_run()


if __name__ == "__main__":
    main()
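The entry point picks the trainer method by name from `config.mode` and rejects anything outside a fixed allow-list before calling `getattr`. A small self-contained sketch of that dispatch pattern, using a stand-in `DummyTrainer` and a hypothetical `run_mode` helper rather than the real `rationai.mlkit.Trainer`:

```python
ALLOWED_MODES = ("fit", "test", "validate", "predict")


class DummyTrainer:
    """Stand-in exposing the four entry points plus one unrelated method."""

    def fit(self, model, **kwargs):
        return f"fit:{model}"

    def test(self, model, **kwargs):
        return f"test:{model}"

    def validate(self, model, **kwargs):
        return f"validate:{model}"

    def predict(self, model, **kwargs):
        return f"predict:{model}"

    def save_checkpoint(self, path):
        return f"saved:{path}"


def run_mode(trainer, mode: str, model, **kwargs):
    """Dispatch to trainer.<mode>, but only for explicitly allowed modes."""
    if mode not in ALLOWED_MODES:
        raise ValueError(f"Invalid mode {mode!r}. Allowed: {list(ALLOWED_MODES)}")
    return getattr(trainer, mode)(model, **kwargs)


print(run_mode(DummyTrainer(), "predict", "linear", ckpt_path=None))  # predict:linear
```

Without the allow-list, a misconfigured `mode` such as `save_checkpoint` would resolve to a real attribute and be called with the wrong arguments. The explicit check turns that into a clear configuration error.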
4 changes: 4 additions & 0 deletions ml/callbacks/__init__.py
@@ -0,0 +1,4 @@
from ml.callbacks.parquet_prediction_writer import ParquetPredictionWriter


__all__ = ["ParquetPredictionWriter"]
72 changes: 72 additions & 0 deletions ml/callbacks/parquet_prediction_writer.py
@@ -0,0 +1,72 @@
"""Aggregate ``predict_step`` outputs and write them as a parquet artifact."""

from pathlib import Path
from typing import Any

import lightning as pl
import mlflow
import numpy as np
import pandas as pd
from lightning.pytorch.callbacks import BasePredictionWriter


class ParquetPredictionWriter(BasePredictionWriter):
    """Collect per-tile predictions and write them as a parquet artifact.

    Aggregates ``predict_step`` outputs across the predict loop, writes one
    parquet file with ``slide_id``, ``target``, ``pred``, ``prob_<class>``
    columns and logs it to the active MLflow run.
    """

    def __init__(self, output_filename: str = "predictions.parquet") -> None:
        super().__init__(write_interval="epoch")
        self.output_filename = output_filename

    def write_on_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        predictions: Any,
        batch_indices: Any,
    ) -> None:
        if trainer.global_rank != 0:
            return

        batches = (
            predictions
            if not predictions or isinstance(predictions[0], dict)
            else [b for dataloader_preds in predictions for b in dataloader_preds]
        )

        slide_ids: list[str] = []
        targets: list[int] = []
        preds: list[int] = []
        probs: list[np.ndarray] = []
        for b in batches:
            slide_ids.extend(b["slide_id"])
            targets.extend(b["target"].tolist())
            preds.extend(b["pred"].tolist())
            probs.append(b["probs"].numpy())

        if not slide_ids:
            return

        prob_matrix = np.concatenate(probs, axis=0)

        class_names = getattr(pl_module, "class_names", None)
        prob_columns = (
            [f"prob_{c}" for c in class_names]
            if class_names is not None and len(class_names) == prob_matrix.shape[1]
            else [f"prob_{i}" for i in range(prob_matrix.shape[1])]
        )

        df = pd.DataFrame({"slide_id": slide_ids, "target": targets, "pred": preds})
        df = pd.concat([df, pd.DataFrame(prob_matrix, columns=prob_columns)], axis=1)

        out_path = Path(trainer.default_root_dir) / self.output_filename
        out_path.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(out_path, index=False)

        active = mlflow.active_run()
        if active is not None:
            mlflow.log_artifact(str(out_path), artifact_path="predictions")
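Two small decisions in `write_on_epoch_end` are easy to miss: the predictions may arrive either as a flat list of batch dicts or as one list per dataloader, and the probability columns fall back to numeric names when class names are missing or do not match the matrix width. A dependency-free sketch of both rules, with hypothetical helper names `flatten_predictions` and `prob_column_names` and plain lists standing in for tensors:

```python
from typing import Any, Optional, Sequence


def flatten_predictions(predictions: Sequence[Any]) -> list[dict]:
    """Return a flat list of batch dicts for one or many dataloaders."""
    if not predictions or isinstance(predictions[0], dict):
        return list(predictions)
    # Nested case: one inner list of batch dicts per dataloader.
    return [batch for per_loader in predictions for batch in per_loader]


def prob_column_names(
    n_classes: int, class_names: Optional[Sequence[str]] = None
) -> list[str]:
    """Use class names when they match the matrix width, else indices."""
    if class_names is not None and len(class_names) == n_classes:
        return [f"prob_{name}" for name in class_names]
    return [f"prob_{i}" for i in range(n_classes)]


single = [{"slide_id": ["a"]}, {"slide_id": ["b"]}]
multi = [[{"slide_id": ["a"]}], [{"slide_id": ["b"]}, {"slide_id": ["c"]}]]

print(len(flatten_predictions(single)), len(flatten_predictions(multi)))  # 2 3
print(prob_column_names(2, ["Fat", "Muscle"]))  # ['prob_Fat', 'prob_Muscle']
print(prob_column_names(3, ["Fat", "Muscle"]))  # ['prob_0', 'prob_1', 'prob_2']
```

The numeric fallback keeps the parquet export usable even when the module exposes no `class_names`, at the cost of less readable column headers.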
4 changes: 4 additions & 0 deletions ml/data/__init__.py
@@ -0,0 +1,4 @@
from ml.data.data_module import DataModule


__all__ = ["DataModule"]