diff --git a/benchmarks/microbenchmarks/README.md b/benchmarks/microbenchmarks/README.md index d65b295645..a95dc53755 100644 --- a/benchmarks/microbenchmarks/README.md +++ b/benchmarks/microbenchmarks/README.md @@ -63,15 +63,7 @@ Currently, quantization string is in same format as the one being passed in llam ### Model Types - `linear`: Simple linear layer -- `ln_linear_`: LayerNorm + Linear + Activation, where activation can be: - - `ln_linear_sigmoid`: LayerNorm + Linear + Sigmoid - - `ln_linear_relu`: LayerNorm + Linear + ReLU - - `ln_linear_leakyrelu`: LayerNorm + Linear + LeakyReLU - - `ln_linear_relu6`: LayerNorm + Linear + ReLU6 - - `ln_linear_gelu`: LayerNorm + Linear + GELU - - `ln_linear_silu`: LayerNorm + Linear + SiLU - - `ln_linear_hardswish`: LayerNorm + Linear + Hardswish -- `transformer_block`: Transformer block with self-attention and MLP +- `ln_linear_sigmoid`: LayerNorm + Linear + Sigmoid ### Device Options - `cuda`: NVIDIA GPU @@ -79,58 +71,6 @@ Currently, quantization string is in same format as the one being passed in llam - `mps`: Apple Silicon GPU - `cpu`: CPU fallback -### Shape Generation Options -- `custom`: Manually specify shapes as a list of [m, k, n] dimensions - ```yaml - matrix_shapes: - - name: "custom" - shapes: [ - [1024, 1024, 1024], # [m, k, n] - [2048, 4096, 1024] - ] - ``` - -- `llama`: Use LLaMa 2 70B single-node weight shapes (assumes fused attn.wqkv and ffn.w13) - - Generates shapes for: "attn.wqkv", "attn.w0", "ffn.w13", "ffn.w2" - ```yaml - matrix_shapes: - - name: "llama" - ``` - -- `pow2`: Generate shapes with dimensions that are powers of 2 - - Parameters: - - `min_power`: Minimum power of 2 (default: 10, which is 1024) - - `max_power`: Maximum power of 2 (default: 14, which is 16,384) - ```yaml - matrix_shapes: - - name: "pow2" - min_power: 10 # 2^10 = 1024 - max_power: 12 # 2^12 = 4096 - ``` - -- `pow2_extended`: Generate shapes with dimensions that are powers of 2 and powers of 2 + half - - Parameters: - - `min_power`: Minimum power of 2 (default: 10, which is 1024) - - `max_power`: Maximum power of 2 (default: 14, which is 16,384) - ```yaml - matrix_shapes: - - name: "pow2_extended" - min_power: 10 # Generates: 1024, 1536, 2048, 3072, etc. 
- max_power: 11 - ``` - -- `sweep`: Generate a sweep of shapes with different powers of 2 for M, K, N dimensions - - Parameters: - - `min_power`: Minimum power of 2 (default: 8, which is 256) - - `max_power`: Maximum power of 2 (default: 15, which is 32,768) - - Note: This generates all combinations of M, K, N dimensions, which can be a large number of shapes - ```yaml - matrix_shapes: - - name: "sweep" - min_power: 8 # 2^8 = 256 - max_power: 9 # 2^9 = 512 - ``` - ## Output Results are saved to a CSV file in the specified output directory diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index b7a8e8d7c4..3af0ceb57b 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -22,14 +22,12 @@ BenchmarkConfig, BenchmarkResult, clean_caches, + create_model_and_input, model_inference_time_in_ms, string_to_config, ) from torchao.quantization import quantize_ from torchao.sparsity.sparse_api import sparsify_ -from torchao.testing.model_architectures import ( - create_model_and_input_data, -) def run(config: BenchmarkConfig) -> BenchmarkResult: @@ -40,7 +38,7 @@ def run(config: BenchmarkConfig) -> BenchmarkResult: # Create output directory if it doesn't exist Path(config.output_dir).mkdir(parents=True, exist_ok=True) - base_model, input_data = create_model_and_input_data( + base_model, input_data = create_model_and_input( config.model_type, config.m, config.k, diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index fbd7f08388..e38fc93819 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -48,50 +48,9 @@ def get_shapes_for_config( name = shape_config["name"] if name == "custom": shapes.extend([(name, shape) for shape in shape_config["shapes"]]) - elif name == "llama": - # LLaMa 2 70B single-node weight shapes - # assumes fused attn.wqkv and ffn.w13 - bsz, seq_len = 4, 4096 - M = bsz * seq_len - llama_shapes = { - "attn.wqkv": (M, 8192, 1280), - "attn.w0": (M, 1024, 8192), - "ffn.w13": (M, 8192, 7168), - "ffn.w2": (M, 3584, 8192), - } - shapes.extend([(f"{name}_{k}", v) for k, v in llama_shapes.items()]) - elif name == "pow2": - # Generate shapes with dimensions that are powers of 2 - min_power_of_2 = shape_config.get("min_power", 10) # 1024 - max_power_of_2 = shape_config.get("max_power", 14) # 16,384 - for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)): - val = 2**power_of_2 - shapes.append((f"{name}_{idx}", [val, val, val])) - elif name == "pow2_extended": - # Generate shapes with dimensions that are powers of 2 and powers of 2 + half - min_power_of_2 = shape_config.get("min_power", 10) # 1024 - max_power_of_2 = shape_config.get("max_power", 14) # 16,384 - for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)): - val1 = 2**power_of_2 - val2 = 2**power_of_2 + 2 ** (power_of_2 - 1) - shapes.append((f"{name}_{idx * 2}", [val1, val1, val1])) - shapes.append((f"{name}_{idx * 2 + 1}", [val2, val2, val2])) - elif name == "sweep": - # Generate a sweep of shapes with different powers of 2 for M, K, N - min_p2 = shape_config.get("min_power", 8) # 256 - max_p2 = shape_config.get("max_power", 15) # 32,768 - counter = 0 - for M_p2 in range(min_p2, max_p2 + 1): - M = 2**M_p2 - for K_p2 in range(min_p2, max_p2 + 1): - K = 2**K_p2 - for N_p2 in range(min_p2, max_p2 + 1): - N = 2**N_p2 - 
shapes.append((f"{name}_{counter}", [M, K, N])) - counter += 1 else: raise NotImplementedError( - f"Shape config {name} not supported. Supported options: custom, llama, pow2, pow2_extended, sweep." + f"Shape config {name} not supported. Currently only supports custom shapes." ) return shapes diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index 2fc0433c36..5ea3f5d642 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -26,48 +26,3 @@ model_params: device: "cuda" model_type: "linear" enable_profiler: true # Enable profiling for this model - - - name: "ln_linear_sigmoid_cuda" - matrix_shapes: - - name: "custom" - shapes: [ - [2048, 4096, 1024], - ] - high_precision_dtype: "torch.bfloat16" - use_torch_compile: true - torch_compile_mode: "max-autotune" - device: "cuda" - model_type: "ln_linear_sigmoid" - enable_profiler: true - - - name: "bf16_transformer_block" - matrix_shapes: - - name: "custom" - shapes: [ - [2048, 4096, 1024], # For transformer_block, k is the hidden dimension - ] - high_precision_dtype: "torch.bfloat16" - use_torch_compile: true - torch_compile_mode: "max-autotune" - device: "cuda" - model_type: "transformer_block" # TODO: Add a custom model (Figure out how to do this, maybe pass a .py file with model definition) - enable_profiler: true - - - name: "large_bf16_ln_linear" - matrix_shapes: - - name: "llama" # Example of using LLaMa shapes - - name: "pow2" # Example of using power of 2 shapes - min_power: 10 # 1024 - max_power: 12 # 4096 - - name: "pow2_extended" # Example of using extended power of 2 shapes - min_power: 10 # 1024 - max_power: 11 # 2048 - - name: "sweep" # Example of using sweep shapes (commented out as it generates many shapes) - min_power: 8 # 256 - max_power: 9 # 512 - high_precision_dtype: "torch.bfloat16" - use_torch_compile: true - torch_compile_mode: "max-autotune" - device: "cuda" - model_type: "linear" - enable_profiler: true # Enable profiling for this model diff --git a/benchmarks/microbenchmarks/test/test_benchmark_profiler.py b/benchmarks/microbenchmarks/test/test_benchmark_profiler.py index e3971b5986..0e398b4899 100644 --- a/benchmarks/microbenchmarks/test/test_benchmark_profiler.py +++ b/benchmarks/microbenchmarks/test/test_benchmark_profiler.py @@ -15,8 +15,8 @@ ) from benchmarks.microbenchmarks.utils import ( BenchmarkConfig, + ToyLinearModel, ) -from torchao.testing.model_architectures import ToyLinearModel class TestBenchmarkProfiler(unittest.TestCase): diff --git a/benchmarks/microbenchmarks/test/test_benchmark_runner.py b/benchmarks/microbenchmarks/test/test_benchmark_runner.py index 7f93213a22..a8683a1de8 100644 --- a/benchmarks/microbenchmarks/test/test_benchmark_runner.py +++ b/benchmarks/microbenchmarks/test/test_benchmark_runner.py @@ -57,72 +57,12 @@ def tearDown(self): shutil.rmtree(self.temp_dir) def test_get_shapes_for_config(self): - # Test custom shapes shapes = get_shapes_for_config( self.test_config["model_params"][0]["matrix_shapes"] ) self.assertEqual(len(shapes), 1) self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024])) - # Test llama shapes - llama_shapes = get_shapes_for_config([{"name": "llama"}]) - self.assertEqual(len(llama_shapes), 4) # 4 LLaMa shapes - self.assertTrue( - any(name.startswith("llama_attn.wqkv") for name, _ in llama_shapes) - ) - self.assertTrue( - any(name.startswith("llama_attn.w0") for name, _ in llama_shapes) - ) - self.assertTrue( - 
any(name.startswith("llama_ffn.w13") for name, _ in llama_shapes) - ) - self.assertTrue( - any(name.startswith("llama_ffn.w2") for name, _ in llama_shapes) - ) - - # Test pow2 shapes - pow2_shapes = get_shapes_for_config( - [{"name": "pow2", "min_power": 10, "max_power": 12}] - ) - self.assertEqual(len(pow2_shapes), 3) # 3 powers of 2 (10, 11, 12) - self.assertEqual(pow2_shapes[0], ("pow2_0", [1024, 1024, 1024])) # 2^10 - self.assertEqual(pow2_shapes[1], ("pow2_1", [2048, 2048, 2048])) # 2^11 - self.assertEqual(pow2_shapes[2], ("pow2_2", [4096, 4096, 4096])) # 2^12 - - # Test pow2_extended shapes - pow2_extended_shapes = get_shapes_for_config( - [{"name": "pow2_extended", "min_power": 10, "max_power": 11}] - ) - self.assertEqual( - len(pow2_extended_shapes), 4 - ) # 2 powers of 2, each with 2 variants - self.assertEqual( - pow2_extended_shapes[0], ("pow2_extended_0", [1024, 1024, 1024]) - ) # 2^10 - self.assertEqual( - pow2_extended_shapes[1], ("pow2_extended_1", [1536, 1536, 1536]) - ) # 2^10 + 2^9 - self.assertEqual( - pow2_extended_shapes[2], ("pow2_extended_2", [2048, 2048, 2048]) - ) # 2^11 - self.assertEqual( - pow2_extended_shapes[3], ("pow2_extended_3", [3072, 3072, 3072]) - ) # 2^11 + 2^10 - - # Test sweep shapes (limited to a small range for testing) - sweep_shapes = get_shapes_for_config( - [{"name": "sweep", "min_power": 8, "max_power": 9}] - ) - # For min_power=8, max_power=9, we should have 8 shapes (2^3 = 8 combinations) - self.assertEqual(len(sweep_shapes), 8) - # Check that all shapes have the expected format - for name, shape in sweep_shapes: - self.assertTrue(name.startswith("sweep_")) - self.assertEqual(len(shape), 3) # [M, K, N] - # Check that all dimensions are powers of 2 between 2^8 and 2^9 - for dim in shape: - self.assertTrue(dim in [256, 512]) # 2^8, 2^9 - def test_get_param_combinations(self): model_param = self.test_config["model_params"][0] shapes, params = get_param_combinations(model_param) diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py index bb721e9e03..14f226bd7e 100644 --- a/benchmarks/microbenchmarks/test/test_utils.py +++ b/benchmarks/microbenchmarks/test/test_utils.py @@ -16,17 +16,15 @@ BlockSparseWeightConfig, Float8DynamicActivationFloat8SemiSparseWeightConfig, Int4WeightOnlyConfig, + LNLinearSigmoid, SemiSparseWeightConfig, + ToyLinearModel, clean_caches, + create_model_and_input, generate_results_csv, get_default_device, string_to_config, ) -from torchao.testing.model_architectures import ( - LNLinearActivationModel, - ToyLinearModel, - create_model_and_input_data, -) class TestUtils(unittest.TestCase): @@ -155,7 +153,7 @@ def test_toy_linear_model(self): self.assertEqual(out.dtype, torch.float32) def test_ln_linear_sigmoid(self): - model = LNLinearActivationModel(fc_dim1=64, fc_dim2=32, dtype=torch.float32) + model = LNLinearSigmoid(fc_dim1=64, fc_dim2=32, dtype=torch.float32) x = torch.randn(16, 64) out = model(x) self.assertEqual(out.shape, (16, 32)) @@ -164,9 +162,9 @@ def test_ln_linear_sigmoid(self): torch.all((out >= 0) & (out <= 1)) ) # Check sigmoid output range - def test_create_model_and_input_data(self): + def test_create_model_and_input(self): m, k, n = 16, 64, 32 - model, input_data = create_model_and_input_data( + model, input_data = create_model_and_input( model_type="linear", m=m, k=k, @@ -177,7 +175,7 @@ def test_create_model_and_input_data(self): self.assertIsInstance(model, ToyLinearModel) self.assertEqual(input_data.shape, (m, k)) - model, input_data = 
create_model_and_input_data(
+        model, input_data = create_model_and_input(
             model_type="ln_linear_sigmoid",
             m=m,
             k=k,
@@ -185,7 +183,7 @@ def test_create_model_and_input_data(self):
             high_precision_dtype=torch.float32,
             device="cpu",
         )
-        self.assertIsInstance(model, LNLinearActivationModel)
+        self.assertIsInstance(model, LNLinearSigmoid)
         self.assertEqual(input_data.shape, (m, k))
 
     def test_generate_results_csv(self):
diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py
index 3907abfa89..2fef1317fc 100644
--- a/benchmarks/microbenchmarks/utils.py
+++ b/benchmarks/microbenchmarks/utils.py
@@ -137,6 +137,30 @@ def to_dict(self) -> Dict[str, Any]:
         return result_dict
 
 
+class ToyLinearModel(torch.nn.Module):
+    def __init__(self, k=64, n=32, dtype=torch.bfloat16):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        return x
+
+
+class LNLinearSigmoid(torch.nn.Module):
+    def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16):
+        super().__init__()
+        self.ln = torch.nn.LayerNorm(fc_dim1, elementwise_affine=False)
+        self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype)
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        x = self.ln(x)
+        x = self.fc(x)
+        x = self.sigmoid(x)
+        return x
+
+
 def string_to_config(
     quantization: Optional[str], sparsity: Optional[str], **kwargs
 ) -> AOBaseConfig:
@@ -313,6 +337,34 @@ def model_inference_time_in_ms(model, input_data):
     return res * 1e6
 
 
+def create_model_and_input(
+    model_type: str,
+    m: int,
+    k: int,
+    n: int,
+    high_precision_dtype: torch.dtype = torch.bfloat16,
+    device: str = get_default_device(),
+):
+    """Create a model and input data for benchmarking.
+
+    Args:
+        model_type (str): type of the model to be created
+        m (int): number of rows in the input tensor
+        k, n (int): input and output dimensions of the linear layer
+        high_precision_dtype (torch.dtype): data type of the model
+        device (str): device to run the model on
+    """
+    if model_type == "linear":
+        model = ToyLinearModel(k, n, high_precision_dtype).to(device)
+        input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
+    elif model_type == "ln_linear_sigmoid":
+        model = LNLinearSigmoid(k, n, high_precision_dtype).to(device)
+        input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
+    else:
+        raise ValueError(f"Unknown model type: {model_type}")
+    return model, input_data
+
+
 def clean_caches():
     import gc
 
diff --git a/test/test_model_architecture.py b/test/test_model_architecture.py
deleted file mode 100644
index 973939a56a..0000000000
--- a/test/test_model_architecture.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD 3-Clause license found in the
-# LICENSE file in the root directory of this source tree.
- -import unittest - -import torch -from parameterized import parameterized - -from torchao.testing.model_architectures import create_model_and_input_data -from torchao.utils import get_available_devices - - -class TestModels(unittest.TestCase): - @parameterized.expand([(device,) for device in get_available_devices()]) - def test_toy_linear_model(self, device): - # Skip if device is not available - if device == "cuda" and not torch.cuda.is_available(): - self.skipTest("CUDA not available") - - model, input_data = create_model_and_input_data( - "linear", 10, 64, 32, device=device - ) - output = model(input_data) - self.assertEqual(output.shape, (10, 32)) - - @parameterized.expand([(device,) for device in get_available_devices()]) - def test_ln_linear_activation_model(self, device): - # Skip if device is not available - if device == "cuda" and not torch.cuda.is_available(): - self.skipTest("CUDA not available") - - model, input_data = create_model_and_input_data( - "ln_linear_sigmoid", 10, 64, 32, device=device - ) - output = model(input_data) - self.assertEqual(output.shape, (10, 32)) - - @parameterized.expand([(device,) for device in get_available_devices()]) - def test_transformer_block(self, device): - # Skip if device is not available - if device == "cuda" and not torch.cuda.is_available(): - self.skipTest("CUDA not available") - - model, input_data = create_model_and_input_data( - "transformer_block", 10, 64, 32, device=device - ) - output = model(input_data) - self.assertEqual(output.shape, (10, 16, 64)) - - -if __name__ == "__main__": - unittest.main() diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py index 97b5920c7a..02772f05f0 100644 --- a/torchao/quantization/qat/embedding.py +++ b/torchao/quantization/qat/embedding.py @@ -196,7 +196,7 @@ def convert( """ self._convert_helper(model) return model - + @staticmethod def quantize_weights( weight: torch.Tensor, @@ -207,11 +207,12 @@ def quantize_weights( Helper function to quantize weights """ (qmin, qmax) = _get_qmin_qmax(bit_width) - (s, zp) = get_group_qparams_symmetric(weight, bit_width, group_size) + (s, zp) = get_group_qparams_symmetric( + weight, bit_width, group_size + ) from torchao._executorch_ops import ( _quantized_decomposed_quantize_per_channel_group_wrapper, ) - q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( weight, s, @@ -223,6 +224,7 @@ def quantize_weights( ) return (q_weight, s, zp) + def _convert_helper(self, module: torch.nn.Module): """ Helper function to recursively swap `Int4WeightOnlyQATEmbedding` @@ -253,9 +255,7 @@ def _convert_helper(self, module: torch.nn.Module): ) setattr(module, name, quantized_embedding) - q_weight, s, zp = self.quantize_weights( - child.weight, self.bit_width, group_size - ) + q_weight, s, zp = self.quantize_weights(child.weight, self.bit_width, group_size) # Load weights and qparams into quantized embedding quantized_embedding.weight = q_weight quantized_embedding.scale = s.to(scale_precision) diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index d384eff2d6..ab5417fb16 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -197,7 +197,7 @@ def convert( ) -> torch.nn.Module: self._convert_qat_linear_8da4w(model) return model - + @staticmethod def quantize_weights( weight: torch.Tensor, @@ -209,7 +209,9 @@ def quantize_weights( # Load weights and qparams into quantized linear n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) - (s, zp) = 
get_group_qparams_symmetric(weight, n_bit, group_size) + (s, zp) = get_group_qparams_symmetric( + weight, n_bit, group_size + ) from torchao._executorch_ops import ( _quantized_decomposed_quantize_per_channel_group_wrapper, ) @@ -225,6 +227,7 @@ def quantize_weights( ) return (q_weight, s, zp) + def _convert_qat_linear_8da4w(self, module: torch.nn.Module): """ Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. @@ -242,9 +245,7 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module): ) setattr(module, name, quantized_linear) - q_weight, scales, zeros = self.quantize_weights( - child.weight, config.group_size - ) + q_weight, scales, zeros = self.quantize_weights(child.weight, config.group_size) quantized_linear.weight = q_weight quantized_linear.scales = scales quantized_linear.zeros = zeros diff --git a/torchao/testing/model_architectures.py b/torchao/testing/model_architectures.py deleted file mode 100644 index f59a1271b1..0000000000 --- a/torchao/testing/model_architectures.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -import re - -import torch -import torch.nn as nn - - -# TODO: Refactor torchao and tests to use these models -class ToyLinearModel(torch.nn.Module): - def __init__(self, k=64, n=32, dtype=torch.bfloat16): - super().__init__() - self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype) - - def forward(self, x): - x = self.linear1(x) - return x - - -class LNLinearActivationModel(nn.Module): - def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="sigmoid"): - super().__init__() - - activation = activation.lower() - activation_map = { - "relu": nn.ReLU(), - "sigmoid": nn.Sigmoid(), - "leakyrelu": nn.LeakyReLU(), - "relu6": nn.ReLU6(), - "gelu": nn.GELU(), - "silu": nn.SiLU(), - "hardswish": nn.Hardswish(), - } - - if activation not in activation_map: - raise ValueError(f"Unsupported activation: {activation}") - - self.ln = nn.LayerNorm(fc_dim1, elementwise_affine=False) - self.fc = nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype=dtype) - self.activation = activation_map[activation] - - def forward(self, x): - x = self.ln(x) - x = self.fc(x) - return self.activation(x) - - -class RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -class TransformerBlock(torch.nn.Module): - def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16): - super().__init__() - self.hidden_dim = hidden_dim - self.num_heads = num_heads - self.head_dim = hidden_dim // num_heads - - # Self-attention - self.qkv = torch.nn.Linear(hidden_dim, 3 * hidden_dim, bias=False).to(dtype) - self.proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False).to(dtype) - - # MLP - self.mlp_ratio = mlp_ratio - self.mlp_hidden_dim = int(hidden_dim * mlp_ratio) - self.mlp_fc1 = torch.nn.Linear(hidden_dim, self.mlp_hidden_dim, bias=False).to( - dtype - ) - self.mlp_fc2 = torch.nn.Linear(self.mlp_hidden_dim, hidden_dim, bias=False).to( - dtype - ) - - # Layer norms - self.norm1 = RMSNorm(hidden_dim).to(dtype) - 
self.norm2 = RMSNorm(hidden_dim).to(dtype) - - # Activation - self.activation = torch.nn.GELU() - - def forward(self, x): - batch_size, seq_len, _ = x.shape - - # Self-attention - residual = x - x = self.norm1(x) - - # Reshape qkv projection for better memory layout - qkv = self.qkv(x) # [batch_size, seq_len, 3 * hidden_dim] - qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim) - qkv = qkv.permute( - 2, 0, 3, 1, 4 - ) # [3, batch_size, num_heads, seq_len, head_dim] - q, k, v = qkv # Each has shape [batch_size, num_heads, seq_len, head_dim] - - # Scaled dot-product attention with proper reshaping - # Reshape for better memory layout and avoid broadcasting issues - q = q.reshape(batch_size * self.num_heads, seq_len, self.head_dim) - k = k.reshape(batch_size * self.num_heads, seq_len, self.head_dim) - v = v.reshape(batch_size * self.num_heads, seq_len, self.head_dim) - - # Compute attention scores - attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim**0.5)) - attn = torch.softmax(attn, dim=-1) - - # Apply attention to values - x = attn @ v # [batch_size * num_heads, seq_len, head_dim] - - # Reshape back to original dimensions - x = x.reshape(batch_size, self.num_heads, seq_len, self.head_dim) - x = x.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim) - - # Project back to hidden dimension - x = self.proj(x) - x = residual + x - - # MLP - residual = x - x = self.norm2(x) - x = self.mlp_fc1(x) - x = self.activation(x) - x = self.mlp_fc2(x) - x = residual + x - - return x - - -def create_model_and_input_data( - model_type: str, - m: int, - k: int, - n: int, - high_precision_dtype: torch.dtype = torch.bfloat16, - device: str = "cuda", - activation: str = "relu", -): - """Create a model and input data for benchmarking. - - Args: - model_type (str): type of the model to be created - batch_size (int): batch size of the input data - device (str): device to run the model on - high_precision_dtype (torch.dtype): data type of the model - m, k, n (int): dimensions of the model and input data - """ - if model_type == "linear": - model = ToyLinearModel(k, n, high_precision_dtype).to(device) - input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) - elif "ln_linear" in model_type: - # Extract activation type from model_type string - match = re.search(r"ln_linear_?(\w+)?", model_type) - activation = match.group(1) if match and match.group(1) else "relu" - model = LNLinearActivationModel( - k, n, high_precision_dtype, activation=activation - ).to(device) - input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) - elif model_type == "transformer_block": - # For transformer block, k is the hidden dimension - model = TransformerBlock( - k, num_heads=8, mlp_ratio=4, dtype=high_precision_dtype - ).to(device) - # Input shape for transformer is [batch_size, seq_len, hidden_dim] - input_data = torch.randn(m, 16, k, device=device, dtype=high_precision_dtype) - else: - raise ValueError(f"Unknown model type: {model_type}") - return model, input_data