Skip to content

Commit 471f22d

Browse files
committed
train sentiment RNN
1 parent 5f615b5 commit 471f22d

File tree

4 files changed

+141
-4
lines changed

4 files changed

+141
-4
lines changed
Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
"""Train a LeNet-5 CNN on MNIST dataset"""
2-
31
import argparse
42
import numba
53
import numpy as np
Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
"""Train a 1D CNN on sentiment classification data"""
2-
31
import argparse
42
import embeddings
53
import numba

examples/train_sentiment_rnn.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import argparse
2+
import embeddings
3+
import numba
4+
import numpy as np
5+
import os
6+
import shutil
7+
import sys
8+
import warnings
9+
warnings.filterwarnings("ignore")
10+
from tensorboardX import SummaryWriter
11+
from tqdm import tqdm
12+
from sklearn.model_selection import train_test_split
13+
14+
import minitorch
15+
from minitorch.datasets import uci_sentiment
16+
from minitorch.dataloader import DataLoader
17+
18+
# CPU backend (Numba-accelerated FastOps); used as the default everywhere below.
FastTensorBackend = minitorch.TensorBackend(minitorch.FastOps)
# Only construct the CUDA backend when a GPU is actually present, so importing
# this script on a CPU-only machine does not fail.
if numba.cuda.is_available():
    GPUBackend = minitorch.TensorBackend(minitorch.CudaOps)
21+
22+
23+
class RNNSentiment(minitorch.Module):
    """Binary sentiment classifier: an RNN encoder with a linear head.

    The RNN's final hidden state is projected to a single logit per
    example, passed through a sigmoid, and returned as a 1-D tensor of
    positive-class probabilities.
    """

    def __init__(self, embedding_size=50, hidden_size=100, backend=FastTensorBackend):
        super().__init__()
        self.hidden_size = hidden_size
        # Attribute names are kept stable so serialized weight keys match.
        self.rnn = minitorch.RNN(embedding_size, hidden_size, backend=backend)
        self.fc = minitorch.Linear(hidden_size, 1, backend=backend)

    def forward(self, embeddings):
        """Map a batch of token embeddings to per-example probabilities."""
        _, final_hidden = self.rnn(embeddings)
        batch = final_hidden.shape[0]
        flat = final_hidden.view(batch, self.hidden_size)
        # Dropout is applied to the logit; active only while training.
        logit = minitorch.dropout(self.fc(flat), 0.2, not self.training)
        return logit.sigmoid().view(logit.shape[0])
36+
37+
38+
def default_log_fn(epoch, total_loss, correct, total):
    """Print a one-line epoch summary: mean loss and validation accuracy."""
    mean_loss = total_loss / total
    accuracy = correct / total
    print(f"Epoch {epoch} | loss {mean_loss:.2f} | valid acc {accuracy:.2f}")
42+
43+
44+
def train(
    model,
    train_loader,
    val_loader,
    logger=None,
    learning_rate=1e-2,
    max_epochs=50,
    log_fn=default_log_fn,
):
    """Train `model` with RMSProp, validating after every epoch.

    Parameters
    ----------
    model : minitorch.Module
        Binary classifier whose forward() returns per-example
        probabilities in [0, 1].
    train_loader, val_loader : DataLoader
        Yield (X, y) batches; y holds 0/1 labels.
    logger : tensorboardX.SummaryWriter or None
        If given, receives per-batch train loss and per-epoch val accuracy.
    learning_rate : float
        RMSProp step size.
    max_epochs : int
        Number of full passes over the training data.
    log_fn : callable
        Called as log_fn(epoch, total_loss, correct, total) at epoch end.
        NOTE(review): `total_loss` is the summed per-batch train loss while
        `total` is the validation-sample count, so the default logger's
        "loss" mixes the two — confirm this is intended.

    Side effect: saves weights to "sentiment_rnn_model.npz" whenever the
    validation accuracy improves on the best seen so far.
    """
    optim = minitorch.RMSProp(model.parameters(), learning_rate)
    best_val_acc = float('-inf')  # any first accuracy (>= 0) triggers a save
    for epoch in range(1, max_epochs + 1):
        total_loss = 0.0
        model.train()  # enable dropout for the training pass
        pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Train epoch {epoch}/{max_epochs}")
        for i, (X_train, y_train) in pbar:
            optim.zero_grad()
            out = model.forward(X_train)
            # Bernoulli likelihood: equals `out` where y == 1 and
            # `1 - out` where y == 0, written branch-free so the
            # expression stays differentiable.
            prob = (out * y_train) + (out - 1.0) * (y_train - 1.0)
            prob = prob + 1e-10  # for numerical stability
            # Mean negative log-likelihood over the batch.
            loss = -(prob.log() / y_train.shape[0]).sum().view(1)
            loss.backward()

            total_loss += loss.item()
            optim.step()
            pbar.set_postfix(loss=loss.item())

            if logger:
                # Global step = number of batches seen across all epochs.
                logger.add_scalar('Loss/train', loss.item(), (epoch - 1) * len(train_loader) + (i + 1))

        correct = 0
        total = 0
        model.eval()  # disable dropout for validation
        pbar = tqdm(val_loader, total=len(val_loader), desc=f"Val epoch {epoch}/{max_epochs}")
        for X_val, y_val in pbar:
            out = model.forward(X_val)
            preds = (out > 0.5)  # threshold probabilities at 0.5
            correct += (preds == y_val).sum().item()
            total += y_val.shape[0]
            pbar.set_postfix(acc=correct / total * 100)

        # Checkpoint whenever validation accuracy sets a new best.
        if best_val_acc < correct / total:
            best_val_acc = correct / total
            model.save_weights("sentiment_rnn_model.npz")
            print(f"Model saved to sentiment_rnn_model.npz (val acc: {best_val_acc:.4f})")

        if logger:
            logger.add_scalar('Accuracy/val', correct / total * 100, epoch)
        log_fn(epoch, total_loss, correct, total)
93+
94+
95+
if __name__ == "__main__":
    # Command-line interface for the sentiment-RNN training script.
    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", default="cpu", help="backend mode")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for training")
    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs to train for")
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate")
    parser.add_argument("--hidden_size", type=int, default=100, help="RNN hidden size")
    parser.add_argument("--data_dir", type=str, default="/home/minh/datasets/", help="Directory containing sentiment dataset")
    parser.add_argument("--log_dir", type=str, default=None, help="Directory to log training parameters")
    args = parser.parse_args()

    # Pick the tensor backend; fall back to CPU when CUDA is unavailable.
    wants_gpu = args.backend == "gpu"
    if wants_gpu and numba.cuda.is_available():
        backend = GPUBackend
        print("Using CUDA backend")
    else:
        if wants_gpu:
            print("CUDA backend not available, using CPU instead.", file=sys.stderr)
        backend = FastTensorBackend
        print("Using CPU backend")

    # Build the dataset with 50-d GloVe embeddings, then an 80/20 split.
    emb_lookup = embeddings.GloveEmbedding("wikipedia_gigaword", d_emb=50, show_progress=True)
    ds = uci_sentiment.UCISentimentDataset(root=args.data_dir, emb_lookup=emb_lookup)
    sentiment_train, sentiment_val = train_test_split(ds, test_size=0.2, random_state=42)

    # Shuffle only the training split; validation order is fixed.
    train_loader = DataLoader(sentiment_train, batch_size=args.batch_size, shuffle=True, backend=backend)
    val_loader = DataLoader(sentiment_val, batch_size=args.batch_size, shuffle=False, backend=backend)

    model = RNNSentiment(embedding_size=50, hidden_size=args.hidden_size, backend=backend)

    # Optional TensorBoard logging; a pre-existing log directory is wiped
    # so each run starts with a clean event file.
    logger = None
    if args.log_dir:
        if os.path.exists(args.log_dir):
            shutil.rmtree(args.log_dir)
        os.makedirs(args.log_dir)
        logger = SummaryWriter(args.log_dir)

    print("Starting training...")
    train(model, train_loader, val_loader, logger, args.lr, max_epochs=args.epochs)

sentiment_model.npz

-474 KB
Binary file not shown.

0 commit comments

Comments
 (0)