-
Notifications
You must be signed in to change notification settings - Fork 0
# feat: linear classifier training pipeline on precomputed embeddings #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
55 commits
Select commit
Hold shift + click to select a range
bef70df
feat: add embedding dataset build pipeline
vojtech-cifka 911bec2
feat: add class tresholds and run ids
vojtech-cifka 1a02395
fix: wrong run id
vojtech-cifka 08d7ba5
Merge remote-tracking branch 'origin/master' into feature/embedding-d…
vojtech-cifka b38465e
feat: add timing
vojtech-cifka bfc9578
refactor: use pyarrow to avoid to pandas conversion
vojtech-cifka eb213c6
fix: join on keys only
vojtech-cifka c92d9a1
fix: typing
vojtech-cifka 01cc394
fix: add prints
vojtech-cifka cad0d37
refactor: use combine chunks
vojtech-cifka ae04552
fix: lazy-cast embeddings to large_list and stay in Arrow during join
vojtech-cifka 82320db
fix: validate label/tissue_prop columns when derive=False
vojtech-cifka 3b0137f
chore: remove time
vojtech-cifka 8df47aa
feat: add timing
vojtech-cifka 926753d
chore: revert to the previous state
vojtech-cifka b0e9ba4
feat: add prints
vojtech-cifka 6a915de
refactor: use discusssed thresholds
vojtech-cifka 0f50307
refactor: use different labeling strategy
vojtech-cifka 4d953dc
feat: implement training pipeline
vojtech-cifka d5798bc
feat: add class weights
vojtech-cifka ae45cd5
refactor: join embeddings with metadata while loading the dataset
vojtech-cifka bdce760
feat: add prints
vojtech-cifka ac633d5
fix: use chunks
vojtech-cifka 2793562
fix: use numpy chunks
vojtech-cifka e81973e
fix: call end at the end of the main
vojtech-cifka 0071592
chore: remove prints
vojtech-cifka c0a7499
chore: remove debug prints, stale TODO, and unused preprocessing pipe…
vojtech-cifka fe918d1
chore: remove markdown file
vojtech-cifka 6b7d1e8
fix: edge cases
vojtech-cifka 4ff988e
feat: normalize the confusion matrix rows per class recall
vojtech-cifka 32375b2
fix: format
vojtech-cifka af9538a
feat: use stratified k fold run
vojtech-cifka bc0819a
fix: remove criterion
vojtech-cifka b8e85e0
fix: remove criterion from configs
vojtech-cifka ff4d307
Merge branch 'master' into feature/ml-linear-classifier
vojtech-cifka 3cc670d
feat: add option to use different kfold strategies
vojtech-cifka 27ceea3
fix: lower LR and patience
vojtech-cifka efde82a
fix: use f1 macro as a monitor
vojtech-cifka c8102de
fix: rever back to validation loss
vojtech-cifka c5bab90
fix: add weight decay 1e-3 to linear classifier
vojtech-cifka 475b67c
Revert "fix: add weight decay 1e-3 to linear classifier"
vojtech-cifka 43663a9
feat: add logistic regression
vojtech-cifka a2fe451
feat: polish and add two distinct submission scripts
vojtech-cifka 31ecf6d
fix: submission scripts
vojtech-cifka ff8d0bf
feat: implement knn
vojtech-cifka 1f87154
refactor: focus on convergence
vojtech-cifka 7039307
Remove kNN sklearn baseline
vojtech-cifka 729eccd
fix: change monitor to focus on train losss
vojtech-cifka d3ed2ed
feat: add run name
vojtech-cifka e9fd559
chore: remove logistic regression
vojtech-cifka 6dadbd7
feat: implement lbfgs
vojtech-cifka d5d3edd
fix: run id
vojtech-cifka 2163699
fix: cache the tiles and embeddings so they do not need to be downloa…
vojtech-cifka 9286807
fix: limit num of workers
vojtech-cifka bb8a043
fix: support checkpoint test and prediction export
vojtech-cifka File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
27 changes: 27 additions & 0 deletions
27
configs/experiment/ml/linear_classifier_lbfgs_stratified_group_kfold.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| # @package _global_ | ||
|
|
||
| defaults: | ||
| - /experiment/ml/linear_classifier_stratified_group_kfold | ||
| - _self_ | ||
|
|
||
| trainer: | ||
| max_epochs: 10 | ||
|
|
||
| data: | ||
| batch_size: 1000000000 | ||
| train_shuffle: false | ||
| train_drop_last: false | ||
| num_workers: 0 | ||
|
|
||
| model: | ||
| optimizer: lbfgs | ||
| learning_rate: 1.0 | ||
| lbfgs: | ||
| max_iter: 100 | ||
| max_eval: null | ||
| tolerance_grad: 1.0e-7 | ||
| tolerance_change: 1.0e-9 | ||
| history_size: 100 | ||
| line_search_fn: strong_wolfe | ||
| accumulate_batches: 1 | ||
| accumulate_on_cpu: false |
26 changes: 26 additions & 0 deletions
26
configs/experiment/ml/linear_classifier_lbfgs_stratified_kfold.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| # @package _global_ | ||
|
|
||
| defaults: | ||
| - /experiment/ml/linear_classifier_stratified_kfold | ||
| - _self_ | ||
|
|
||
| trainer: | ||
| max_epochs: 10 | ||
|
|
||
| data: | ||
| batch_size: 1000000000 | ||
| train_shuffle: false | ||
| train_drop_last: false | ||
|
|
||
| model: | ||
| optimizer: lbfgs | ||
| learning_rate: 1.0 | ||
| lbfgs: | ||
| max_iter: 100 | ||
| max_eval: null | ||
| tolerance_grad: 1.0e-7 | ||
| tolerance_change: 1.0e-9 | ||
| history_size: 100 | ||
| line_search_fn: strong_wolfe | ||
| accumulate_batches: 1 | ||
| accumulate_on_cpu: false |
8 changes: 8 additions & 0 deletions
8
configs/experiment/ml/linear_classifier_stratified_group_kfold.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| # @package _global_ | ||
|
|
||
| defaults: | ||
| - /ml/linear_classifier | ||
| - _self_ | ||
|
|
||
| kfold_strategy: stratified_group | ||
| kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id} |
8 changes: 8 additions & 0 deletions
8
configs/experiment/ml/linear_classifier_stratified_kfold.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| # @package _global_ | ||
|
|
||
| defaults: | ||
| - /ml/linear_classifier | ||
| - _self_ | ||
|
|
||
| kfold_strategy: stratified | ||
| kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| # @package _global_ | ||
|
|
||
| data: | ||
| batch_size: 1024 | ||
| num_workers: 4 | ||
| train_shuffle: true | ||
| train_drop_last: true | ||
|
|
||
| train: | ||
| _target_: ml.data.datasets.EmbeddingTilesDataset | ||
| embedding_uri: ${train_embedding_uri} | ||
| metadata_uri: ${train_metadata_uri} | ||
| class_indices: ${class_indices} | ||
| thresholds: ${thresholds} | ||
| tissue_prop_min: ${tissue_prop_min} | ||
| exclude_folds: | ||
| - ${val_fold} | ||
|
|
||
| val: | ||
| _target_: ml.data.datasets.EmbeddingTilesDataset | ||
| embedding_uri: ${train_embedding_uri} | ||
| metadata_uri: ${train_metadata_uri} | ||
| class_indices: ${class_indices} | ||
| thresholds: ${thresholds} | ||
| tissue_prop_min: ${tissue_prop_min} | ||
| include_folds: | ||
| - ${val_fold} | ||
|
|
||
| test: | ||
| _target_: ml.data.datasets.EmbeddingTilesDataset | ||
| embedding_uri: ${test_embedding_uri} | ||
| metadata_uri: ${test_metadata_uri} | ||
| class_indices: ${class_indices} | ||
| thresholds: ${thresholds} | ||
| tissue_prop_min: ${tissue_prop_min} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| # @package _global_ | ||
|
|
||
| defaults: | ||
| - /data: dataset | ||
| - /class_mapping: collapse_alterations_to_other | ||
| - /ml/trainer: default | ||
| - /ml/data: embedding | ||
| - /ml/model: linear_classifier | ||
| - _self_ | ||
|
|
||
| mode: fit | ||
|
|
||
| embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id} | ||
| kfold_strategy: stratified | ||
| kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id} | ||
| filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id} | ||
|
|
||
| train_embedding_uri: runs:/${embedding_run_id}/train/tiles | ||
| test_embedding_uri: runs:/${embedding_run_id}/test/tiles | ||
| train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet | ||
| test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet | ||
|
|
||
| val_fold: 0 | ||
|
|
||
| tissue_prop_min: 0.2 | ||
| thresholds: | ||
| Nerve: 0.0 | ||
| Blood: 0.0 | ||
| Connective-Tissue: 0.4 | ||
| Fat: 0.6 | ||
| Epithelium: 0.2 | ||
| Muscle: 0.5 | ||
| Other: 0.5 | ||
|
|
||
| mlflow_artifact_path: linear_classifier | ||
|
|
||
| metadata: | ||
| run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} opt=${model.optimizer} wd=${model.weight_decay} | ||
| description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}." | ||
| hyperparams: | ||
| embedding_run_id: ${embedding_run_id} | ||
| kfold_strategy: ${kfold_strategy} | ||
| kfold_run_id: ${kfold_run_id} | ||
| filter_tiles_run_id: ${filter_tiles_run_id} | ||
| val_fold: ${val_fold} | ||
| tissue_prop_min: ${tissue_prop_min} | ||
| thresholds: ${thresholds} | ||
| optimizer: ${model.optimizer} | ||
| learning_rate: ${model.learning_rate} | ||
| weight_decay: ${model.weight_decay} | ||
| lbfgs: ${model.lbfgs} | ||
| batch_size: ${data.batch_size} | ||
| train_shuffle: ${data.train_shuffle} | ||
| train_drop_last: ${data.train_drop_last} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| # @package _global_ | ||
|
|
||
| model: | ||
| backbone: | ||
| _target_: torch.nn.Identity | ||
|
|
||
| decode_head: | ||
| _target_: torch.nn.Linear | ||
| in_features: 2560 | ||
| out_features: ${len:${class_indices}} | ||
|
|
||
| class_indices: ${class_indices} | ||
|
|
||
| optimizer: adamw | ||
| learning_rate: 1.0e-4 | ||
| weight_decay: 0.0 | ||
| lbfgs: | ||
| max_iter: 100 | ||
| max_eval: null | ||
| tolerance_grad: 1.0e-7 | ||
| tolerance_change: 1.0e-9 | ||
| history_size: 100 | ||
| line_search_fn: strong_wolfe | ||
| accumulate_batches: 1 | ||
| accumulate_on_cpu: false |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| # @package _global_ | ||
|
|
||
| trainer: | ||
| max_epochs: 500 | ||
| accelerator: auto | ||
| devices: auto | ||
| precision: 32 | ||
| log_every_n_steps: 50 | ||
| deterministic: false | ||
|
|
||
| callbacks: | ||
| early_stopping: | ||
| _target_: lightning.pytorch.callbacks.EarlyStopping | ||
| monitor: train/loss_epoch | ||
| mode: min | ||
| patience: 1 | ||
| min_delta: 1.0e-4 | ||
| model_checkpoint: | ||
| _target_: lightning.pytorch.callbacks.ModelCheckpoint | ||
| monitor: train/loss_epoch | ||
| mode: min | ||
| save_top_k: 1 | ||
| filename: "epoch={epoch}-train_loss={train/loss_epoch:.4f}" | ||
| auto_insert_metric_name: false | ||
| lr_monitor: | ||
| _target_: lightning.pytorch.callbacks.LearningRateMonitor | ||
| logging_interval: epoch | ||
| prediction_writer: | ||
| _target_: ml.callbacks.ParquetPredictionWriter | ||
| output_filename: predictions.parquet |
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| from random import randint | ||
|
|
||
| import hydra | ||
| import mlflow | ||
| from lightning import seed_everything | ||
| from omegaconf import DictConfig, OmegaConf | ||
| from rationai.mlkit import Trainer, autolog | ||
| from rationai.mlkit.lightning.loggers import MLFlowLogger | ||
|
|
||
| from ml.data import DataModule | ||
| from ml.meta_arch import MetaArch | ||
|
|
||
|
|
||
| OmegaConf.register_new_resolver( | ||
| "random_seed", lambda: randint(0, 2**31), use_cache=True | ||
| ) | ||
| OmegaConf.register_new_resolver("len", lambda x: len(x)) | ||
|
|
||
|
|
||
| @hydra.main(config_path="../configs", config_name="ml", version_base=None) | ||
| @autolog | ||
| def main(config: DictConfig, logger: MLFlowLogger) -> None: | ||
| seed_everything(config.seed, workers=True) | ||
|
|
||
| data = hydra.utils.instantiate(config.data, _recursive_=False, _target_=DataModule) | ||
| model = hydra.utils.instantiate(config.model, _target_=MetaArch) | ||
| trainer = hydra.utils.instantiate(config.trainer, _target_=Trainer, logger=logger) | ||
| allowed_modes = ["fit", "test", "validate", "predict"] | ||
| if config.mode not in allowed_modes: | ||
| raise ValueError(f"Invalid mode {config.mode!r}. Allowed: {allowed_modes}") | ||
| getattr(trainer, config.mode)(model, datamodule=data, ckpt_path=config.checkpoint) | ||
| mlflow.end_run() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from ml.callbacks.parquet_prediction_writer import ParquetPredictionWriter | ||
|
|
||
|
|
||
| __all__ = ["ParquetPredictionWriter"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| """Aggregate ``predict_step`` outputs and write them as a parquet artifact.""" | ||
|
|
||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| import lightning as pl | ||
| import mlflow | ||
| import numpy as np | ||
| import pandas as pd | ||
| from lightning.pytorch.callbacks import BasePredictionWriter | ||
|
|
||
|
|
||
| class ParquetPredictionWriter(BasePredictionWriter): | ||
| """Collect per-tile predictions and write them as a parquet artifact. | ||
|
|
||
| Aggregates ``predict_step`` outputs across the predict loop, writes one | ||
| parquet file with ``slide_id``, ``target``, ``pred``, ``prob_<class>`` | ||
| columns and logs it to the active MLflow run. | ||
| """ | ||
|
|
||
| def __init__(self, output_filename: str = "predictions.parquet") -> None: | ||
| super().__init__(write_interval="epoch") | ||
| self.output_filename = output_filename | ||
|
|
||
| def write_on_epoch_end( | ||
| self, | ||
| trainer: pl.Trainer, | ||
| pl_module: pl.LightningModule, | ||
| predictions: Any, | ||
| batch_indices: Any, | ||
| ) -> None: | ||
| if trainer.global_rank != 0: | ||
| return | ||
|
|
||
| batches = ( | ||
| predictions | ||
| if not predictions or isinstance(predictions[0], dict) | ||
| else [b for dataloader_preds in predictions for b in dataloader_preds] | ||
| ) | ||
|
|
||
| slide_ids: list[str] = [] | ||
| targets: list[int] = [] | ||
| preds: list[int] = [] | ||
| probs: list[np.ndarray] = [] | ||
| for b in batches: | ||
| slide_ids.extend(b["slide_id"]) | ||
| targets.extend(b["target"].tolist()) | ||
| preds.extend(b["pred"].tolist()) | ||
| probs.append(b["probs"].numpy()) | ||
|
|
||
| if not slide_ids: | ||
| return | ||
|
|
||
| prob_matrix = np.concatenate(probs, axis=0) | ||
|
|
||
| class_names = getattr(pl_module, "class_names", None) | ||
| prob_columns = ( | ||
| [f"prob_{c}" for c in class_names] | ||
| if class_names is not None and len(class_names) == prob_matrix.shape[1] | ||
| else [f"prob_{i}" for i in range(prob_matrix.shape[1])] | ||
| ) | ||
|
|
||
| df = pd.DataFrame({"slide_id": slide_ids, "target": targets, "pred": preds}) | ||
| df = pd.concat([df, pd.DataFrame(prob_matrix, columns=prob_columns)], axis=1) | ||
|
|
||
| out_path = Path(trainer.default_root_dir) / self.output_filename | ||
| out_path.parent.mkdir(parents=True, exist_ok=True) | ||
| df.to_parquet(out_path, index=False) | ||
|
|
||
| active = mlflow.active_run() | ||
| if active is not None: | ||
| mlflow.log_artifact(str(out_path), artifact_path="predictions") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from ml.data.data_module import DataModule | ||
|
|
||
|
|
||
| __all__ = ["DataModule"] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.