Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
129 commits
Select commit Hold shift + click to select a range
24668c3
feat: create ml pipeline for linear probe
vojtech-cifka May 1, 2026
f340038
refactor(ml): switch DataModule to HF datasets with fold-based split
vojtech-cifka May 4, 2026
c3ef38a
feat(ml): wire up linear probe training with k-fold CV on cached embe…
vojtech-cifka May 7, 2026
c644f22
fix(configs): use override for class_mapping in experiment yaml
vojtech-cifka May 7, 2026
564b0b1
fix(scripts): drop duplicate +ml= from linear-probe submit command
vojtech-cifka May 7, 2026
3a77adc
fix(ml): register random_seed/len resolvers and unflatten class_mappi…
vojtech-cifka May 7, 2026
11b19f0
fix(ml): accept already-canonical labels in datamodule label map
vojtech-cifka May 8, 2026
c6bfe8e
feat(ml): class-weighted CE, raise class_coverage_min to 0.5
vojtech-cifka May 8, 2026
894c27b
fix: sort only tiles parquet
vojtech-cifka May 8, 2026
fc824ad
fix: log join types of tile keys
vojtech-cifka May 8, 2026
11931d1
fix: remove embeddings from the join
vojtech-cifka May 8, 2026
fb6b320
fix: remove label column
vojtech-cifka May 8, 2026
7434ae9
fix: prevent overflow
vojtech-cifka May 8, 2026
1b18daa
Merge remote-tracking branch 'origin/master' into feature/linear-probe
vojtech-cifka May 8, 2026
bef70df
feat: add embedding dataset build pipeline
vojtech-cifka May 8, 2026
911bec2
feat: add class tresholds and run ids
vojtech-cifka May 8, 2026
1a02395
fix: wrong run id
vojtech-cifka May 8, 2026
08d7ba5
Merge remote-tracking branch 'origin/master' into feature/embedding-d…
vojtech-cifka May 9, 2026
b38465e
feat: add timing
vojtech-cifka May 9, 2026
bfc9578
refactor: use pyarrow to avoid to pandas conversion
vojtech-cifka May 9, 2026
eb213c6
fix: join on keys only
vojtech-cifka May 9, 2026
c92d9a1
fix: typing
vojtech-cifka May 9, 2026
01cc394
fix: add prints
vojtech-cifka May 9, 2026
cad0d37
refactor: use combine chunks
vojtech-cifka May 9, 2026
ae04552
fix: lazy-cast embeddings to large_list and stay in Arrow during join
vojtech-cifka May 9, 2026
82320db
fix: validate label/tissue_prop columns when derive=False
vojtech-cifka May 9, 2026
3b0137f
chore: remove time
vojtech-cifka May 9, 2026
8df47aa
feat: add timing
vojtech-cifka May 10, 2026
926753d
chore: revert to the previous state
vojtech-cifka May 10, 2026
b0e9ba4
feat: add prints
vojtech-cifka May 10, 2026
6a915de
refactor: use discusssed thresholds
vojtech-cifka May 11, 2026
0f50307
refactor: use different labeling strategy
vojtech-cifka May 11, 2026
4d953dc
feat: implement training pipeline
vojtech-cifka May 11, 2026
d5798bc
feat: add class weights
vojtech-cifka May 11, 2026
ae45cd5
refactor: join embeddings with metadata while loading the dataset
vojtech-cifka May 11, 2026
bdce760
feat: add prints
vojtech-cifka May 11, 2026
ac633d5
fix: use chunks
vojtech-cifka May 11, 2026
2793562
fix: use numpy chunks
vojtech-cifka May 11, 2026
e81973e
fix: call end at the end of the main
vojtech-cifka May 11, 2026
0071592
chore: remove prints
vojtech-cifka May 11, 2026
c0a7499
chore: remove debug prints, stale TODO, and unused preprocessing pipe…
vojtech-cifka May 11, 2026
fe918d1
chore: remove markdown file
vojtech-cifka May 11, 2026
6b7d1e8
fix: edge cases
vojtech-cifka May 12, 2026
4ff988e
feat: normalize the confusion matrix rows per class recall
vojtech-cifka May 12, 2026
32375b2
fix: format
vojtech-cifka May 12, 2026
af9538a
feat: use stratified k fold run
vojtech-cifka May 12, 2026
bc0819a
fix: remove criterion
vojtech-cifka May 12, 2026
b8e85e0
fix: remove criterion from configs
vojtech-cifka May 12, 2026
c387189
feat: implement test pipeline
vojtech-cifka May 13, 2026
1216504
fix: Hydra unreached
vojtech-cifka May 13, 2026
7ec86ef
fix: set weights only to false
vojtech-cifka May 13, 2026
c9b566e
fix: criterion weight
vojtech-cifka May 13, 2026
ff4d307
Merge branch 'master' into feature/ml-linear-classifier
vojtech-cifka May 13, 2026
3cc670d
feat: add option to use different kfold strategies
vojtech-cifka May 13, 2026
ad0a4e7
feat: add training without validation
vojtech-cifka May 13, 2026
811e21c
feat: implement final test run
vojtech-cifka May 13, 2026
27ceea3
fix: lower LR and patience
vojtech-cifka May 13, 2026
efde82a
fix: use f1 macro as a monitor
vojtech-cifka May 14, 2026
c8102de
fix: rever back to validation loss
vojtech-cifka May 14, 2026
c5bab90
fix: add weight decay 1e-3 to linear classifier
vojtech-cifka May 14, 2026
475b67c
Revert "fix: add weight decay 1e-3 to linear classifier"
vojtech-cifka May 14, 2026
43663a9
feat: add logistic regression
vojtech-cifka May 14, 2026
a2fe451
feat: polish and add two distinct submission scripts
vojtech-cifka May 14, 2026
31ecf6d
fix: submission scripts
vojtech-cifka May 14, 2026
ff8d0bf
feat: implement knn
vojtech-cifka May 14, 2026
1f87154
refactor: focus on convergence
vojtech-cifka May 14, 2026
7039307
Remove kNN sklearn baseline
vojtech-cifka May 14, 2026
729eccd
fix: change monitor to focus on train losss
vojtech-cifka May 14, 2026
d3ed2ed
feat: add run name
vojtech-cifka May 14, 2026
e9fd559
chore: remove logistic regression
vojtech-cifka May 15, 2026
6dadbd7
feat: implement lbfgs
vojtech-cifka May 15, 2026
d5d3edd
fix: run id
vojtech-cifka May 15, 2026
2163699
fix: cache the tiles and embeddings so they do not need to be downloa…
vojtech-cifka May 15, 2026
9286807
fix: limit num of workers
vojtech-cifka May 15, 2026
bb8a043
fix: support checkpoint test and prediction export
vojtech-cifka May 15, 2026
c284d8d
Merge remote-tracking branch 'origin/feature/linear-probe' into featu…
vojtech-cifka May 15, 2026
efddcd6
Revert "Merge remote-tracking branch 'origin/feature/linear-probe' in…
vojtech-cifka May 15, 2026
420534e
Merge remote-tracking branch 'origin/feature/ml-linear-classifier' in…
vojtech-cifka May 15, 2026
14909e2
feat: add functionality to submit final train for both adamw and lbfgs
vojtech-cifka May 15, 2026
8167363
feat: implement prediction maps
vojtech-cifka May 16, 2026
4e45ce1
fix: change the adamw checkpoint dir name to last
vojtech-cifka May 16, 2026
8f9ce70
fix: lower the batch so the compute does not hang
vojtech-cifka May 16, 2026
99c2d0d
fix: put num workers to 0
vojtech-cifka May 16, 2026
01486bd
feat: add prints
vojtech-cifka May 16, 2026
64963ac
Merge branch 'master' into feature/ml-test-mode
vojtech-cifka May 16, 2026
85270fd
feat: add diagnostic prints
vojtech-cifka May 16, 2026
5db671c
fix: use numpy buffer
vojtech-cifka May 16, 2026
3aea3c2
refactor: use HeatmapAssembler
vojtech-cifka May 16, 2026
756642a
chore: clean config structure
vojtech-cifka May 16, 2026
2771e78
fix: prediction maps class indices
vojtech-cifka May 16, 2026
918b691
fix: format and mypy
vojtech-cifka May 16, 2026
4032df3
feat: add posibility to predict the whole slide with tissue area
vojtech-cifka May 17, 2026
ca50a7c
feat: add embeddings for whole slide
vojtech-cifka May 17, 2026
6489cd0
refactor: compute grayscale mask per each class
vojtech-cifka May 17, 2026
afddfc1
feat: add the provgigapath train and test runs
vojtech-cifka May 17, 2026
b618807
feat: set final weight decay for train
vojtech-cifka May 17, 2026
0ec0da8
feat: turn on prediction maps for the test runs over the annotated re…
vojtech-cifka May 18, 2026
9d8729a
refactor: do not generate error masks
vojtech-cifka May 18, 2026
ac0ce16
chore: config cleanup
vojtech-cifka May 18, 2026
099e277
feat: add prints to the prediction maps writer
vojtech-cifka May 18, 2026
4909324
feat: add embeddings run id for the whole tissue tiles run
vojtech-cifka May 18, 2026
ee9d2da
feat: add prediction maps in configs
vojtech-cifka May 18, 2026
8b3a82d
chore: deduplicate, apply safety nets
vojtech-cifka May 18, 2026
27c7596
Merge branch 'feature/ml-test-mode' into feature/provgigapath-metrics…
vojtech-cifka May 18, 2026
e16426e
fix: pytorch checkpoint loading
vojtech-cifka May 18, 2026
fd3fdd6
chore: remove redundancy, rename variables
vojtech-cifka May 18, 2026
c401015
chore: remove username and branch
vojtech-cifka May 18, 2026
847c3cc
refactor: rename configs
vojtech-cifka May 18, 2026
b138a42
Merge branch 'feature/ml-test-mode' into feature/provgigapath-metrics…
vojtech-cifka May 18, 2026
51f36fb
chore: remove pgp test prediction maps
vojtech-cifka May 18, 2026
2ba0562
fix: keep criterion.weight in state_dict for strict checkpoint load
vojtech-cifka May 18, 2026
e370417
fix: criterion weight
vojtech-cifka May 18, 2026
632a8f6
fix: keep space in MUG prediction masks names
vojtech-cifka May 18, 2026
3cd0243
fix: log test accuracy as jsons
vojtech-cifka May 18, 2026
76e4194
chore: remove username from the submission script
vojtech-cifka May 18, 2026
597e348
fix: force the entering of the write phase of the prediction maps
vojtech-cifka May 18, 2026
e4a4cc5
fix: surface why prediction-map write phase skips
vojtech-cifka May 18, 2026
3829ebd
fix: remove username
vojtech-cifka May 18, 2026
0b2d38e
feat: generate embeddings up to a budget
vojtech-cifka May 18, 2026
ff7c06e
Merge branch 'feature/ml-test-mode' into feature/provgigapath-metrics…
vojtech-cifka May 18, 2026
4a528d7
Merge branch 'master' into feature/provgigapath-metrics-test
vojtech-cifka May 19, 2026
b3d803a
chore: rename ml experiments for clarity
vojtech-cifka May 19, 2026
1703c01
feat: add original slide name in the per slide statistics
vojtech-cifka May 19, 2026
3c72c4f
Merge remote-tracking branch 'origin/master' into feature/provgigapat…
vojtech-cifka May 19, 2026
67d3ef4
refactor: simplify the preprocessing name scripts
vojtech-cifka May 19, 2026
cd4b19a
fix: commnets, generate safer filenames
vojtech-cifka May 19, 2026
c9b4c67
fix: remove erorr masks
vojtech-cifka May 19, 2026
df7888f
fix: remove rendundant column selection
vojtech-cifka May 19, 2026
191892c
fix: format
vojtech-cifka May 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions configs/experiment/ml/final_linear_provgigapath_adamw.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# @package _global_

defaults:
- /experiment/ml/final_linear_virchow2_adamw
- _self_

embedding_model_name: ProvGigaPath
embedding_dim: 1536
embedding_run_id: 410c8672471348ceb4c58817f70fa097
kfold_strategy: stratified_group
kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id}
mlflow_artifact_path: linear_classifier_final_provgigapath

# Set after Stage 1 from ProvGigaPath's own AdamW sweep selected by
# validation/f1_macro.
model:
weight_decay: 1.0e-4

metadata:
run_name: Final Linear Classifier AdamW ProvGigaPath ${dataset.name}
description: "Final AdamW linear probe over frozen ProvGigaPath embeddings, trained on all training folds with the ProvGigaPath-selected weight decay."
21 changes: 21 additions & 0 deletions configs/experiment/ml/final_linear_provgigapath_lbfgs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# @package _global_

defaults:
- /experiment/ml/final_linear_virchow2_lbfgs
- _self_

embedding_model_name: ProvGigaPath
embedding_dim: 1536
embedding_run_id: 410c8672471348ceb4c58817f70fa097
kfold_strategy: stratified_group
kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id}
mlflow_artifact_path: linear_classifier_final_provgigapath

# Set after Stage 1 from ProvGigaPath's own LBFGS sweep selected by
# validation/f1_macro.
model:
weight_decay: 1.0e-4

metadata:
run_name: Final Linear Classifier LBFGS ProvGigaPath ${dataset.name}
description: "Final LBFGS linear probe over frozen ProvGigaPath embeddings, exact full-batch solve with the ProvGigaPath-selected weight decay."

This file was deleted.

16 changes: 16 additions & 0 deletions configs/experiment/ml/test_linear_provgigapath_adamw.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# @package _global_

defaults:
- /experiment/ml/final_linear_provgigapath_adamw
- _self_

# Held-out test for the final ProvGigaPath AdamW checkpoint. Uses the same
# filtered labeled test split, thresholds, metrics, and checkpoint convention as
# the Virchow2 test config.
mode: test
final_train_run_id: fe172ccd8c1140269f7f3d1fdbd351ea
checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt
checkpoint_weights_only: false

data:
num_workers: 0
18 changes: 18 additions & 0 deletions configs/experiment/ml/test_linear_provgigapath_lbfgs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# @package _global_

defaults:
- /experiment/ml/final_linear_provgigapath_lbfgs
- override /ml/trainer: early_stopping
- _self_

# Held-out test for the final ProvGigaPath LBFGS checkpoint. Uses the same
# filtered labeled test split, thresholds, metrics, and checkpoint convention as
# the Virchow2 test config.
mode: test
final_train_run_id: 067b08dcbdb54d9187fbd4dd8d5599a1
checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt
checkpoint_weights_only: false

data:
train_batch_size: 1024
num_workers: 0
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# @package _global_

defaults:
- /experiment/ml/linear_classifier_final_adamw
- /experiment/ml/final_linear_virchow2_adamw
- _self_

# Test the AdamW final checkpoint on the held-out test split. Same model
Expand All @@ -20,13 +20,3 @@ checkpoint_weights_only: false
# before the first test batch. final_embedding_tiles defaults to 4; override here.
data:
num_workers: 0

trainer:
callbacks:
tiff_prediction_maps:
_target_: ml.callbacks.TiffPredictionMapWriter
slides_uri: runs:/${dataset.mlflow_artifacts.tiling_run_id}/test_split/slides.parquet
artifact_path: prediction_maps_tiff
draw_region: central_stride
slide_selection: all
max_slides: null
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# @package _global_

defaults:
- /experiment/ml/linear_classifier_final_lbfgs
- /experiment/ml/final_linear_virchow2_lbfgs
- override /ml/trainer: early_stopping
- _self_

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# @package _global_

defaults:
- /experiment/ml/train_linear_virchow2_adamw_group_kfold
- _self_

embedding_model_name: ProvGigaPath
embedding_dim: 1536
embedding_run_id: 410c8672471348ceb4c58817f70fa097
mlflow_artifact_path: linear_classifier_provgigapath

metadata:
run_name: Linear Classifier ProvGigaPath ${dataset.name} ${kfold_strategy} fold=${val_fold} opt=${model.optimizer} wd=${model.weight_decay}
description: "Linear probe over frozen ProvGigaPath embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}."
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# @package _global_

defaults:
- /experiment/ml/linear_classifier_stratified_kfold
- /experiment/ml/train_linear_provgigapath_adamw_group_kfold
- _self_

trainer:
Expand All @@ -11,6 +11,7 @@ data:
train_batch_size: 1000000000
train_shuffle: false
train_drop_last: false
num_workers: 0

model:
optimizer: lbfgs
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# @package _global_

defaults:
- /experiment/ml/linear_classifier_stratified_group_kfold
- /experiment/ml/train_linear_virchow2_adamw_group_kfold
- _self_

trainer:
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# @package _global_

defaults:
- /experiment/preprocessing/embeddings_virchow2_0_5mpp
- _self_

# Embeddings for a deterministic sampled subset of test slides whose tiles
# intersect the tissue mask. The sample is capped by slide_sample_max_tiles and
# selected with slide_sample_seed for doctor-review prediction maps.
splits:
- test
tile_source_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id}
tile_source_artifact_template: "tissue_stats/{split}_tiles.parquet"
tile_filter_column: tile_tissue_coverage
slide_sample_max_tiles: 2000000
slide_sample_seed: 0
Comment thread
vojtech-cifka marked this conversation as resolved.

metadata:
run_name: "Embeddings: ${model} tissue tiles"
description: "Tile embeddings using ${model} over a sampled held-out test slide subset with tile_tissue_coverage > 0, capped by slide_sample_max_tiles=${slide_sample_max_tiles} and selected with slide_sample_seed=${slide_sample_seed}."
1 change: 1 addition & 0 deletions configs/ml/data/final_embedding_tiles.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ data:
class_indices: ${class_indices}
thresholds: ${thresholds}
tissue_prop_min: ${tissue_prop_min}
slide_metadata_uri: ${test_slide_metadata_uri}
2 changes: 1 addition & 1 deletion configs/ml/model/linear_classifier.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ model:

decode_head:
_target_: torch.nn.Linear
in_features: 2560
in_features: ${embedding_dim}
out_features: ${len:${class_indices}}

class_indices: ${class_indices}
Expand Down
7 changes: 6 additions & 1 deletion configs/ml/task/final_linear_classifier.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ defaults:

mode: fit

embedding_model_name: Virchow2
embedding_dim: 2560
embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id}
kfold_strategy: stratified
kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id}
Expand All @@ -19,6 +21,7 @@ train_embedding_uri: runs:/${embedding_run_id}/train/tiles
test_embedding_uri: runs:/${embedding_run_id}/test/tiles
train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet
test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet
test_slide_metadata_uri: runs:/${embedding_run_id}/test/slides.parquet

tissue_prop_min: 0.2
thresholds:
Expand All @@ -34,8 +37,10 @@ mlflow_artifact_path: linear_classifier_final

metadata:
run_name: Final Linear Classifier ${dataset.name}
description: "Final linear probe over frozen Virchow2 embeddings trained on all training folds for ${trainer.max_epochs} epochs."
description: "Final linear probe over frozen ${embedding_model_name} embeddings trained on all training folds for ${trainer.max_epochs} epochs."
hyperparams:
embedding_model_name: ${embedding_model_name}
embedding_dim: ${embedding_dim}
embedding_run_id: ${embedding_run_id}
kfold_strategy: ${kfold_strategy}
kfold_run_id: ${kfold_run_id}
Expand Down
6 changes: 5 additions & 1 deletion configs/ml/task/kfold_linear_classifier.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ defaults:

mode: fit

embedding_model_name: Virchow2
embedding_dim: 2560
embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id}
kfold_strategy: stratified
kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id}
Expand All @@ -36,8 +38,10 @@ mlflow_artifact_path: linear_classifier

metadata:
run_name: Linear Classifier ${dataset.name} ${kfold_strategy} fold=${val_fold} opt=${model.optimizer} wd=${model.weight_decay}
description: "Linear probe over frozen Virchow2 embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}."
description: "Linear probe over frozen ${embedding_model_name} embeddings (run ${embedding_run_id}), ${kfold_strategy} kfold metadata ${kfold_run_id}."
hyperparams:
embedding_model_name: ${embedding_model_name}
embedding_dim: ${embedding_dim}
embedding_run_id: ${embedding_run_id}
kfold_strategy: ${kfold_strategy}
kfold_run_id: ${kfold_run_id}
Expand Down
1 change: 1 addition & 0 deletions configs/ml/trainer/early_stopping.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ trainer:
_target_: lightning.pytorch.callbacks.ModelCheckpoint
monitor: train/loss_epoch
mode: min
save_last: true
save_top_k: 1
filename: "epoch={epoch}-train_loss={train/loss_epoch:.4f}"
auto_insert_metric_name: false
Expand Down
4 changes: 4 additions & 0 deletions configs/preprocessing/embeddings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ splits:
tile_source_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id}
tile_source_artifact_template: "filter_tiles/{split}_tiles.parquet"
tile_filter_column: null
slide_sample_max_tiles: null
slide_sample_seed: 0

metadata:
run_name: "Embeddings: ${model}"
Expand All @@ -23,3 +25,5 @@ metadata:
tile_source_run_id: ${tile_source_run_id}
tile_source_artifact_template: ${tile_source_artifact_template}
tile_filter_column: ${tile_filter_column}
slide_sample_max_tiles: ${slide_sample_max_tiles}
slide_sample_seed: ${slide_sample_seed}
6 changes: 5 additions & 1 deletion ml/callbacks/tiff_prediction_map_writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Write tile predictions as WSI-aligned BigTIFF masks."""

from collections.abc import Mapping
from hashlib import blake2b
from pathlib import Path
from re import sub
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -517,7 +518,10 @@ def _safe_filename(value: str) -> str:


def _slide_prediction_filename(path: str | Path) -> str:
return Path(str(path)).with_suffix(".tiff").name
path_str = str(path)
stem = Path(path_str).stem
digest = blake2b(path_str.encode("utf-8"), digest_size=4).hexdigest()
return _safe_filename(f"{stem}-{digest}.tiff")


def _spread_lut(n_classes: int) -> np.ndarray:
Expand Down
16 changes: 14 additions & 2 deletions ml/data/datasets/embedding_tiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
embedding_uri: str | Path,
meta_df: pd.DataFrame,
diag: Callable[[str], None],
slide_metadata_uri: str | Path | None = None,
) -> None:
diag(f"metadata filtered: {len(meta_df)} rows; reading embeddings")
joined_keys, embeddings = _load_embeddings_and_join(
Expand All @@ -37,6 +38,9 @@ def __init__(
self.slide_ids = joined_keys.column("slide_id").to_pandas().to_numpy()
self.xs = joined_keys.column("x").to_pandas().to_numpy(dtype=np.int64)
self.ys = joined_keys.column("y").to_pandas().to_numpy(dtype=np.int64)
self.slide_names_by_id = (
_load_slide_names(slide_metadata_uri) if slide_metadata_uri else {}
)
diag(f"dataset ready: {len(self.labels)} samples, dim={embeddings.shape[1]}")

def __len__(self) -> int:
Expand Down Expand Up @@ -78,6 +82,7 @@ def __init__(
tissue_prop_min: float,
include_folds: list[int] | None = None,
exclude_folds: list[int] | None = None,
slide_metadata_uri: str | Path | None = None,
) -> None:
self.class_indices = class_indices
diag = _make_diag(type(self).__name__)
Expand All @@ -89,7 +94,7 @@ def __init__(
include_folds,
exclude_folds,
)
super().__init__(embedding_uri, meta_df, diag)
super().__init__(embedding_uri, meta_df, diag, slide_metadata_uri)

def _labels_from_joined_keys(self, joined_keys: pa.Table) -> np.ndarray:
labels = joined_keys.column("label").to_pandas()
Expand Down Expand Up @@ -233,12 +238,13 @@ def __init__(
tissue_column: str = "tile_tissue_coverage",
tissue_min: float = 0.0,
label_value: int = -1,
slide_metadata_uri: str | Path | None = None,
) -> None:
self.label_value = label_value
diag = _make_diag(type(self).__name__)
diag("filtering metadata")
meta_df = self._filter_metadata(metadata_uri, tissue_column, tissue_min)
super().__init__(embedding_uri, meta_df, diag)
super().__init__(embedding_uri, meta_df, diag, slide_metadata_uri)

def _labels_from_joined_keys(self, joined_keys: pa.Table) -> np.ndarray:
return np.full(joined_keys.num_rows, self.label_value, dtype=np.int64)
Expand Down Expand Up @@ -268,6 +274,12 @@ def _resolve_uri(path_or_uri: str | Path) -> str:
return _resolve_uri_cached(str(path_or_uri))


def _load_slide_names(slide_metadata_uri: str | Path) -> dict[str, str]:
local = _resolve_uri(slide_metadata_uri)
df = pd.read_parquet(local, columns=["id", "path"])
return {str(row.id): Path(str(row.path)).name for row in df.itertuples(index=False)}


def _make_diag(dataset_name: str) -> Callable[[str], None]:
t0 = perf_counter()

Expand Down
Loading
Loading