diff --git a/utils.py b/utils.py index bdca336c..07d3a572 100644 --- a/utils.py +++ b/utils.py @@ -491,7 +491,7 @@ def __init__(self, data_source, num_epochs, start_itr=0, batch_size=128): self.data_source = data_source self.num_samples = len(self.data_source) self.num_epochs = num_epochs - self.start_itr = start_itr + self.start_itr = start_itr%int((self.num_samples*self.num_epochs)/batch_size) self.batch_size = batch_size if not isinstance(self.num_samples, int) or self.num_samples <= 0: @@ -514,6 +514,7 @@ def __iter__(self): # return iter(.tolist()) output = torch.cat(out).tolist() print('Length dataset output is %d' % len(output)) + self.start_itr = 0 return iter(output) def __len__(self):