
Commit 3874ed2

refactor(metric): Speedup AUPRO (#3115)
update aupro for speedup, remove binning test case (we do not bin anymore)

Signed-off-by: Alfie Roddan <[email protected]>
1 parent 2bccb2a commit 3874ed2

File tree

2 files changed (+157, -157 lines)


src/anomalib/metrics/aupro.py

Lines changed: 157 additions & 123 deletions
@@ -57,17 +57,22 @@
 import torch
 from matplotlib.figure import Figure
 from torchmetrics import Metric
-from torchmetrics.functional.classification import binary_roc
 from torchmetrics.utilities.compute import auc
 from torchmetrics.utilities.data import dim_zero_cat

 from anomalib.metrics.pro import connected_components_cpu, connected_components_gpu
+from anomalib.utils import deprecate

 from .base import AnomalibMetric
-from .binning import thresholds_between_0_and_1, thresholds_between_min_and_max
 from .utils import plot_metric_curve


+@deprecate(
+    args={"num_thresholds": None},
+    since="2.1.0",
+    remove="3.0.0",
+    reason="New AUPRO computation does not require number of thresholds",
+)
 class _AUPRO(Metric):
     """Area under per region overlap (AUPRO) Metric.

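For context on the new decorator above, a hedged sketch (not part of this commit) of how a caller is affected, assuming `anomalib.utils.deprecate` warns about the listed argument and otherwise leaves the call intact:

# Illustrative only; the exact warning behaviour comes from anomalib.utils.deprecate.
from anomalib.metrics.aupro import _AUPRO

metric = _AUPRO(num_thresholds=100)  # expected to emit a deprecation warning; value is ignored
metric = _AUPRO(fpr_limit=0.3)       # preferred: no deprecated arguments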
@@ -82,9 +87,8 @@ class _AUPRO(Metric):
             Defaults to ``None``.
         fpr_limit (float): Limit for the false positive rate.
             Defaults to ``0.3``.
-        num_thresholds (int | None): Number of thresholds to use for computing
-            the ROC curve. When ``None``, uses thresholds from torchmetrics.
-            Defaults to ``None``.
+        num_thresholds (int | None): Present for backward compatibility with the
+            old implementation, but ignored in this fast version.

     Examples:
         >>> import torch
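As a rough usage sketch (not part of this commit; shapes and values are illustrative, and the internal `_AUPRO` class is used directly rather than the public wrapper):

import torch
from anomalib.metrics.aupro import _AUPRO

preds = torch.rand(2, 32, 32)                 # anomaly maps, higher = more anomalous
target = (torch.rand(2, 32, 32) > 0.7).int()  # binary ground-truth masks in {0, 1}

metric = _AUPRO(fpr_limit=0.3)
metric.update(preds, target)
score = metric.compute()                      # scalar AUPRO tensor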
@@ -109,14 +113,6 @@ class _AUPRO(Metric):
     full_state_update: bool = False
     preds: list[torch.Tensor]
     target: list[torch.Tensor]
-    # When not None, the computation is performed in constant-memory by computing
-    # the roc curve for fixed thresholds buckets/thresholds.
-    # Warning: The thresholds are evenly distributed between the min and max
-    # predictions if all predictions are inside [0, 1]. Otherwise, the thresholds
-    # are evenly distributed between 0 and 1.
-    # This warning can be removed when
-    # https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed
-    # and the roc curve is computed with deactivated formatting
     num_thresholds: int | None

     def __init__(
@@ -136,6 +132,8 @@ def __init__(
         self.add_state("preds", default=[], dist_reduce_fx="cat")
         self.add_state("target", default=[], dist_reduce_fx="cat")
         self.register_buffer("fpr_limit", torch.tensor(fpr_limit))
+
+        # Kept for API compatibility; ignored by the fast implementation.
         self.num_thresholds = num_thresholds

     def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
@@ -158,7 +156,7 @@ def perform_cca(self) -> torch.Tensor:
         Returns:
             Tensor: Components labeled from 0 to N.
         """
-        target = dim_zero_cat(self.target)
+        target = dim_zero_cat(self.target)  # (B, ..., H, W)

         # check and prepare target for labeling via kornia
         if target.min() < 0 or target.max() > 1:
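A hedged sketch (not part of this commit) of what `perform_cca` consumes and returns, following the docstring above; the toy shapes are assumptions:

import torch
from anomalib.metrics.aupro import _AUPRO

metric = _AUPRO()
target = torch.zeros(1, 8, 8, dtype=torch.int)
target[0, 1:3, 1:3] = 1   # first defect region
target[0, 5:7, 5:7] = 1   # second defect region
metric.update(torch.rand(1, 8, 8), target)

components = metric.perform_cca()  # same shape as target; 0 = background, >0 = component labels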
@@ -173,130 +171,166 @@ def perform_cca(self) -> torch.Tensor:
         target = target.type(torch.float)  # kornia expects FloatTensor
         return connected_components_gpu(target) if target.is_cuda else connected_components_cpu(target)

+    @staticmethod
+    def _make_global_region_labels(cca: torch.Tensor) -> torch.Tensor:
+        """Offset connected component labels across batch to make them unique.
+
+        Args:
+            cca (torch.Tensor): (B, H, W) integer labels, starting at 0 for each image.
+
+        Returns:
+            torch.Tensor: (B, H, W) labels where:
+                - 0 is still background
+                - positive labels are unique across the batch
+        """
+        # We iterate over batch dimension only; typically small.
+        batch_size = cca.size(0)
+        cca_off = cca.clone()
+        current_offset = 0
+
+        for b in range(batch_size):
+            img_labels = cca_off[b]
+            unique = img_labels.unique()
+            unique_fg = unique[unique != 0]
+            num_regions = int(unique_fg.numel())
+            if num_regions == 0:
+                continue
+
+            # shift all foreground labels in this image by current_offset
+            fg_mask = img_labels > 0
+            img_labels[fg_mask] = img_labels[fg_mask] + current_offset
+
+            cca_off[b] = img_labels
+            current_offset += num_regions
+
+        return cca_off
+
+    @deprecate(
+        args={"target": None},
+        since="2.1.0",
+        remove="3.0.0",
+        reason="Compute PRO computes overlap with connected components not target.",
+    )
     def compute_pro(
         self,
         cca: torch.Tensor,
-        target: torch.Tensor,
         preds: torch.Tensor,
+        target: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Compute the pro/fpr value-pairs until the fpr specified by fpr_limit.
+        """Compute the PRO curve (FPR vs. averaged per-region TPR/overlap).

-        It leverages the fact that the overlap corresponds to the tpr, and thus
-        computes the overall PRO curve by aggregating per-region tpr/fpr values
-        produced by ROC-construction.
+        This implementation is inspired by the MvTec implementation found at
+        https://www.mvtec.com/company/research/datasets/mvtec-ad

         Args:
-            cca (torch.Tensor): Connected components tensor
-            target (torch.Tensor): Ground truth tensor
-            preds (torch.Tensor): Model predictions tensor
+            cca (torch.Tensor):
+                Connected-component labels of shape (B, H, W). Must contain integers
+                ≥ 0, where 0 denotes background and >0 denote region IDs.
+            preds (torch.Tensor):
+                Prediction scores of shape (B, H, W). Higher values indicate more
+                anomalous.
+            target (torch.Tensor | None):
+                Unused; accepted only for API compatibility.

         Returns:
-            tuple[torch.Tensor, torch.Tensor]: Tuple containing final fpr and tpr
-                values.
+            tuple[torch.Tensor, torch.Tensor]:
+                (fpr, pro), both 1-D tensors sorted by increasing FPR and clipped
+                at ``self.fpr_limit``.
         """
-        if self.num_thresholds is not None:
-            # binary_roc is applying a sigmoid on the predictions before computing
-            # the roc curve when some predictions are out of [0, 1], the binning
-            # between min and max predictions cannot be applied in that case.
-            # This can be removed when
-            # https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed
-            # and the roc curve is computed with deactivated formatting.
-
-            if torch.all((preds >= 0) * (preds <= 1)):
-                thresholds = thresholds_between_min_and_max(preds, self.num_thresholds, self.device)
-            else:
-                thresholds = thresholds_between_0_and_1(self.num_thresholds, self.device)
+        del target
+        device = preds.device
+
+        # flatten (already on correct device)
+        labels = cca.reshape(-1).long()
+        preds_flat = preds.reshape(-1).float()
+
+        # background = FPR contribution
+        background = labels == 0
+        fp_change = background.float()
+        num_bg = fp_change.sum()
+
+        if num_bg == 0:
+            f = float(self.fpr_limit)
+            return (
+                torch.tensor([0.0, f], device=device),
+                torch.tensor([0.0, 0.0], device=device),
+            )

-        else:
-            thresholds = None
-
-        # compute the global fpr-size
-        fpr: torch.Tensor = binary_roc(
-            preds=preds,
-            target=target,
-            thresholds=thresholds,
-        )[0]  # only need fpr
-        output_size = torch.where(fpr <= self.fpr_limit)[0].size(0)
-
-        # compute the PRO curve by aggregating per-region tpr/fpr curves/values.
-        tpr = torch.zeros(output_size, device=preds.device, dtype=torch.float)
-        fpr = torch.zeros(output_size, device=preds.device, dtype=torch.float)
-        new_idx = torch.arange(0, output_size, device=preds.device, dtype=torch.float)
-
-        # Loop over the labels, computing per-region tpr/fpr curves, and
-        # aggregating them. Note that, since the groundtruth is different for
-        # every all to `roc`, we also get different/unique tpr/fpr curves
-        # (i.e. len(_fpr_idx) is different for every call).
-        # We therefore need to resample per-region curves to a fixed sampling
-        # ratio (defined above).
-        labels = cca.unique()[1:]  # 0 is background
-        background = cca == 0
-        fpr_: torch.Tensor
-        tpr_: torch.Tensor
-        for label in labels:
-            interp: bool = False
-            new_idx[-1] = output_size - 1
-            mask = cca == label
-            # Need to calculate label-wise roc on union of background & mask, as
-            # otherwise we wrongly consider other label in labels as FPs.
-            # We also don't need to return the thresholds
-            fpr_, tpr_ = binary_roc(
-                preds=preds[background | mask],
-                target=mask[background | mask],
-                thresholds=thresholds,
-            )[:-1]
-
-            # catch edge-case where ROC only has fpr vals > self.fpr_limit
-            if fpr_[fpr_ <= self.fpr_limit].max() == 0:
-                fpr_limit_ = fpr_[fpr_ > self.fpr_limit].min()
-            else:
-                fpr_limit_ = self.fpr_limit
-
-            fpr_idx_ = torch.where(fpr_ <= fpr_limit_)[0]
-            # if computed roc curve is not specified sufficiently close to
-            # self.fpr_limit, we include the closest higher tpr/fpr pair and
-            # linearly interpolate the tpr/fpr point at self.fpr_limit
-            if not torch.allclose(fpr_[fpr_idx_].max(), self.fpr_limit):
-                tmp_idx_ = torch.searchsorted(fpr_, self.fpr_limit)
-                fpr_idx_ = torch.cat([fpr_idx_, tmp_idx_.unsqueeze_(0)])
-                slope_ = 1 - ((fpr_[tmp_idx_] - self.fpr_limit) / (fpr_[tmp_idx_] - fpr_[tmp_idx_ - 1]))
-                interp = True
-
-            fpr_ = fpr_[fpr_idx_]
-            tpr_ = tpr_[fpr_idx_]
-
-            fpr_idx_ = fpr_idx_.float()
-            fpr_idx_ /= fpr_idx_.max()
-            fpr_idx_ *= new_idx.max()
-
-            if interp:
-                # last point will be sampled at self.fpr_limit
-                new_idx[-1] = fpr_idx_[-2] + ((fpr_idx_[-1] - fpr_idx_[-2]) * slope_)
-
-            tpr_ = self.interp1d(fpr_idx_, tpr_, new_idx)
-            fpr_ = self.interp1d(fpr_idx_, fpr_, new_idx)
-            tpr += tpr_
-            fpr += fpr_
-
-        # Actually perform the averaging
-        tpr /= labels.size(0)
-        fpr /= labels.size(0)
-        return fpr, tpr
+        max_label = int(labels.max())
+        if max_label == 0:
+            f = float(self.fpr_limit)
+            return (
+                torch.tensor([0.0, f], device=device),
+                torch.tensor([0.0, 0.0], device=device),
+            )

-    def _compute(self) -> tuple[torch.Tensor, torch.Tensor]:
-        """Compute the PRO curve.
+        region_sizes = torch.bincount(labels, minlength=max_label + 1).float()
+        num_regions = (region_sizes[1:] > 0).sum()
+
+        if num_regions == 0:
+            f = float(self.fpr_limit)
+            return (
+                torch.tensor([0.0, f], device=device),
+                torch.tensor([0.0, 0.0], device=device),
+            )

-        Perform the Connected Component Analysis first then compute the PRO curve.
+        fg_mask = labels > 0

-        Returns:
-            tuple[torch.Tensor, torch.Tensor]: Tuple containing final fpr and tpr
-                values.
-        """
-        cca = self.perform_cca().flatten()
-        target = dim_zero_cat(self.target).flatten()
-        preds = dim_zero_cat(self.preds).flatten()
+        pro_change = torch.zeros_like(preds_flat)
+        pro_change[fg_mask] = 1.0 / region_sizes[labels[fg_mask]]
+
+        # global sort
+        idx = torch.argsort(preds_flat, descending=True)
+        fp_sorted = fp_change[idx]
+        pro_sorted = pro_change[idx]
+        preds_sorted = preds_flat[idx]
+
+        # cumulative sums
+        fpr = torch.cumsum(fp_sorted, 0) / num_bg
+        pro = torch.cumsum(pro_sorted, 0) / num_regions
+        fpr.clamp_(max=1.0)
+        pro.clamp_(max=1.0)
+
+        # remove duplicate thresholds
+        keep = torch.ones_like(preds_sorted, dtype=torch.bool)
+        keep[:-1] = preds_sorted[:-1] != preds_sorted[1:]
+        fpr = fpr[keep]
+        pro = pro[keep]
+
+        # prepend zero
+        fpr = torch.cat([torch.tensor([0.0], device=device), fpr])
+        pro = torch.cat([torch.tensor([0.0], device=device), pro])
+
+        # FPR limit clipping
+        f_lim = float(self.fpr_limit)
+        mask = fpr <= f_lim

-        return self.compute_pro(cca=cca, target=target, preds=preds)
+        if mask.any():
+            i = mask.nonzero(as_tuple=True)[0][-1].item()
+
+            if fpr[i] < f_lim and i + 1 < fpr.numel():
+                f1, f2 = fpr[i], fpr[i + 1]
+                p1, p2 = pro[i], pro[i + 1]
+                p_lim = p1 + (p2 - p1) * (f_lim - f1) / (f2 - f1)
+
+                fpr = torch.cat([fpr[: i + 1], torch.tensor([f_lim], device=device)])
+                pro = torch.cat([pro[: i + 1], torch.tensor([p_lim], device=device)])
+            else:
+                fpr = fpr[: i + 1]
+                pro = pro[: i + 1]
+        else:
+            fpr = torch.tensor([0.0, f_lim], device=device)
+            pro = torch.tensor([0.0, 0.0], device=device)
+
+        return fpr, pro
+
+    def _compute(self) -> tuple[torch.Tensor, torch.Tensor]:
+        """Compute the PRO curve (FPR vs PRO) for all stored predictions."""
+        cca = self.perform_cca()  # (B, H, W)
+        preds = dim_zero_cat(self.preds)  # (B, 1, H, W) or (B, H, W)
+        if preds.dim() > 3 and preds.size(1) == 1:
+            preds = preds.squeeze(1)
+        return self.compute_pro(cca=cca, preds=preds)

     def compute(self) -> torch.Tensor:
         """First compute PRO curve, then compute and scale area under the curve.
@@ -319,7 +353,7 @@ def generate_figure(self) -> tuple[Figure, str]:
         fpr, tpr = self._compute()
         aupro = self.compute()

-        xlim = (0.0, float(self.fpr_limit.detach().cpu().numpy()))
+        xlim = (0.0, float(self.fpr_limit.detach().cpu().item()))
         ylim = (0.0, 1.0)
         xlabel = "Global FPR"
         ylabel = "Averaged Per-Region TPR"

0 commit comments
