Skip to content

Commit 5f615b5

Browse files
committed
add clamping for numerical stability
1 parent d76477f commit 5f615b5

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

examples/train_sentiment.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def train(
6767
optim.zero_grad()
6868
out = model.forward(X_train)
6969
prob = (out * y_train) + (out - 1.0) * (y_train - 1.0)
70+
prob = prob + 1e-10 # for numerical stability
7071
loss = -(prob.log() / y_train.shape[0]).sum().view(1)
7172
loss.backward()
7273

@@ -81,9 +82,9 @@ def train(
8182
total = 0
8283
model.eval()
8384
pbar = tqdm(val_loader, total=len(val_loader), desc=f"Val epoch {epoch}/{max_epochs}")
84-
for X_val, y_val in pbar:
85+
for X_val, y_val in pbar:
8586
out = model.forward(X_val)
86-
preds = (out > 0.5).astype(y_val.dtype)
87+
preds = (out > 0.5)
8788
correct += (preds == y_val).sum().item()
8889
total += y_val.shape[0]
8990
pbar.set_postfix(acc=correct / total * 100)

minitorch/datasets/uci_sentiment.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ def encode(self):
8181
max_length = 0
8282
for sent in self.sentences:
8383
max_length = max(max_length, len(sent))
84-
84+
8585
unks = set()
86-
unk_emb = [0.1 * (random.random() - 0.5) for i in range(max_length)]
87-
86+
unk_emb = [0.1 * (random.random() - 0.5) for i in range(self.emb_lookup.d_emb)]
87+
8888
self.samples = encode_sentences(self.sentences, max_length, self.emb_lookup, unk_emb, unks)
8989

9090
def __len__(self):

sentiment_model.npz

474 KB
Binary file not shown.

0 commit comments

Comments
 (0)