Skip to content

Commit d35a86a

Browse files
committed
Implement strided convolutional and pooling layers
1 parent 88e18ae commit d35a86a

34 files changed

+761
-3010
lines changed

project/run_mnist_multiclass.py renamed to examples/run_mnist_multiclass.py

Lines changed: 22 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import numba
33
import numpy as np
4+
import matplotlib.pyplot as plt
45
import sys
56
import warnings
67
warnings.filterwarnings("ignore")
@@ -14,57 +15,42 @@
1415
if numba.cuda.is_available():
1516
GPUBackend = minitorch.TensorBackend(minitorch.CudaOps)
1617

17-
C = 10
18-
1918
H, W = 28, 28
19+
C = 10
2020

2121

22-
def mnist_transform(image):
23-
"""Normalize MNIST image from uint8 [0, 255] to float [0, 1]"""
22+
def preprocess(image):
2423
return image.astype(np.float64) / 255.0
2524

2625

27-
class Network(minitorch.Module):
28-
"""
29-
Implement a CNN for MNist classification based on LeNet.
30-
This model should implement the following procedure:
31-
1. Apply a convolution with 4 output channels and a 3x3 kernel followed by a ReLU (save to self.mid)
32-
2. Apply a convolution with 8 output channels and a 3x3 kernel followed by a ReLU (save to self.out)
33-
3. Apply 2D pooling (either Avg or Max) with 4x4 kernel.
34-
4. Flatten channels, height, and width. (Should be size BATCHx392)
35-
5. Apply a Linear to size 64 followed by a ReLU and Dropout with rate 25%
36-
6. Apply a Linear to size C (number of classes).
37-
7. Apply a logsoftmax over the class dimension.
38-
"""
26+
class Network(minitorch.Module): # LeNet-5
3927

4028
def __init__(self, backend=FastTensorBackend):
4129
super().__init__()
30+
self.conv1 = minitorch.Conv2d(in_channels=1, out_channels=6, kernel=(5, 5), stride=1, backend=backend)
31+
self.conv2 = minitorch.Conv2d(in_channels=6, out_channels=16, kernel=(5, 5), stride=1, backend=backend)
4232

43-
# For vis
44-
self.mid = None
45-
self.out = None
46-
47-
self.conv1 = minitorch.Conv2d(1, 4, 3, 3, backend=backend)
48-
self.conv2 = minitorch.Conv2d(4, 8, 3, 3, backend=backend)
49-
self.linear1 = minitorch.Linear(392, 64, backend=backend)
50-
self.linear2 = minitorch.Linear(64, C, backend=backend)
33+
self.fc1 = minitorch.Linear(16 * 4 * 4, 120, backend=backend)
34+
self.fc2 = minitorch.Linear(120, 84, backend=backend)
35+
self.fc3 = minitorch.Linear(84, C, backend=backend)
5136

5237
def forward(self, x):
5338
batch_size = x.shape[0]
5439
x = self.conv1(x).relu()
55-
self.mid = x
40+
x = minitorch.avgpool2d(x, kernel=(2, 2), stride=(2, 2))
5641
x = self.conv2(x).relu()
57-
self.out = x
58-
x = minitorch.avgpool2d(x, (4, 4))
59-
x = x.view(batch_size, 392)
60-
x = self.linear1(x).relu()
61-
x = minitorch.dropout(x, 0.25, not self.training)
62-
x = self.linear2(x)
42+
x = minitorch.avgpool2d(x, kernel=(2, 2), stride=(2, 2))
43+
x = x.view(batch_size, 16 * 4 * 4)
44+
x = self.fc1(x).relu()
45+
x = minitorch.dropout(x, 0.2, not self.training)
46+
x = self.fc2(x).relu()
47+
x = minitorch.dropout(x, 0.2, not self.training)
48+
x = self.fc3(x)
6349
x = minitorch.logsoftmax(x, dim=1)
6450
return x
6551

6652

67-
def default_log_fn(epoch, total_loss, correct, total):
53+
def default_log_fn(epoch, total_loss, correct, total, loss_list):
6854
print(
6955
f"Epoch {epoch} | loss {total_loss / total:.2f} | valid acc {correct / total:.2f}"
7056
)
@@ -74,11 +60,11 @@ def train(
7460
model,
7561
train_loader,
7662
val_loader,
77-
learning_rate,
63+
learning_rate=1e-2,
7864
max_epochs=50,
7965
log_fn=default_log_fn,
8066
):
81-
optim = minitorch.SGD(model.parameters(), learning_rate)
67+
optim = minitorch.RMSProp(model.parameters(), learning_rate)
8268
for epoch in range(1, max_epochs + 1):
8369
total_loss = 0.0
8470
model.train()
@@ -133,14 +119,14 @@ def train(
133119
batch_size=args.batch_size,
134120
shuffle=True,
135121
backend=backend,
136-
transform=mnist_transform
122+
transform=preprocess
137123
)
138124
val_loader = DataLoader(
139125
mnist_val,
140126
batch_size=args.batch_size,
141127
shuffle=False,
142128
backend=backend,
143-
transform=mnist_transform
129+
transform=preprocess
144130
)
145131

146132
model = Network(backend=backend)

0 commit comments

Comments
 (0)