diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/train.py b/DeepLense_Classification_Transformers_Archil_Srivastava/train.py index a5a6303..3961565 100644 --- a/DeepLense_Classification_Transformers_Archil_Srivastava/train.py +++ b/DeepLense_Classification_Transformers_Archil_Srivastava/train.py @@ -44,8 +44,7 @@ def train_step(model, images, labels, optimizer, scheduler, criterion, device="c Loss value from the forward pass """ # Send to device - images, labels = images.to(device, dtype=torch.float), labels.type( - torch.LongTensor + images, labels = images.to(device, dtype=torch.float), labels.long().to(device) ).to(device) model.train() # Set train mode optimizer.zero_grad() # Reset gradients @@ -54,7 +53,7 @@ def train_step(model, images, labels, optimizer, scheduler, criterion, device="c loss.backward() # Backward pass optimizer.step() # Optimize weights step if scheduler is not None: - scheduler.step(loss) # Modify learning rate if scheduler is set + scheduler.step() # Modify learning rate if scheduler is set return loss