While reviewing the training pipeline, I noticed two small issues that may affect training stability:
- The scheduler is stepped with scheduler.step(loss), but schedulers such as CosineAnnealingWarmRestarts do not take a metric argument and expect a plain scheduler.step() call. Worse, the positional argument there is interpreted as the epoch number, so passing the loss silently distorts the schedule; only ReduceLROnPlateau-style schedulers consume the loss.
- Labels are converted with labels.type(torch.LongTensor).to(device), which forces a cast to a CPU tensor before moving to the target device, adding an unnecessary host round trip when the labels are already on the GPU. Using labels.long().to(device) casts on the tensor's current device and avoids the extra transfer.
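For concreteness, here is a minimal sketch of what the corrected training step could look like. The model, optimizer, and tensor shapes are illustrative placeholders, not taken from the actual pipeline:

```python
import torch
import torch.nn.functional as F

# Illustrative stand-ins for the real pipeline's model and optimizer.
model = torch.nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

inputs = torch.randn(2, 4)
labels = torch.tensor([0.0, 2.0])  # e.g. float labels coming out of a DataLoader

# Before: labels.type(torch.LongTensor).to(device) forces a CPU cast first.
# After: cast on whatever device the tensor already lives on, then move once.
labels = labels.long().to(device)

logits = model(inputs.to(device))
loss = F.cross_entropy(logits, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()  # no loss argument for CosineAnnealingWarmRestarts
```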
I can submit a small fix addressing both points if this sounds reasonable.