Skip to content

Commit 1f40224

Browse files
committed
- Update requirements
- Update training code
1 parent d35a86a commit 1f40224

File tree

3 files changed

+34
-12
lines changed

3 files changed

+34
-12
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -133,4 +133,5 @@ pyodide
133133

134134
# IDE and editor specific ignores
135135
.vscode/
136-
.idea/
136+
.idea/
137+
log/

examples/run_mnist_multiclass.py

Lines changed: 30 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,14 @@
1+
"""Train a LeNet-5 CNN on MNIST dataset"""
2+
13
import argparse
24
import numba
35
import numpy as np
4-
import matplotlib.pyplot as plt
6+
import os
7+
import shutil
58
import sys
69
import warnings
710
warnings.filterwarnings("ignore")
11+
from tensorboardX import SummaryWriter
812
from tqdm import tqdm
913

1014
import minitorch
@@ -50,7 +54,7 @@ def forward(self, x):
5054
return x
5155

5256

53-
def default_log_fn(epoch, total_loss, correct, total, loss_list):
57+
def default_log_fn(epoch, total_loss, correct, total):
5458
print(
5559
f"Epoch {epoch} | loss {total_loss / total:.2f} | valid acc {correct / total:.2f}"
5660
)
@@ -60,16 +64,18 @@ def train(
6064
model,
6165
train_loader,
6266
val_loader,
67+
logger=None,
6368
learning_rate=1e-2,
6469
max_epochs=50,
6570
log_fn=default_log_fn,
66-
):
71+
):
6772
optim = minitorch.RMSProp(model.parameters(), learning_rate)
73+
best_val_acc = float('-inf')
6874
for epoch in range(1, max_epochs + 1):
6975
total_loss = 0.0
7076
model.train()
71-
pbar = tqdm(train_loader, total=len(train_loader), desc=f"Train epoch {epoch}/{max_epochs}")
72-
for X_train, y_train in pbar:
77+
pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Train epoch {epoch}/{max_epochs}")
78+
for i, (X_train, y_train) in pbar:
7379
optim.zero_grad()
7480
out = model.forward(X_train.view(X_train.shape[0], 1, H, W))
7581
loss = minitorch.nll_loss(out, y_train)
@@ -78,6 +84,9 @@ def train(
7884
total_loss += loss.item()
7985
optim.step()
8086
pbar.set_postfix(loss=loss.item())
87+
88+
if logger:
89+
logger.add_scalar('Loss/train', loss.item(), (epoch - 1) * len(train_loader) + (i + 1))
8190

8291
correct = 0
8392
total = 0
@@ -89,7 +98,13 @@ def train(
8998
correct += (y_hat == y_val).sum().item()
9099
total += y_val.shape[0]
91100
pbar.set_postfix(acc=correct / total * 100)
92-
101+
102+
if best_val_acc < correct / total:
103+
best_val_acc = correct / total
104+
model.save_weights("mnist_model.npz")
105+
print("Model saved to mnist_model.npz")
106+
107+
logger.add_scalar('Accuracy/val', correct / total * 100, epoch)
93108
log_fn(epoch, total_loss, correct, total)
94109

95110

@@ -100,6 +115,7 @@ def train(
100115
parser.add_argument("--epochs", type=int, default=1, help="Number of epochs to train for")
101116
parser.add_argument("--lr", type=float, default=0.01, help="Learning rate")
102117
parser.add_argument("--data_dir", type=str, default="/home/minh/datasets/", help="Directory containing MNIST dataset")
118+
parser.add_argument("--log_dir", type=str, default=None, help="Directory to log training parameters")
103119
args = parser.parse_args()
104120

105121
if args.backend == "gpu" and numba.cuda.is_available():
@@ -130,9 +146,13 @@ def train(
130146
)
131147

132148
model = Network(backend=backend)
149+
150+
logger = None
151+
if args.log_dir:
152+
if os.path.exists(args.log_dir):
153+
shutil.rmtree(args.log_dir)
154+
os.makedirs(args.log_dir)
155+
logger = SummaryWriter(args.log_dir)
133156

134157
print("Starting training...")
135-
train(model, train_loader, val_loader, args.lr, max_epochs=args.epochs)
136-
137-
model.save_weights("mnist_model.npz")
138-
print("Model saved to mnist_model.npz")
158+
train(model, train_loader, val_loader, logger, args.lr, max_epochs=args.epochs)

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,5 +7,6 @@ pre-commit==2.20.0
77
pytest==7.1.2
88
pytest-env
99
pytest-runner==5.2
10-
matplotlib==3.10.6
10+
tensorboardX==2.6.4
11+
tensorboard==2.20.0
1112
typing_extensions

0 commit comments

Comments
 (0)