-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathdata.py
More file actions
87 lines (67 loc) · 2.51 KB
/
data.py
File metadata and controls
87 lines (67 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import tqdm
import json
import torch
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
def get_name_label_pairs(json_name, image_root='/opt/tiger/debug_server/Phase2/data/train_p2'):
    """Flatten a ``{label: [image, ...]}`` JSON index into ``(path, label)`` pairs.

    Args:
        json_name: Path to a JSON file whose top-level object maps label keys
            (strings parseable by ``int``) to lists of image file names.
        image_root: Directory the image file names are joined onto. Defaults
            to the directory that used to be hard-coded here, so existing
            callers see identical paths. A trailing ``/`` is tolerated.

    Returns:
        A list of ``(image_path, label)`` tuples with ``label`` as ``int``,
        ordered by the JSON object's key order and then by each list's order.

    Raises:
        ValueError: If a key cannot be parsed as an integer label.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(json_name, 'r', encoding='utf-8') as f:
        label_to_images = json.load(f)

    # Normalise so both '/data' and '/data/' yield '/data/<image>'.
    root = image_root.rstrip('/')
    pairs = []
    for key, images in label_to_images.items():
        # Parse the label once per key; this also validates keys whose list is empty.
        label = int(key)
        pairs.extend((f'{root}/{image}', label) for image in images)
    return pairs
class MyDataset(Dataset):
    """Map-style dataset over a sequence of ``(image_path, label)`` pairs.

    Each item is read from disk with PIL, converted to RGB, passed through
    ``transform``, and returned together with its label cast to ``int``.
    """

    def __init__(self, names, transform):
        # names: indexable sequence of (image_path, label) pairs, e.g. the
        # output of get_name_label_pairs().
        # transform: callable applied to each RGB PIL image.
        self.names = names
        self.transform = transform

    def __len__(self):
        """Return the number of ``(image_path, label)`` pairs."""
        return len(self.names)

    def __getitem__(self, index):
        """Load image ``index`` and return ``(transform(image), int(label))``."""
        img_name, label = self.names[index]
        # Open the file in a context manager so its handle is released as soon
        # as the RGB copy exists. Otherwise the close is left to Pillow's
        # lazy-load / garbage-collection behaviour, which matters when many
        # DataLoader workers open many files.
        with Image.open(img_name) as raw:
            img = raw.convert('RGB')
        return self.transform(img), int(label)
def data_pipeline(train_json, val_json, transform, batch_size,
                  train_workers=24, val_workers=12):
    """Build the training and validation DataLoaders.

    Args:
        train_json: JSON index for the training split (format read by
            get_name_label_pairs).
        val_json: JSON index for the validation split.
        transform: Transform applied to training images.
        batch_size: Batch size used by both loaders.
        train_workers: Worker processes for the training loader.
        val_workers: Worker processes for the validation loader.

    Returns:
        ``(train_loader, val_loader)``. The training loader is shuffled. The
        validation loader iterates in index order, so evaluation passes are
        reproducible (shuffling cannot change aggregate validation metrics).
    """
    # Validation images only get the deterministic resize + center crop that
    # get_val_dataset builds; reuse it rather than duplicating the pipeline.
    train_set = get_train_dataset(train_json, transform)
    val_set = get_val_dataset(val_json)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=train_workers)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                            num_workers=val_workers)
    return train_loader, val_loader
def get_train_dataset(train_json, transform):
    """Return a MyDataset over the training index stored in ``train_json``.

    ``transform`` is applied to every image unchanged; no extra preprocessing
    is added here.
    """
    return MyDataset(get_name_label_pairs(train_json), transform)
def get_val_dataset(val_json):
    """Return a MyDataset over the validation index stored in ``val_json``.

    Every image goes through the same deterministic steps:
    ``Resize(256)``, then ``CenterCrop(224)``, then ``ToTensor()``.
    No random augmentation is applied.
    """
    pairs = get_name_label_pairs(val_json)
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    return MyDataset(pairs, transforms.Compose(steps))
if __name__ == '__main__':
    # Smoke test: build both loaders, then make one pass over the training
    # loader, printing each batch's tensor shapes and value range.
    # The transform here is deterministic (resize + center crop), not augmentation.
    batch_size = 64
    train_json = './data/train.json'
    val_json = './data/val.json'
    smoke_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    train_loader, val_loader = data_pipeline(
        train_json, val_json, smoke_transform, batch_size
    )
    for images, labels in tqdm.tqdm(train_loader):
        print(images.shape, labels.shape, torch.min(images), torch.max(images))