@@ -1,6 +1,7 @@
 import argparse
 import numba
 import numpy as np
+import matplotlib.pyplot as plt
 import sys
 import warnings
 warnings.filterwarnings("ignore")
@@ -14,57 +15,42 @@
 if numba.cuda.is_available():
     GPUBackend = minitorch.TensorBackend(minitorch.CudaOps)
 
-C = 10
-
 H, W = 28, 28
+C = 10
 
 
-def mnist_transform(image):
-    """Normalize MNIST image from uint8 [0, 255] to float [0, 1]"""
+def preprocess(image):  # normalize uint8 [0, 255] to float [0, 1]
     return image.astype(np.float64) / 255.0
 
 
-class Network(minitorch.Module):
-    """
-    Implement a CNN for MNist classification based on LeNet.
-    This model should implement the following procedure:
-    1. Apply a convolution with 4 output channels and a 3x3 kernel followed by a ReLU (save to self.mid)
-    2. Apply a convolution with 8 output channels and a 3x3 kernel followed by a ReLU (save to self.out)
-    3. Apply 2D pooling (either Avg or Max) with 4x4 kernel.
-    4. Flatten channels, height, and width. (Should be size BATCHx392)
-    5. Apply a Linear to size 64 followed by a ReLU and Dropout with rate 25%
-    6. Apply a Linear to size C (number of classes).
-    7. Apply a logsoftmax over the class dimension.
-    """
+class Network(minitorch.Module):  # LeNet-5 style CNN for MNIST
 
     def __init__(self, backend=FastTensorBackend):
         super().__init__()
+        self.conv1 = minitorch.Conv2d(1, 6, 5, 5, backend=backend)  # 1 -> 6 channels; kh, kw are positional, stride is fixed at 1
+        self.conv2 = minitorch.Conv2d(6, 16, 5, 5, backend=backend)  # 6 -> 16 channels, 5x5 kernel
 
-        # For vis
-        self.mid = None
-        self.out = None
-
-        self.conv1 = minitorch.Conv2d(1, 4, 3, 3, backend=backend)
-        self.conv2 = minitorch.Conv2d(4, 8, 3, 3, backend=backend)
-        self.linear1 = minitorch.Linear(392, 64, backend=backend)
-        self.linear2 = minitorch.Linear(64, C, backend=backend)
+        self.fc1 = minitorch.Linear(16 * 4 * 4, 120, backend=backend)  # 16 channels of 4x4 after two conv/pool stages
+        self.fc2 = minitorch.Linear(120, 84, backend=backend)
+        self.fc3 = minitorch.Linear(84, C, backend=backend)
 
     def forward(self, x):
         batch_size = x.shape[0]
         x = self.conv1(x).relu()
-        self.mid = x
+        x = minitorch.avgpool2d(x, (2, 2))  # minitorch pooling is non-overlapping, so stride (2, 2) is implicit; 24x24 -> 12x12
         x = self.conv2(x).relu()
-        self.out = x
-        x = minitorch.avgpool2d(x, (4, 4))
-        x = x.view(batch_size, 392)
-        x = self.linear1(x).relu()
-        x = minitorch.dropout(x, 0.25, not self.training)
-        x = self.linear2(x)
+        x = minitorch.avgpool2d(x, (2, 2))  # 8x8 -> 4x4
+        x = x.view(batch_size, 16 * 4 * 4)  # flatten to BATCH x 256
+        x = self.fc1(x).relu()
+        x = minitorch.dropout(x, 0.2, not self.training)
+        x = self.fc2(x).relu()
+        x = minitorch.dropout(x, 0.2, not self.training)
+        x = self.fc3(x)
         x = minitorch.logsoftmax(x, dim=1)
         return x
 
 
-def default_log_fn(epoch, total_loss, correct, total):
+def default_log_fn(epoch, total_loss, correct, total, loss_list):
     print(
         f"Epoch {epoch} | loss {total_loss / total:.2f} | valid acc {correct / total:.2f}"
     )
@@ -74,11 +60,11 @@ def train(
     model,
     train_loader,
     val_loader,
-    learning_rate,
+    learning_rate=1e-2,
     max_epochs=50,
     log_fn=default_log_fn,
 ):
-    optim = minitorch.SGD(model.parameters(), learning_rate)
+    optim = minitorch.RMSProp(model.parameters(), learning_rate)
     for epoch in range(1, max_epochs + 1):
         total_loss = 0.0
         model.train()
@@ -133,14 +119,14 @@ def train(
         batch_size=args.batch_size,
         shuffle=True,
         backend=backend,
-        transform=mnist_transform
+        transform=preprocess
     )
     val_loader = DataLoader(
         mnist_val,
         batch_size=args.batch_size,
         shuffle=False,
         backend=backend,
-        transform=mnist_transform
+        transform=preprocess
     )
 
     model = Network(backend=backend)