11"""GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization.
22
3- This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in industrial settings.
3+ This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both
4+ global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in
5+ industrial settings.
46
57The model consists of:
68 - A feature extractor and feature adaptor to obtain robust normal representations
7- - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with truncated projection
9+ - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with
10+ truncated projection
811 - A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks
912 - A shared discriminator trained with features from normal, global, and local synthetic samples
1013
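For orientation before the code changes below, here is a rough usage sketch of the model this module provides. The datamodule, the `Engine` API, and the dataset paths are assumptions based on typical anomalib workflows, not something this diff adds:

```python
from anomalib.data import MVTecAD
from anomalib.engine import Engine

# Assumed workflow: fit GLASS on one MVTec AD category, with a DTD-style
# texture folder serving as the anomaly source for Local Anomaly Synthesis.
datamodule = MVTecAD(category="bottle")
model = Glass(input_shape=(256, 256), anomaly_source_path="./datasets/dtd/images")
engine = Engine()
engine.fit(model=model, datamodule=datamodule)
```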
 import torch
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from torch import nn, optim
-from torch.nn import functional as F
+from torch.nn import functional as f
 from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize
 
 from anomalib import LearningType
 class Glass(AnomalibModule):
     """PyTorch Lightning Implementation of the GLASS Model.
 
-    The model uses a pre-trained feature extractor to extract features and a feature adaptor to mitigate latent domain bias.
+    The model uses a pre-trained feature extractor to extract features and a feature adaptor to mitigate latent domain
+    bias.
     Global anomaly features are synthesized from adapted normal features using gradient ascent.
-    Local anomaly images are synthesized using texture overlay datasets like dtd which are then processed by feature extractor and feature adaptor.
+    Local anomaly images are synthesized using texture overlay datasets such as DTD, which are then processed by the
+    feature extractor and feature adaptor.
     All three feature types are passed to the discriminator, which is trained using the corresponding loss functions.
 
     Args:
-        input_shape (tuple[int, int]): Input image dimensions as a tuple of (height, width). Required for shaping the input pipeline.
-        anomaly_source_path (str): Path to the dataset or source directory containing normal images and anomaly textures.
+        input_shape (tuple[int, int]): Input image dimensions as a tuple of (height, width). Required for shaping the
+            input pipeline.
+        anomaly_source_path (str): Path to the dataset or source directory containing normal images and anomaly
+            textures.
         backbone (str, optional): Name of the CNN backbone used for feature extraction.
             Defaults to `"resnet18"`.
-        pretrain_embed_dim (int, optional): Dimensionality of features extracted by the pre-trained backbone before adaptation.
+        pretrain_embed_dim (int, optional): Dimensionality of features extracted by the pre-trained backbone before
+            adaptation.
             Defaults to `1024`.
         target_embed_dim (int, optional): Dimensionality of the target adapted features after projection.
             Defaults to `1024`.
@@ -62,31 +69,37 @@ class Glass(AnomalibModule):
             Defaults to `True`.
         layers (list[str], optional): List of backbone layers to extract features from.
             Defaults to `["layer1", "layer2", "layer3"]`.
-        pre_proj (int, optional): Number of projection layers used in the feature adaptor (e.g., MLP before discriminator).
+        pre_proj (int, optional): Number of projection layers used in the feature adaptor (e.g., an MLP before the
+            discriminator).
             Defaults to `1`.
         dsc_layers (int, optional): Number of layers in the discriminator network.
             Defaults to `2`.
         dsc_hidden (int, optional): Number of hidden units in each discriminator layer.
             Defaults to `1024`.
-        dsc_margin (float, optional): Margin used for contrastive or binary classification loss in discriminator training.
+        dsc_margin (float, optional): Margin used for contrastive or binary classification loss in discriminator
+            training.
             Defaults to `0.5`.
         pre_processor (PreProcessor | bool, optional): Preprocessing module or flag to enable default preprocessing.
             Set to `True` to apply default normalization and resizing.
             Defaults to `True`.
-        post_processor (PostProcessor | bool, optional): Postprocessing module or flag to enable default output smoothing or thresholding.
+        post_processor (PostProcessor | bool, optional): Postprocessing module or flag to enable default output
+            smoothing or thresholding.
             Defaults to `True`.
         evaluator (Evaluator | bool, optional): Evaluation module for calculating metrics such as AUROC and PRO.
             Defaults to `True`.
-        visualizer (Visualizer | bool, optional): Visualization module to generate heatmaps, segmentation overlays, and anomaly scores.
+        visualizer (Visualizer | bool, optional): Visualization module to generate heatmaps, segmentation overlays,
+            and anomaly scores.
             Defaults to `True`.
-        mining (int, optional): Number of iterations or difficulty level for Online Hard Example Mining (OHEM) during training.
+        mining (int, optional): Number of iterations or difficulty level for Online Hard Example Mining (OHEM) during
+            training.
             Defaults to `1`.
         noise (float, optional): Standard deviation of Gaussian noise used in feature-level anomaly synthesis.
             Defaults to `0.015`.
         radius (float, optional): Radius parameter used for truncated projection in the anomaly synthesis strategy.
             Determines the range of valid synthetic anomalies in the hypersphere or manifold.
             Defaults to `0.75`.
-        p (float, optional): Probability used in random selection logic, such as anomaly mask generation or augmentation choice.
+        p (float, optional): Probability used in random selection logic, such as anomaly mask generation or
+            augmentation choice.
             Defaults to `0.5`.
         lr (float, optional): Learning rate for training the feature adaptor and discriminator networks.
             Defaults to `0.0001`.
@@ -106,7 +119,7 @@ def __init__(
         patchsize: int = 3,
         patchstride: int = 1,
         pre_trained: bool = True,
-        layers: list[str] = ["layer1", "layer2", "layer3"],
+        layers: list[str] | None = None,
         pre_proj: int = 1,
         dsc_layers: int = 2,
         dsc_hidden: int = 1024,
@@ -122,14 +135,17 @@ def __init__(
         lr: float = 0.0001,
         step: int = 20,
         svd: int = 0,
-    ):
+    ) -> None:
         super().__init__(
             pre_processor=pre_processor,
             post_processor=post_processor,
             evaluator=evaluator,
             visualizer=visualizer,
         )
 
+        if layers is None:
+            layers = ["layer1", "layer2", "layer3"]
+
         self.augmentor = PerlinAnomalyGenerator(anomaly_source_path)
 
         self.model = GlassModel(
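The `layers` change in the hunk above replaces a mutable default argument with a `None` sentinel that is resolved inside `__init__`. A minimal, self-contained illustration of the pitfall being avoided (toy functions, not from this codebase):

```python
def broken(layers: list[str] = []) -> list[str]:  # default list is built once, at definition time
    layers.append("layer1")
    return layers

print(broken())  # ['layer1']
print(broken())  # ['layer1', 'layer1'] - state leaks across calls

def fixed(layers: list[str] | None = None) -> list[str]:
    layers = ["layer1", "layer2", "layer3"] if layers is None else layers
    return layers
```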
@@ -157,6 +173,8 @@ def __init__(
         self.step = step
         self.svd = svd
 
+        self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         self.focal_loss = FocalLoss()
 
         if pre_proj > 0:
@@ -170,7 +188,7 @@ def __init__(
 
         if not pre_trained:
             self.backbone_opt = optim.AdamW(
-                self.model.foward_modules["feature_aggregator"].backbone.parameters(),
+                self.model.forward_modules["feature_aggregator"].backbone.parameters(),
                 self.lr,
             )
         else:
@@ -182,6 +200,30 @@ def configure_pre_processor(
         image_size: tuple[int, int] | None = None,
         center_crop_size: tuple[int, int] | None = None,
     ) -> PreProcessor:
+        """Configure the default pre-processor for GLASS.
+
+        If a valid ``center_crop_size`` is provided, the pre-processor will
+        also perform center cropping, as described in the paper.
+
+        Args:
+            image_size (tuple[int, int] | None, optional): Target size for
+                resizing. Defaults to ``(256, 256)``.
+            center_crop_size (tuple[int, int] | None, optional): Size for center
+                cropping. Defaults to ``None``.
+
+        Returns:
+            PreProcessor: Configured pre-processor instance.
+
+        Raises:
+            ValueError: If at least one dimension of ``center_crop_size`` is larger
+                than the corresponding ``image_size`` dimension.
+
+        Example:
+            >>> pre_processor = Glass.configure_pre_processor(
+            ...     image_size=(256, 256)
+            ... )
+            >>> transformed_image = pre_processor(image)
+        """
         image_size = image_size or (256, 256)
 
         if center_crop_size is not None:
@@ -201,10 +243,13 @@ def configure_pre_processor(
 
         return PreProcessor(transform=transform)
 
-    def configure_optimizers(self) -> list[optim.Optimizer]:
-        dsc_opt = optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2)
+    def configure_optimizers(self) -> optim.Optimizer:
+        """Configure the optimizer for the discriminator.
 
-        return dsc_opt
+        Returns:
+            Optimizer: AdamW optimizer for the discriminator.
+        """
+        return optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2)
 
     def training_step(
         self,
@@ -220,6 +265,7 @@ def training_step(
         Returns:
             STEP_OUTPUT: Dictionary containing loss values and metrics
         """
+        del batch_idx
        dsc_opt = self.optimizers()
 
         self.model.forward_modules.eval()
@@ -235,21 +281,22 @@ def training_step(
 
         img = batch.image
         aug, mask_s = self.augmentor(img)
-        batch_size = img.shape[0]
+        if img is not None:
+            batch_size = img.shape[0]
 
         true_feats, fake_feats = self.model(img, aug)
 
         h_ratio = mask_s.shape[2] // int(math.sqrt(fake_feats.shape[0] // batch_size))
         w_ratio = mask_s.shape[3] // int(math.sqrt(fake_feats.shape[0] // batch_size))
 
-        mask_s_resized = F.interpolate(
+        mask_s_resized = f.interpolate(
             mask_s.float(),
             size=(mask_s.shape[2] // h_ratio, mask_s.shape[3] // w_ratio),
             mode="nearest",
         )
         mask_s_gt = mask_s_resized.reshape(-1, 1)
 
-        noise = torch.normal(0, self.noise, true_feats.shape)
+        noise = torch.normal(0, self.noise, true_feats.shape).to(self.dev)
         gaus_feats = true_feats + noise
 
         center = self.c.repeat(img.shape[0], 1, 1)
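A quick sanity check of the ratio arithmetic above, with illustrative numbers not taken from this diff: for 256×256 masks, `batch_size = 8`, and `fake_feats.shape[0] == 8 * 1024`, the patch grid side is `sqrt(1024) = 32`, so `h_ratio = w_ratio = 256 // 32 = 8`, and the Perlin mask is nearest-neighbour downsampled to 32×32 before being flattened into one ground-truth label per patch feature.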
@@ -260,7 +307,7 @@ def training_step(
         )
         c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0)
         dist_t = torch.norm(true_points - c_t_points, dim=1)
-        r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.device)
+        r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.dev)
 
         for step in range(self.step + 1):
             scores = self.model.discriminator(torch.cat([true_feats, gaus_feats]))
@@ -272,10 +319,6 @@ def training_step(
 
             if step == self.step:
                 break
-            if self.mining == 0:
-                dist_g = torch.norm(gaus_feats - center, dim=1)
-                r_g = torch.tensor([torch.quantile(dist_g, q=self.radius)])
-                break
 
             grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0]
             grad_norm = torch.norm(grad, dim=1)
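The loop that begins here is the gradient-ascent half of Global Anomaly Synthesis. A condensed, self-contained sketch of the idea only (the discriminator is any callable returning anomaly scores, the hyperparameter names are assumptions, and the truncated projection is folded into a single clamp; this is not the exact code from the diff):

```python
import torch

def gas_sketch(feats, center, disc, noise_std=0.015, steps=20, lr=0.01, radius_q=0.75):
    """Perturb normal features with Gaussian noise, then ascend the discriminator
    score while projecting the result back into a radius band around the center."""
    r = torch.quantile(torch.norm(feats - center, dim=1), q=radius_q).item()
    gaus = feats + torch.randn_like(feats) * noise_std
    for _ in range(steps):
        gaus = gaus.detach().requires_grad_(True)
        score = disc(gaus).mean()  # higher score = more anomalous
        (grad,) = torch.autograd.grad(score, [gaus])
        gaus = gaus + lr * grad / (grad.norm(dim=1, keepdim=True) + 1e-8)  # normalized ascent
        offset = gaus - center
        dist = offset.norm(dim=1, keepdim=True)
        gaus = center + offset * dist.clamp(r, 2 * r) / (dist + 1e-8)  # truncated projection
    return gaus.detach()
```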
@@ -326,7 +369,7 @@ def training_step(
         self.log("true_loss", true_loss, prog_bar=True)
         self.log("gaus_loss", gaus_loss, prog_bar=True)
         self.log("bce_loss", bce_loss, prog_bar=True)
-        self.log("focal_losss", focal_loss, prog_bar=True)
+        self.log("focal_loss", focal_loss, prog_bar=True)
         self.log("loss", loss, prog_bar=True)
 
     def on_train_start(self) -> None:
@@ -340,9 +383,9 @@ def on_train_start(self) -> None:
         with torch.no_grad():
             for i, batch in enumerate(dataloader):
                 if i == 0:
-                    self.c = self.model.calculate_mean(batch.image)
+                    self.c = self.model.calculate_mean(batch.image.to(self.dev))
                 else:
-                    self.c += self.model.calculate_mean(batch.image)
+                    self.c += self.model.calculate_mean(batch.image.to(self.dev))
 
         self.c /= len(dataloader)
 
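Net effect of this hunk: `self.c` becomes the average of per-batch mean features, i.e. an estimate of the center of the normal-feature hypersphere, now accumulated on `self.dev`; this is the same center that `training_step` measures distances against when deriving the quantile radius `r_t`.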