Skip to content

Commit 1baa0b7

Browse files
committed
GPU bug fixed
Signed-off-by: Devansh Agarwal <[email protected]>
1 parent 838bc50 commit 1baa0b7

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

src/anomalib/models/image/glass/lightning_model.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,6 @@ def __init__(
170170
self.c = torch.tensor([1])
171171
self.lr = lr
172172

173-
self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
174-
175173
if pre_proj > 0:
176174
self.proj_opt = optim.AdamW(
177175
self.model.pre_projection.parameters(),
@@ -275,7 +273,7 @@ def training_step(
275273
self.backbone_opt.zero_grad()
276274

277275
img = batch.image
278-
true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(img, self.c)
276+
true_loss, gaus_loss, bce_loss, focal_loss, loss = self.model(img, self.c, self.device)
279277
loss.backward()
280278

281279
if self.proj_opt is not None:
@@ -301,9 +299,9 @@ def on_train_start(self) -> None:
301299
with torch.no_grad():
302300
for i, batch in enumerate(dataloader):
303301
if i == 0:
304-
self.c = self.model.calculate_mean(batch.image.to(self.dev))
302+
self.c = self.model.calculate_mean(batch.image.to(self.device))
305303
else:
306-
self.c += self.model.calculate_mean(batch.image.to(self.dev))
304+
self.c += self.model.calculate_mean(batch.image.to(self.device))
307305

308306
self.c /= len(dataloader)
309307

src/anomalib/models/image/glass/torch_model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,6 @@ def __init__(
348348

349349
self.augmentor = PerlinAnomalyGenerator(anomaly_source_path)
350350

351-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
352-
353351
self.focal_loss = FocalLoss()
354352

355353
self.forward_modules = torch.nn.ModuleDict({})
@@ -550,6 +548,7 @@ def forward(
550548
self,
551549
img: torch.Tensor,
552550
c: torch.Tensor | None = None,
551+
device: torch.device | None = None,
553552
) -> tuple[torch.Tensor, torch.Tensor]:
554553
"""Forward pass to compute patch-wise feature embeddings for original and augmented images.
555554
@@ -573,7 +572,7 @@ def forward(
573572
)
574573
mask_s_gt = mask_s_resized.reshape(-1, 1)
575574

576-
noise = torch.normal(0, self.noise, true_feats.shape).to(self.device)
575+
noise = torch.normal(0, self.noise, true_feats.shape).to(device)
577576
gaus_feats = true_feats + noise
578577

579578
center = c.repeat(img.shape[0], 1, 1)
@@ -584,7 +583,7 @@ def forward(
584583
)
585584
c_t_points = torch.concat([center[mask_s_gt[:, 0] == 0], center], dim=0)
586585
dist_t = torch.norm(true_points - c_t_points, dim=1)
587-
r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(self.device)
586+
r_t = torch.tensor([torch.quantile(dist_t, q=self.radius)]).to(device)
588587

589588
for step in range(self.step + 1):
590589
scores = self.discriminator(torch.cat([true_feats, gaus_feats]))

0 commit comments

Comments
 (0)