Commit 7fea20f

Matched code to the original implementation
Signed-off-by: Devansh Agarwal <[email protected]>
1 parent f9d3207 commit 7fea20f

2 files changed (+155 −30 lines)

src/anomalib/models/image/glass/lightning_model.py

Lines changed: 77 additions & 10 deletions
@@ -15,14 +15,18 @@
 # Copyright (C) 2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+import math
 from typing import Any
 
 import torch
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from torch import nn, optim
+from torch.nn import functional as F
+from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize
 
 from anomalib import LearningType
 from anomalib.data import Batch
+from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator
 from anomalib.metrics import Evaluator
 from anomalib.models.components import AnomalibModule
 from anomalib.post_processing import PostProcessor
@@ -32,11 +36,9 @@
 from .loss import FocalLoss
 from .torch_model import GlassModel
 
-from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator
-
 
 class Glass(AnomalibModule):
-    """PyTorch Lightning Implementation of the GLASS Model
+    """PyTorch Lightning Implementation of the GLASS Model.
 
     The model uses a pre-trained feature extractor to extract features and a feature adaptor to mitigate latent domain bias.
     Global anomaly features are synthesized from adapted normal features using gradient ascent.
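
Note: for readers unfamiliar with GLASS, the gradient-ascent synthesis mentioned in the docstring can be sketched as below. This is a minimal illustration, not the model's actual implementation (the real logic lives in training_step with truncated projection around the hypersphere center); the discriminator argument, step count, and step size are assumptions.

    import torch

    def synthesize_anomaly_features(feats, discriminator, steps=20, step_size=0.01):
        # feats: (N, D) adapted normal features (hypothetical helper, for illustration only)
        # discriminator: maps features to an anomaly score (assumed convention: higher = more anomalous)
        adv = feats.clone().detach().requires_grad_(True)
        for _ in range(steps):
            score = discriminator(adv).mean()  # scalar anomaly score over the batch
            score.backward()
            with torch.no_grad():
                adv += step_size * adv.grad  # gradient ascent on the anomaly score
            adv.grad = None
        return adv.detach()

The helper only conveys the direction of the optimization: normal features are pushed toward the discriminator's anomaly region, then used as synthetic global anomalies.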
@@ -88,7 +90,10 @@ class Glass(AnomalibModule):
             Defaults to `0.5`.
         lr (float, optional): Learning rate for training the feature adaptor and discriminator networks.
             Defaults to `0.0001`.
-        step (int, optional): Number of gradient ascent steps or
+        step (int, optional): Number of gradient ascent steps for anomaly synthesis.
+            Defaults to `20`.
+        svd (int, optional): Flag to enable SVD-based feature projection.
+            Defaults to `0`.
     """
 
     def __init__(
@@ -116,6 +121,7 @@ def __init__(
         p: float = 0.5,
         lr: float = 0.0001,
         step: int = 20,
+        svd: int = 0,
     ):
         super().__init__(
             pre_processor=pre_processor,
@@ -149,12 +155,15 @@ def __init__(
         self.distribution = 0
         self.lr = lr
         self.step = step
+        self.svd = svd
 
         self.focal_loss = FocalLoss()
 
         if pre_proj > 0:
             self.proj_opt = optim.AdamW(
-                self.model.pre_projection.parameters(), self.lr, weight_decay=1e-5
+                self.model.pre_projection.parameters(),
+                self.lr,
+                weight_decay=1e-5,
             )
         else:
             self.proj_opt = None
@@ -167,6 +176,31 @@ def __init__(
         else:
             self.backbone_opt = None
 
+    @classmethod
+    def configure_pre_processor(
+        cls,
+        image_size: tuple[int, int] | None = None,
+        center_crop_size: tuple[int, int] | None = None,
+    ) -> PreProcessor:
+        image_size = image_size or (256, 256)
+
+        if center_crop_size is not None:
+            if center_crop_size[0] > image_size[0] or center_crop_size[1] > image_size[1]:
+                msg = f"Center crop size {center_crop_size} cannot be larger than image size {image_size}."
+                raise ValueError(msg)
+            transform = Compose([
+                Resize(image_size, antialias=True),
+                CenterCrop(center_crop_size),
+                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+            ])
+        else:
+            transform = Compose([
+                Resize(image_size, antialias=True),
+                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+            ])
+
+        return PreProcessor(transform=transform)
+
     def configure_optimizers(self) -> list[optim.Optimizer]:
         dsc_opt = optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2)
 
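Note: as a quick sanity check of the new configure_pre_processor defaults, the composed transform behaves as below. This minimal sketch uses only torchvision (the transform pipeline is copied from the hunk above); the input shape and crop size are illustrative.

    import torch
    from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize

    transform = Compose([
        Resize((256, 256), antialias=True),           # default image_size
        CenterCrop((224, 224)),                       # optional center_crop_size
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet stats
    ])

    dummy = torch.rand(1, 3, 512, 512)  # stand-in RGB batch
    out = transform(dummy)
    print(out.shape)  # torch.Size([1, 3, 224, 224])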
@@ -177,6 +211,15 @@ def training_step(
         batch: Batch,
         batch_idx: int,
     ) -> STEP_OUTPUT:
+        """Training step for GLASS model.
+
+        Args:
+            batch (Batch): Input batch containing images and metadata
+            batch_idx (int): Index of the current batch
+
+        Returns:
+            STEP_OUTPUT: Dictionary containing loss values and metrics
+        """
         dsc_opt = self.optimizers()
 
         self.model.forward_modules.eval()
@@ -192,17 +235,28 @@ def training_step(
 
         img = batch.image
         aug, mask_s = self.augmentor(img)
+        batch_size = img.shape[0]
 
         true_feats, fake_feats = self.model(img, aug)
 
-        mask_s_gt = mask_s.reshape(-1, 1)
+        h_ratio = mask_s.shape[2] // int(math.sqrt(fake_feats.shape[0] // batch_size))
+        w_ratio = mask_s.shape[3] // int(math.sqrt(fake_feats.shape[0] // batch_size))
+
+        mask_s_resized = F.interpolate(
+            mask_s.float(),
+            size=(mask_s.shape[2] // h_ratio, mask_s.shape[3] // w_ratio),
+            mode="nearest",
+        )
+        mask_s_gt = mask_s_resized.reshape(-1, 1)
+
         noise = torch.normal(0, self.noise, true_feats.shape)
         gaus_feats = true_feats + noise
 
         center = self.c.repeat(img.shape[0], 1, 1)
         center = center.reshape(-1, center.shape[-1])
         true_points = torch.concat(
-            [fake_feats[mask_s_gt[:, 0] == 0], true_feats], dim=0
+            [fake_feats[mask_s_gt[:, 0] == 0], true_feats],
+            dim=0,
         )
         c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0)
         dist_t = torch.norm(true_points - c_t_points, dim=1)
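Note: the replaced mask handling aligns the pixel-level synthetic mask with the flattened patch features before indexing by label. A standalone sketch of the arithmetic, with assumed shapes (8 images, 288x288 masks, a 36x36 patch grid, 1536-d features); nearest-neighbour interpolation keeps the downsampled mask binary, so each patch inherits a 0/1 label.

    import math

    import torch
    from torch.nn import functional as F

    batch_size = 8
    mask_s = (torch.rand(batch_size, 1, 288, 288) > 0.5).float()  # pixel-level anomaly mask (assumed size)
    fake_feats = torch.rand(batch_size * 36 * 36, 1536)           # one feature row per patch (assumed dims)

    side = int(math.sqrt(fake_feats.shape[0] // batch_size))      # 36 patches per side
    h_ratio = mask_s.shape[2] // side                             # 288 // 36 = 8
    w_ratio = mask_s.shape[3] // side

    mask_small = F.interpolate(
        mask_s,
        size=(mask_s.shape[2] // h_ratio, mask_s.shape[3] // w_ratio),  # back to (36, 36)
        mode="nearest",                                                 # no fractional labels
    )
    mask_s_gt = mask_small.reshape(-1, 1)  # (batch_size * 36 * 36, 1), one label per fake_feats row
    print(mask_s_gt.shape)                 # torch.Size([10368, 1])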
@@ -235,7 +289,6 @@ def training_step(
         true_points = true_feats[mask_s_gt[:, 0] == 1]
         c_f_points = center[mask_s_gt[:, 0] == 1]
         dist_f = torch.norm(fake_points - c_f_points, dim=1)
-        r_f = torch.tensor([torch.quantile(dist_f, q=self.radius)]).to(self.device)
         proj_feats = c_f_points if self.svd == 1 else true_points
         r = r_t if self.svd == 1 else 1
 
@@ -270,7 +323,18 @@ def training_step(
             self.backbone_opt.step()
         dsc_opt.step()
 
+        self.log("true_loss", true_loss, prog_bar=True)
+        self.log("gaus_loss", gaus_loss, prog_bar=True)
+        self.log("bce_loss", bce_loss, prog_bar=True)
+        self.log("focal_loss", focal_loss, prog_bar=True)
+        self.log("loss", loss, prog_bar=True)
+
     def on_train_start(self) -> None:
+        """Initialize the model by computing the mean feature representation across the training dataset.
+
+        This method is called at the start of training and computes a mean feature vector
+        that serves as a reference point for the normal class distribution.
+        """
         dataloader = self.trainer.train_dataloader
 
         with torch.no_grad():
@@ -293,6 +357,9 @@ def learning_type(self) -> LearningType:
 
     @property
     def trainer_arguments(self) -> dict[str, Any]:
-        """Return GLASS trainer arguments."""
+        """Return GLASS trainer arguments.
+
+        Returns:
+            dict[str, Any]: Dictionary containing trainer configuration
+        """
         return {"gradient_clip_val": 0, "num_sanity_val_steps": 0}
-        # TODO
