 import torch.nn.functional as f
 from torch import nn
 
+from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator
 from anomalib.models.components import TimmFeatureExtractor
 from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims
 
+from .loss import FocalLoss
+
 
 def init_weight(m: nn.Module) -> None:
     """Initializes network weights using Xavier normal initialization.
@@ -313,6 +316,7 @@ class GlassModel(nn.Module):
     def __init__(
         self,
         input_shape: tuple[int, int],  # (H, W)
+        anomaly_source_path: str,
         pretrain_embed_dim: int = 1024,
         target_embed_dim: int = 1024,
         backbone: str = "resnet18",
@@ -324,6 +328,13 @@ def __init__(
         dsc_layers: int = 2,
         dsc_hidden: int = 1024,
         dsc_margin: float = 0.5,
+        mining: int = 1,
+        noise: float = 0.015,
+        radius: float = 0.75,
+        p: float = 0.5,
+        lr: float = 0.0001,
+        step: int = 20,
+        svd: int = 0,
     ) -> None:
         super().__init__()
 
@@ -335,6 +346,12 @@ def __init__(
         self.input_shape = input_shape
         self.pre_trained = pre_trained
 
+        self.augmentor = PerlinAnomalyGenerator(anomaly_source_path)
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.focal_loss = FocalLoss()
+
         self.forward_modules = torch.nn.ModuleDict({})
         feature_aggregator = TimmFeatureExtractor(
             backbone=self.backbone,
@@ -367,6 +384,15 @@ def __init__(
             hidden=self.dsc_hidden,
         )
 
+        self.p = p
+        self.radius = radius
+        self.mining = mining
+        self.noise = noise
+        self.distribution = 0
+        self.lr = lr
+        self.step = step
+        self.svd = svd
+
         self.patch_maker = PatchMaker(patchsize, stride=patchstride)
 
     def calculate_mean(self, images: torch.Tensor) -> torch.Tensor:
@@ -400,6 +426,41 @@ def calculate_mean(self, images: torch.Tensor) -> torch.Tensor:
 
         return torch.mean(outputs, dim=0)
 
+    def calculate_features(
+        self,
+        img: torch.Tensor,
+        aug: torch.Tensor,
+        evaluation: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
434+ """Calculate and return feature embeddings for the input and augmented images.
435+
436+ Depending on whether a pre-projection module is used, this method optionally applies it to the
437+
+        Args:
+            img (torch.Tensor): The original input image tensor.
+            aug (torch.Tensor): The augmented image tensor.
+            evaluation (bool, optional): Whether the model is in evaluation mode. Defaults to False.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: A tuple containing the feature embeddings for the original
+                image (`true_feats`) and the augmented image (`fake_feats`).
+        """
+        if self.pre_proj > 0:
+            fake_feats = self.pre_projection(
+                self.generate_embeddings(aug, evaluation=evaluation)[0],
+            )
+            fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats
+            true_feats = self.pre_projection(
+                self.generate_embeddings(img, evaluation=evaluation)[0],
+            )
+            true_feats = true_feats[0] if len(true_feats) == 2 else true_feats
+        else:
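+            # No pre-projection: use the raw embeddings and enable gradients on them so they can be
+            # perturbed during the gradient-based anomaly synthesis in forward().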
+            fake_feats = self.generate_embeddings(aug, evaluation=evaluation)[0]
+            fake_feats.requires_grad = True
+            true_feats = self.generate_embeddings(img, evaluation=evaluation)[0]
+            true_feats.requires_grad = True
+
+        return true_feats, fake_feats
+
     def generate_embeddings(
         self,
         images: torch.Tensor,
@@ -488,28 +549,90 @@ def generate_embeddings(
     def forward(
         self,
         img: torch.Tensor,
-        aug: torch.Tensor,
-        evaluation: bool = False,
+        c: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Forward pass to compute patch-wise feature embeddings for original and augmented images.
-
-        Depending on whether a pre-projection module is used, this method optionally applies it to the
-        embeddings generated for both `img` and `aug`. If not, the embeddings are directly obtained and
-        `requires_grad` is enabled for them, likely for gradient-based optimization or anomaly generation.
-        """
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Forward pass that synthesizes anomalous features and computes the training losses.
+
+        The input batch is augmented with Perlin-based pseudo-anomalies, patch features are extracted for
+        the original and augmented images, and global (Gaussian-noise) and local (augmented) anomalous
+        features are refined before being scored by the discriminator.
+
+        Returns:
+            tuple[torch.Tensor, ...]: `(true_loss, gaus_loss, bce_loss, focal_loss, loss)`, where `loss`
+                is the sum of the BCE and focal terms.
+        """
-        if self.pre_proj > 0:
-            fake_feats = self.pre_projection(
-                self.generate_embeddings(aug, evaluation=evaluation)[0],
-            )
-            fake_feats = fake_feats[0] if len(fake_feats) == 2 else fake_feats
-            true_feats = self.pre_projection(
-                self.generate_embeddings(img, evaluation=evaluation)[0],
-            )
-            true_feats = true_feats[0] if len(true_feats) == 2 else true_feats
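+        # Synthesize a pseudo-anomalous image and its pixel-level anomaly mask with the Perlin augmentor.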
+        aug, mask_s = self.augmentor(img)
+        if img is not None:
+            batch_size = img.shape[0]
+
+        true_feats, fake_feats = self.calculate_features(img, aug)
+
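+        # Downsample the pixel mask to the patch-feature grid so each patch feature gets a 0/1 label.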
+        h_ratio = mask_s.shape[2] // int(math.sqrt(fake_feats.shape[0] // batch_size))
+        w_ratio = mask_s.shape[3] // int(math.sqrt(fake_feats.shape[0] // batch_size))
+
+        mask_s_resized = f.interpolate(
+            mask_s.float(),
+            size=(mask_s.shape[2] // h_ratio, mask_s.shape[3] // w_ratio),
+            mode="nearest",
+        )
+        mask_s_gt = mask_s_resized.reshape(-1, 1)
+
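+        # Global anomaly synthesis: perturb the normal features with Gaussian noise.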
+        noise = torch.normal(0, self.noise, true_feats.shape).to(self.device)
+        gaus_feats = true_feats + noise
+
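+        # Broadcast the center `c` to every patch and take the `radius`-quantile of the normal-feature
+        # distances as the reference radius r_t.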
+        center = c.repeat(img.shape[0], 1, 1)
+        center = center.reshape(-1, center.shape[-1])
+        true_points = torch.concat(
+            [fake_feats[mask_s_gt[:, 0] == 0], true_feats],
+            dim=0,
+        )
+        c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0)
+        dist_t = torch.norm(true_points - c_t_points, dim=1)
+        r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.device)
+
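+        # Mining loop: score normal and Gaussian features with the discriminator and, for `self.step`
+        # iterations, push the Gaussian features along the normalized gradient of their BCE loss.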
+        for step in range(self.step + 1):
+            scores = self.discriminator(torch.cat([true_feats, gaus_feats]))
+            true_scores = scores[: len(true_feats)]
+            gaus_scores = scores[len(true_feats) :]
+            true_loss = nn.BCELoss()(true_scores, torch.zeros_like(true_scores))
+            gaus_loss = nn.BCELoss()(gaus_scores, torch.ones_like(gaus_scores))
+            bce_loss = true_loss + gaus_loss
+
+            if step == self.step:
+                break
+
+            grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0]
+            grad_norm = torch.norm(grad, dim=1)
+            grad_norm = grad_norm.view(-1, 1)
+            grad_normalized = grad / (grad_norm + 1e-10)
+
+            with torch.no_grad():
+                gaus_feats.add_(0.001 * grad_normalized)
+
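+        # Local anomaly features (patches marked anomalous by the mask); with `svd == 1` their distance
+        # to the center is clamped into the [2 * r_t, 4 * r_t] band before being written back.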
+        fake_points = fake_feats[mask_s_gt[:, 0] == 1]
+        true_points = true_feats[mask_s_gt[:, 0] == 1]
+        c_f_points = center[mask_s_gt[:, 0] == 1]
+        dist_f = torch.norm(fake_points - c_f_points, dim=1)
+        proj_feats = c_f_points if self.svd == 1 else true_points
+        r = r_t if self.svd == 1 else 1
+
+        if self.svd == 1:
+            h = fake_points - proj_feats
+            h_norm = dist_f if self.svd == 1 else torch.norm(h, dim=1)
+            alpha = torch.clamp(h_norm, 2 * r, 4 * r)
+            proj = (alpha / (h_norm + 1e-10)).view(-1, 1)
+            h = proj * h
+            fake_points = proj_feats + h
+            fake_feats[mask_s_gt[:, 0] == 1] = fake_points
+
+        fake_scores = self.discriminator(fake_feats)
+
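+        # Hard-example mining: keep only the fake scores whose squared error is above the `p`-quantile.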
+        if self.p > 0:
+            fake_dist = (fake_scores - mask_s_gt) ** 2
+            d_hard = torch.quantile(fake_dist, q=self.p)
+            fake_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1)
+            mask_ = mask_s_gt[fake_dist >= d_hard].unsqueeze(1)
         else:
-            fake_feats = self.generate_embeddings(aug, evaluation=evaluation)[0]
-            fake_feats.requires_grad = True
-            true_feats = self.generate_embeddings(img, evaluation=evaluation)[0]
-            true_feats.requires_grad = True
+            fake_scores_ = fake_scores
+            mask_ = mask_s_gt
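+        # Stack (1 - score, score) into two-channel predictions for the focal loss against the patch mask.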
+        output = torch.cat([1 - fake_scores_, fake_scores_], dim=1)
+        focal_loss = self.focal_loss(output, mask_)
 
-        return true_feats, fake_feats
+        loss = bce_loss + focal_loss
+        return true_loss, gaus_loss, bce_loss, focal_loss, loss