5757import torch
5858from matplotlib .figure import Figure
5959from torchmetrics import Metric
60- from torchmetrics .functional .classification import binary_roc
6160from torchmetrics .utilities .compute import auc
6261from torchmetrics .utilities .data import dim_zero_cat
6362
6463from anomalib .metrics .pro import connected_components_cpu , connected_components_gpu
64+ from anomalib .utils import deprecate
6565
6666from .base import AnomalibMetric
67- from .binning import thresholds_between_0_and_1 , thresholds_between_min_and_max
6867from .utils import plot_metric_curve
6968
7069
70+ @deprecate (
71+ args = {"num_thresholds" : None },
72+ since = "2.1.0" ,
73+ remove = "3.0.0" ,
74+ reason = "New AUPRO computation does not require number of thresholds" ,
75+ )
7176class _AUPRO (Metric ):
7277 """Area under per region overlap (AUPRO) Metric.
7378
@@ -82,9 +87,8 @@ class _AUPRO(Metric):
8287 Defaults to ``None``.
8388 fpr_limit (float): Limit for the false positive rate.
8489 Defaults to ``0.3``.
85- num_thresholds (int | None): Number of thresholds to use for computing
86- the ROC curve. When ``None``, uses thresholds from torchmetrics.
87- Defaults to ``None``.
90+ num_thresholds (int | None): Present for backward compatibility with the
91+ old implementation, but ignored in this fast version.
8892
8993 Examples:
9094 >>> import torch
@@ -109,14 +113,6 @@ class _AUPRO(Metric):
109113 full_state_update : bool = False
110114 preds : list [torch .Tensor ]
111115 target : list [torch .Tensor ]
112- # When not None, the computation is performed in constant-memory by computing
113- # the roc curve for fixed thresholds buckets/thresholds.
114- # Warning: The thresholds are evenly distributed between the min and max
115- # predictions if all predictions are inside [0, 1]. Otherwise, the thresholds
116- # are evenly distributed between 0 and 1.
117- # This warning can be removed when
118- # https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed
119- # and the roc curve is computed with deactivated formatting
120116 num_thresholds : int | None
121117
122118 def __init__ (
@@ -136,6 +132,8 @@ def __init__(
136132 self .add_state ("preds" , default = [], dist_reduce_fx = "cat" )
137133 self .add_state ("target" , default = [], dist_reduce_fx = "cat" )
138134 self .register_buffer ("fpr_limit" , torch .tensor (fpr_limit ))
135+
136+ # Kept for API compatibility; ignored by the fast implementation.
139137 self .num_thresholds = num_thresholds
140138
141139 def update (self , preds : torch .Tensor , target : torch .Tensor ) -> None :
@@ -158,7 +156,7 @@ def perform_cca(self) -> torch.Tensor:
158156 Returns:
159157 Tensor: Components labeled from 0 to N.
160158 """
161- target = dim_zero_cat (self .target )
159+ target = dim_zero_cat (self .target ) # (B, ..., H, W)
162160
163161 # check and prepare target for labeling via kornia
164162 if target .min () < 0 or target .max () > 1 :
@@ -173,130 +171,166 @@ def perform_cca(self) -> torch.Tensor:
173171 target = target .type (torch .float ) # kornia expects FloatTensor
174172 return connected_components_gpu (target ) if target .is_cuda else connected_components_cpu (target )
175173
174+ @staticmethod
175+ def _make_global_region_labels (cca : torch .Tensor ) -> torch .Tensor :
176+ """Offset connected component labels across batch to make them unique.
177+
178+ Args:
179+ cca (torch.Tensor): (B, H, W) integer labels, starting at 0 for each image.
180+
181+ Returns:
182+ torch.Tensor: (B, H, W) labels where:
183+ - 0 is still background
184+ - positive labels are unique across the batch
185+ """
186+ # We iterate over batch dimension only; typically small.
187+ batch_size = cca .size (0 )
188+ cca_off = cca .clone ()
189+ current_offset = 0
190+
191+ for b in range (batch_size ):
192+ img_labels = cca_off [b ]
193+ unique = img_labels .unique ()
194+ unique_fg = unique [unique != 0 ]
195+ num_regions = int (unique_fg .numel ())
196+ if num_regions == 0 :
197+ continue
198+
199+ # shift all foreground labels in this image by current_offset
200+ fg_mask = img_labels > 0
201+ img_labels [fg_mask ] = img_labels [fg_mask ] + current_offset
202+
203+ cca_off [b ] = img_labels
204+ current_offset += num_regions
205+
206+ return cca_off
207+
208+ @deprecate (
209+ args = {"target" : None },
210+ since = "2.1.0" ,
211+ remove = "3.0.0" ,
212+ reason = "Compute PRO computes overlap with connected components not target." ,
213+ )
176214 def compute_pro (
177215 self ,
178216 cca : torch .Tensor ,
179- target : torch .Tensor ,
180217 preds : torch .Tensor ,
218+ target : torch .Tensor | None = None ,
181219 ) -> tuple [torch .Tensor , torch .Tensor ]:
182- """Compute the pro/fpr value-pairs until the fpr specified by fpr_limit .
220+ """Compute the PRO curve (FPR vs. averaged per-region TPR/overlap) .
183221
184- It leverages the fact that the overlap corresponds to the tpr, and thus
185- computes the overall PRO curve by aggregating per-region tpr/fpr values
186- produced by ROC-construction.
222+ This implementation is inspired by the MvTec implementation found at
223+ https://www.mvtec.com/company/research/datasets/mvtec-ad
187224
188225 Args:
189- cca (torch.Tensor): Connected components tensor
190- target (torch.Tensor): Ground truth tensor
191- preds (torch.Tensor): Model predictions tensor
226+ cca (torch.Tensor):
227+ Connected-component labels of shape (B, H, W). Must contain integers
228+ ≥ 0, where 0 denotes background and >0 denote region IDs.
229+ preds (torch.Tensor):
230+ Prediction scores of shape (B, H, W). Higher values indicate more
231+ anomalous.
232+ target (torch.Tensor | None):
233+ Unused; accepted only for API compatibility.
192234
193235 Returns:
194- tuple[torch.Tensor, torch.Tensor]: Tuple containing final fpr and tpr
195- values.
236+ tuple[torch.Tensor, torch.Tensor]:
237+ (fpr, pro), both 1-D tensors sorted by increasing FPR and clipped
238+ at ``self.fpr_limit``.
196239 """
197- if self .num_thresholds is not None :
198- # binary_roc is applying a sigmoid on the predictions before computing
199- # the roc curve when some predictions are out of [0, 1], the binning
200- # between min and max predictions cannot be applied in that case.
201- # This can be removed when
202- # https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed
203- # and the roc curve is computed with deactivated formatting.
204-
205- if torch .all ((preds >= 0 ) * (preds <= 1 )):
206- thresholds = thresholds_between_min_and_max (preds , self .num_thresholds , self .device )
207- else :
208- thresholds = thresholds_between_0_and_1 (self .num_thresholds , self .device )
240+ del target
241+ device = preds .device
242+
243+ # flatten (already on correct device)
244+ labels = cca .reshape (- 1 ).long ()
245+ preds_flat = preds .reshape (- 1 ).float ()
246+
247+ # background = FPR contribution
248+ background = labels == 0
249+ fp_change = background .float ()
250+ num_bg = fp_change .sum ()
251+
252+ if num_bg == 0 :
253+ f = float (self .fpr_limit )
254+ return (
255+ torch .tensor ([0.0 , f ], device = device ),
256+ torch .tensor ([0.0 , 0.0 ], device = device ),
257+ )
209258
210- else :
211- thresholds = None
212-
213- # compute the global fpr-size
214- fpr : torch .Tensor = binary_roc (
215- preds = preds ,
216- target = target ,
217- thresholds = thresholds ,
218- )[0 ] # only need fpr
219- output_size = torch .where (fpr <= self .fpr_limit )[0 ].size (0 )
220-
221- # compute the PRO curve by aggregating per-region tpr/fpr curves/values.
222- tpr = torch .zeros (output_size , device = preds .device , dtype = torch .float )
223- fpr = torch .zeros (output_size , device = preds .device , dtype = torch .float )
224- new_idx = torch .arange (0 , output_size , device = preds .device , dtype = torch .float )
225-
226- # Loop over the labels, computing per-region tpr/fpr curves, and
227- # aggregating them. Note that, since the groundtruth is different for
228- # every all to `roc`, we also get different/unique tpr/fpr curves
229- # (i.e. len(_fpr_idx) is different for every call).
230- # We therefore need to resample per-region curves to a fixed sampling
231- # ratio (defined above).
232- labels = cca .unique ()[1 :] # 0 is background
233- background = cca == 0
234- fpr_ : torch .Tensor
235- tpr_ : torch .Tensor
236- for label in labels :
237- interp : bool = False
238- new_idx [- 1 ] = output_size - 1
239- mask = cca == label
240- # Need to calculate label-wise roc on union of background & mask, as
241- # otherwise we wrongly consider other label in labels as FPs.
242- # We also don't need to return the thresholds
243- fpr_ , tpr_ = binary_roc (
244- preds = preds [background | mask ],
245- target = mask [background | mask ],
246- thresholds = thresholds ,
247- )[:- 1 ]
248-
249- # catch edge-case where ROC only has fpr vals > self.fpr_limit
250- if fpr_ [fpr_ <= self .fpr_limit ].max () == 0 :
251- fpr_limit_ = fpr_ [fpr_ > self .fpr_limit ].min ()
252- else :
253- fpr_limit_ = self .fpr_limit
254-
255- fpr_idx_ = torch .where (fpr_ <= fpr_limit_ )[0 ]
256- # if computed roc curve is not specified sufficiently close to
257- # self.fpr_limit, we include the closest higher tpr/fpr pair and
258- # linearly interpolate the tpr/fpr point at self.fpr_limit
259- if not torch .allclose (fpr_ [fpr_idx_ ].max (), self .fpr_limit ):
260- tmp_idx_ = torch .searchsorted (fpr_ , self .fpr_limit )
261- fpr_idx_ = torch .cat ([fpr_idx_ , tmp_idx_ .unsqueeze_ (0 )])
262- slope_ = 1 - ((fpr_ [tmp_idx_ ] - self .fpr_limit ) / (fpr_ [tmp_idx_ ] - fpr_ [tmp_idx_ - 1 ]))
263- interp = True
264-
265- fpr_ = fpr_ [fpr_idx_ ]
266- tpr_ = tpr_ [fpr_idx_ ]
267-
268- fpr_idx_ = fpr_idx_ .float ()
269- fpr_idx_ /= fpr_idx_ .max ()
270- fpr_idx_ *= new_idx .max ()
271-
272- if interp :
273- # last point will be sampled at self.fpr_limit
274- new_idx [- 1 ] = fpr_idx_ [- 2 ] + ((fpr_idx_ [- 1 ] - fpr_idx_ [- 2 ]) * slope_ )
275-
276- tpr_ = self .interp1d (fpr_idx_ , tpr_ , new_idx )
277- fpr_ = self .interp1d (fpr_idx_ , fpr_ , new_idx )
278- tpr += tpr_
279- fpr += fpr_
280-
281- # Actually perform the averaging
282- tpr /= labels .size (0 )
283- fpr /= labels .size (0 )
284- return fpr , tpr
259+ max_label = int (labels .max ())
260+ if max_label == 0 :
261+ f = float (self .fpr_limit )
262+ return (
263+ torch .tensor ([0.0 , f ], device = device ),
264+ torch .tensor ([0.0 , 0.0 ], device = device ),
265+ )
285266
286- def _compute (self ) -> tuple [torch .Tensor , torch .Tensor ]:
287- """Compute the PRO curve.
267+ region_sizes = torch .bincount (labels , minlength = max_label + 1 ).float ()
268+ num_regions = (region_sizes [1 :] > 0 ).sum ()
269+
270+ if num_regions == 0 :
271+ f = float (self .fpr_limit )
272+ return (
273+ torch .tensor ([0.0 , f ], device = device ),
274+ torch .tensor ([0.0 , 0.0 ], device = device ),
275+ )
288276
289- Perform the Connected Component Analysis first then compute the PRO curve.
277+ fg_mask = labels > 0
290278
291- Returns:
292- tuple[torch.Tensor, torch.Tensor]: Tuple containing final fpr and tpr
293- values.
294- """
295- cca = self .perform_cca ().flatten ()
296- target = dim_zero_cat (self .target ).flatten ()
297- preds = dim_zero_cat (self .preds ).flatten ()
279+ pro_change = torch .zeros_like (preds_flat )
280+ pro_change [fg_mask ] = 1.0 / region_sizes [labels [fg_mask ]]
281+
282+ # global sort
283+ idx = torch .argsort (preds_flat , descending = True )
284+ fp_sorted = fp_change [idx ]
285+ pro_sorted = pro_change [idx ]
286+ preds_sorted = preds_flat [idx ]
287+
288+ # cumulative sums
289+ fpr = torch .cumsum (fp_sorted , 0 ) / num_bg
290+ pro = torch .cumsum (pro_sorted , 0 ) / num_regions
291+ fpr .clamp_ (max = 1.0 )
292+ pro .clamp_ (max = 1.0 )
293+
294+ # remove duplicate thresholds
295+ keep = torch .ones_like (preds_sorted , dtype = torch .bool )
296+ keep [:- 1 ] = preds_sorted [:- 1 ] != preds_sorted [1 :]
297+ fpr = fpr [keep ]
298+ pro = pro [keep ]
299+
300+ # prepend zero
301+ fpr = torch .cat ([torch .tensor ([0.0 ], device = device ), fpr ])
302+ pro = torch .cat ([torch .tensor ([0.0 ], device = device ), pro ])
303+
304+ # FPR limit clipping
305+ f_lim = float (self .fpr_limit )
306+ mask = fpr <= f_lim
298307
299- return self .compute_pro (cca = cca , target = target , preds = preds )
308+ if mask .any ():
309+ i = mask .nonzero (as_tuple = True )[0 ][- 1 ].item ()
310+
311+ if fpr [i ] < f_lim and i + 1 < fpr .numel ():
312+ f1 , f2 = fpr [i ], fpr [i + 1 ]
313+ p1 , p2 = pro [i ], pro [i + 1 ]
314+ p_lim = p1 + (p2 - p1 ) * (f_lim - f1 ) / (f2 - f1 )
315+
316+ fpr = torch .cat ([fpr [: i + 1 ], torch .tensor ([f_lim ], device = device )])
317+ pro = torch .cat ([pro [: i + 1 ], torch .tensor ([p_lim ], device = device )])
318+ else :
319+ fpr = fpr [: i + 1 ]
320+ pro = pro [: i + 1 ]
321+ else :
322+ fpr = torch .tensor ([0.0 , f_lim ], device = device )
323+ pro = torch .tensor ([0.0 , 0.0 ], device = device )
324+
325+ return fpr , pro
326+
327+ def _compute (self ) -> tuple [torch .Tensor , torch .Tensor ]:
328+ """Compute the PRO curve (FPR vs PRO) for all stored predictions."""
329+ cca = self .perform_cca () # (B, H, W)
330+ preds = dim_zero_cat (self .preds ) # (B, 1, H, W) or (B, H, W)
331+ if preds .dim () > 3 and preds .size (1 ) == 1 :
332+ preds = preds .squeeze (1 )
333+ return self .compute_pro (cca = cca , preds = preds )
300334
301335 def compute (self ) -> torch .Tensor :
302336 """First compute PRO curve, then compute and scale area under the curve.
@@ -319,7 +353,7 @@ def generate_figure(self) -> tuple[Figure, str]:
319353 fpr , tpr = self ._compute ()
320354 aupro = self .compute ()
321355
322- xlim = (0.0 , float (self .fpr_limit .detach ().cpu ().numpy ()))
356+ xlim = (0.0 , float (self .fpr_limit .detach ().cpu ().item ()))
323357 ylim = (0.0 , 1.0 )
324358 xlabel = "Global FPR"
325359 ylabel = "Averaged Per-Region TPR"
0 commit comments