feat: callbacks #7

Open
xrusnack wants to merge 22 commits into master from feature/callbacks

@xrusnack xrusnack commented Mar 24, 2026

Summary by CodeRabbit

  • New Features
    • Inference and validation recipes for multiple datasets and model variants
    • Automated per-slide and dataset prediction metrics (including confusion matrix) logged to MLflow
    • Per-slide mask and attention-map generation exported as tiled TIFF artifacts
    • Threshold analysis with ROC/PR curves and suggested operating points logged and visualized
    • Enhanced ML/trainer defaults (checkpointing, epochs, data/worker settings)

@xrusnack xrusnack requested review from matejpekar and vejtek March 24, 2026 10:42
@xrusnack xrusnack self-assigned this Mar 24, 2026
@xrusnack xrusnack requested a review from a team March 24, 2026 10:42

coderabbitai Bot commented Mar 24, 2026

📝 Walkthrough

Adds numerous Hydra configs and new PyTorch Lightning callbacks for prediction, metric logging, mask/attention TIFF generation, and threshold curve plotting; wires MLflow artifact/metric logging and integrates OpenSlide/polyline rasterization for per-slide outputs.

Changes

| Cohort / File(s) | Summary |
|---|---|
| **Hydra callback configs**<br>`configs/callbacks/attention/nuclei_masks.yaml`, `configs/callbacks/predictions/labels/mil.yaml`, `configs/callbacks/predictions/labels/wsl.yaml`, `configs/callbacks/predictions/metrics_dataset/mil.yaml`, `configs/callbacks/predictions/metrics_dataset/wsl.yaml`, `configs/callbacks/predictions/metrics_slide/wsl.yaml`, `configs/callbacks/predictions/nuclei_masks.yaml`, `configs/callbacks/thresholds/plot_curves/mil.yaml`, `configs/callbacks/thresholds/plot_curves/wsl.yaml` | New Hydra YAMLs registering the new callbacks (MIL/WSL variants) with thresholds, artifact paths, package scope, and mask tiling params. |
| **Experiment configs**<br>`configs/experiment/modeling/inference/base.yaml`, `configs/experiment/modeling/inference/crop_level/*`, `configs/experiment/modeling/inference/nuclei_level/*`, `configs/experiment/modeling/validation/base.yaml`, `configs/experiment/modeling/validation/crop_level/*`, `configs/experiment/modeling/validation/nuclei_level/*` | New inference/validation experiment recipes for Radboud and MMCI datasets (crop- and nuclei-level), composing model/dataset/supervision defaults and wiring checkpoints/MLflow URIs and callbacks. |
| **Top-level ML config**<br>`configs/ml.yaml` | Expanded ML config: dataset/default experiment, explicit trainer settings (checkpointing, max_epochs, log frequency), data and model hyperparameter structures, and MLflow URI placeholders. |
| **Package initializers**<br>`nuclei_graph/callbacks/__init__.py`, `nuclei_graph/callbacks/predictions/__init__.py`, `nuclei_graph/callbacks/thresholds/__init__.py` | New init modules exporting the public callback symbols for package-level imports. |
| **Prediction label callbacks**<br>`nuclei_graph/callbacks/predictions/labels.py` | Adds BasePredictionsCallback, WSLPredictionsCallback, MILPredictionsCallback: writes per-slide Parquet predictions, accumulates slide preds, logs misclassifications to MLflow (MIL), and uploads artifacts at epoch end. |
| **Prediction metrics (dataset/slide)**<br>`nuclei_graph/callbacks/predictions/metrics_dataset.py`, `nuclei_graph/callbacks/predictions/metrics_slide.py` | Adds WSL/MIL dataset-level metric callbacks and a slide-level WSL metrics callback: accumulates metrics, moves to device, logs scalars and confusion matrix figure to MLflow, and writes per-slide metric tables. |
| **Mask & attention TIFF generation**<br>`nuclei_graph/callbacks/predictions/nuclei_masks.py` | Adds BaseMasksCallback, WSLPredictionMasksCallback, MILAttentionMasksCallback: open whole-slide images, rasterize polygon masks from parquet, map predictions/attention to 0–255 intensities, write BigTIFFs with tiling, and upload artifacts to MLflow. |
| **Threshold curve plotting**<br>`nuclei_graph/callbacks/thresholds/plot_curves.py` | Adds BaseCurvesCallback, MILCurvesCallback, WSLCurvesCallback: accumulate preds/targets, compute ROC/PR, identify thresholds (Youden’s J, F1, TPR=1 point), plot and log figures and threshold scalars to MLflow. |

Sequence Diagram(s)

sequenceDiagram
    participant Trainer
    participant Callback
    participant PLModule
    participant Filesystem
    participant MLflow
    participant OpenSlide

    Trainer->>PLModule: run predict batch
    PLModule->>Callback: on_predict_batch_end(outputs, batch)
    alt write predictions/parquet
        Callback->>Filesystem: create temp dir & write <slide>.parquet
    end
    alt generate masks
        Callback->>OpenSlide: open slide file & read level dims
        Callback->>Filesystem: write BigTIFF mask (tiled)
    end
    alt log metrics/figures
        Callback->>MLflow: log metrics/artifacts (if active run)
    else no active run
        Callback-->>Trainer: skip MLflow uploads
    end
    Trainer->>Callback: on_predict_epoch_end()
    Callback->>MLflow: upload remaining artifacts & cleanup temp dir

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • vejtek
  • matejpekar

Poem

🐰 I hopped through configs, callbacks, and MLflow streams,
I stitched masks from polygons and stitched them into dreams.
Predictions saved in parquet rows, attention painted bright,
A tiny rabbit cheers at night — metrics logged, and runs take flight! 🎩🥕

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 2.63% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Title check | ❓ Inconclusive | The PR title 'feat: callbacks' is vague and generic. While it mentions callbacks, it does not specify what type of callbacks or their primary purpose, making it unclear for developers scanning history. | Consider using a more descriptive title such as 'feat: add prediction and threshold callbacks for inference' or 'feat: implement MIL/WSL prediction callbacks' to better convey the scope and purpose. |

✅ Passed checks (1 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |


@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the prediction capabilities of the machine learning framework by introducing a suite of specialized callbacks. These callbacks automate the process of capturing and logging model predictions, generating visual masks based on attention or prediction scores, and computing comprehensive evaluation metrics. By integrating these functionalities with MLflow, the PR aims to provide a more robust and streamlined workflow for analyzing and tracking model inference results, particularly for complex tasks involving Whole Slide Imaging (WSI) with WSL and MIL approaches.

Highlights

  • Prediction Callbacks Introduced: New callbacks have been implemented to handle various aspects of model prediction, including logging labels, generating masks, and computing metrics.
  • MLflow Integration for Predictions: All new prediction callbacks are designed to seamlessly integrate with MLflow, enabling automatic logging of prediction data, masks, and performance metrics as artifacts.
  • Support for WSL and MIL Models: Dedicated callbacks and configurations have been added to support both Weakly Supervised Learning (WSL) and Multiple Instance Learning (MIL) models during the prediction phase.
  • New Inference Configurations: New YAML configuration files (base.yaml, crop_level.yaml, nuclei_level.yaml) were added to define and streamline inference setups for different model types.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a set of PyTorch Lightning callbacks for processing model predictions, including saving labels, generating masks, and calculating metrics. The overall structure is logical, but there are key areas for improvement. My review focuses on enhancing configuration robustness by removing hardcoded values and improving the maintainability and performance of the new callbacks. Specifically, I suggest refactoring duplicated code and optimizing interactions with MLflow to make the prediction pipeline more efficient and easier to maintain.

Comment thread nuclei_graph/callbacks/prediction_labels.py Outdated
Comment thread nuclei_graph/callbacks/predictions/nuclei_masks.py Outdated
Comment thread configs/experiment/modeling/inference/nuclei_level.yaml Outdated

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 9

🧹 Nitpick comments (8)
nuclei_graph/callbacks/predictions/metrics_slide.py (2)

22-31: Consider calling super().__init__() in the Callback subclass.

While lightning.Callback currently has a trivial __init__, explicitly calling super().__init__() is a small defensive measure against future base-class changes and keeps the subclass well-behaved for cooperative multiple inheritance.
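The cooperative-inheritance point can be illustrated with a minimal sketch. The classes below are hypothetical stand-ins, not Lightning's real `Callback`; `TracksInit` is an invented mixin used only to show what is lost when a subclass skips `super().__init__()`.

```python
class Callback:
    """Stand-in for a framework base class with a trivial __init__ (assumption)."""

    def __init__(self) -> None:
        super().__init__()


class TracksInit:
    """Hypothetical mixin whose state is created in __init__."""

    def __init__(self) -> None:
        super().__init__()
        self.initialized = True


class WithoutSuper(Callback, TracksInit):
    def __init__(self, threshold: float) -> None:
        # No super().__init__(): neither Callback nor TracksInit runs its __init__.
        self.threshold = threshold


class WithSuper(Callback, TracksInit):
    def __init__(self, threshold: float) -> None:
        super().__init__()  # walks the MRO: Callback, then TracksInit
        self.threshold = threshold


print(hasattr(WithoutSuper(0.5), "initialized"))
print(hasattr(WithSuper(0.5), "initialized"))
```

With a single trivial base class the call changes nothing today, which is why this is only a defensive nit.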

♻️ Suggested change
     def __init__(self, threshold: float) -> None:
+        super().__init__()
         self.slide_nuclei_metrics = NestedMetricCollection(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nuclei_graph/callbacks/predictions/metrics_slide.py` around lines 22 - 31,
The Callback subclass's __init__ should call the base constructor; in the
__init__ method where slide_nuclei_metrics is created (the __init__ method that
sets slide_nuclei_metrics = NestedMetricCollection(...)), add a call to
super().__init__() (typically as the first statement) to ensure cooperative
initialization with lightning.Callback and support future base-class changes.

59-67: Logging path inconsistent with sibling callback.

WSLSlidePredictionMetricsCallback logs via trainer.logger.log_table(...) and asserts MLFlowLogger, whereas WSLDatasetPredictionMetricsCallback / MILDatasetPredictionMetricsCallback in metrics_dataset.py log directly with the global mlflow module (bypassing trainer.logger). Pick one approach consistently across the callback family so that run-context (run_id, nested runs, fluent/non-fluent setup) is handled the same way and metrics don't end up attached to different runs in edge cases.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nuclei_graph/callbacks/predictions/metrics_slide.py` around lines 59 - 67,
The on_predict_epoch_end implementation in WSLSlidePredictionMetricsCallback
currently asserts trainer.logger is an MLFlowLogger and calls
trainer.logger.log_table(...), which is inconsistent with
WSLDatasetPredictionMetricsCallback and MILDatasetPredictionMetricsCallback that
use the global mlflow module; change
WSLSlidePredictionMetricsCallback.on_predict_epoch_end to use the same global
mlflow logging approach as the dataset callbacks (log the result of
self.slide_nuclei_metrics.compute() via the global mlflow API and remove the
MLFlowLogger assert), ensuring the artifact/metric is attached to the same
run/context as the other callbacks and then call
self.slide_nuclei_metrics.reset().
nuclei_graph/callbacks/predictions/metrics_dataset.py (2)

39-50: Redundant prefix handling.

MetricCollection already applies prefix="prediction/", so keys returned by compute() are already prediction/<name>. The loop then strips the prefix via key.split("/")[-1] and re-prepends "prediction/", which is a net no-op. Simplify:

♻️ Suggested change
-        for key, value in computed_metrics.items():
-            metric_name = key.split("/")[-1]
-            mlflow.log_metric(f"prediction/{metric_name}", float(value))
+        for key, value in computed_metrics.items():
+            mlflow.log_metric(key, float(value))

The same simplification applies to the else branch in MILDatasetPredictionMetricsCallback.on_predict_epoch_end (Line 136).
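The round trip can be checked with plain strings. This is a small sketch using hypothetical metric names and values; `relabel` is an illustrative helper that mimics the original loop body and is not part of the PR.

```python
def relabel(key: str) -> str:
    """Mimic the original loop: keep the part after the last '/' and re-add the prefix."""
    metric_name = key.split("/")[-1]
    return f"prediction/{metric_name}"


# Hypothetical compute() output with the prefix already applied.
computed_metrics = {"prediction/accuracy": 0.91, "prediction/f1": 0.84}

for key in computed_metrics:
    print(key, "->", relabel(key))

# A key with an extra path level would not survive the round trip unchanged.
print(relabel("prediction/nuclei/auroc"))
```

The last line shows one more reason to pass keys through as-is: if a metric name ever contains its own `/`, the strip-and-reprefix step silently drops the middle segment.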

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nuclei_graph/callbacks/predictions/metrics_dataset.py` around lines 39 - 50,
The loop in on_predict_epoch_end is redundantly removing and re-adding the
"prediction/" prefix: dataset_nuclei_metrics.compute() already returns keys like
"prediction/<name>", so in the on_predict_epoch_end method of this callback you
should stop splitting the key and re-prepending the prefix; instead pass the
returned key directly into mlflow.log_metric (i.e., use the compute() keys
as-is). Make the same simplification in
MILDatasetPredictionMetricsCallback.on_predict_epoch_end (the else branch
mentioned) so neither callback strips and re-adds the "prediction/" prefix.

1-4: Direct mlflow usage bypasses the Lightning logger.

Both callbacks log via the global mlflow module, which relies on the ambient active run rather than the run managed by trainer.logger (an MLFlowLogger). This can attach metrics/figures to the wrong run in nested-run or multi-logger scenarios, and is inconsistent with WSLSlidePredictionMetricsCallback in metrics_slide.py, which routes through trainer.logger. Consider unifying on trainer.logger (use log_metrics and experiment.log_figure(run_id=..., ...)) for consistency and correctness.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nuclei_graph/callbacks/predictions/metrics_dataset.py` around lines 1 - 4,
The callbacks in metrics_dataset.py currently call the global mlflow module
directly; change them to use the Trainer's logger (trainer.logger) like
WSLSlidePredictionMetricsCallback does in metrics_slide.py: remove direct mlflow
imports/usages, call trainer.logger.log_metrics(...) to record metric dicts, and
use trainer.logger.experiment.log_figure(run_id, figure, artifact_file) or
other trainer.logger.experiment.log_* APIs to attach figures (obtain run_id from
the logger's run_id property, trainer.logger.run_id); update any
functions/methods in the callbacks that reference mlflow to instead accept a
Trainer instance and route logging through trainer.logger and
trainer.logger.experiment.
configs/experiment/modeling/inference/nuclei_level/prostate_cancer_mmci_tl.yaml (1)

12-14: Commented-out callback entries — intentional?

The labels, nuclei_masks, and thresholds callback defaults are commented out here but active in the sibling radboud.yaml inference config. If this is a temporary disable, consider a brief comment explaining why (e.g., MMCI lacks labels for these artifacts) so the divergence is self-documenting; otherwise remove the dead lines.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@configs/experiment/modeling/inference/nuclei_level/prostate_cancer_mmci_tl.yaml`
around lines 12 - 14, The three callback entries (trainer.callbacks.labels,
trainer.callbacks.nuclei_masks, trainer.callbacks.thresholds) are commented out
here while kept active in the sibling config; either remove these dead lines or
add a short explanatory comment above them explaining why they're disabled
(e.g., "MMCI lacks labels/nuclei masks, thresholds disabled") so the divergence
is documented; update the block containing the commented lines "-
/callbacks/predictions/labels@trainer.callbacks.labels", "-
/callbacks/predictions@trainer.callbacks.nuclei_masks", and "-
/callbacks/thresholds/plot_curves@trainer.callbacks.thresholds" accordingly.
nuclei_graph/callbacks/predictions/labels.py (1)

74-79: Reset slide_preds in on_predict_start for robustness.

self.slide_preds is only cleared at the end of on_predict_epoch_end. If a previous predict run raises before reaching the epoch end, a subsequent trainer.predict() on the same callback instance will accumulate stale slides. Reinitializing in on_predict_start (in addition to the current end-of-epoch reset) makes the callback reuse-safe.
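The stale-buffer scenario can be reproduced without Lightning. The sketch below uses a hypothetical `PredictionsBuffer` class with simplified hook signatures (the real hooks also receive `trainer` and `pl_module`); only the reset logic is meant to carry over.

```python
class PredictionsBuffer:
    """Hypothetical stand-in for the callback's slide-level accumulator."""

    def __init__(self) -> None:
        self.slide_preds = self._empty()

    @staticmethod
    def _empty() -> dict[str, list]:
        return {"slide_id": [], "is_carcinoma": [], "prediction": []}

    def on_predict_start(self) -> None:
        # Reset here so an interrupted earlier run cannot leak into this one.
        self.slide_preds = self._empty()

    def on_predict_batch_end(self, slide_id: str, is_carcinoma: bool, prediction: float) -> None:
        self.slide_preds["slide_id"].append(slide_id)
        self.slide_preds["is_carcinoma"].append(is_carcinoma)
        self.slide_preds["prediction"].append(prediction)

    def on_predict_epoch_end(self) -> int:
        n_slides = len(self.slide_preds["slide_id"])
        self.slide_preds = self._empty()  # existing end-of-epoch reset is kept
        return n_slides


buf = PredictionsBuffer()

# First run: fails after one batch, so on_predict_epoch_end never runs.
buf.on_predict_start()
buf.on_predict_batch_end("slide_a", True, 0.9)

# Second run on the same instance starts from a clean buffer.
buf.on_predict_start()
buf.on_predict_batch_end("slide_b", False, 0.1)
print(buf.slide_preds["slide_id"])
```

Without the reset in `on_predict_start`, the second run would report both `slide_a` and `slide_b`.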

♻️ Proposed fix
+    def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        super().on_predict_start(trainer, pl_module)
+        self.slide_preds = {"slide_id": [], "is_carcinoma": [], "prediction": []}
+
     def on_predict_batch_end(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nuclei_graph/callbacks/predictions/labels.py` around lines 74 - 79, The
callback currently initializes self.slide_preds only in __init__ and clears it
in on_predict_epoch_end, which can leave stale data if a predict run fails; add
a reset of self.slide_preds = {"slide_id": [], "is_carcinoma": [], "prediction":
[]} at the start of on_predict_start to ensure each trainer.predict() run starts
with a fresh buffer (keep the existing clear in on_predict_epoch_end as well).
nuclei_graph/callbacks/predictions/nuclei_masks.py (2)

100-101: int(pred * 255) truncates; consider rounding.

int() truncates toward zero, so e.g. pred=0.999 → 254 rather than 255. Using round (or np.uint8(round(...))) gives a slightly more faithful visual mapping of probabilities/attention to grayscale. Same note applies to line 163 in MILAttentionMasksCallback. Nit.
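The difference is easy to confirm in isolation. In this sketch the helper names are illustrative, and the clamp to 0–255 is an added safety assumption rather than something the callback is known to do.

```python
def to_gray_truncated(pred: float) -> int:
    """Current behaviour: int() drops the fractional part."""
    return int(pred * 255)


def to_gray_rounded(pred: float) -> int:
    """Suggested behaviour: round to the nearest level, clamped to the uint8 range."""
    return max(0, min(255, round(pred * 255)))


print(to_gray_truncated(0.999))  # 254
print(to_gray_rounded(0.999))    # 255
```

Note that Python's built-in `round` uses round-half-to-even, which only matters for values landing exactly on a `.5` boundary.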

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nuclei_graph/callbacks/predictions/nuclei_masks.py` around lines 100 - 101,
The code currently computes pixel_val = int(pred * 255) which truncates toward
zero; change this to use rounding so high probabilities map correctly (e.g., use
round(pred * 255) or np.uint8(round(pred * 255))) before passing to
canvas.polygon; update the same pattern in the MILAttentionMasksCallback
occurrence as well (search for pixel_val/int(pred * 255) and replace with a
rounded conversion to uint8) to ensure faithful grayscale mapping.

56-113: Significant duplication between WSLPredictionMasksCallback and MILAttentionMasksCallback.

The slide-open/scale computation, mask canvas setup, parquet polygon loading, polygon rasterization loop, and write_big_tiff call are essentially identical in both subclasses — only the per-nucleus intensity source (sigmoid(logits) vs. normalized attention) differs. Consider pulling the common rasterization pipeline into a helper on BaseMasksCallback (e.g., _render_and_write(metadata, per_nucleus_values)), leaving subclasses to compute just the values array. This will also make future additions (new overlay kinds) much cheaper.

Also applies to: 115-175
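One possible shape for this refactor is a template-method base class. The sketch below is deliberately stripped down: class and method names are hypothetical, `render` only stands in for the proposed `_render_and_write` (no OpenSlide, polygons, or TIFF writing), and min-max scaling is an assumed form of the attention normalization.

```python
import math
from abc import ABC, abstractmethod


class BaseMasks(ABC):
    """Hypothetical base: owns the shared pipeline; subclasses supply the values."""

    @abstractmethod
    def per_nucleus_values(self, outputs: list[float]) -> list[float]:
        """Return one value in [0, 1] per nucleus."""

    def render(self, outputs: list[float]) -> list[int]:
        # Stand-in for the shared rasterization pipeline.
        # int() mirrors the current truncating conversion to 0-255.
        return [int(v * 255) for v in self.per_nucleus_values(outputs)]


class PredictionMasks(BaseMasks):
    def per_nucleus_values(self, outputs: list[float]) -> list[float]:
        # Sigmoid over raw logits.
        return [1.0 / (1.0 + math.exp(-x)) for x in outputs]


class AttentionMasks(BaseMasks):
    def per_nucleus_values(self, outputs: list[float]) -> list[float]:
        # Min-max normalization (assumption); constant input maps to zeros.
        lo, hi = min(outputs), max(outputs)
        span = hi - lo
        return [(x - lo) / span if span else 0.0 for x in outputs]


print(PredictionMasks().render([0.0]))
print(AttentionMasks().render([1.0, 2.0, 3.0]))
```

Each subclass shrinks to a single method, so a new overlay type would only need to define how its per-nucleus values are computed.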

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nuclei_graph/callbacks/predictions/nuclei_masks.py` around lines 56 - 113,
WSLPredictionMasksCallback and MILAttentionMasksCallback duplicate the entire
rasterization pipeline; extract the shared logic (OpenSlide slide open & scale
computation, mask canvas creation, parquet polygon loading, polygon -> pixel
scaling and drawing, and write_big_tiff invocation) into a new BaseMasksCallback
helper (suggested name _render_and_write(metadata, per_nucleus_values)). Update
WSLPredictionMasksCallback.on_predict_batch_end to only compute its per-nucleus
values (predicted_labels = sigmoid(logits_ordered).cpu().numpy().flatten()) and
pass metadata and that array to _render_and_write; do the same in
MILAttentionMasksCallback (compute normalized attention values then call
_render_and_write). The helper should reuse existing symbols/fields: self.level,
mask_tile_width, mask_tile_height, self._get_output_path, slide_resolution,
OpenSlide, PILImage, rearrange, write_big_tiff, and preserve the current mask
sizing and scaling behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@configs/callbacks/predictions/metrics_dataset/wsl.yaml`:
- Line 4: The inline comment on the threshold key contains an inconsistent MIL
value; update or remove the inline note so it no longer shows "MIL: 0.957" —
locate the threshold entry (the YAML key "threshold") and either delete the
trailing comment "MIL: 0.957, WSL: 0.388" or replace it with a single accurate
note (e.g., just "WSL: 0.388") that matches the MIL thresholds used in this PR.

In `@configs/callbacks/predictions/metrics_slide/wsl.yaml`:
- Line 4: The inline comment on the threshold line in the file (the `threshold:
0.388 # MIL: 0.957, WSL: 0.388` entry) is stale/misleading; update or remove the
inline comparison so it doesn't conflict with the current MIL threshold used
elsewhere (see the `threshold` entry in `metrics_dataset/mil.yaml` which is
`0.540`)—either delete the trailing `# MIL: ...` fragment or replace it with the
correct, current MIL value to keep `threshold` comments consistent.

In
`@configs/experiment/modeling/inference/crop_level/prostate_cancer_mmci_tl.yaml`:
- Around line 10-11: The Hydra config duplicates the target key
trainer.callbacks.metrics_dataset so the second entry (wsl) overwrites the first
(mil); change the entries to use distinct keys or separate list entries so both
callbacks register (e.g., use unique keys referencing the same target such as
trainer.callbacks.metrics_dataset_mil and trainer.callbacks.metrics_dataset_wsl
or convert to an explicit sequence under trainer.callbacks that includes both
targets) and ensure the identifiers reference the correct callback
implementations for MIL and WSL respectively.

In `@configs/experiment/modeling/inference/crop_level/radboud.yaml`:
- Line 12: The inline comment is incorrect: the callback entry
'/callbacks/predictions/metrics_slide@trainer.callbacks.metrics_slide: wsl'
refers to nuclei-level WSL metrics, not "graph-level preds"; update the comment
to accurately state "nuclei-level preds" (or remove the misleading "graph-level
preds") next to the metrics_slide / wsl entry so the comment matches the actual
behavior of the wsl metric.
- Around line 10-11: The two callback registrations use the identical Hydra
target key trainer.callbacks.metrics_dataset so the second (wsl) overwrites the
first (mil); give each entry a distinct package target after the `@` so both are
unique (for example
/callbacks/predictions/metrics_dataset@trainer.callbacks.metrics_dataset_mil for
the MIL entry and
/callbacks/predictions/metrics_dataset@trainer.callbacks.metrics_dataset_wsl for
WSL) so both callbacks (mil and wsl) are registered and active.
- Line 15: The callback reference for attention is using an unnecessary '+'
prefix; locate the mapping key that reads
callbacks/attention@+trainer.callbacks.attn_masks and change it to
callbacks/attention@trainer.callbacks.attn_masks so it matches the other
callback entries and the pre-defined trainer.callbacks dict (i.e., remove the
'+' from the @+trainer.callbacks.attn_masks reference).

In `@nuclei_graph/callbacks/predictions/labels.py`:
- Around line 126-135: The CSV is being logged to the run root instead of the
predictions artifact path; inside the with tempfile.TemporaryDirectory() block
where misclassif_df is written, either move the CSV into the same temp directory
used by the base uploader (e.g., write to self.tmp_dir or the temp dir the base
class expects so _save_parquet and the base log_artifacts call pick it up) or
call mlflow.log_artifact with artifact_path=self.mlflow_artifact_path (replace
the current mlflow.log_artifact call) so misclassifications.csv is uploaded
alongside the prediction parquet files.

In `@nuclei_graph/callbacks/thresholds/plot_curves.py`:
- Around line 68-92: The _perform_roc function currently looks for exact tpr ==
1 and falls back to thresholds[0] (sklearn's sentinel) when no tpr reaches 1;
change this to (1) use a tolerant comparison like np.isclose(tpr, 1.0) to find
indices and (2) when no index is found do not pick thresholds[0] — instead set
tpr_threshold to np.nan (or None) to signal "not available" so you don't log the
sklearn sentinel; update the variables idx, tpr_idx, and tpr_threshold in
_perform_roc accordingly and ensure callers that expect a numeric threshold
handle np.nan/None instead of the sentinel.
- Around line 94-99: In _perform_pr, f1 is computed from precision and recall
which are length n_thresholds+1 while thresholds is length n_thresholds, causing
a possible IndexError when picking best_threshold; fix by computing f1 only over
the entries that correspond to thresholds (e.g., use precision[:-1] and
recall[:-1] so f1.shape == thresholds.shape) and then select best_threshold =
thresholds[best_idx]; also guard against the case of empty thresholds (handle or
return a sensible default when thresholds.size == 0) to avoid indexing errors.
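The two threshold fixes above can be sketched in plain Python. The function names are illustrative, and the inputs are assumed to follow scikit-learn's conventions: `roc_curve` thresholds start with a sentinel, and `precision_recall_curve` returns one more precision/recall entry than thresholds.

```python
import math


def first_full_recall_threshold(tpr, thresholds, tol=1e-9):
    """Return the first threshold where TPR is (approximately) 1, else NaN."""
    for rate, thr in zip(tpr, thresholds):
        if math.isclose(rate, 1.0, abs_tol=tol):
            return thr
    return math.nan  # signal "not available" instead of the sentinel


def best_f1_threshold(precision, recall, thresholds):
    """Pick the threshold with the highest F1, ignoring the extra trailing point."""
    if len(thresholds) == 0:
        return math.nan
    best_idx, best_f1 = 0, -1.0
    # zip stops at the shortest input, so the final (precision, recall) pair,
    # which has no matching threshold, is never considered.
    for i, (p, r, _) in enumerate(zip(precision, recall, thresholds)):
        f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
        if f1 > best_f1:
            best_idx, best_f1 = i, f1
    return thresholds[best_idx]


# Hypothetical curve values for illustration only.
print(first_full_recall_threshold([0.0, 0.5, 1.0], [math.inf, 0.8, 0.3]))
print(best_f1_threshold([0.5, 0.67, 1.0, 1.0], [1.0, 1.0, 0.5, 0.0], [0.2, 0.6, 0.9]))
```

Callers that log the result would then need to skip or handle `NaN` rather than assume a numeric threshold.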

---

Nitpick comments:
In
`@configs/experiment/modeling/inference/nuclei_level/prostate_cancer_mmci_tl.yaml`:
- Around line 12-14: The three callback entries (trainer.callbacks.labels,
trainer.callbacks.nuclei_masks, trainer.callbacks.thresholds) are commented out
here while kept active in the sibling config; either remove these dead lines or
add a short explanatory comment above them explaining why they're disabled
(e.g., "MMCI lacks labels/nuclei masks, thresholds disabled") so the divergence
is documented; update the block containing the commented lines "-
/callbacks/predictions/labels@trainer.callbacks.labels", "-
/callbacks/predictions@trainer.callbacks.nuclei_masks", and "-
/callbacks/thresholds/plot_curves@trainer.callbacks.thresholds" accordingly.

In `@nuclei_graph/callbacks/predictions/labels.py`:
- Around line 74-79: The callback currently initializes self.slide_preds only in
__init__ and clears it in on_predict_epoch_end, which can leave stale data if a
predict run fails; add a reset of self.slide_preds = {"slide_id": [],
"is_carcinoma": [], "prediction": []} at the start of on_predict_start to ensure
each trainer.predict() run starts with a fresh buffer (keep the existing clear
in on_predict_epoch_end as well).

In `@nuclei_graph/callbacks/predictions/metrics_dataset.py`:
- Around line 39-50: The loop in on_predict_epoch_end is redundantly removing
and re-adding the "prediction/" prefix: dataset_nuclei_metrics.compute() already
returns keys like "prediction/<name>", so in the on_predict_epoch_end method of
this callback you should stop splitting the key and re-prepending the prefix;
instead pass the returned key directly into mlflow.log_metric (i.e., use the
compute() keys as-is). Make the same simplification in
MILDatasetPredictionMetricsCallback.on_predict_epoch_end (the else branch
mentioned) so neither callback strips and re-adds the "prediction/" prefix.
- Around line 1-4: The callbacks in metrics_dataset.py currently call the global
mlflow module directly; change them to use the Trainer's logger (trainer.logger)
like WSLSlidePredictionMetricsCallback does in metrics_slide.py: remove direct
mlflow imports/usages, call trainer.logger.log_metrics(...) to record metric
dicts, and use trainer.logger.experiment.log_figure(run_id, name, figure) or
trainer.logger.experiment.log_* APIs to attach figures (obtain run_id from
trainer.logger.experiment.active_run_id or the logger's run_id property as
appropriate); update any functions/methods in the callbacks that reference
mlflow to instead accept a Trainer instance and route logging through
trainer.logger and trainer.logger.experiment.

In `@nuclei_graph/callbacks/predictions/metrics_slide.py`:
- Around line 22-31: The Callback subclass's __init__ should call the base
constructor; in the __init__ method where slide_nuclei_metrics is created (the
__init__ method that sets slide_nuclei_metrics = NestedMetricCollection(...)),
add a call to super().__init__() (typically as the first statement) to ensure
cooperative initialization with lightning.Callback and support future base-class
changes.
- Around line 59-67: The on_predict_epoch_end implementation in
WSLSlidePredictionMetricsCallback currently asserts trainer.logger is an
MLFlowLogger and calls trainer.logger.log_table(...), which is inconsistent with
WSLDatasetPredictionMetricsCallback and MILDatasetPredictionMetricsCallback that
use the global mlflow module; change
WSLSlidePredictionMetricsCallback.on_predict_epoch_end to use the same global
mlflow logging approach as the dataset callbacks (log the result of
self.slide_nuclei_metrics.compute() via the global mlflow API and remove the
MLFlowLogger assert), ensuring the artifact/metric is attached to the same
run/context as the other callbacks and then call
self.slide_nuclei_metrics.reset().

In `@nuclei_graph/callbacks/predictions/nuclei_masks.py`:
- Around line 100-101: The code currently computes pixel_val = int(pred * 255)
which truncates toward zero; change this to use rounding so high probabilities
map correctly (e.g., use round(pred * 255) or np.uint8(round(pred * 255)))
before passing to canvas.polygon; update the same pattern in the
MILAttentionMasksCallback occurrence as well (search for pixel_val/int(pred *
255) and replace with a rounded conversion to uint8) to ensure faithful
grayscale mapping.
- Around line 56-113: WSLPredictionMasksCallback and MILAttentionMasksCallback
duplicate the entire rasterization pipeline; extract the shared logic (OpenSlide
slide open & scale computation, mask canvas creation, parquet polygon loading,
polygon -> pixel scaling and drawing, and write_big_tiff invocation) into a new
BaseMasksCallback helper (suggested name _render_and_write(metadata,
per_nucleus_values)). Update WSLPredictionMasksCallback.on_predict_batch_end to
only compute its per-nucleus values (predicted_labels =
sigmoid(logits_ordered).cpu().numpy().flatten()) and pass metadata and that
array to _render_and_write; do the same in MILAttentionMasksCallback (compute
normalized attention values then call _render_and_write). The helper should
reuse existing symbols/fields: self.level, mask_tile_width, mask_tile_height,
self._get_output_path, slide_resolution, OpenSlide, PILImage, rearrange,
write_big_tiff, and preserve the current mask sizing and scaling behavior.
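
To illustrate the rounding point in the first item above, here is a small sketch comparing truncation with rounding when mapping a probability to an 8-bit grayscale value. The function names are hypothetical and the clamp is an added safety assumption, not something taken from the PR.

```python
def to_gray_truncate(pred: float) -> int:
    # Current behavior: int() truncates toward zero.
    return int(pred * 255)


def to_gray_round(pred: float) -> int:
    # Suggested behavior: round to the nearest integer, then clamp to [0, 255].
    return max(0, min(255, round(pred * 255)))


for p in (0.0, 0.998, 0.999, 1.0):
    print(p, to_gray_truncate(p), to_gray_round(p))
```

With truncation, only a prediction of exactly 1.0 reaches 255. With rounding, values from roughly 0.998 upward do.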
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b0ffc4db-acaf-44c9-a21f-50137756319f

📥 Commits

Reviewing files that changed from the base of the PR and between fcce74f and 18ed904.

📒 Files selected for processing (28)
  • configs/callbacks/attention/nuclei_masks.yaml
  • configs/callbacks/predictions/labels/mil.yaml
  • configs/callbacks/predictions/labels/wsl.yaml
  • configs/callbacks/predictions/metrics_dataset/mil.yaml
  • configs/callbacks/predictions/metrics_dataset/wsl.yaml
  • configs/callbacks/predictions/metrics_slide/wsl.yaml
  • configs/callbacks/predictions/nuclei_masks.yaml
  • configs/callbacks/thresholds/plot_curves/mil.yaml
  • configs/callbacks/thresholds/plot_curves/wsl.yaml
  • configs/experiment/modeling/inference/base.yaml
  • configs/experiment/modeling/inference/crop_level/prostate_cancer_mmci_tl.yaml
  • configs/experiment/modeling/inference/crop_level/radboud.yaml
  • configs/experiment/modeling/inference/nuclei_level/prostate_cancer_mmci_tl.yaml
  • configs/experiment/modeling/inference/nuclei_level/radboud.yaml
  • configs/experiment/modeling/validation/base.yaml
  • configs/experiment/modeling/validation/crop_level/prostate_cancer_mmci_tl.yaml
  • configs/experiment/modeling/validation/crop_level/radboud.yaml
  • configs/experiment/modeling/validation/nuclei_level/prostate_cancer_mmci_tl.yaml
  • configs/experiment/modeling/validation/nuclei_level/radboud.yaml
  • configs/ml.yaml
  • nuclei_graph/callbacks/__init__.py
  • nuclei_graph/callbacks/predictions/__init__.py
  • nuclei_graph/callbacks/predictions/labels.py
  • nuclei_graph/callbacks/predictions/metrics_dataset.py
  • nuclei_graph/callbacks/predictions/metrics_slide.py
  • nuclei_graph/callbacks/predictions/nuclei_masks.py
  • nuclei_graph/callbacks/thresholds/__init__.py
  • nuclei_graph/callbacks/thresholds/plot_curves.py

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
nuclei_graph/callbacks/thresholds/plot_curves.py (1)

68-74: ⚠️ Potential issue | 🟡 Minor

tpr_threshold fallback still emits sklearn's sentinel np.inf.

A previous review flagged this and it has not been addressed. Since scikit-learn 1.3, roc_curve sets thresholds[0] = np.inf, so when len(idx) == 0 (no point on the curve reaches TPR == 1), tpr_threshold = thresholds[0] = inf is forwarded to mlflow.log_metric(..., float(roc_t)) on line 30. Depending on the tracking backend, MLflow may reject a non-finite metric value or silently replace it, so the result is either an error at log time or a misleading logged threshold.

Additionally, np.where(tpr == 1) uses exact float equality; prefer np.isclose(tpr, 1.0) for robustness.

🛡️ Proposed fix
-        idx = np.where(tpr == 1)[0]
-        tpr_idx = idx[np.argmin(fpr[idx])] if len(idx) > 0 else 0
-        tpr_threshold = thresholds[tpr_idx]
+        idx = np.where(np.isclose(tpr, 1.0))[0]
+        if len(idx) > 0:
+            tpr_idx = idx[np.argmin(fpr[idx])]
+            tpr_threshold = float(thresholds[tpr_idx])
+        else:
+            tpr_idx = int(np.argmax(tpr))  # best-effort point for plotting
+            tpr_threshold = float("nan")   # signal "not available"

And in _log_and_clear_curves, skip logging / handle NaN accordingly:

-        mlflow.log_metric(f"thresholds/{level_name}_tpr_threshold", float(roc_t))
+        if np.isfinite(roc_t):
+            mlflow.log_metric(f"thresholds/{level_name}_tpr_threshold", float(roc_t))
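
The same selection logic can be sketched in plain Python, without NumPy or MLflow, to make the fallback behavior easy to check. `pick_tpr_threshold` and `should_log` are hypothetical helper names, and the `math.isclose` tolerances are chosen only to roughly mirror the `np.isclose` defaults.

```python
import math
from typing import Sequence, Tuple


def pick_tpr_threshold(
    fpr: Sequence[float], tpr: Sequence[float], thresholds: Sequence[float]
) -> Tuple[int, float]:
    """Return (index, threshold) of the lowest-FPR point with TPR close to 1."""
    idx = [
        i for i, t in enumerate(tpr)
        if math.isclose(t, 1.0, rel_tol=1e-5, abs_tol=1e-8)
    ]
    if idx:
        best = min(idx, key=lambda i: fpr[i])
        return best, float(thresholds[best])
    # No point reaches TPR ~ 1: keep a best-effort index for plotting,
    # but report NaN instead of the infinite sentinel at thresholds[0].
    best = max(range(len(tpr)), key=lambda i: tpr[i])
    return best, float("nan")


def should_log(value: float) -> bool:
    # Only finite values are passed on to the metric logger.
    return math.isfinite(value)


inf = float("inf")
print(pick_tpr_threshold([0.0, 0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.0], [inf, 0.9, 0.4, 0.1]))
print(pick_tpr_threshold([0.0, 0.0, 0.5], [0.0, 0.5, 0.8], [inf, 0.9, 0.4]))
```

Returning NaN keeps "no threshold available" distinguishable from a real value, and the `should_log` guard keeps it out of the metric store.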
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nuclei_graph/callbacks/thresholds/plot_curves.py` around lines 68 - 74, The
code in _perform_roc uses exact equality to find TPR==1 and falls back to
thresholds[0] which may be np.inf; change the selection to use np.isclose(tpr,
1.0) and if no matching index exists set tpr_threshold = np.nan (or otherwise
ensure it is finite) rather than using thresholds[0]; additionally update
_log_and_clear_curves to skip or convert non-finite roc_t values before calling
mlflow.log_metric (only call mlflow.log_metric when np.isfinite(roc_t) and cast
to float), so MLflow never receives inf/NaN.
🧹 Nitpick comments (1)
nuclei_graph/callbacks/thresholds/plot_curves.py (1)

40-66: Optional: iterate with zip and add type hints.

Minor idiomatic cleanup — no behavioral change.

♻️ Proposed refactor
-        for i in range(len(to_pinpoint)):
-            x, y = to_pinpoint[i]
-            ax.scatter(x, y, color=point_colors[i], label=point_labels[i], zorder=5)
+        for (x, y), label, color in zip(to_pinpoint, point_labels, point_colors):
+            ax.scatter(x, y, color=color, label=label, zorder=5)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nuclei_graph/callbacks/thresholds/plot_curves.py` around lines 40 - 66, The
_plot_curve method currently iterates by index over to_pinpoint and accesses
parallel lists; change the loop to iterate with zip (e.g., for (x, y), color,
label in zip(to_pinpoint, point_colors, point_labels)) to make it idiomatic and
safer, and add lightweight type hints to the signature (e.g., xs:
Sequence[float], ys: Sequence[float], to_pinpoint: Sequence[Tuple[float,
float]], point_labels: Sequence[str], point_colors: Sequence[str], xlabel: str,
ylabel: str, title: str, loc: Any) to improve readability and static checking
while preserving behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@nuclei_graph/callbacks/thresholds/plot_curves.py`:
- Around line 127-149: Add an early guard in
MILCurvesCallback.on_validation_batch_end that returns immediately if outputs is
None (as WSLCurvesCallback does) before accessing outputs["graph"] or
outputs["nuclei"]; keep the existing trainer.sanity_checking check and place the
outputs is None check right after it, so the subsequent lines (graph_outputs =
outputs["graph"].view(-1), nuclei_outputs = outputs["nuclei"][...]) cannot raise
a TypeError from subscripting None.
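
A minimal sketch of the guard ordering described above. `CurveCollector` is a hypothetical stand-in for the callback, with plain lists in place of tensors and a boolean in place of `trainer.sanity_checking`, so it runs without Lightning or PyTorch.

```python
from typing import Dict, List, Optional


class CurveCollector:
    """Hypothetical stand-in for the MIL curves callback."""

    def __init__(self) -> None:
        self.graph_outputs: List[float] = []

    def on_validation_batch_end(
        self, sanity_checking: bool, outputs: Optional[Dict[str, List[float]]]
    ) -> None:
        if sanity_checking:
            return
        if outputs is None:
            # Without this guard, outputs["graph"] would raise TypeError.
            return
        self.graph_outputs.extend(outputs["graph"])


collector = CurveCollector()
collector.on_validation_batch_end(False, None)  # ignored, no exception
collector.on_validation_batch_end(True, {"graph": [0.5]})  # ignored during sanity check
collector.on_validation_batch_end(False, {"graph": [0.2, 0.8]})
print(collector.graph_outputs)
```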

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e3494ca0-71da-4001-baa5-ac68e2ec8d57

📥 Commits

Reviewing files that changed from the base of the PR and between 18ed904 and 4809b6d.

📒 Files selected for processing (1)
  • nuclei_graph/callbacks/thresholds/plot_curves.py

@xrusnack xrusnack changed the title feat: prediction callbacks feat: callbacks Apr 20, 2026