Skip to content

Commit 345a00c

Browse files
committed
Updates
1 parent 4b7ea5d commit 345a00c

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

torchao/testing/model_architectures.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import torch
1010
import torch.nn as nn
11-
from torch.nn import RMSNorm
1211

1312

1413
class ToyLinearModel(torch.nn.Module):
@@ -73,8 +72,8 @@ def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16):
7372
)
7473

7574
# Layer norms
76-
self.norm1 = RMSNorm(hidden_dim, dtype=dtype)
77-
self.norm2 = RMSNorm(hidden_dim, dtype=dtype)
75+
self.norm1 = nn.RMSNorm(hidden_dim, dtype=dtype)
76+
self.norm2 = nn.RMSNorm(hidden_dim, dtype=dtype)
7877

7978
# Activation
8079
self.activation = torch.nn.GELU()

0 commit comments

Comments
 (0)