
Commit 328b7bf

Update
[ghstack-poisoned]
1 parent dd9f50d commit 328b7bf

File tree: 3 files changed, 0 additions & 253 deletions

benchmarks/microbenchmarks/test/benchmark_config.yml

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -50,31 +50,3 @@ model_params:
5050
# device: "cpu"
5151
# model_type: "linear"
5252
# enable_profiler: true # Enable profiling for this model
53-
54-
- name: "bf16_rms_norm_linear_activation"
55-
matrix_shapes:
56-
- name: "custom"
57-
shapes: [
58-
[2048, 4096, 1024],
59-
]
60-
high_precision_dtype: "torch.bfloat16"
61-
use_torch_compile: true
62-
torch_compile_mode: "max-autotune"
63-
device: "cuda"
64-
model_type: "rms_norm_linear_activation"
65-
enable_profiler: true
66-
enable_memory_profile: true
67-
68-
- name: "bf16_transformer_block"
69-
matrix_shapes:
70-
- name: "custom"
71-
shapes: [
72-
[2048, 4096, 1024], # For transformer_block, k is the hidden dimension
73-
]
74-
high_precision_dtype: "torch.bfloat16"
75-
use_torch_compile: true
76-
torch_compile_mode: "max-autotune"
77-
device: "cuda"
78-
model_type: "transformer_block"
79-
enable_profiler: true
80-
enable_memory_profile: true
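
For context, the removed "bf16_rms_norm_linear_activation" entry benchmarked an RMSNorm -> Linear -> activation pipeline at the (m, k, n) shape listed above, where k is the input width and n the output width. Below is a minimal, self-contained sketch of that computation, reconstructed from the RMSNormLinearActivation code this commit deletes from benchmarks/microbenchmarks/utils.py (float32 on CPU here for portability; the removed config ran torch.bfloat16 on CUDA with torch.compile):

import torch

# Functional RMSNorm matching the removed RMSNorm module
def rms_norm(x, weight, eps=1e-6):
    norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * norm * weight

m, k, n = 2048, 4096, 1024              # the shape from the removed config entry
x = torch.randn(m, k)                   # [batch_size, in_features]
weight = torch.ones(k)                  # RMSNorm scale, initialized to ones
fc = torch.nn.Linear(k, n, bias=False)  # the removed model used bias=False
gelu = torch.nn.GELU()                  # default activation in the removed model

out = gelu(fc(rms_norm(x, weight)))     # RMSNorm -> Linear -> GELU
print(out.shape)                        # torch.Size([2048, 1024])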

benchmarks/microbenchmarks/test/test_utils.py

Lines changed: 0 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,8 @@
1717
Float8DynamicActivationFloat8SemiSparseWeightConfig,
1818
Int4WeightOnlyConfig,
1919
LNLinearSigmoid,
20-
RMSNorm,
21-
RMSNormLinearActivation,
2220
SemiSparseWeightConfig,
2321
ToyLinearModel,
24-
TransformerBlock,
2522
clean_caches,
2623
create_model_and_input,
2724
generate_results_csv,
@@ -165,61 +162,6 @@ def test_ln_linear_sigmoid(self):
165162
torch.all((out >= 0) & (out <= 1))
166163
) # Check sigmoid output range
167164

168-
def test_rms_norm(self):
169-
# Test RMSNorm
170-
rms_norm = RMSNorm(dim=64)
171-
x = torch.randn(16, 64)
172-
out = rms_norm(x)
173-
self.assertEqual(out.shape, (16, 64))
174-
175-
# Test with different eps
176-
rms_norm = RMSNorm(dim=64, eps=1e-5)
177-
out = rms_norm(x)
178-
self.assertEqual(out.shape, (16, 64))
179-
180-
def test_rms_norm_linear_activation(self):
181-
# Test with default GELU activation
182-
model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
183-
x = torch.randn(16, 64)
184-
out = model(x)
185-
self.assertEqual(out.shape, (16, 32))
186-
self.assertEqual(out.dtype, torch.float32)
187-
188-
# Test with ReLU activation
189-
model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu")
190-
out = model(x)
191-
self.assertEqual(out.shape, (16, 32))
192-
self.assertTrue(torch.all(out >= 0)) # Check ReLU output range
193-
194-
# Test with SiLU activation
195-
model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu")
196-
out = model(x)
197-
self.assertEqual(out.shape, (16, 32))
198-
199-
# Test with invalid activation
200-
with self.assertRaises(ValueError):
201-
RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid")
202-
203-
def test_transformer_block(self):
204-
# Test with default parameters
205-
model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
206-
x = torch.randn(16, 16, 64) # [batch_size, seq_len, hidden_dim]
207-
out = model(x)
208-
self.assertEqual(out.shape, (16, 16, 64))
209-
self.assertEqual(out.dtype, torch.float32)
210-
211-
# Test with different parameters
212-
model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32)
213-
x = torch.randn(8, 32, 128)
214-
out = model(x)
215-
self.assertEqual(out.shape, (8, 32, 128))
216-
217-
# Test with different head dimensions
218-
model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32)
219-
x = torch.randn(4, 8, 96)
220-
out = model(x)
221-
self.assertEqual(out.shape, (4, 8, 96))
222-
223165
def test_create_model_and_input(self):
224166
m, k, n = 16, 64, 32
225167
model, input_data = create_model_and_input(
@@ -244,63 +186,6 @@ def test_create_model_and_input(self):
244186
self.assertIsInstance(model, LNLinearSigmoid)
245187
self.assertEqual(input_data.shape, (m, k))
246188

247-
# Test RMSNormLinearActivation
248-
model, input_data = create_model_and_input(
249-
model_type="rms_norm_linear_activation",
250-
m=m,
251-
k=k,
252-
n=n,
253-
high_precision_dtype=torch.float32,
254-
device="cpu",
255-
)
256-
self.assertIsInstance(model, RMSNormLinearActivation)
257-
self.assertEqual(input_data.shape, (m, k))
258-
259-
# Test TransformerBlock
260-
model, input_data = create_model_and_input(
261-
model_type="transformer_block",
262-
m=m,
263-
k=k,
264-
n=n, # n is not used for transformer_block
265-
high_precision_dtype=torch.float32,
266-
device="cpu",
267-
)
268-
self.assertIsInstance(model, TransformerBlock)
269-
self.assertEqual(input_data.shape, (m, 16, k)) # [batch_size, seq_len, hidden_dim]
270-
271-
def test_quantization_on_models(self):
272-
# Test quantization on RMSNormLinearActivation
273-
model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
274-
x = torch.randn(16, 64)
275-
276-
# Test with Int8WeightOnlyConfig
277-
config = string_to_config(quantization="int8wo", sparsity=None)
278-
if config is not None:
279-
# Skip quantization test if torchao.quantization.quantize is not available
280-
try:
281-
from torchao.quantization import quantize
282-
quantized_model = quantize(model, config)
283-
out = quantized_model(x)
284-
self.assertEqual(out.shape, (16, 32))
285-
except ImportError:
286-
print("Skipping quantization test: torchao.quantization.quantize not available")
287-
288-
# Test quantization on TransformerBlock
289-
model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
290-
x = torch.randn(16, 16, 64)
291-
292-
# Test with Int8WeightOnlyConfig
293-
config = string_to_config(quantization="int8wo", sparsity=None)
294-
if config is not None:
295-
# Skip quantization test if torchao.quantization.quantize is not available
296-
try:
297-
from torchao.quantization import quantize
298-
quantized_model = quantize(model, config)
299-
out = quantized_model(x)
300-
self.assertEqual(out.shape, (16, 16, 64))
301-
except ImportError:
302-
print("Skipping quantization test: torchao.quantization.quantize not available")
303-
304189
def test_generate_results_csv(self):
305190
results = [
306191
BenchmarkResult(

benchmarks/microbenchmarks/utils.py

Lines changed: 0 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -383,108 +383,6 @@ def forward(self, x):
383383
return x
384384

385385

386-
class RMSNorm(torch.nn.Module):
387-
def __init__(self, dim, eps=1e-6, dtype=torch.bfloat16):
388-
super().__init__()
389-
self.eps = eps
390-
self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
391-
392-
def forward(self, x):
393-
norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
394-
return x * norm * self.weight
395-
396-
397-
class RMSNormLinearActivation(torch.nn.Module):
398-
def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="gelu"):
399-
super().__init__()
400-
self.rms_norm = RMSNorm(fc_dim1, dtype=dtype)
401-
self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype)
402-
403-
if activation == "gelu":
404-
self.activation = torch.nn.GELU()
405-
elif activation == "relu":
406-
self.activation = torch.nn.ReLU()
407-
elif activation == "silu":
408-
self.activation = torch.nn.SiLU()
409-
else:
410-
raise ValueError(f"Unsupported activation: {activation}")
411-
412-
def forward(self, x):
413-
x = self.rms_norm(x)
414-
x = self.fc(x)
415-
x = self.activation(x)
416-
return x
417-
418-
419-
class TransformerBlock(torch.nn.Module):
420-
def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16):
421-
super().__init__()
422-
self.hidden_dim = hidden_dim
423-
self.num_heads = num_heads
424-
self.head_dim = hidden_dim // num_heads
425-
426-
# Self-attention
427-
self.qkv = torch.nn.Linear(hidden_dim, 3 * hidden_dim, bias=False).to(dtype)
428-
self.proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False).to(dtype)
429-
430-
# MLP
431-
self.mlp_ratio = mlp_ratio
432-
self.mlp_hidden_dim = int(hidden_dim * mlp_ratio)
433-
self.mlp_fc1 = torch.nn.Linear(hidden_dim, self.mlp_hidden_dim, bias=False).to(dtype)
434-
self.mlp_fc2 = torch.nn.Linear(self.mlp_hidden_dim, hidden_dim, bias=False).to(dtype)
435-
436-
# Layer norms
437-
self.norm1 = RMSNorm(hidden_dim, dtype=dtype)
438-
self.norm2 = RMSNorm(hidden_dim, dtype=dtype)
439-
440-
# Activation
441-
self.activation = torch.nn.GELU()
442-
443-
def forward(self, x):
444-
batch_size, seq_len, _ = x.shape
445-
446-
# Self-attention
447-
residual = x
448-
x = self.norm1(x)
449-
450-
# Reshape qkv projection for better memory layout
451-
qkv = self.qkv(x) # [batch_size, seq_len, 3 * hidden_dim]
452-
qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
453-
qkv = qkv.permute(2, 0, 3, 1, 4) # [3, batch_size, num_heads, seq_len, head_dim]
454-
q, k, v = qkv # Each has shape [batch_size, num_heads, seq_len, head_dim]
455-
456-
# Scaled dot-product attention with proper reshaping
457-
# Reshape for better memory layout and avoid broadcasting issues
458-
q = q.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
459-
k = k.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
460-
v = v.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
461-
462-
# Compute attention scores
463-
attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim ** 0.5))
464-
attn = torch.softmax(attn, dim=-1)
465-
466-
# Apply attention to values
467-
x = attn @ v # [batch_size * num_heads, seq_len, head_dim]
468-
469-
# Reshape back to original dimensions
470-
x = x.reshape(batch_size, self.num_heads, seq_len, self.head_dim)
471-
x = x.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim)
472-
473-
# Project back to hidden dimension
474-
x = self.proj(x)
475-
x = residual + x
476-
477-
# MLP
478-
residual = x
479-
x = self.norm2(x)
480-
x = self.mlp_fc1(x)
481-
x = self.activation(x)
482-
x = self.mlp_fc2(x)
483-
x = residual + x
484-
485-
return x
486-
487-
488386
def string_to_config(
489387
quantization: Optional[str], sparsity: Optional[str], **kwargs
490388
) -> AOBaseConfig:
@@ -678,14 +576,6 @@ def create_model_and_input(
678576
elif model_type == "ln_linear_sigmoid":
679577
model = LNLinearSigmoid(k, n, high_precision_dtype).to(device)
680578
input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
681-
elif model_type == "rms_norm_linear_activation":
682-
model = RMSNormLinearActivation(k, n, high_precision_dtype).to(device)
683-
input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
684-
elif model_type == "transformer_block":
685-
# For transformer block, k is the hidden dimension
686-
model = TransformerBlock(k, num_heads=8, mlp_ratio=4, dtype=high_precision_dtype).to(device)
687-
# Input shape for transformer is [batch_size, seq_len, hidden_dim]
688-
input_data = torch.randn(m, 16, k, device=device, dtype=high_precision_dtype)
689579
else:
690580
raise ValueError(f"Unknown model type: {model_type}")
691581
return model, input_data
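
For reference only (not part of this commit): the hand-rolled attention in the removed TransformerBlock, softmax(QK^T / sqrt(head_dim)) V with no mask and no dropout, computes the same result as PyTorch's built-in torch.nn.functional.scaled_dot_product_attention. A minimal sketch of that equivalence:

import torch
import torch.nn.functional as F

batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 8
q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)
v = torch.randn(batch_size, num_heads, seq_len, head_dim)

# Manual attention, as in the removed TransformerBlock.forward
attn = (q @ k.transpose(-2, -1)) * (1.0 / (head_dim ** 0.5))
attn = torch.softmax(attn, dim=-1)
manual = attn @ v

# Built-in kernel: no mask, no dropout, default 1/sqrt(head_dim) scaling
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-5))  # expected: True, up to float tolerance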
