
Commit 4a2eabf

Dynamically select which backend to train

1 parent 3f3e5a3 commit 4a2eabf

File tree

9 files changed: +401 −168 lines changed

minitorch/__init__.py

Lines changed: 6 additions & 4 deletions

@@ -10,13 +10,15 @@
 
 from .autodiff import *  # noqa: F401,F403
 from .backends.cuda_ops import *  # noqa: F401,F403
-from .datasets import dummy_datasets  # noqa: F401,F403
+from .datasets import dummy_datasets, mnist  # noqa: F401,F403
 from .backends.fast_conv import *  # noqa: F401,F403
-
 from .backends.fast_ops import *  # noqa: F401,F403
-from .nn.module import *  # noqa: F401,F403
 
+from .nn.module import *  # noqa: F401,F403
 from .nn.nn import *  # noqa: F401,F403
 from .nn.optim import *  # noqa: F401,F403
+from .nn.layers import *
+from .nn.loss import nll_loss, bce_loss, cross_entropy_loss, mse_loss
+from .dataloader import DataLoader
 
-version = "0.4"
+version = "1.0"
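
With these re-exports, the whole training surface is reachable from the top-level package. A minimal sketch of the expanded API, assuming TensorBackend and FastOps are among the names star-imported from the backend modules above:

    import minitorch

    BACKEND = minitorch.TensorBackend(minitorch.FastOps)  # assumed constructor name

    x = minitorch.tensor([[1.0, 2.0]], backend=BACKEND)
    layer = minitorch.Linear(2, 4, backend=BACKEND)
    print(minitorch.version)            # "1.0"
    print(layer.forward(x).shape)       # (1, 4)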

minitorch/dataloader.py

Lines changed: 29 additions & 0 deletions (new file)

import numpy as np
from .tensor.functions import tensor


class DataLoader:
    def __init__(self, dataset, backend, batch_size=1, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.backend = backend

    def __len__(self):
        return int(np.ceil(len(self.dataset) / self.batch_size))

    def __iter__(self):
        indices = np.arange(len(self.dataset))
        if self.shuffle:
            np.random.shuffle(indices)

        for i in range(0, len(indices), self.batch_size):
            batch_indices = indices[i:i + self.batch_size]
            batch_data = [self.dataset[j] for j in batch_indices]

            inputs, labels = zip(*batch_data)

            inputs_tensor = tensor(list(inputs), backend=self.backend)
            labels_tensor = tensor(list(labels), backend=self.backend)

            yield inputs_tensor, labels_tensor
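
A usage sketch with an invented toy dataset of (input, label) pairs. Note there is no drop_last option: when len(dataset) is not a multiple of batch_size, the final batch is simply shorter.

    import minitorch

    BACKEND = minitorch.TensorBackend(minitorch.FastOps)  # assumed constructor name

    # Toy dataset: any indexable sequence of (input, label) pairs works.
    data = [([0.0, 0.0], 0.0), ([0.0, 1.0], 1.0),
            ([1.0, 0.0], 1.0), ([1.0, 1.0], 0.0)]

    loader = minitorch.DataLoader(data, backend=BACKEND, batch_size=2, shuffle=True)
    print(len(loader))  # 2 batches
    for inputs, labels in loader:
        print(inputs.shape, labels.shape)  # (2, 2) and (2,)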

minitorch/datasets/mnist.py

Lines changed: 98 additions & 0 deletions (new file)

import os
import gzip
import shutil
import urllib.request
import numpy as np


class MNISTDataset:
    @staticmethod
    def load_mnist_img(path):
        try:
            with open(path, "rb") as fi:
                _ = int.from_bytes(fi.read(4), "big")  # magic number
                n_images = int.from_bytes(fi.read(4), "big")
                h = int.from_bytes(fi.read(4), "big")
                w = int.from_bytes(fi.read(4), "big")
                buffer = fi.read()
                images = np.frombuffer(buffer, dtype=np.uint8).reshape(n_images, h, w)
        except Exception as e:
            print(f"Could not read MNIST image file at {path}")
            print(e)
            exit(1)
        return images

    @staticmethod
    def load_mnist_lbl(path):
        try:
            with open(path, "rb") as fi:
                _ = int.from_bytes(fi.read(4), "big")
                n_labels = int.from_bytes(fi.read(4), "big")
                buffer = fi.read()
                labels = np.frombuffer(buffer, dtype=np.uint8)
        except Exception as e:
            print(f"Could not read MNIST label file at {path}")
            print(e)
            exit(1)
        return labels

    @staticmethod
    def _download_and_extract(root):
        """
        Downloads and extracts the MNIST dataset files if they don't exist.
        """
        mnist_path = os.path.join(root, "MNIST")
        os.makedirs(mnist_path, exist_ok=True)

        urls = [
            "https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz",
            "https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz",
            "https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz",
            "https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz",
        ]

        for url in urls:
            filename = url.split("/")[-1]
            gz_path = os.path.join(mnist_path, filename)
            uncompressed_path = os.path.join(mnist_path, filename[:-3])

            if not os.path.exists(uncompressed_path):
                print(f"Downloading {url}")
                urllib.request.urlretrieve(url, gz_path)

                print(f"Extracting {gz_path}")
                with gzip.open(gz_path, 'rb') as f_in:
                    with open(uncompressed_path, 'wb') as f_out:
                        shutil.copyfileobj(f_in, f_out)
                os.remove(gz_path)

    '''
    dataset_dir
    ├── MNIST
        ├── train-images-idx3-ubyte (train images file)
        ├── train-labels-idx1-ubyte
        ├── t10k-images-idx3-ubyte (val images file)
        ├── t10k-labels-idx1-ubyte
    '''

    def __init__(self, root, download=True, train=True):
        if download and not os.path.exists(os.path.join(root, "MNIST")):
            self._download_and_extract(root)

        if train:
            img_dir = os.path.join(root, "MNIST", "train-images-idx3-ubyte")
            lbl_dir = os.path.join(root, "MNIST", "train-labels-idx1-ubyte")
        else:
            img_dir = os.path.join(root, "MNIST", "t10k-images-idx3-ubyte")
            lbl_dir = os.path.join(root, "MNIST", "t10k-labels-idx1-ubyte")

        images = self.load_mnist_img(img_dir)
        labels = self.load_mnist_lbl(lbl_dir)

        self.data = [(image, label) for image, label in zip(images, labels)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]
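
Putting the dataset and the loader together. The raw images come back as uint8 numpy arrays in [0, 255] with labels 0-9; the minitorch tensor constructor is assumed to accept numpy rows, and the batch size here is just illustrative:

    import minitorch
    from minitorch.datasets.mnist import MNISTDataset

    BACKEND = minitorch.TensorBackend(minitorch.FastOps)  # assumed constructor name

    train_set = MNISTDataset("data", download=True, train=True)   # 60,000 pairs
    val_set = MNISTDataset("data", download=True, train=False)    # 10,000 pairs

    image, label = train_set[0]
    print(image.shape, int(label))  # (28, 28) and a digit in 0-9

    train_loader = minitorch.DataLoader(train_set, backend=BACKEND,
                                        batch_size=64, shuffle=True)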

minitorch/nn/init.py

Lines changed: 0 additions & 1 deletion

@@ -1,5 +1,4 @@
 import math
-from ..tensor import tensor
 
 
 def kaiming_uniform(tensor, fan_in, **kwargs):
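
The removed import was shadowed by the function's own tensor parameter, so dropping it is a no-op. The body of kaiming_uniform is not shown in this diff; for orientation only, the usual Kaiming-uniform bound looks like this (an assumption about the scheme, not this repo's code):

    import math

    def kaiming_uniform_bound(fan_in, gain=math.sqrt(2.0)):
        # Weights are drawn from U(-bound, bound) with
        # bound = gain * sqrt(3 / fan_in); gain = sqrt(2) is the
        # usual ReLU choice, giving bound = sqrt(6 / fan_in).
        return gain * math.sqrt(3.0 / fan_in)

    print(round(kaiming_uniform_bound(784), 4))  # ~0.0875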

minitorch/nn/layers.py

Lines changed: 9 additions & 9 deletions

@@ -5,11 +5,11 @@
 
 
 class Linear(Module):
-    def __init__(self, in_size, out_size, initializer=init.kaiming_uniform):
+    def __init__(self, in_size, out_size, backend, initializer=init.kaiming_uniform):
         super().__init__()
-        self.weights = Parameter(rand((in_size, out_size)))
+        self.weights = Parameter(rand((in_size, out_size), backend=backend))
         initializer(self.weights.value, in_size)
-        self.bias = Parameter(zeros((out_size,)))
+        self.bias = Parameter(zeros((out_size,), backend=backend))
         self.out_size = out_size
 
     def forward(self, x):
@@ -20,25 +20,25 @@ def forward(self, x):
 
 
 class Conv1d(Module):
-    def __init__(self, in_channels, out_channels, kernel_width, initializer=init.kaiming_uniform):
+    def __init__(self, in_channels, out_channels, kernel_width, backend, initializer=init.kaiming_uniform):
         super().__init__()
-        self.weights = Parameter(rand((out_channels, in_channels, kernel_width)))
+        self.weights = Parameter(rand((out_channels, in_channels, kernel_width), backend=backend))
         fan_in = in_channels * kernel_width
         initializer(self.weights.value, fan_in)
-        self.bias = Parameter(zeros((1, out_channels, 1)))
+        self.bias = Parameter(zeros((1, out_channels, 1), backend=backend))
 
     def forward(self, input):
         out = fast_conv.conv1d(input, self.weights.value) + self.bias.value
         return out
 
 
 class Conv2d(Module):
-    def __init__(self, in_channels, out_channels, kh, kw, initializer=init.kaiming_uniform):
+    def __init__(self, in_channels, out_channels, kh, kw, backend, initializer=init.kaiming_uniform):
         super().__init__()
-        self.weights = Parameter(rand((out_channels, in_channels, kh, kw)))
+        self.weights = Parameter(rand((out_channels, in_channels, kh, kw), backend=backend))
         fan_in = in_channels * kh * kw
         initializer(self.weights.value, fan_in)
-        self.bias = Parameter(zeros((out_channels, 1, 1)))
+        self.bias = Parameter(zeros((out_channels, 1, 1), backend=backend))
 
     def forward(self, input):
         out = fast_conv.conv2d(input, self.weights.value) + self.bias.value
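
This is the heart of the commit: layers no longer assume a global default backend, the caller threads one through, so the same model code can run on either ops implementation. A selection sketch (the TensorBackend/FastOps/CudaOps names are assumptions based on the backend modules re-exported in __init__.py):

    import minitorch

    def make_backend(use_gpu):
        # Assumed constructors; pick the ops implementation at run time.
        ops = minitorch.CudaOps if use_gpu else minitorch.FastOps
        return minitorch.TensorBackend(ops)

    BACKEND = make_backend(use_gpu=False)

    fc = minitorch.Linear(784, 10, backend=BACKEND)
    conv = minitorch.Conv2d(1, 8, 3, 3, backend=BACKEND)  # in_ch, out_ch, kh, kw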

minitorch/nn/loss.py

Lines changed: 71 additions & 0 deletions (new file)

from ..tensor import tensor
from .nn import logsoftmax


def mse_loss(y_pred, y_true):
    """
    Mean Squared Error Loss.

    Args:
        y_pred (Tensor): Predicted values, shape (batch_size, 1).
        y_true (Tensor): True values, shape (batch_size, 1).

    Returns:
        Tensor: The mean squared error loss.
    """
    diff = y_pred - y_true
    return (diff * diff).sum() / y_pred.shape[0]


def nll_loss(y_pred_log_probs, y_true):
    """
    Negative Log-Likelihood Loss.

    Args:
        y_pred_log_probs (Tensor): Log-probabilities of predictions, shape (batch_size, num_classes).
        y_true (Tensor): True class indices, shape (batch_size,).

    Returns:
        Tensor: The negative log-likelihood loss.
    """
    batch_size, num_classes = y_pred_log_probs.shape

    # Create one-hot encoded tensor for y_true
    y_one_hot = y_pred_log_probs.zeros(y_pred_log_probs.shape)
    y_one_hot.requires_grad_(False)

    for i in range(batch_size):
        y_one_hot[i, int(y_true[i].item())] = 1

    loss = -(y_pred_log_probs * y_one_hot).sum()
    return loss / batch_size


def cross_entropy_loss(y_pred_logits, y_true):
    """
    Cross-Entropy Loss.

    Args:
        y_pred_logits (Tensor): Raw logits from the model, shape (batch_size, num_classes).
        y_true (Tensor): True class indices, shape (batch_size,).

    Returns:
        Tensor: The cross-entropy loss.
    """
    log_probs = logsoftmax(y_pred_logits, dim=1)
    return nll_loss(log_probs, y_true)


def bce_loss(y_pred, y_true):
    """
    Binary Cross-Entropy Loss.

    Args:
        y_pred (Tensor): Predicted probabilities, shape (batch_size, 1).
        y_true (Tensor): True labels (0 or 1), shape (batch_size, 1).

    Returns:
        Tensor: The binary cross-entropy loss.
    """
    loss = -(y_true * y_pred.log() + (1 - y_true) * (1 - y_pred).log())
    return loss.sum() / y_pred.shape[0]
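
cross_entropy_loss is just logsoftmax followed by nll_loss, so a quick numeric check is easy. A sketch (backend constructor assumed as above); note also that nll_loss builds its one-hot mask with a per-row Python loop, which is fine for small batches but not vectorized:

    import minitorch

    BACKEND = minitorch.TensorBackend(minitorch.FastOps)  # assumed constructor name

    logits = minitorch.tensor([[2.0, 0.5, 0.1]], backend=BACKEND)
    target = minitorch.tensor([0.0], backend=BACKEND)

    loss = minitorch.cross_entropy_loss(logits, target)
    # logsumexp([2.0, 0.5, 0.1]) ~ 2.317, so logsoftmax at class 0
    # is ~ -0.317 and the batch loss is ~ 0.317.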

minitorch/nn/module.py

Lines changed: 10 additions & 0 deletions

@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 from typing import Any, Dict, Optional, Sequence, Tuple
+import minitorch
+import numpy as np
 
 
 class Module:
@@ -121,6 +123,14 @@ def _addindent(s_: str, numSpaces: int) -> str:
         main_str += ")"
         return main_str
 
+    def save_weights(self, path: str) -> None:
+        weights = {name: p.value.to_numpy() for name, p in self.named_parameters()}
+        np.savez(path, **weights)
+
+    def load_weights(self, path: str) -> None:
+        weights = np.load(path)
+        for name, p in self.named_parameters():
+            p.update(minitorch.tensor(weights[name].tolist()))
 
 class Parameter:
     """

minitorch/nn/optim.py

Lines changed: 71 additions & 7 deletions

@@ -1,4 +1,5 @@
 from typing import Sequence
+import math
 
 from .module import Parameter
 from ..scalar.scalar import Scalar
@@ -10,9 +11,11 @@ def __init__(self, parameters: Sequence[Parameter]):
 
 
 class SGD(Optimizer):
-    def __init__(self, parameters: Sequence[Parameter], lr: float = 1.0):
+    def __init__(self, parameters: Sequence[Parameter], lr: float = 1.0, momentum: float = 0.0):
         super().__init__(parameters)
         self.lr = lr
+        self.momentum = momentum
+        self.velocities = {}
 
     def zero_grad(self) -> None:
         for p in self.parameters:
@@ -29,9 +32,70 @@ def step(self) -> None:
         for p in self.parameters:
             if p.value is None:
                 continue
-            if hasattr(p.value, "derivative"):
-                if p.value.derivative is not None:
-                    p.update(Scalar(p.value.data - self.lr * p.value.derivative))
-            elif hasattr(p.value, "grad"):
-                if p.value.grad is not None:
-                    p.update(p.value - self.lr * p.value.grad)
+
+            is_scalar = hasattr(p.value, "derivative")
+
+            grad = p.value.derivative if is_scalar and p.value.derivative is not None else p.value.grad
+
+            if grad is None:
+                continue
+
+            if self.momentum == 0.0:
+                # Standard SGD
+                update_val = self.lr * grad
+            else:
+                # SGD with momentum
+                if p not in self.velocities:
+                    self.velocities[p] = 0.0 if is_scalar else grad * 0.0
+
+                v = self.velocities[p]
+                v_new = self.momentum * v + grad
+                self.velocities[p] = v_new
+                update_val = self.lr * v_new
+
+            if is_scalar:
+                p.update(Scalar(p.value.data - update_val))
+            else:
+                p.update(p.value - update_val)
+
+
+class RMSProp(Optimizer):
+    def __init__(self, parameters: Sequence[Parameter], lr: float = 1e-2, decay_rate: float = 0.9, eps: float = 1e-8):
+        super().__init__(parameters)
+        self.lr = lr
+        self.decay_rate = decay_rate
+        self.eps = eps
+        self.s_vals = {}
+
+    def step(self) -> None:
+        for p in self.parameters:
+            if p.value is None:
+                continue
+
+            is_scalar = hasattr(p.value, "derivative")
+
+            grad = p.value.derivative if is_scalar and p.value.derivative is not None else p.value.grad
+
+            if grad is None:
+                continue
+
+            if p not in self.s_vals:
+                if is_scalar:
+                    self.s_vals[p] = 0.0
+                else:
+                    self.s_vals[p] = grad * 0.0
+
+            s = self.s_vals[p]
+
+            s_new = self.decay_rate * s + (1 - self.decay_rate) * (grad * grad)
+            self.s_vals[p] = s_new
+
+            if is_scalar:
+                update_val = self.lr * grad / (math.sqrt(s_new) + self.eps)
+            else:
+                update_val = self.lr * grad / (s_new.sqrt() + self.eps)
+
+            if is_scalar:
+                p.update(Scalar(p.value.data - update_val))
+            else:
+                p.update(p.value - update_val)
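
In update-rule form: SGD with momentum keeps v <- momentum * v + g and steps theta <- theta - lr * v, while RMSProp keeps a running second moment s <- decay_rate * s + (1 - decay_rate) * g^2 and steps theta <- theta - lr * g / (sqrt(s) + eps). A training-loop sketch tying the commit together (backend, loader, and normalization as assumed in the earlier sketches; the diff shows zero_grad only on SGD, so SGD is used here):

    model = minitorch.Linear(784, 10, backend=BACKEND)
    optim = minitorch.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for epoch in range(5):
        for images, labels in train_loader:
            batch = images.view(images.shape[0], 784) / 255.0  # flatten and scale
            optim.zero_grad()
            out = model.forward(batch)
            loss = minitorch.cross_entropy_loss(out, labels)
            loss.backward()
            optim.step()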
