diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml
index 227cb90948..4394d0208b 100644
--- a/benchmarks/microbenchmarks/test/benchmark_config.yml
+++ b/benchmarks/microbenchmarks/test/benchmark_config.yml
@@ -50,3 +50,31 @@ model_params:
   #   device: "cpu"
   #   model_type: "linear"
   #   enable_profiler: true  # Enable profiling for this model
+
+  - name: "bf16_rms_norm_linear_activation"
+    matrix_shapes:
+      - name: "custom"
+        shapes: [
+          [2048, 4096, 1024],
+        ]
+    high_precision_dtype: "torch.bfloat16"
+    use_torch_compile: true
+    torch_compile_mode: "max-autotune"
+    device: "cuda"
+    model_type: "rms_norm_linear_activation"
+    enable_profiler: true
+    enable_memory_profile: true
+
+  - name: "bf16_transformer_block"
+    matrix_shapes:
+      - name: "custom"
+        shapes: [
+          [2048, 4096, 1024], # For transformer_block, k is the hidden dimension
+        ]
+    high_precision_dtype: "torch.bfloat16"
+    use_torch_compile: true
+    torch_compile_mode: "max-autotune"
+    device: "cuda"
+    model_type: "transformer_block"
+    enable_profiler: true
+    enable_memory_profile: true
diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py
index 14f226bd7e..46f6a74685 100644
--- a/benchmarks/microbenchmarks/test/test_utils.py
+++ b/benchmarks/microbenchmarks/test/test_utils.py
@@ -17,8 +17,11 @@
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Int4WeightOnlyConfig,
     LNLinearSigmoid,
+    RMSNorm,
+    RMSNormLinearActivation,
     SemiSparseWeightConfig,
     ToyLinearModel,
+    TransformerBlock,
     clean_caches,
     create_model_and_input,
     generate_results_csv,
@@ -162,6 +165,61 @@ def test_ln_linear_sigmoid(self):
             torch.all((out >= 0) & (out <= 1))
         )  # Check sigmoid output range
 
+    def test_rms_norm(self):
+        # Test RMSNorm
+        rms_norm = RMSNorm(dim=64)
+        x = torch.randn(16, 64)
+        out = rms_norm(x)
+        self.assertEqual(out.shape, (16, 64))
+
+        # Test with different eps
+        rms_norm = RMSNorm(dim=64, eps=1e-5)
+        out = rms_norm(x)
+        self.assertEqual(out.shape, (16, 64))
+
+    def test_rms_norm_linear_activation(self):
+        # Test with default GELU activation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
+        x = torch.randn(16, 64)
+        out = model(x)
+        self.assertEqual(out.shape, (16, 32))
+        self.assertEqual(out.dtype, torch.float32)
+
+        # Test with ReLU activation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu")
+        out = model(x)
+        self.assertEqual(out.shape, (16, 32))
+        self.assertTrue(torch.all(out >= 0))  # Check ReLU output range
+
+        # Test with SiLU activation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu")
+        out = model(x)
+        self.assertEqual(out.shape, (16, 32))
+
+        # Test with invalid activation
+        with self.assertRaises(ValueError):
+            RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid")
+
+    def test_transformer_block(self):
+        # Test with default parameters
+        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        x = torch.randn(16, 16, 64)  # [batch_size, seq_len, hidden_dim]
+        out = model(x)
+        self.assertEqual(out.shape, (16, 16, 64))
+        self.assertEqual(out.dtype, torch.float32)
+
+        # Test with different parameters
+        model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32)
+        x = torch.randn(8, 32, 128)
+        out = model(x)
+        self.assertEqual(out.shape, (8, 32, 128))
+
+        # Test with different head dimensions
+        model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32)
+        x = torch.randn(4, 8, 96)
+        out = model(x)
+        self.assertEqual(out.shape, (4, 8, 96))
+
     def test_create_model_and_input(self):
         m, k, n = 16, 64, 32
         model, input_data = create_model_and_input(
@@ -186,6 +244,63 @@ def test_create_model_and_input(self):
         self.assertIsInstance(model, LNLinearSigmoid)
         self.assertEqual(input_data.shape, (m, k))
 
+        # Test RMSNormLinearActivation
+        model, input_data = create_model_and_input(
+            model_type="rms_norm_linear_activation",
+            m=m,
+            k=k,
+            n=n,
+            high_precision_dtype=torch.float32,
+            device="cpu",
+        )
+        self.assertIsInstance(model, RMSNormLinearActivation)
+        self.assertEqual(input_data.shape, (m, k))
+
+        # Test TransformerBlock
+        model, input_data = create_model_and_input(
+            model_type="transformer_block",
+            m=m,
+            k=k,
+            n=n,  # n is not used for transformer_block
+            high_precision_dtype=torch.float32,
+            device="cpu",
+        )
+        self.assertIsInstance(model, TransformerBlock)
+        self.assertEqual(input_data.shape, (m, 16, k))  # [batch_size, seq_len, hidden_dim]
+
+    def test_quantization_on_models(self):
+        # Test quantization on RMSNormLinearActivation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
+        x = torch.randn(16, 64)
+
+        # Test with Int8WeightOnlyConfig
+        config = string_to_config(quantization="int8wo", sparsity=None)
+        if config is not None:
+            # Skip quantization test if torchao.quantization.quantize is not available
+            try:
+                from torchao.quantization import quantize
+                quantized_model = quantize(model, config)
+                out = quantized_model(x)
+                self.assertEqual(out.shape, (16, 32))
+            except ImportError:
+                print("Skipping quantization test: torchao.quantization.quantize not available")
+
+        # Test quantization on TransformerBlock
+        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        x = torch.randn(16, 16, 64)
+
+        # Test with Int8WeightOnlyConfig
+        config = string_to_config(quantization="int8wo", sparsity=None)
+        if config is not None:
+            # Skip quantization test if torchao.quantization.quantize is not available
+            try:
+                from torchao.quantization import quantize
+                quantized_model = quantize(model, config)
+                out = quantized_model(x)
+                self.assertEqual(out.shape, (16, 16, 64))
+            except ImportError:
+                print("Skipping quantization test: torchao.quantization.quantize not available")
+
     def test_generate_results_csv(self):
         results = [
             BenchmarkResult(
diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py
index 677f66ac75..9e978f70fa 100644
--- a/benchmarks/microbenchmarks/utils.py
+++ b/benchmarks/microbenchmarks/utils.py
@@ -383,6 +383,108 @@ def forward(self, x):
         return x
 
 
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim, eps=1e-6, dtype=torch.bfloat16):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+
+    def forward(self, x):
+        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+        return x * norm * self.weight
+
+
+class RMSNormLinearActivation(torch.nn.Module):
+    def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="gelu"):
+        super().__init__()
+        self.rms_norm = RMSNorm(fc_dim1, dtype=dtype)
+        self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype)
+
+        if activation == "gelu":
+            self.activation = torch.nn.GELU()
+        elif activation == "relu":
+            self.activation = torch.nn.ReLU()
+        elif activation == "silu":
+            self.activation = torch.nn.SiLU()
+        else:
+            raise ValueError(f"Unsupported activation: {activation}")
+
+    def forward(self, x):
+        x = self.rms_norm(x)
+        x = self.fc(x)
+        x = self.activation(x)
+        return x
+
+
+class TransformerBlock(torch.nn.Module):
+    def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16):
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.head_dim = hidden_dim // num_heads
+
+        # Self-attention
+        self.qkv = torch.nn.Linear(hidden_dim, 3 * hidden_dim, bias=False).to(dtype)
+        self.proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False).to(dtype)
+
+        # MLP
+        self.mlp_ratio = mlp_ratio
+        self.mlp_hidden_dim = int(hidden_dim * mlp_ratio)
+        self.mlp_fc1 = torch.nn.Linear(hidden_dim, self.mlp_hidden_dim, bias=False).to(dtype)
+        self.mlp_fc2 = torch.nn.Linear(self.mlp_hidden_dim, hidden_dim, bias=False).to(dtype)
+
+        # Layer norms
+        self.norm1 = RMSNorm(hidden_dim, dtype=dtype)
+        self.norm2 = RMSNorm(hidden_dim, dtype=dtype)
+
+        # Activation
+        self.activation = torch.nn.GELU()
+
+    def forward(self, x):
+        batch_size, seq_len, _ = x.shape
+
+        # Self-attention
+        residual = x
+        x = self.norm1(x)
+
+        # Reshape qkv projection for better memory layout
+        qkv = self.qkv(x)  # [batch_size, seq_len, 3 * hidden_dim]
+        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
+        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, batch_size, num_heads, seq_len, head_dim]
+        q, k, v = qkv  # Each has shape [batch_size, num_heads, seq_len, head_dim]
+
+        # Scaled dot-product attention with proper reshaping
+        # Reshape for better memory layout and avoid broadcasting issues
+        q = q.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
+        k = k.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
+        v = v.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
+
+        # Compute attention scores
+        attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim ** 0.5))
+        attn = torch.softmax(attn, dim=-1)
+
+        # Apply attention to values
+        x = attn @ v  # [batch_size * num_heads, seq_len, head_dim]
+
+        # Reshape back to original dimensions
+        x = x.reshape(batch_size, self.num_heads, seq_len, self.head_dim)
+        x = x.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim)
+
+        # Project back to hidden dimension
+        x = self.proj(x)
+        x = residual + x
+
+        # MLP
+        residual = x
+        x = self.norm2(x)
+        x = self.mlp_fc1(x)
+        x = self.activation(x)
+        x = self.mlp_fc2(x)
+        x = residual + x
+
+        return x
+
+
 def string_to_config(
     quantization: Optional[str], sparsity: Optional[str], **kwargs
 ) -> AOBaseConfig:
@@ -576,6 +678,14 @@ def create_model_and_input(
     elif model_type == "ln_linear_sigmoid":
         model = LNLinearSigmoid(k, n, high_precision_dtype).to(device)
         input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
+    elif model_type == "rms_norm_linear_activation":
+        model = RMSNormLinearActivation(k, n, high_precision_dtype).to(device)
+        input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
+    elif model_type == "transformer_block":
+        # For transformer block, k is the hidden dimension
+        model = TransformerBlock(k, num_heads=8, mlp_ratio=4, dtype=high_precision_dtype).to(device)
+        # Input shape for transformer is [batch_size, seq_len, hidden_dim]
+        input_data = torch.randn(m, 16, k, device=device, dtype=high_precision_dtype)
     else:
         raise ValueError(f"Unknown model type: {model_type}")
     return model, input_data