1515# Copyright (C) 2025 Intel Corporation
1616# SPDX-License-Identifier: Apache-2.0
1717
18+ import math
1819from typing import Any
1920
2021import torch
2122from lightning .pytorch .utilities .types import STEP_OUTPUT
2223from torch import nn , optim
24+ from torch .nn import functional as F
25+ from torchvision .transforms .v2 import CenterCrop , Compose , Normalize , Resize
2326
2427from anomalib import LearningType
2528from anomalib .data import Batch
29+ from anomalib .data .utils .generators .perlin import PerlinAnomalyGenerator
2630from anomalib .metrics import Evaluator
2731from anomalib .models .components import AnomalibModule
2832from anomalib .post_processing import PostProcessor
3236from .loss import FocalLoss
3337from .torch_model import GlassModel
3438
35- from anomalib .data .utils .generators .perlin import PerlinAnomalyGenerator
36-
3739
3840class Glass (AnomalibModule ):
39- """PyTorch Lightning Implementation of the GLASS Model
41+ """PyTorch Lightning Implementation of the GLASS Model.
4042
4143 The model uses a pre-trained feature extractor to extract features and a feature adaptor to mitigate latent domain bias.
4244 Global anomaly features are synthesized from adapted normal features using gradient ascent.
@@ -88,7 +90,10 @@ class Glass(AnomalibModule):
8890 Defaults to `0.5`.
8991 lr (float, optional): Learning rate for training the feature adaptor and discriminator networks.
9092 Defaults to `0.0001`.
91- step (int, optional): Number of gradient ascent steps or
93+ step (int, optional): Number of gradient ascent steps for anomaly synthesis.
94+ Defaults to `20`.
95+ svd (int, optional): Flag to enable SVD-based feature projection.
96+ Defaults to `0`.
9297 """
9398
9499 def __init__ (
@@ -116,6 +121,7 @@ def __init__(
116121 p : float = 0.5 ,
117122 lr : float = 0.0001 ,
118123 step : int = 20 ,
124+ svd : int = 0 ,
119125 ):
120126 super ().__init__ (
121127 pre_processor = pre_processor ,
@@ -149,12 +155,15 @@ def __init__(
149155 self .distribution = 0
150156 self .lr = lr
151157 self .step = step
158+ self .svd = svd
152159
153160 self .focal_loss = FocalLoss ()
154161
155162 if pre_proj > 0 :
156163 self .proj_opt = optim .AdamW (
157- self .model .pre_projection .parameters (), self .lr , weight_decay = 1e-5
164+ self .model .pre_projection .parameters (),
165+ self .lr ,
166+ weight_decay = 1e-5 ,
158167 )
159168 else :
160169 self .proj_opt = None
@@ -167,6 +176,31 @@ def __init__(
167176 else :
168177 self .backbone_opt = None
169178
@classmethod
def configure_pre_processor(
    cls,
    image_size: tuple[int, int] | None = None,
    center_crop_size: tuple[int, int] | None = None,
) -> PreProcessor:
    """Build the default pre-processor for the GLASS model.

    The pipeline resizes the input, optionally center-crops it, and then
    normalizes it with fixed per-channel mean and standard deviation values.

    Args:
        image_size (tuple[int, int] | None, optional): Size passed to ``Resize``.
            Falls back to ``(256, 256)`` when not given.
        center_crop_size (tuple[int, int] | None, optional): Size passed to
            ``CenterCrop``. When ``None`` no cropping step is added.

    Returns:
        PreProcessor: Pre-processor wrapping the composed transform.

    Raises:
        ValueError: If either dimension of ``center_crop_size`` is larger than
            the corresponding dimension of ``image_size``.
    """
    if not image_size:
        image_size = (256, 256)

    use_crop = center_crop_size is not None

    # Validate before building any transform so an invalid crop fails early.
    if use_crop:
        too_tall = center_crop_size[0] > image_size[0]
        if too_tall or center_crop_size[1] > image_size[1]:
            msg = f"Center crop size {center_crop_size} cannot be larger than image size {image_size}."
            raise ValueError(msg)

    # Assemble the steps in order: resize -> (optional crop) -> normalize.
    steps = [Resize(image_size, antialias=True)]
    if use_crop:
        steps.append(CenterCrop(center_crop_size))
    steps.append(Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))

    return PreProcessor(transform=Compose(steps))
203+
170204 def configure_optimizers (self ) -> list [optim .Optimizer ]:
171205 dsc_opt = optim .AdamW (self .model .discriminator .parameters (), lr = self .lr * 2 )
172206
@@ -177,6 +211,15 @@ def training_step(
177211 batch : Batch ,
178212 batch_idx : int ,
179213 ) -> STEP_OUTPUT :
214+ """Training step for GLASS model.
215+
216+ Args:
217+ batch (Batch): Input batch containing images and metadata
218+ batch_idx (int): Index of the current batch
219+
220+ Returns:
221+ STEP_OUTPUT: Dictionary containing loss values and metrics
222+ """
180223 dsc_opt = self .optimizers ()
181224
182225 self .model .forward_modules .eval ()
@@ -192,17 +235,28 @@ def training_step(
192235
193236 img = batch .image
194237 aug , mask_s = self .augmentor (img )
238+ batch_size = img .shape [0 ]
195239
196240 true_feats , fake_feats = self .model (img , aug )
197241
198- mask_s_gt = mask_s .reshape (- 1 , 1 )
242+ h_ratio = mask_s .shape [2 ] // int (math .sqrt (fake_feats .shape [0 ] // batch_size ))
243+ w_ratio = mask_s .shape [3 ] // int (math .sqrt (fake_feats .shape [0 ] // batch_size ))
244+
245+ mask_s_resized = F .interpolate (
246+ mask_s .float (),
247+ size = (mask_s .shape [2 ] // h_ratio , mask_s .shape [3 ] // w_ratio ),
248+ mode = "nearest" ,
249+ )
250+ mask_s_gt = mask_s_resized .reshape (- 1 , 1 )
251+
199252 noise = torch .normal (0 , self .noise , true_feats .shape )
200253 gaus_feats = true_feats + noise
201254
202255 center = self .c .repeat (img .shape [0 ], 1 , 1 )
203256 center = center .reshape (- 1 , center .shape [- 1 ])
204257 true_points = torch .concat (
205- [fake_feats [mask_s_gt [:, 0 ] == 0 ], true_feats ], dim = 0
258+ [fake_feats [mask_s_gt [:, 0 ] == 0 ], true_feats ],
259+ dim = 0 ,
206260 )
207261 c_t_points = torch .concat ([center [mask_s_gt [:, 0 ] == 0 ], center ], dim = 0 )
208262 dist_t = torch .norm (true_points - c_t_points , dim = 1 )
@@ -235,7 +289,6 @@ def training_step(
235289 true_points = true_feats [mask_s_gt [:, 0 ] == 1 ]
236290 c_f_points = center [mask_s_gt [:, 0 ] == 1 ]
237291 dist_f = torch .norm (fake_points - c_f_points , dim = 1 )
238- r_f = torch .tensor ([torch .quantile (dist_f , q = self .radius )]).to (self .device )
239292 proj_feats = c_f_points if self .svd == 1 else true_points
240293 r = r_t if self .svd == 1 else 1
241294
@@ -270,7 +323,18 @@ def training_step(
270323 self .backbone_opt .step ()
271324 dsc_opt .step ()
272325
326+ self .log ("true_loss" , true_loss , prog_bar = True )
327+ self .log ("gaus_loss" , gaus_loss , prog_bar = True )
328+ self .log ("bce_loss" , bce_loss , prog_bar = True )
329+ self .log ("focal_losss" , focal_loss , prog_bar = True )
330+ self .log ("loss" , loss , prog_bar = True )
331+
273332 def on_train_start (self ) -> None :
333+ """Initialize model by computing mean feature representation across training dataset.
334+
335+ This method is called at the start of training and computes a mean feature vector
336+ that serves as a reference point for the normal class distribution.
337+ """
274338 dataloader = self .trainer .train_dataloader
275339
276340 with torch .no_grad ():
@@ -293,6 +357,9 @@ def learning_type(self) -> LearningType:
293357
@property
def trainer_arguments(self) -> dict[str, Any]:
    """Return the trainer arguments specific to GLASS.

    Returns:
        dict[str, Any]: Mapping with ``gradient_clip_val`` set to ``0`` and
            ``num_sanity_val_steps`` set to ``0``.
    """
    arguments: dict[str, Any] = {
        "gradient_clip_val": 0,
        "num_sanity_val_steps": 0,
    }
    return arguments
298- # TODO
0 commit comments