-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathforward.py
More file actions
30 lines (25 loc) · 857 Bytes
/
forward.py
File metadata and controls
30 lines (25 loc) · 857 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
import sys
import torch
from torch.utils.data import DataLoader
from scheduler import forward_diffusion_sample
from dataset import load_transformed_dataset, show_tensor_image
import matplotlib.pyplot as plt
# Resolve this script's real location on disk, then append the parent of its
# directory to the module search path.
# NOTE(review): this runs after the `scheduler` / `dataset` imports at the top
# of the file, so it cannot affect how those two modules were resolved --
# confirm whether the path tweak was meant to come first.
script_file = os.path.realpath(__file__)
dir_path = os.path.dirname(script_file)
main_path = os.path.dirname(dir_path)
sys.path.append(main_path)
# Number of samples per batch handed out by the loader.
BATCH_SIZE = 128

# Build the project dataset and wrap it in a loader that reshuffles on each
# pass and discards a trailing batch smaller than BATCH_SIZE.
data = load_transformed_dataset()
dataloader = DataLoader(
    dataset=data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)
# Forward diffusion: draw the noised result at evenly spaced timesteps,
# side by side in a single row of subplots.
T = 300

# Element 0 of the first batch the loader yields.
# NOTE(review): presumably the image tensor(s) of an (images, labels) pair --
# confirm against load_transformed_dataset.
first_batch = next(iter(dataloader))
image = first_batch[0]

plt.figure(figsize=(15, 15))
plt.axis('off')

num_images = 10
stepsize = T // num_images  # 300 // 10 == 30 timesteps between panels

# The grid is declared with num_images + 1 columns; with the constants above
# the loop visits timesteps 0, 30, ..., 270 and fills columns 1 through 10.
for column, idx in enumerate(range(0, T, stepsize), start=1):
    # One-element int64 tensor carrying the current timestep.
    t = torch.Tensor([idx]).long()
    plt.subplot(1, num_images + 1, column)
    img, noise = forward_diffusion_sample(image, t)
    show_tensor_image(img, 'forward_pass')