1+ """Train a LeNet-5 CNN on MNIST dataset"""
2+
13import argparse
24import numba
35import numpy as np
4- import matplotlib .pyplot as plt
6+ import os
7+ import shutil
58import sys
69import warnings
710warnings .filterwarnings ("ignore" )
11+ from tensorboardX import SummaryWriter
812from tqdm import tqdm
913
1014import minitorch
@@ -50,7 +54,7 @@ def forward(self, x):
         return x


-def default_log_fn(epoch, total_loss, correct, total, loss_list):
+def default_log_fn(epoch, total_loss, correct, total):
     print(
         f"Epoch {epoch} | loss {total_loss / total:.2f} | valid acc {correct / total:.2f}"
     )
@@ -60,16 +64,18 @@ def train(
     model,
     train_loader,
     val_loader,
+    logger=None,
     learning_rate=1e-2,
     max_epochs=50,
     log_fn=default_log_fn,
-):
+):
     optim = minitorch.RMSProp(model.parameters(), learning_rate)
+    best_val_acc = float('-inf')
     for epoch in range(1, max_epochs + 1):
         total_loss = 0.0
         model.train()
-        pbar = tqdm(train_loader, total=len(train_loader), desc=f"Train epoch {epoch}/{max_epochs}")
-        for X_train, y_train in pbar:
+        pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Train epoch {epoch}/{max_epochs}")
+        for i, (X_train, y_train) in pbar:
             optim.zero_grad()
             out = model.forward(X_train.view(X_train.shape[0], 1, H, W))
             loss = minitorch.nll_loss(out, y_train)
@@ -78,6 +84,9 @@ def train(
             total_loss += loss.item()
             optim.step()
             pbar.set_postfix(loss=loss.item())
+
+            if logger:
+                logger.add_scalar('Loss/train', loss.item(), (epoch - 1) * len(train_loader) + (i + 1))

         correct = 0
         total = 0
@@ -89,7 +98,14 @@ def train(
             correct += (y_hat == y_val).sum().item()
             total += y_val.shape[0]
             pbar.set_postfix(acc=correct / total * 100)
-
+
+        if best_val_acc < correct / total:
+            best_val_acc = correct / total
+            model.save_weights("mnist_model.npz")
+            print("Model saved to mnist_model.npz")
+
+        if logger:
+            logger.add_scalar('Accuracy/val', correct / total * 100, epoch)
         log_fn(epoch, total_loss, correct, total)


@@ -100,6 +116,7 @@ def train(
     parser.add_argument("--epochs", type=int, default=1, help="Number of epochs to train for")
     parser.add_argument("--lr", type=float, default=0.01, help="Learning rate")
     parser.add_argument("--data_dir", type=str, default="/home/minh/datasets/", help="Directory containing MNIST dataset")
+    parser.add_argument("--log_dir", type=str, default=None, help="Directory for TensorBoard event logs")
     args = parser.parse_args()

     if args.backend == "gpu" and numba.cuda.is_available():
@@ -130,9 +147,13 @@ def train(
     )

     model = Network(backend=backend)
+
+    logger = None
+    if args.log_dir:
+        if os.path.exists(args.log_dir):
+            shutil.rmtree(args.log_dir)
+        os.makedirs(args.log_dir)
+        logger = SummaryWriter(args.log_dir)

     print("Starting training...")
-    train(model, train_loader, val_loader, args.lr, max_epochs=args.epochs)
-
-    model.save_weights("mnist_model.npz")
-    print("Model saved to mnist_model.npz")
+    train(model, train_loader, val_loader, logger, args.lr, max_epochs=args.epochs)
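
With these changes, TensorBoard logging is opt-in via --log_dir, and the best-accuracy checkpoint is written during training instead of only once at the end. A minimal usage sketch: the script's path is not shown in this diff, so project/run_mnist.py below is a placeholder, and tensorboardX plus the TensorBoard CLI are assumed to be installed. All flags are the ones defined in the argparse block above.

    python project/run_mnist.py --backend gpu --epochs 10 --lr 0.01 \
        --data_dir /path/to/mnist --log_dir runs/mnist
    tensorboard --logdir runs/mnist

Note that an existing --log_dir is deleted with shutil.rmtree before the SummaryWriter is created, so each invocation starts from a clean TensorBoard run rather than appending to stale event files.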