
Commit 1beedf5

Added support for gpu
Signed-off-by: Devansh Agarwal <[email protected]>
1 parent 7fea20f commit 1beedf5

File tree

4 files changed: +415 −182 lines changed
src/anomalib/models/image/glass/__init__.py

Lines changed: 20 additions & 1 deletion

@@ -1,4 +1,23 @@
+"""GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization.
+
+This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both
+global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in
+industrial settings.
+
+The model consists of:
+- A feature extractor and feature adaptor to obtain robust normal representations
+- A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with
+  truncated projection
+- A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks
+- A shared discriminator trained with features from normal, global, and local synthetic samples
+
+Paper: `A Unified Anomaly Synthesis Strategy with Gradient Ascent for Industrial Anomaly Detection and Localization
+<https://arxiv.org/pdf/2407.09359>`
+"""
+
 # Copyright (C) 2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-from .lightning_model import Glass as Glass
+from .lightning_model import Glass
+
+__all__ = ["Glass"]

src/anomalib/models/image/glass/lightning_model.py

Lines changed: 74 additions & 31 deletions
@@ -1,10 +1,13 @@
 """GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization.
 
-This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in industrial settings.
+This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both
+global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in
+industrial settings.
 
 The model consists of:
 - A feature extractor and feature adaptor to obtain robust normal representations
-- A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with truncated projection
+- A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with
+  truncated projection
 - A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks
 - A shared discriminator trained with features from normal, global, and local synthetic samples
 
@@ -21,7 +24,7 @@
 import torch
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from torch import nn, optim
-from torch.nn import functional as F
+from torch.nn import functional as f
 from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize
 
 from anomalib import LearningType
@@ -40,17 +43,21 @@
 class Glass(AnomalibModule):
     """PyTorch Lightning Implementation of the GLASS Model.
 
-    The model uses a pre-trained feature extractor to extract features and a feature adaptor to mitigate latent domain bias.
+    The model uses a pre-trained feature extractor to extract features and a feature adaptor to mitigate latent domain
+    bias.
     Global anomaly features are synthesized from adapted normal features using gradient ascent.
-    Local anomaly images are synthesized using texture overlay datasets like dtd which are then processed by feature extractor and feature adaptor.
+    Local anomaly images are synthesized using texture overlay datasets like dtd which are then processed by feature
+    extractor and feature adaptor.
     All three different features are passed to the discriminator trained using loss functions.
 
     Args:
-        input_shape (tuple[int, int]): Input image dimensions as a tuple of (height, width). Required for shaping the input pipeline.
-        anomaly_source_path (str): Path to the dataset or source directory containing normal images and anomaly textures.
+        input_shape (tuple[int, int]): Input image dimensions as a tuple of (height, width). Required for shaping the
+            input pipeline.
+        anomaly_source_path (str): Path to the dataset or source directory containing normal images and anomaly
+            textures.
         backbone (str, optional): Name of the CNN backbone used for feature extraction.
             Defaults to `"resnet18"`.
-        pretrain_embed_dim (int, optional): Dimensionality of features extracted by the pre-trained backbone before adaptation.
+        pretrain_embed_dim (int, optional): Dimensionality of features extracted by the pre-trained backbone before
+            adaptation.
             Defaults to `1024`.
         target_embed_dim (int, optional): Dimensionality of the target adapted features after projection.
             Defaults to `1024`.
@@ -62,31 +69,37 @@ class Glass(AnomalibModule):
             Defaults to `True`.
         layers (list[str], optional): List of backbone layers to extract features from.
             Defaults to `["layer1", "layer2", "layer3"]`.
-        pre_proj (int, optional): Number of projection layers used in the feature adaptor (e.g., MLP before discriminator).
+        pre_proj (int, optional): Number of projection layers used in the feature adaptor (e.g., MLP before
+            discriminator).
             Defaults to `1`.
         dsc_layers (int, optional): Number of layers in the discriminator network.
             Defaults to `2`.
         dsc_hidden (int, optional): Number of hidden units in each discriminator layer.
             Defaults to `1024`.
-        dsc_margin (float, optional): Margin used for contrastive or binary classification loss in discriminator training.
+        dsc_margin (float, optional): Margin used for contrastive or binary classification loss in discriminator
+            training.
             Defaults to `0.5`.
         pre_processor (PreProcessor | bool, optional): Preprocessing module or flag to enable default preprocessing.
             Set to `True` to apply default normalization and resizing.
             Defaults to `True`.
-        post_processor (PostProcessor | bool, optional): Postprocessing module or flag to enable default output smoothing or thresholding.
+        post_processor (PostProcessor | bool, optional): Postprocessing module or flag to enable default output
+            smoothing or thresholding.
             Defaults to `True`.
         evaluator (Evaluator | bool, optional): Evaluation module for calculating metrics such as AUROC and PRO.
             Defaults to `True`.
-        visualizer (Visualizer | bool, optional): Visualization module to generate heatmaps, segmentation overlays, and anomaly scores.
+        visualizer (Visualizer | bool, optional): Visualization module to generate heatmaps, segmentation overlays, and
+            anomaly scores.
             Defaults to `True`.
-        mining (int, optional): Number of iterations or difficulty level for Online Hard Example Mining (OHEM) during training.
+        mining (int, optional): Number of iterations or difficulty level for Online Hard Example Mining (OHEM) during
+            training.
             Defaults to `1`.
         noise (float, optional): Standard deviation of Gaussian noise used in feature-level anomaly synthesis.
             Defaults to `0.015`.
         radius (float, optional): Radius parameter used for truncated projection in the anomaly synthesis strategy.
             Determines the range for valid synthetic anomalies in the hypersphere or manifold.
             Defaults to `0.75`.
-        p (float, optional): Probability used in random selection logic, such as anomaly mask generation or augmentation choice.
+        p (float, optional): Probability used in random selection logic, such as anomaly mask generation or augmentation
+            choice.
             Defaults to `0.5`.
         lr (float, optional): Learning rate for training the feature adaptor and discriminator networks.
             Defaults to `0.0001`.
@@ -106,7 +119,7 @@ def __init__(
         patchsize: int = 3,
         patchstride: int = 1,
         pre_trained: bool = True,
-        layers: list[str] = ["layer1", "layer2", "layer3"],
+        layers: list[str] | None = None,
         pre_proj: int = 1,
         dsc_layers: int = 2,
         dsc_hidden: int = 1024,
@@ -122,14 +135,17 @@
         lr: float = 0.0001,
         step: int = 20,
         svd: int = 0,
-    ):
+    ) -> None:
         super().__init__(
             pre_processor=pre_processor,
             post_processor=post_processor,
             evaluator=evaluator,
             visualizer=visualizer,
         )
 
+        if layers is None:
+            layers = ["layer1", "layer2", "layer3"]
+
         self.augmentor = PerlinAnomalyGenerator(anomaly_source_path)
 
         self.model = GlassModel(
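The `layers` change in the hunk above is the standard Python fix for mutable default arguments: a list literal in a signature is created once at definition time and shared by every call, so the commit swaps it for a `None` sentinel resolved inside `__init__`. A self-contained illustration of the difference:

def bad(layers: list[str] = []) -> list[str]:
    layers.append("layer1")  # mutates the single shared default
    return layers

def good(layers: list[str] | None = None) -> list[str]:
    if layers is None:       # fresh list on every call
        layers = ["layer1", "layer2", "layer3"]
    return layers

print(bad())   # ['layer1']
print(bad())   # ['layer1', 'layer1']  <- the default list persisted across calls
print(good())  # ['layer1', 'layer2', 'layer3']
print(good())  # ['layer1', 'layer2', 'layer3'] - a fresh list each time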
@@ -157,6 +173,8 @@
         self.step = step
         self.svd = svd
 
+        self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         self.focal_loss = FocalLoss()
 
         if pre_proj > 0:
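This `self.dev` attribute is the core of the commit's GPU support: tensors the model creates on the fly (the Gaussian noise, the radius `r_t`) and batches touched outside Lightning's usual hooks are moved to it explicitly, so they land on the same device as the model's parameters. A minimal sketch of the pattern with placeholder tensors:

import torch

# Pick CUDA when available, otherwise fall back to CPU - the same
# expression the commit stores as `self.dev`.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

feats = torch.randn(8, 1024, device=dev)             # stand-in for model features
noise = torch.normal(0, 0.015, feats.shape).to(dev)  # matches the training_step change
assert noise.device == feats.device                  # feats + noise cannot cross devices now

Note that inside a LightningModule the inherited `self.device` property tracks wherever the trainer moves the module, so caching a separate `self.dev` at construction time is the simpler but more static choice.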
@@ -170,7 +188,7 @@
 
         if not pre_trained:
             self.backbone_opt = optim.AdamW(
-                self.model.foward_modules["feature_aggregator"].backbone.parameters(),
+                self.model.forward_modules["feature_aggregator"].backbone.parameters(),
                 self.lr,
             )
         else:
@@ -182,6 +200,30 @@ def configure_pre_processor(
         image_size: tuple[int, int] | None = None,
         center_crop_size: tuple[int, int] | None = None,
     ) -> PreProcessor:
+        """Configure the default pre-processor for GLASS.
+
+        If a valid center_crop_size is provided, the pre-processor will
+        also perform center cropping, according to the paper.
+
+        Args:
+            image_size (tuple[int, int] | None, optional): Target size for
+                resizing. Defaults to ``(256, 256)``.
+            center_crop_size (tuple[int, int] | None, optional): Size for center
+                cropping. Defaults to ``None``.
+
+        Returns:
+            PreProcessor: Configured pre-processor instance.
+
+        Raises:
+            ValueError: If at least one dimension of ``center_crop_size`` is larger
+                than the corresponding ``image_size`` dimension.
+
+        Example:
+            >>> pre_processor = Glass.configure_pre_processor(
+            ...     image_size=(256, 256)
+            ... )
+            >>> transformed_image = pre_processor(image)
+        """
         image_size = image_size or (256, 256)
 
         if center_crop_size is not None:
@@ -201,10 +243,13 @@ def configure_pre_processor(
 
         return PreProcessor(transform=transform)
 
-    def configure_optimizers(self) -> list[optim.Optimizer]:
-        dsc_opt = optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2)
+    def configure_optimizers(self) -> optim.Optimizer:
+        """Configure the optimizer for the discriminator.
 
-        return dsc_opt
+        Returns:
+            Optimizer: AdamW optimizer for the discriminator.
+        """
+        return optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2)
 
     def training_step(
         self,
@@ -220,6 +265,7 @@ def training_step(
         Returns:
             STEP_OUTPUT: Dictionary containing loss values and metrics
         """
+        del batch_idx
         dsc_opt = self.optimizers()
 
         self.model.forward_modules.eval()
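The pairing shown above — `configure_optimizers` returning a single AdamW instance and `training_step` fetching it through `self.optimizers()` — is Lightning's manual-optimization contract. A condensed sketch, assuming `automatic_optimization = False` (implied by the hand-driven optimizer use, though not visible in this diff):

import torch
from torch import nn, optim
import lightning.pytorch as pl

class Sketch(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.automatic_optimization = False   # assumed: manual optimization
        self.discriminator = nn.Linear(8, 1)  # placeholder network

    def configure_optimizers(self) -> optim.Optimizer:
        return optim.AdamW(self.discriminator.parameters(), lr=2e-4)

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
        del batch_idx
        opt = self.optimizers()     # the single optimizer returned above
        opt.zero_grad()
        loss = self.discriminator(batch).mean()
        self.manual_backward(loss)  # replaces loss.backward() under manual optimization
        opt.step()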
@@ -235,21 +281,22 @@
 
         img = batch.image
         aug, mask_s = self.augmentor(img)
-        batch_size = img.shape[0]
+        if img is not None:
+            batch_size = img.shape[0]
 
         true_feats, fake_feats = self.model(img, aug)
 
         h_ratio = mask_s.shape[2] // int(math.sqrt(fake_feats.shape[0] // batch_size))
         w_ratio = mask_s.shape[3] // int(math.sqrt(fake_feats.shape[0] // batch_size))
 
-        mask_s_resized = F.interpolate(
+        mask_s_resized = f.interpolate(
             mask_s.float(),
             size=(mask_s.shape[2] // h_ratio, mask_s.shape[3] // w_ratio),
             mode="nearest",
         )
         mask_s_gt = mask_s_resized.reshape(-1, 1)
 
-        noise = torch.normal(0, self.noise, true_feats.shape)
+        noise = torch.normal(0, self.noise, true_feats.shape).to(self.dev)
         gaus_feats = true_feats + noise
 
         center = self.c.repeat(img.shape[0], 1, 1)
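The mask handling above downsamples the pixel-level Perlin mask to the feature-map grid so that every feature vector receives a binary normal/anomalous label. A minimal sketch of that resize-and-flatten step, under assumed shapes (256x256 masks, a 16x16 feature grid):

import torch
from torch.nn import functional as f

mask_s = (torch.rand(2, 1, 256, 256) > 0.5).float()  # placeholder Perlin masks
feat_hw = 16  # int(math.sqrt(fake_feats.shape[0] // batch_size)) in the model

# Nearest-neighbour keeps the mask binary; bilinear would blur the labels.
mask_small = f.interpolate(mask_s, size=(feat_hw, feat_hw), mode="nearest")
mask_s_gt = mask_small.reshape(-1, 1)  # one 0/1 label per feature location
print(mask_s_gt.shape)  # torch.Size([512, 1]) -> 2 * 16 * 16 labels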
@@ -260,7 +307,7 @@
         )
         c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0)
         dist_t = torch.norm(true_points - c_t_points, dim=1)
-        r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.device)
+        r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.dev)
 
         for step in range(self.step + 1):
             scores = self.model.discriminator(torch.cat([true_feats, gaus_feats]))
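The radius `r_t` computed above is simply the `radius`-quantile of the distances between normal features and the center; synthetic features pushed beyond it are projected back during the gradient-ascent loop. A tiny self-contained example of the quantile step, with random placeholder features:

import torch

feats = torch.randn(1024, 128)            # placeholder normal features
center = feats.mean(dim=0, keepdim=True)  # placeholder center
dist_t = torch.norm(feats - center, dim=1)
r_t = torch.quantile(dist_t, q=0.75)      # the model's `radius` default: 75% of
print(float(r_t))                         # normal features fall inside this sphere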
@@ -272,10 +319,6 @@
 
             if step == self.step:
                 break
-            if self.mining == 0:
-                dist_g = torch.norm(gaus_feats - center, dim=1)
-                r_g = torch.tensor([torch.quantile(dist_g, q=self.radius)])
-                break
 
             grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0]
             grad_norm = torch.norm(grad, dim=1)
@@ -326,7 +369,7 @@
         self.log("true_loss", true_loss, prog_bar=True)
         self.log("gaus_loss", gaus_loss, prog_bar=True)
         self.log("bce_loss", bce_loss, prog_bar=True)
-        self.log("focal_losss", focal_loss, prog_bar=True)
+        self.log("focal_loss", focal_loss, prog_bar=True)
         self.log("loss", loss, prog_bar=True)
 
     def on_train_start(self) -> None:
@@ -340,9 +383,9 @@ def on_train_start(self) -> None:
         with torch.no_grad():
             for i, batch in enumerate(dataloader):
                 if i == 0:
-                    self.c = self.model.calculate_mean(batch.image)
+                    self.c = self.model.calculate_mean(batch.image.to(self.dev))
                 else:
-                    self.c += self.model.calculate_mean(batch.image)
+                    self.c += self.model.calculate_mean(batch.image.to(self.dev))
 
         self.c /= len(dataloader)
 
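`on_train_start` above accumulates a per-batch feature mean and divides by the number of batches to obtain the hypersphere center `c`; the new `.to(self.dev)` calls are the commit's fix for DataLoader batches arriving on CPU while the model sits on the GPU. A schematic version with a placeholder `calculate_mean`:

import torch

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def calculate_mean(images: torch.Tensor) -> torch.Tensor:
    # Placeholder for GlassModel.calculate_mean: a per-batch mean feature.
    return images.mean(dim=0)

batches = [torch.randn(8, 128) for _ in range(4)]  # e.g. CPU batches from a DataLoader

c = None
with torch.no_grad():
    for batch in batches:
        mean = calculate_mean(batch.to(dev))  # move each batch to the model's device first
        c = mean if c is None else c + mean
c /= len(batches)  # running mean over batches, matching on_train_start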