Commit 838bc50

Refactored code from lightning model to torch model

Signed-off-by: Devansh Agarwal <[email protected]>

1 parent 1beedf5 commit 838bc50

2 files changed (+148, -107 lines)

src/anomalib/models/image/glass/lightning_model.py (9 additions, 91 deletions)

@@ -18,13 +18,11 @@
 # Copyright (C) 2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-import math
 from typing import Any
 
 import torch
 from lightning.pytorch.utilities.types import STEP_OUTPUT
-from torch import nn, optim
-from torch.nn import functional as f
+from torch import optim
 from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize
 
 from anomalib import LearningType
@@ -36,7 +34,6 @@
 from anomalib.pre_processing import PreProcessor
 from anomalib.visualization import Visualizer
 
-from .loss import FocalLoss
 from .torch_model import GlassModel
 
 
@@ -150,6 +147,7 @@ def __init__(
 
         self.model = GlassModel(
             input_shape=input_shape,
+            anomaly_source_path=anomaly_source_path,
             pretrain_embed_dim=pretrain_embed_dim,
             target_embed_dim=target_embed_dim,
             backbone=backbone,
@@ -161,22 +159,19 @@ def __init__(
             dsc_layers=dsc_layers,
             dsc_hidden=dsc_hidden,
             dsc_margin=dsc_margin,
+            step=step,
+            svd=svd,
+            mining=mining,
+            noise=noise,
+            radius=radius,
+            p=p,
         )
 
         self.c = torch.tensor([1])
-        self.p = p
-        self.radius = radius
-        self.mining = mining
-        self.noise = noise
-        self.distribution = 0
         self.lr = lr
-        self.step = step
-        self.svd = svd
 
         self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-        self.focal_loss = FocalLoss()
-
         if pre_proj > 0:
             self.proj_opt = optim.AdamW(
                 self.model.pre_projection.parameters(),
@@ -280,84 +275,7 @@ def training_step(
         self.backbone_opt.zero_grad()
 
         img = batch.image
-        aug, mask_s = self.augmentor(img)
-        if img is not None:
-            batch_size = img.shape[0]
-
-        true_feats, fake_feats = self.model(img, aug)
-
-        h_ratio = mask_s.shape[2] // int(math.sqrt(fake_feats.shape[0] // batch_size))
-        w_ratio = mask_s.shape[3] // int(math.sqrt(fake_feats.shape[0] // batch_size))
-
-        mask_s_resized = f.interpolate(
-            mask_s.float(),
-            size=(mask_s.shape[2] // h_ratio, mask_s.shape[3] // w_ratio),
-            mode="nearest",
-        )
-        mask_s_gt = mask_s_resized.reshape(-1, 1)
-
-        noise = torch.normal(0, self.noise, true_feats.shape).to(self.dev)
-        gaus_feats = true_feats + noise
-
-        center = self.c.repeat(img.shape[0], 1, 1)
-        center = center.reshape(-1, center.shape[-1])
-        true_points = torch.concat(
-            [fake_feats[mask_s_gt[:, 0] == 0], true_feats],
-            dim=0,
-        )
-        c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0)
-        dist_t = torch.norm(true_points - c_t_points, dim=1)
-        r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.dev)
-
-        for step in range(self.step + 1):
-            scores = self.model.discriminator(torch.cat([true_feats, gaus_feats]))
-            true_scores = scores[: len(true_feats)]
-            gaus_scores = scores[len(true_feats) :]
-            true_loss = nn.BCELoss()(true_scores, torch.zeros_like(true_scores))
-            gaus_loss = nn.BCELoss()(gaus_scores, torch.ones_like(gaus_scores))
-            bce_loss = true_loss + gaus_loss
-
-            if step == self.step:
-                break
-
-            grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0]
-            grad_norm = torch.norm(grad, dim=1)
-            grad_norm = grad_norm.view(-1, 1)
-            grad_normalized = grad / (grad_norm + 1e-10)
-
-            with torch.no_grad():
-                gaus_feats.add_(0.001 * grad_normalized)
-
-        fake_points = fake_feats[mask_s_gt[:, 0] == 1]
-        true_points = true_feats[mask_s_gt[:, 0] == 1]
-        c_f_points = center[mask_s_gt[:, 0] == 1]
-        dist_f = torch.norm(fake_points - c_f_points, dim=1)
-        proj_feats = c_f_points if self.svd == 1 else true_points
-        r = r_t if self.svd == 1 else 1
-
-        if self.svd == 1:
-            h = fake_points - proj_feats
-            h_norm = dist_f if self.svd == 1 else torch.norm(h, dim=1)
-            alpha = torch.clamp(h_norm, 2 * r, 4 * r)
-            proj = (alpha / (h_norm + 1e-10)).view(-1, 1)
-            h = proj * h
-            fake_points = proj_feats + h
-            fake_feats[mask_s_gt[:, 0] == 1] = fake_points
-
-        fake_scores = self.model.discriminator(fake_feats)
-
-        if self.p > 0:
-            fake_dist = (fake_scores - mask_s_gt) ** 2
-            d_hard = torch.quantile(fake_dist, q=self.p)
-            fake_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1)
-            mask_ = mask_s_gt[fake_dist >= d_hard].unsqueeze(1)
-        else:
-            fake_scores_ = fake_scores
-            mask_ = mask_s_gt
-        output = torch.cat([1 - fake_scores_, fake_scores_], dim=1)
-        focal_loss = self.focal_loss(output, mask_)
-
-        loss = bce_loss + focal_loss
+        true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(img, self.c)
         loss.backward()
 
         if self.proj_opt is not None:
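
With this hunk, training_step no longer implements the GLASS loss computation itself: it hands the image and the center tensor to GlassModel and backpropagates the total loss returned. A minimal, self-contained sketch of that delegation pattern; ToyModel and every name below are illustrative stand-ins, not anomalib code:

import torch
from torch import nn, optim


class ToyModel(nn.Module):
    """Stand-in for GlassModel: forward returns each loss term plus the total."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Linear(8, 1)

    def forward(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        scores = torch.sigmoid(self.net(img))
        term_a = scores.mean()           # e.g. a BCE-style term
        term_b = (scores**2).mean()      # e.g. a focal-style term
        return term_a, term_b, term_a + term_b


model = ToyModel()
opt = optim.AdamW(model.parameters(), lr=1e-4)

img = torch.randn(4, 8)                  # stand-in for batch.image
opt.zero_grad()
term_a, term_b, loss = model(img)        # all loss logic lives in the model
loss.backward()
opt.step()

Keeping every loss term inside the torch module lets the model be exercised and tested without the Lightning wrapper, which is one practical benefit of this refactor.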

src/anomalib/models/image/glass/torch_model.py (139 additions, 16 deletions)

@@ -24,9 +24,12 @@
 import torch.nn.functional as f
 from torch import nn
 
+from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator
 from anomalib.models.components import TimmFeatureExtractor
 from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims
 
+from .loss import FocalLoss
+
 
 def init_weight(m: nn.Module) -> None:
     """Initializes network weights using Xavier normal initialization.
@@ -313,6 +316,7 @@ class GlassModel(nn.Module):
     def __init__(
         self,
         input_shape: tuple[int, int],  # (H, W)
+        anomaly_source_path: str,
         pretrain_embed_dim: int = 1024,
         target_embed_dim: int = 1024,
         backbone: str = "resnet18",
@@ -324,6 +328,13 @@ def __init__(
         dsc_layers: int = 2,
         dsc_hidden: int = 1024,
         dsc_margin: float = 0.5,
+        mining: int = 1,
+        noise: float = 0.015,
+        radius: float = 0.75,
+        p: float = 0.5,
+        lr: float = 0.0001,
+        step: int = 20,
+        svd: int = 0,
     ) -> None:
         super().__init__()
 
@@ -335,6 +346,12 @@ def __init__(
         self.input_shape = input_shape
         self.pre_trained = pre_trained
 
+        self.augmentor = PerlinAnomalyGenerator(anomaly_source_path)
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.focal_loss = FocalLoss()
+
         self.forward_modules = torch.nn.ModuleDict({})
         feature_aggregator = TimmFeatureExtractor(
             backbone=self.backbone,
@@ -367,6 +384,15 @@ def __init__(
             hidden=self.dsc_hidden,
         )
 
+        self.p = p
+        self.radius = radius
+        self.mining = mining
+        self.noise = noise
+        self.distribution = 0
+        self.lr = lr
+        self.step = step
+        self.svd = svd
+
         self.patch_maker = PatchMaker(patchsize, stride=patchstride)
 
     def calculate_mean(self, images: torch.Tensor) -> torch.Tensor:
@@ -400,6 +426,41 @@ def calculate_mean(self, images: torch.Tensor) -> torch.Tensor:
 
         return torch.mean(outputs, dim=0)
 
+    def calculate_features(self,
+        img: torch.Tensor,
+        aug: torch.Tensor,
+        evaluation: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate and return feature embeddings for the input and augmented images.
+
+        If a pre-projection module is used, it is applied to the embeddings of both `img` and `aug`.
+
+        Args:
+            img (torch.Tensor): The original input image tensor.
+            aug (torch.Tensor): The augmented image tensor.
+            evaluation (bool, optional): Whether the model is in evaluation mode. Defaults to False.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: A tuple containing the feature embeddings for the original
+                image (`true_feats`) and the augmented image (`fake_feats`).
+        """
+        if self.pre_proj > 0:
+            fake_feats = self.pre_projection(
+                self.generate_embeddings(aug, evaluation=evaluation)[0],
+            )
+            fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats
+            true_feats = self.pre_projection(
+                self.generate_embeddings(img, evaluation=evaluation)[0],
+            )
+            true_feats = true_feats[0] if len(true_feats) == 2 else true_feats
+        else:
+            fake_feats = self.generate_embeddings(aug, evaluation=evaluation)[0]
+            fake_feats.requires_grad = True
+            true_feats = self.generate_embeddings(img, evaluation=evaluation)[0]
+            true_feats.requires_grad = True
+
+        return true_feats, fake_feats
+
     def generate_embeddings(
         self,
         images: torch.Tensor,
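
In the else branch above, calculate_features turns on `requires_grad` for the raw embeddings because forward later differentiates the discriminator loss with respect to them via torch.autograd.grad. A minimal illustration with toy tensors; nothing below is taken from anomalib:

import torch

feats = torch.randn(8, 4)            # embeddings from a frozen extractor
feats.requires_grad = True           # make the features differentiable inputs

gaus_feats = feats + torch.normal(0.0, 0.015, feats.shape)  # noisy copy
loss = gaus_feats.pow(2).sum()       # stand-in for the discriminator loss
(grad,) = torch.autograd.grad(loss, [gaus_feats])

print(grad.shape)                    # torch.Size([8, 4]), same as the features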
@@ -488,28 +549,90 @@ def generate_embeddings
     def forward(
         self,
         img: torch.Tensor,
-        aug: torch.Tensor,
-        evaluation: bool = False,
+        c: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Forward pass to compute patch-wise feature embeddings for original and augmented images.
 
         Depending on whether a pre-projection module is used, this method optionally applies it to the
         embeddings generated for both `img` and `aug`. If not, the embeddings are directly obtained and
         `requires_grad` is enabled for them, likely for gradient-based optimization or anomaly generation.
         """
-        if self.pre_proj > 0:
-            fake_feats = self.pre_projection(
-                self.generate_embeddings(aug, evaluation=evaluation)[0],
-            )
-            fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats
-            true_feats = self.pre_projection(
-                self.generate_embeddings(img, evaluation=evaluation)[0],
-            )
-            true_feats = true_feats[0] if len(true_feats) == 2 else true_feats
+        aug, mask_s = self.augmentor(img)
+        if img is not None:
+            batch_size = img.shape[0]
+
+        true_feats, fake_feats = self.calculate_features(img, aug)
+
+        h_ratio = mask_s.shape[2] // int(math.sqrt(fake_feats.shape[0] // batch_size))
+        w_ratio = mask_s.shape[3] // int(math.sqrt(fake_feats.shape[0] // batch_size))
+
+        mask_s_resized = f.interpolate(
+            mask_s.float(),
+            size=(mask_s.shape[2] // h_ratio, mask_s.shape[3] // w_ratio),
+            mode="nearest",
+        )
+        mask_s_gt = mask_s_resized.reshape(-1, 1)
+
+        noise = torch.normal(0, self.noise, true_feats.shape).to(self.device)
+        gaus_feats = true_feats + noise
+
+        center = c.repeat(img.shape[0], 1, 1)
+        center = center.reshape(-1, center.shape[-1])
+        true_points = torch.concat(
+            [fake_feats[mask_s_gt[:, 0] == 0], true_feats],
+            dim=0,
+        )
+        c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0)
+        dist_t = torch.norm(true_points - c_t_points, dim=1)
+        r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.device)
+
+        for step in range(self.step + 1):
+            scores = self.discriminator(torch.cat([true_feats, gaus_feats]))
+            true_scores = scores[: len(true_feats)]
+            gaus_scores = scores[len(true_feats) :]
+            true_loss = nn.BCELoss()(true_scores, torch.zeros_like(true_scores))
+            gaus_loss = nn.BCELoss()(gaus_scores, torch.ones_like(gaus_scores))
+            bce_loss = true_loss + gaus_loss
+
+            if step == self.step:
+                break
+
+            grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0]
+            grad_norm = torch.norm(grad, dim=1)
+            grad_norm = grad_norm.view(-1, 1)
+            grad_normalized = grad / (grad_norm + 1e-10)
+
+            with torch.no_grad():
+                gaus_feats.add_(0.001 * grad_normalized)
+
+        fake_points = fake_feats[mask_s_gt[:, 0] == 1]
+        true_points = true_feats[mask_s_gt[:, 0] == 1]
+        c_f_points = center[mask_s_gt[:, 0] == 1]
+        dist_f = torch.norm(fake_points - c_f_points, dim=1)
+        proj_feats = c_f_points if self.svd == 1 else true_points
+        r = r_t if self.svd == 1 else 1
+
+        if self.svd == 1:
+            h = fake_points - proj_feats
+            h_norm = dist_f if self.svd == 1 else torch.norm(h, dim=1)
+            alpha = torch.clamp(h_norm, 2 * r, 4 * r)
+            proj = (alpha / (h_norm + 1e-10)).view(-1, 1)
+            h = proj * h
+            fake_points = proj_feats + h
+            fake_feats[mask_s_gt[:, 0] == 1] = fake_points
+
+        fake_scores = self.discriminator(fake_feats)
+
+        if self.p > 0:
+            fake_dist = (fake_scores - mask_s_gt) ** 2
+            d_hard = torch.quantile(fake_dist, q=self.p)
+            fake_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1)
+            mask_ = mask_s_gt[fake_dist >= d_hard].unsqueeze(1)
         else:
-            fake_feats = self.generate_embeddings(aug, evaluation=evaluation)[0]
-            fake_feats.requires_grad = True
-            true_feats = self.generate_embeddings(img, evaluation=evaluation)[0]
-            true_feats.requires_grad = True
+            fake_scores_ = fake_scores
+            mask_ = mask_s_gt
+        output = torch.cat([1 - fake_scores_, fake_scores_], dim=1)
+        focal_loss = self.focal_loss(output, mask_)
 
-        return true_feats, fake_feats
+        loss = bce_loss + focal_loss
+        return true_loss, gaus_loss, bce_loss, focal_loss, loss
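
The svd == 1 branch in the new forward is the truncated projection step: feature points of synthetic anomalies are radially rescaled so their distance from the hypersphere center lands in the band [2r, 4r]. A standalone sketch with toy shapes and an assumed radius; none of the values come from anomalib:

import torch

n, d = 16, 32
center = torch.zeros(n, d)                   # plays the role of c_f_points
fake_points = torch.randn(n, d) * 5.0        # features of augmented patches
r = torch.tensor(1.0)                        # plays the role of r_t

h = fake_points - center                     # offset from the center
h_norm = torch.norm(h, dim=1)                # current radial distance
alpha = torch.clamp(h_norm, 2 * r, 4 * r)    # target distance, truncated to [2r, 4r]
proj = (alpha / (h_norm + 1e-10)).view(-1, 1)
projected = center + proj * h                # points moved onto the band

new_dist = torch.norm(projected - center, dim=1)
assert bool(torch.all(new_dist >= 2 * r - 1e-4)) and bool(torch.all(new_dist <= 4 * r + 1e-4))

Since r_t in the diff is a quantile of the normal points' distances, this plausibly keeps the synthesized outliers hard for the discriminator without letting them collapse onto the normal manifold.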
