feat: callbacks#7
Conversation
📝 WalkthroughWalkthroughAdds numerous Hydra configs and new PyTorch Lightning callbacks for prediction, metric logging, mask/attention TIFF generation, and threshold curve plotting; wires MLflow artifact/metric logging and integrates OpenSlide/polyline rasterization for per-slide outputs. Changes
Sequence Diagram(s)sequenceDiagram
participant Trainer
participant Callback
participant PLModule
participant Filesystem
participant MLflow
participant OpenSlide
Trainer->>PLModule: run predict batch
PLModule->>Callback: on_predict_batch_end(outputs, batch)
alt write predictions/parquet
Callback->>Filesystem: create temp dir & write <slide>.parquet
end
alt generate masks
Callback->>OpenSlide: open slide file & read level dims
Callback->>Filesystem: write BigTIFF mask (tiled)
end
alt log metrics/figures
Callback->>MLflow: log metrics/artifacts (if active run)
else no active run
Callback-->>Trainer: skip MLflow uploads
end
Trainer->>Callback: on_predict_epoch_end()
Callback->>MLflow: upload remaining artifacts & cleanup temp dir
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 1 | ❌ 2❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the prediction capabilities of the machine learning framework by introducing a suite of specialized callbacks. These callbacks automate the process of capturing and logging model predictions, generating visual masks based on attention or prediction scores, and computing comprehensive evaluation metrics. By integrating these functionalities with MLflow, the PR aims to provide a more robust and streamlined workflow for analyzing and tracking model inference results, particularly for complex tasks involving Whole Slide Imaging (WSI) with WSL and MIL approaches. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces a set of PyTorch Lightning callbacks for processing model predictions, including saving labels, generating masks, and calculating metrics. The overall structure is logical, but there are key areas for improvement. My review focuses on enhancing configuration robustness by removing hardcoded values and improving the maintainability and performance of the new callbacks. Specifically, I suggest refactoring duplicated code and optimizing interactions with MLflow to make the prediction pipeline more efficient and easier to maintain.
There was a problem hiding this comment.
Actionable comments posted: 9
🧹 Nitpick comments (8)
nuclei_graph/callbacks/predictions/metrics_slide.py (2)
22-31: Consider callingsuper().__init__()in the Callback subclass.While
lightning.Callbackcurrently has a trivial__init__, explicitly callingsuper().__init__()is a small defensive measure against future base-class changes and keeps the subclass well-behaved for cooperative multiple inheritance.♻️ Suggested change
def __init__(self, threshold: float) -> None: + super().__init__() self.slide_nuclei_metrics = NestedMetricCollection(🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nuclei_graph/callbacks/predictions/metrics_slide.py` around lines 22 - 31, The Callback subclass's __init__ should call the base constructor; in the __init__ method where slide_nuclei_metrics is created (the __init__ method that sets slide_nuclei_metrics = NestedMetricCollection(...)), add a call to super().__init__() (typically as the first statement) to ensure cooperative initialization with lightning.Callback and support future base-class changes.
59-67: Logging path inconsistent with sibling callback.
WSLSlidePredictionMetricsCallbacklogs viatrainer.logger.log_table(...)and assertsMLFlowLogger, whereasWSLDatasetPredictionMetricsCallback/MILDatasetPredictionMetricsCallbackinmetrics_dataset.pylog directly with the globalmlflowmodule (bypassingtrainer.logger). Pick one approach consistently across the callback family so that run-context (run_id, nested runs, fluent/non-fluent setup) is handled the same way and metrics don't end up attached to different runs in edge cases.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nuclei_graph/callbacks/predictions/metrics_slide.py` around lines 59 - 67, The on_predict_epoch_end implementation in WSLSlidePredictionMetricsCallback currently asserts trainer.logger is an MLFlowLogger and calls trainer.logger.log_table(...), which is inconsistent with WSLDatasetPredictionMetricsCallback and MILDatasetPredictionMetricsCallback that use the global mlflow module; change WSLSlidePredictionMetricsCallback.on_predict_epoch_end to use the same global mlflow logging approach as the dataset callbacks (log the result of self.slide_nuclei_metrics.compute() via the global mlflow API and remove the MLFlowLogger assert), ensuring the artifact/metric is attached to the same run/context as the other callbacks and then call self.slide_nuclei_metrics.reset().nuclei_graph/callbacks/predictions/metrics_dataset.py (2)
39-50: Redundant prefix handling.
MetricCollectionalready appliesprefix="prediction/", so keys returned bycompute()are alreadyprediction/<name>. The loop then strips the prefix viakey.split("/")[-1]and re-prepends"prediction/"— net-no-op. Simplify:♻️ Suggested change
- for key, value in computed_metrics.items(): - metric_name = key.split("/")[-1] - mlflow.log_metric(f"prediction/{metric_name}", float(value)) + for key, value in computed_metrics.items(): + mlflow.log_metric(key, float(value))The same simplification applies to the
elsebranch inMILDatasetPredictionMetricsCallback.on_predict_epoch_end(Line 136).🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nuclei_graph/callbacks/predictions/metrics_dataset.py` around lines 39 - 50, The loop in on_predict_epoch_end is redundantly removing and re-adding the "prediction/" prefix: dataset_nuclei_metrics.compute() already returns keys like "prediction/<name>", so in the on_predict_epoch_end method of this callback you should stop splitting the key and re-prepending the prefix; instead pass the returned key directly into mlflow.log_metric (i.e., use the compute() keys as-is). Make the same simplification in MILDatasetPredictionMetricsCallback.on_predict_epoch_end (the else branch mentioned) so neither callback strips and re-adds the "prediction/" prefix.
1-4: Directmlflowusage bypasses the Lightning logger.Both callbacks log via the global
mlflowmodule, which relies on the ambient active run rather than the run managed bytrainer.logger(anMLFlowLogger). This can attach metrics/figures to the wrong run in nested-run or multi-logger scenarios, and is inconsistent withWSLSlidePredictionMetricsCallbackinmetrics_slide.py, which routes throughtrainer.logger. Consider unifying ontrainer.logger(uselog_metricsandexperiment.log_figure(run_id=..., ...)) for consistency and correctness.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nuclei_graph/callbacks/predictions/metrics_dataset.py` around lines 1 - 4, The callbacks in metrics_dataset.py currently call the global mlflow module directly; change them to use the Trainer's logger (trainer.logger) like WSLSlidePredictionMetricsCallback does in metrics_slide.py: remove direct mlflow imports/usages, call trainer.logger.log_metrics(...) to record metric dicts, and use trainer.logger.experiment.log_figure(run_id, name, figure) or trainer.logger.experiment.log_* APIs to attach figures (obtain run_id from trainer.logger.experiment.active_run_id or the logger's run_id property as appropriate); update any functions/methods in the callbacks that reference mlflow to instead accept a Trainer instance and route logging through trainer.logger and trainer.logger.experiment.configs/experiment/modeling/inference/nuclei_level/prostate_cancer_mmci_tl.yaml (1)
12-14: Commented-out callback entries — intentional?The
labels,nuclei_masks, andthresholdscallback defaults are commented out here but active in the siblingradboud.yamlinference config. If this is a temporary disable, consider a brief comment explaining why (e.g., MMCI lacks labels for these artifacts) so the divergence is self-documenting; otherwise remove the dead lines.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@configs/experiment/modeling/inference/nuclei_level/prostate_cancer_mmci_tl.yaml` around lines 12 - 14, The three callback entries (trainer.callbacks.labels, trainer.callbacks.nuclei_masks, trainer.callbacks.thresholds) are commented out here while kept active in the sibling config; either remove these dead lines or add a short explanatory comment above them explaining why they're disabled (e.g., "MMCI lacks labels/nuclei masks, thresholds disabled") so the divergence is documented; update the block containing the commented lines "- /callbacks/predictions/labels@trainer.callbacks.labels", "- /callbacks/predictions@trainer.callbacks.nuclei_masks", and "- /callbacks/thresholds/plot_curves@trainer.callbacks.thresholds" accordingly.nuclei_graph/callbacks/predictions/labels.py (1)
74-79: Resetslide_predsinon_predict_startfor robustness.
self.slide_predsis only cleared at the end ofon_predict_epoch_end. If a previous predict run raises before reaching the epoch end, a subsequenttrainer.predict()on the same callback instance will accumulate stale slides. Reinitializing inon_predict_start(in addition to the current end-of-epoch reset) makes the callback reuse-safe.♻️ Proposed fix
+ def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + super().on_predict_start(trainer, pl_module) + self.slide_preds = {"slide_id": [], "is_carcinoma": [], "prediction": []} + def on_predict_batch_end(🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nuclei_graph/callbacks/predictions/labels.py` around lines 74 - 79, The callback currently initializes self.slide_preds only in __init__ and clears it in on_predict_epoch_end, which can leave stale data if a predict run fails; add a reset of self.slide_preds = {"slide_id": [], "is_carcinoma": [], "prediction": []} at the start of on_predict_start to ensure each trainer.predict() run starts with a fresh buffer (keep the existing clear in on_predict_epoch_end as well).nuclei_graph/callbacks/predictions/nuclei_masks.py (2)
100-101:int(pred * 255)truncates; consider rounding.
int()truncates toward zero, so e.g.pred=0.999 → 254rather than255. Usinground(ornp.uint8(round(...))) gives a slightly more faithful visual mapping of probabilities/attention to grayscale. Same note applies to line 163 inMILAttentionMasksCallback. Nit.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nuclei_graph/callbacks/predictions/nuclei_masks.py` around lines 100 - 101, The code currently computes pixel_val = int(pred * 255) which truncates toward zero; change this to use rounding so high probabilities map correctly (e.g., use round(pred * 255) or np.uint8(round(pred * 255))) before passing to canvas.polygon; update the same pattern in the MILAttentionMasksCallback occurrence as well (search for pixel_val/int(pred * 255) and replace with a rounded conversion to uint8) to ensure faithful grayscale mapping.
56-113: Significant duplication betweenWSLPredictionMasksCallbackandMILAttentionMasksCallback.The slide-open/scale computation, mask canvas setup, parquet polygon loading, polygon rasterization loop, and
write_big_tiffcall are essentially identical in both subclasses — only the per-nucleus intensity source (sigmoid(logits) vs. normalized attention) differs. Consider pulling the common rasterization pipeline into a helper onBaseMasksCallback(e.g.,_render_and_write(metadata, per_nucleus_values)), leaving subclasses to compute just the values array. This will also make future additions (new overlay kinds) much cheaper.Also applies to: 115-175
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nuclei_graph/callbacks/predictions/nuclei_masks.py` around lines 56 - 113, WSLPredictionMasksCallback and MILAttentionMasksCallback duplicate the entire rasterization pipeline; extract the shared logic (OpenSlide slide open & scale computation, mask canvas creation, parquet polygon loading, polygon -> pixel scaling and drawing, and write_big_tiff invocation) into a new BaseMasksCallback helper (suggested name _render_and_write(metadata, per_nucleus_values)). Update WSLPredictionMasksCallback.on_predict_batch_end to only compute its per-nucleus values (predicted_labels = sigmoid(logits_ordered).cpu().numpy().flatten()) and pass metadata and that array to _render_and_write; do the same in MILAttentionMasksCallback (compute normalized attention values then call _render_and_write). The helper should reuse existing symbols/fields: self.level, mask_tile_width, mask_tile_height, self._get_output_path, slide_resolution, OpenSlide, PILImage, rearrange, write_big_tiff, and preserve the current mask sizing and scaling behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@configs/callbacks/predictions/metrics_dataset/wsl.yaml`:
- Line 4: The inline comment on the threshold key contains an inconsistent MIL
value; update or remove the inline note so it no longer shows "MIL: 0.957" —
locate the threshold entry (the YAML key "threshold") and either delete the
trailing comment "MIL: 0.957, WSL: 0.388" or replace it with a single accurate
note (e.g., just "WSL: 0.388") that matches the MIL thresholds used in this PR.
In `@configs/callbacks/predictions/metrics_slide/wsl.yaml`:
- Line 4: The inline comment on the threshold line in the file (the `threshold:
0.388 # MIL: 0.957, WSL: 0.388` entry) is stale/misleading; update or remove the
inline comparison so it doesn't conflict with the current MIL threshold used
elsewhere (see the `threshold` entry in `metrics_dataset/mil.yaml` which is
`0.540`)—either delete the trailing `# MIL: ...` fragment or replace it with the
correct, current MIL value to keep `threshold` comments consistent.
In
`@configs/experiment/modeling/inference/crop_level/prostate_cancer_mmci_tl.yaml`:
- Around line 10-11: The Hydra config duplicates the target key
trainer.callbacks.metrics_dataset so the second entry (wsl) overwrites the first
(mil); change the entries to use distinct keys or separate list entries so both
callbacks register (e.g., use unique keys referencing the same target such as
trainer.callbacks.metrics_dataset_mil and trainer.callbacks.metrics_dataset_wsl
or convert to an explicit sequence under trainer.callbacks that includes both
targets) and ensure the identifiers reference the correct callback
implementations for MIL and WSL respectively.
In `@configs/experiment/modeling/inference/crop_level/radboud.yaml`:
- Line 12: The inline comment is incorrect: the callback entry
'/callbacks/predictions/metrics_slide@trainer.callbacks.metrics_slide: wsl'
refers to nuclei-level WSL metrics, not "graph-level preds"; update the comment
to accurately state "nuclei-level preds" (or remove the misleading "graph-level
preds") next to the metrics_slide / wsl entry so the comment matches the actual
behavior of the wsl metric.
- Around line 10-11: The two callback registrations use the identical Hydra
target key trainer.callbacks.metrics_dataset so the second (wsl) overwrites the
first (mil); rename one of the keys so both are unique (for example change the
left-side key for the MIL entry to something like
/callbacks/predictions/metrics_dataset_mil@trainer.callbacks.metrics_dataset or
similarly suffix the key) and keep the other as
/callbacks/predictions/metrics_dataset@trainer.callbacks.metrics_dataset for WSL
so both callbacks (mil and wsl) are registered and active.
- Line 15: The callback reference for attention is using an unnecessary '+'
prefix; locate the mapping key that reads
callbacks/attention@+trainer.callbacks.attn_masks and change it to
callbacks/attention@trainer.callbacks.attn_masks so it matches the other
callback entries and the pre-defined trainer.callbacks dict (i.e., remove the
'+' from the @+trainer.callbacks.attn_masks reference).
In `@nuclei_graph/callbacks/predictions/labels.py`:
- Around line 126-135: The CSV is being logged to the run root instead of the
predictions artifact path; inside the with tempfile.TemporaryDirectory() block
where misclassif_df is written, either move the CSV into the same temp directory
used by the base uploader (e.g., write to self.tmp_dir or the temp dir the base
class expects so _save_parquet and the base log_artifacts call pick it up) or
call mlflow.log_artifact with artifact_path=self.mlflow_artifact_path (replace
the current mlflow.log_artifact call) so misclassifications.csv is uploaded
alongside the prediction parquet files.
In `@nuclei_graph/callbacks/thresholds/plot_curves.py`:
- Around line 68-92: The _perform_roc function currently looks for exact tpr ==
1 and falls back to thresholds[0] (sklearn's sentinel) when no tpr reaches 1;
change this to (1) use a tolerant comparison like np.isclose(tpr, 1.0) to find
indices and (2) when no index is found do not pick thresholds[0] — instead set
tpr_threshold to np.nan (or None) to signal "not available" so you don't log the
sklearn sentinel; update the variables idx, tpr_idx, and tpr_threshold in
_perform_roc accordingly and ensure callers that expect a numeric threshold
handle np.nan/None instead of the sentinel.
- Around line 94-99: In _perform_pr, f1 is computed from precision and recall
which are length n_thresholds+1 while thresholds is length n_thresholds, causing
a possible IndexError when picking best_threshold; fix by computing f1 only over
the entries that correspond to thresholds (e.g., use precision[:-1] and
recall[:-1] so f1.shape == thresholds.shape) and then select best_threshold =
thresholds[best_idx]; also guard against the case of empty thresholds (handle or
return a sensible default when thresholds.size == 0) to avoid indexing errors.
---
Nitpick comments:
In
`@configs/experiment/modeling/inference/nuclei_level/prostate_cancer_mmci_tl.yaml`:
- Around line 12-14: The three callback entries (trainer.callbacks.labels,
trainer.callbacks.nuclei_masks, trainer.callbacks.thresholds) are commented out
here while kept active in the sibling config; either remove these dead lines or
add a short explanatory comment above them explaining why they're disabled
(e.g., "MMCI lacks labels/nuclei masks, thresholds disabled") so the divergence
is documented; update the block containing the commented lines "-
/callbacks/predictions/labels@trainer.callbacks.labels", "-
/callbacks/predictions@trainer.callbacks.nuclei_masks", and "-
/callbacks/thresholds/plot_curves@trainer.callbacks.thresholds" accordingly.
In `@nuclei_graph/callbacks/predictions/labels.py`:
- Around line 74-79: The callback currently initializes self.slide_preds only in
__init__ and clears it in on_predict_epoch_end, which can leave stale data if a
predict run fails; add a reset of self.slide_preds = {"slide_id": [],
"is_carcinoma": [], "prediction": []} at the start of on_predict_start to ensure
each trainer.predict() run starts with a fresh buffer (keep the existing clear
in on_predict_epoch_end as well).
In `@nuclei_graph/callbacks/predictions/metrics_dataset.py`:
- Around line 39-50: The loop in on_predict_epoch_end is redundantly removing
and re-adding the "prediction/" prefix: dataset_nuclei_metrics.compute() already
returns keys like "prediction/<name>", so in the on_predict_epoch_end method of
this callback you should stop splitting the key and re-prepending the prefix;
instead pass the returned key directly into mlflow.log_metric (i.e., use the
compute() keys as-is). Make the same simplification in
MILDatasetPredictionMetricsCallback.on_predict_epoch_end (the else branch
mentioned) so neither callback strips and re-adds the "prediction/" prefix.
- Around line 1-4: The callbacks in metrics_dataset.py currently call the global
mlflow module directly; change them to use the Trainer's logger (trainer.logger)
like WSLSlidePredictionMetricsCallback does in metrics_slide.py: remove direct
mlflow imports/usages, call trainer.logger.log_metrics(...) to record metric
dicts, and use trainer.logger.experiment.log_figure(run_id, name, figure) or
trainer.logger.experiment.log_* APIs to attach figures (obtain run_id from
trainer.logger.experiment.active_run_id or the logger's run_id property as
appropriate); update any functions/methods in the callbacks that reference
mlflow to instead accept a Trainer instance and route logging through
trainer.logger and trainer.logger.experiment.
In `@nuclei_graph/callbacks/predictions/metrics_slide.py`:
- Around line 22-31: The Callback subclass's __init__ should call the base
constructor; in the __init__ method where slide_nuclei_metrics is created (the
__init__ method that sets slide_nuclei_metrics = NestedMetricCollection(...)),
add a call to super().__init__() (typically as the first statement) to ensure
cooperative initialization with lightning.Callback and support future base-class
changes.
- Around line 59-67: The on_predict_epoch_end implementation in
WSLSlidePredictionMetricsCallback currently asserts trainer.logger is an
MLFlowLogger and calls trainer.logger.log_table(...), which is inconsistent with
WSLDatasetPredictionMetricsCallback and MILDatasetPredictionMetricsCallback that
use the global mlflow module; change
WSLSlidePredictionMetricsCallback.on_predict_epoch_end to use the same global
mlflow logging approach as the dataset callbacks (log the result of
self.slide_nuclei_metrics.compute() via the global mlflow API and remove the
MLFlowLogger assert), ensuring the artifact/metric is attached to the same
run/context as the other callbacks and then call
self.slide_nuclei_metrics.reset().
In `@nuclei_graph/callbacks/predictions/nuclei_masks.py`:
- Around line 100-101: The code currently computes pixel_val = int(pred * 255)
which truncates toward zero; change this to use rounding so high probabilities
map correctly (e.g., use round(pred * 255) or np.uint8(round(pred * 255)))
before passing to canvas.polygon; update the same pattern in the
MILAttentionMasksCallback occurrence as well (search for pixel_val/int(pred *
255) and replace with a rounded conversion to uint8) to ensure faithful
grayscale mapping.
- Around line 56-113: WSLPredictionMasksCallback and MILAttentionMasksCallback
duplicate the entire rasterization pipeline; extract the shared logic (OpenSlide
slide open & scale computation, mask canvas creation, parquet polygon loading,
polygon -> pixel scaling and drawing, and write_big_tiff invocation) into a new
BaseMasksCallback helper (suggested name _render_and_write(metadata,
per_nucleus_values)). Update WSLPredictionMasksCallback.on_predict_batch_end to
only compute its per-nucleus values (predicted_labels =
sigmoid(logits_ordered).cpu().numpy().flatten()) and pass metadata and that
array to _render_and_write; do the same in MILAttentionMasksCallback (compute
normalized attention values then call _render_and_write). The helper should
reuse existing symbols/fields: self.level, mask_tile_width, mask_tile_height,
self._get_output_path, slide_resolution, OpenSlide, PILImage, rearrange,
write_big_tiff, and preserve the current mask sizing and scaling behavior.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b0ffc4db-acaf-44c9-a21f-50137756319f
📒 Files selected for processing (28)
configs/callbacks/attention/nuclei_masks.yamlconfigs/callbacks/predictions/labels/mil.yamlconfigs/callbacks/predictions/labels/wsl.yamlconfigs/callbacks/predictions/metrics_dataset/mil.yamlconfigs/callbacks/predictions/metrics_dataset/wsl.yamlconfigs/callbacks/predictions/metrics_slide/wsl.yamlconfigs/callbacks/predictions/nuclei_masks.yamlconfigs/callbacks/thresholds/plot_curves/mil.yamlconfigs/callbacks/thresholds/plot_curves/wsl.yamlconfigs/experiment/modeling/inference/base.yamlconfigs/experiment/modeling/inference/crop_level/prostate_cancer_mmci_tl.yamlconfigs/experiment/modeling/inference/crop_level/radboud.yamlconfigs/experiment/modeling/inference/nuclei_level/prostate_cancer_mmci_tl.yamlconfigs/experiment/modeling/inference/nuclei_level/radboud.yamlconfigs/experiment/modeling/validation/base.yamlconfigs/experiment/modeling/validation/crop_level/prostate_cancer_mmci_tl.yamlconfigs/experiment/modeling/validation/crop_level/radboud.yamlconfigs/experiment/modeling/validation/nuclei_level/prostate_cancer_mmci_tl.yamlconfigs/experiment/modeling/validation/nuclei_level/radboud.yamlconfigs/ml.yamlnuclei_graph/callbacks/__init__.pynuclei_graph/callbacks/predictions/__init__.pynuclei_graph/callbacks/predictions/labels.pynuclei_graph/callbacks/predictions/metrics_dataset.pynuclei_graph/callbacks/predictions/metrics_slide.pynuclei_graph/callbacks/predictions/nuclei_masks.pynuclei_graph/callbacks/thresholds/__init__.pynuclei_graph/callbacks/thresholds/plot_curves.py
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (1)
nuclei_graph/callbacks/thresholds/plot_curves.py (1)
68-74:⚠️ Potential issue | 🟡 Minor
tpr_thresholdfallback still emits sklearn's sentinelnp.inf.Previous review flagged this and it hasn't been addressed. Since sklearn ≥ 1.3,
roc_curvesetsthresholds[0] = np.inf, so whenlen(idx) == 0(e.g., no positives reachTPR == 1),tpr_threshold = thresholds[0] = infgets forwarded tomlflow.log_metric(..., float(roc_t))on line 30 — MLflow rejects non-finite metric values, which will raise at log time.Additionally,
np.where(tpr == 1)uses exact float equality; prefernp.isclose(tpr, 1.0)for robustness.🛡️ Proposed fix
- idx = np.where(tpr == 1)[0] - tpr_idx = idx[np.argmin(fpr[idx])] if len(idx) > 0 else 0 - tpr_threshold = thresholds[tpr_idx] + idx = np.where(np.isclose(tpr, 1.0))[0] + if len(idx) > 0: + tpr_idx = idx[np.argmin(fpr[idx])] + tpr_threshold = float(thresholds[tpr_idx]) + else: + tpr_idx = int(np.argmax(tpr)) # best-effort point for plotting + tpr_threshold = float("nan") # signal "not available"And in
_log_and_clear_curves, skip logging / handleNaNaccordingly:- mlflow.log_metric(f"thresholds/{level_name}_tpr_threshold", float(roc_t)) + if np.isfinite(roc_t): + mlflow.log_metric(f"thresholds/{level_name}_tpr_threshold", float(roc_t))🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nuclei_graph/callbacks/thresholds/plot_curves.py` around lines 68 - 74, The code in _perform_roc uses exact equality to find TPR==1 and falls back to thresholds[0] which may be np.inf; change the selection to use np.isclose(tpr, 1.0) and if no matching index exists set tpr_threshold = np.nan (or otherwise ensure it is finite) rather than using thresholds[0]; additionally update _log_and_clear_curves to skip or convert non-finite roc_t values before calling mlflow.log_metric (only call mlflow.log_metric when np.isfinite(roc_t) and cast to float), so MLflow never receives inf/NaN.
🧹 Nitpick comments (1)
nuclei_graph/callbacks/thresholds/plot_curves.py (1)
40-66: Optional: iterate withzipand add type hints.Minor idiomatic cleanup — no behavioral change.
♻️ Proposed refactor
- for i in range(len(to_pinpoint)): - x, y = to_pinpoint[i] - ax.scatter(x, y, color=point_colors[i], label=point_labels[i], zorder=5) + for (x, y), label, color in zip(to_pinpoint, point_labels, point_colors): + ax.scatter(x, y, color=color, label=label, zorder=5)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nuclei_graph/callbacks/thresholds/plot_curves.py` around lines 40 - 66, The _plot_curve method currently iterates by index over to_pinpoint and accesses parallel lists; change the loop to iterate with zip (e.g., for (x, y), color, label in zip(to_pinpoint, point_colors, point_labels)) to make it idiomatic and safer, and add lightweight type hints to the signature (e.g., xs: Sequence[float], ys: Sequence[float], to_pinpoint: Sequence[Tuple[float, float]], point_labels: Sequence[str], point_colors: Sequence[str], xlabel: str, ylabel: str, title: str, loc: Any) to improve readability and static checking while preserving behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@nuclei_graph/callbacks/thresholds/plot_curves.py`:
- Around line 127-149: Add an early guard in
MILCurvesCallback.on_validation_batch_end to return immediately if outputs is
None (like WSLCurvesCallback does) before accessing outputs["graph"] or
outputs["nuclei"]; keep the existing trainer.sanity_checking check and ensure
you check outputs is None right after it so subsequent lines (graph_outputs =
outputs["graph"].view(-1), nuclei_outputs = outputs["nuclei"][...]) cannot NPE.
---
Duplicate comments:
In `@nuclei_graph/callbacks/thresholds/plot_curves.py`:
- Around line 68-74: The code in _perform_roc uses exact equality to find TPR==1
and falls back to thresholds[0] which may be np.inf; change the selection to use
np.isclose(tpr, 1.0) and if no matching index exists set tpr_threshold = np.nan
(or otherwise ensure it is finite) rather than using thresholds[0]; additionally
update _log_and_clear_curves to skip or convert non-finite roc_t values before
calling mlflow.log_metric (only call mlflow.log_metric when np.isfinite(roc_t)
and cast to float), so MLflow never receives inf/NaN.
---
Nitpick comments:
In `@nuclei_graph/callbacks/thresholds/plot_curves.py`:
- Around line 40-66: The _plot_curve method currently iterates by index over
to_pinpoint and accesses parallel lists; change the loop to iterate with zip
(e.g., for (x, y), color, label in zip(to_pinpoint, point_colors, point_labels))
to make it idiomatic and safer, and add lightweight type hints to the signature
(e.g., xs: Sequence[float], ys: Sequence[float], to_pinpoint:
Sequence[Tuple[float, float]], point_labels: Sequence[str], point_colors:
Sequence[str], xlabel: str, ylabel: str, title: str, loc: Any) to improve
readability and static checking while preserving behavior.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e3494ca0-71da-4001-baa5-ac68e2ec8d57
📒 Files selected for processing (1)
nuclei_graph/callbacks/thresholds/plot_curves.py
Summary by CodeRabbit