-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpytorch_dataset_class.py
More file actions
84 lines (72 loc) · 3.01 KB
/
pytorch_dataset_class.py
File metadata and controls
84 lines (72 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
from torch.utils.data import Dataset
import os, json, h5py
class fetch_data(Dataset):
    """Image-captioning dataset: (image, one-hot caption, caption length).

    Images live in ``image_dataset_<split>.hdf5`` (dataset key ``"default"``),
    with ``n_captions_per_image`` captions per image, so sample index ``i``
    maps to image ``i // n_captions_per_image``.  Tokenized captions, caption
    lengths, and the word-to-index vocabulary are loaded from sibling JSON
    files.  For the ``"val"``/``"test"`` splits, ``__getitem__`` additionally
    returns all captions of the image (presumably for corpus-level scoring
    such as BLEU — TODO confirm against the evaluation code).
    """

    def __init__(self, dataset_path, split, transform=None):
        """Open the HDF5 image store and load caption/vocab JSON files.

        Args:
            dataset_path: Directory containing the per-split HDF5/JSON files.
            split: One of ``"train"``, ``"val"``, ``"test"``.
            transform: Optional callable applied to each image tensor.

        Raises:
            ValueError: If ``split`` is not a recognized split name.
            FileNotFoundError: If ``dataset_path`` does not exist.
        """
        super().__init__()
        # Validate eagerly with real exceptions — `assert` is stripped
        # under `python -O` and would silently skip these checks.
        if split not in ("train", "val", "test"):
            raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")
        if not os.path.exists(dataset_path):
            raise FileNotFoundError(dataset_path)
        self.n_captions_per_image = 5
        self.split = split
        self.transform = transform
        # Keep the File handle on the instance (not just the dataset view)
        # so the underlying file cannot be closed/collected while in use.
        self.h5_file = h5py.File(
            os.path.join(dataset_path, f"image_dataset_{split}.hdf5"), "r"
        )
        self.imgs = self.h5_file["default"]
        with open(os.path.join(dataset_path, f"tokenized_captions_{split}.json"), "r") as f:
            self.captions = json.load(f)
        with open(os.path.join(dataset_path, f"caption_lengths_{split}.json"), "r") as f:
            self.caplens = json.load(f)
        with open(os.path.join(dataset_path, "word_to_index_map.json"), "r") as f:
            self.word2idx = json.load(f)

    def _encode_caption(self, caption):
        """Return a caption (list of word strings) as a list of one-hot rows.

        Each row has length ``len(self.word2idx)``; out-of-vocabulary words
        map to the ``<unk>`` index.
        """
        dim = len(self.word2idx)
        unk = self.word2idx["<unk>"]
        rows = []
        for word in caption:
            row = [0.0] * dim
            row[self.word2idx.get(word, unk)] = 1.0
            rows.append(row)
        return rows

    def __getitem__(self, i):
        """Return the i-th sample.

        Returns:
            ``(img, caption, caplen)`` for the train split, where ``caption``
            is a ``(caption_len, vocab_size)`` one-hot FloatTensor; for
            val/test, additionally ``all_captions`` of shape
            ``(n_captions_per_image, caption_len, vocab_size)``.
        """
        img_idx = i // self.n_captions_per_image
        img = torch.from_numpy(self.imgs[img_idx])
        if self.transform is not None:
            img = self.transform(img)
        caption = torch.FloatTensor(self._encode_caption(self.captions[i]))
        caplen = self.caplens[i]
        if self.split == "train":
            return img, caption, caplen
        # Eval splits: also return every caption belonging to this image.
        # (Use a distinct loop variable — the original shadowed `i` here.)
        start = img_idx * self.n_captions_per_image
        end = start + self.n_captions_per_image
        all_captions_for_image = [
            self._encode_caption(self.captions[j]) for j in range(start, end)
        ]
        return img, caption, caplen, torch.FloatTensor(all_captions_for_image)

    def __len__(self):
        """One sample per caption (not per image)."""
        return len(self.captions)
if __name__ == "__main__":
    # Quick visual sanity check: load one training sample and display it.
    dataset = fetch_data("dataset_0.01", "train", transform=None)
    # The "train" split returns a 3-tuple (img, caption, caplen); the
    # original code unpacked 4 values here and crashed with ValueError.
    img, caption, caplen = dataset[25]
    img = img.permute(1, 2, 0)  # CHW -> HWC so matplotlib can render it
    import matplotlib.pyplot as plt
    plt.imshow(img)
    plt.show()
    print(caption)