
Commit 4b7ea5d

Add support for different models and different shapes
1 parent 04f39ef commit 4b7ea5d


10 files changed: +436 -112 lines changed


benchmarks/microbenchmarks/README.md

Lines changed: 61 additions & 1 deletion
@@ -63,14 +63,74 @@ Currently, quantization string is in same format as the one being passed in llam
 
 ### Model Types
 - `linear`: Simple linear layer
-- `ln_linear_sigmoid`: LayerNorm + Linear + Sigmoid
+- `ln_linear_<activation>`: LayerNorm + Linear + Activation, where activation can be:
+  - `ln_linear_sigmoid`: LayerNorm + Linear + Sigmoid
+  - `ln_linear_relu`: LayerNorm + Linear + ReLU
+  - `ln_linear_leakyrelu`: LayerNorm + Linear + LeakyReLU
+  - `ln_linear_relu6`: LayerNorm + Linear + ReLU6
+  - `ln_linear_gelu`: LayerNorm + Linear + GELU
+  - `ln_linear_silu`: LayerNorm + Linear + SiLU
+  - `ln_linear_hardswish`: LayerNorm + Linear + Hardswish
+- `transformer_block`: Transformer block with self-attention and MLP
 
 ### Device Options
 - `cuda`: NVIDIA GPU
 - `xpu`: Intel GPU
 - `mps`: Apple Silicon GPU
 - `cpu`: CPU fallback
 
+### Shape Generation Options
+- `custom`: Manually specify shapes as a list of [m, k, n] dimensions
+  ```yaml
+  matrix_shapes:
+    - name: "custom"
+      shapes: [
+        [1024, 1024, 1024], # [m, k, n]
+        [2048, 4096, 1024]
+      ]
+  ```
+
+- `llama`: Use LLaMa 2 70B single-node weight shapes (assumes fused attn.wqkv and ffn.w13)
+  - Generates shapes for: "attn.wqkv", "attn.w0", "ffn.w13", "ffn.w2"
+  ```yaml
+  matrix_shapes:
+    - name: "llama"
+  ```
+
+- `pow2`: Generate shapes with dimensions that are powers of 2
+  - Parameters:
+    - `min_power`: Minimum power of 2 (default: 10, which is 1024)
+    - `max_power`: Maximum power of 2 (default: 14, which is 16,384)
+  ```yaml
+  matrix_shapes:
+    - name: "pow2"
+      min_power: 10 # 2^10 = 1024
+      max_power: 12 # 2^12 = 4096
+  ```
+
+- `pow2_extended`: Generate shapes with dimensions that are powers of 2 and powers of 2 + half
+  - Parameters:
+    - `min_power`: Minimum power of 2 (default: 10, which is 1024)
+    - `max_power`: Maximum power of 2 (default: 14, which is 16,384)
+  ```yaml
+  matrix_shapes:
+    - name: "pow2_extended"
+      min_power: 10 # Generates: 1024, 1536, 2048, 3072, etc.
+      max_power: 11
+  ```
+
+- `sweep`: Generate a sweep of shapes with different powers of 2 for M, K, N dimensions
+  - Parameters:
+    - `min_power`: Minimum power of 2 (default: 8, which is 256)
+    - `max_power`: Maximum power of 2 (default: 15, which is 32,768)
+  - Note: This generates all combinations of M, K, N dimensions, which can be a large number of shapes
+  ```yaml
+  matrix_shapes:
+    - name: "sweep"
+      min_power: 8 # 2^8 = 256
+      max_power: 9 # 2^9 = 512
+  ```
+
 ## Output
 
 Results are saved to a CSV file in the specified output directory
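The number of shapes each generator produces follows directly from the option descriptions above. A minimal sketch of those counts, assuming the default `min_power`/`max_power` values listed (this snippet is illustrative and not part of the committed files):

```python
# Rough shape counts per generator, derived from the README descriptions above.

def pow2_count(min_power: int = 10, max_power: int = 14) -> int:
    # One square [val, val, val] shape per power of 2 in the range.
    return max_power - min_power + 1

def pow2_extended_count(min_power: int = 10, max_power: int = 14) -> int:
    # Two shapes per power of 2: the power itself and the power plus its half.
    return 2 * (max_power - min_power + 1)

def sweep_count(min_power: int = 8, max_power: int = 15) -> int:
    # All (M, K, N) combinations, so the per-dimension count cubed.
    return (max_power - min_power + 1) ** 3

print(pow2_count())           # 5 shapes for powers 10..14
print(pow2_extended_count())  # 10 shapes for powers 10..14
print(sweep_count())          # 512 shapes for powers 8..15, hence the note above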

benchmarks/microbenchmarks/benchmark_inference.py

Lines changed: 10 additions & 12 deletions
@@ -19,13 +19,15 @@
     BenchmarkConfig,
     BenchmarkResult,
     clean_caches,
-    create_model_and_input,
     generate_model_profile,
     model_inference_time_in_ms,
     string_to_config,
 )
 from torchao.quantization import quantize_
 from torchao.sparsity.sparse_api import sparsify_
+from torchao.testing.model_architectures import (
+    create_model_and_input_data,
+)
 
 
 def run(config: BenchmarkConfig) -> BenchmarkResult:
@@ -36,7 +38,7 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
     # Create output directory if it doesn't exist
     Path(config.output_dir).mkdir(parents=True, exist_ok=True)
 
-    base_model, input_data = create_model_and_input(
+    base_model, input_data = create_model_and_input_data(
         config.model_type,
         config.m,
         config.k,
@@ -94,16 +96,12 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
         if config.enable_profiler:
             print("Running profiler...")
             try:
-                result.profiler_json_path, result.perfetto_url = generate_model_profile(
+                result.profiler_json_path = generate_model_profile(
                     m_copy, input_data, config.profiler_file_name
                 )
-            except Exception as e:
-                print(f"Error running profiler: {e}")
-
+            except Exception:
+                print(f"Error running profiler for {config.name}")
         return result
-    except Exception as e:
-        print(f"Error in benchmark run: {e}")
-        import traceback
-
-        print(traceback.format_exc())
-        return None
+    except Exception:
+        print(f"Error in benchmark run: {config.name}")
+        return
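For reference, the model/input helper now lives in `torchao.testing.model_architectures`. A minimal standalone sketch of calling it the way `run()` does above; the arguments past `config.k` are not visible in this hunk, so the `n`, dtype, and device parameters shown here are assumptions inferred from the config fields used elsewhere in this commit:

```python
import torch

from torchao.testing.model_architectures import create_model_and_input_data

# Hypothetical standalone call mirroring run() above; keyword names beyond
# (model_type, m, k) are assumed, not confirmed by this diff.
base_model, input_data = create_model_and_input_data(
    "linear",  # model_type, e.g. "linear" or "ln_linear_sigmoid"
    1024,      # m: leading dimension of the input
    1024,      # k: input feature dimension
    1024,      # n: output feature dimension
    high_precision_dtype=torch.bfloat16,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

with torch.no_grad():
    out = base_model(input_data)
print(out.shape)
```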

benchmarks/microbenchmarks/benchmark_runner.py

Lines changed: 42 additions & 4 deletions
@@ -48,9 +48,50 @@ def get_shapes_for_config(
     name = shape_config["name"]
     if name == "custom":
         shapes.extend([(name, shape) for shape in shape_config["shapes"]])
+    elif name == "llama":
+        # LLaMa 2 70B single-node weight shapes
+        # assumes fused attn.wqkv and ffn.w13
+        bsz, seq_len = 4, 4096
+        M = bsz * seq_len
+        llama_shapes = {
+            "attn.wqkv": (M, 8192, 1280),
+            "attn.w0": (M, 1024, 8192),
+            "ffn.w13": (M, 8192, 7168),
+            "ffn.w2": (M, 3584, 8192),
+        }
+        shapes.extend([(f"{name}_{k}", v) for k, v in llama_shapes.items()])
+    elif name == "pow2":
+        # Generate shapes with dimensions that are powers of 2
+        min_power_of_2 = shape_config.get("min_power", 10)  # 1024
+        max_power_of_2 = shape_config.get("max_power", 14)  # 16,384
+        for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+            val = 2**power_of_2
+            shapes.append((f"{name}_{idx}", [val, val, val]))
+    elif name == "pow2_extended":
+        # Generate shapes with dimensions that are powers of 2 and powers of 2 + half
+        min_power_of_2 = shape_config.get("min_power", 10)  # 1024
+        max_power_of_2 = shape_config.get("max_power", 14)  # 16,384
+        for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+            val1 = 2**power_of_2
+            val2 = 2**power_of_2 + 2 ** (power_of_2 - 1)
+            shapes.append((f"{name}_{idx*2}", [val1, val1, val1]))
+            shapes.append((f"{name}_{idx*2+1}", [val2, val2, val2]))
+    elif name == "sweep":
+        # Generate a sweep of shapes with different powers of 2 for M, K, N
+        min_p2 = shape_config.get("min_power", 8)  # 256
+        max_p2 = shape_config.get("max_power", 15)  # 32,768
+        counter = 0
+        for M_p2 in range(min_p2, max_p2 + 1):
+            M = 2**M_p2
+            for K_p2 in range(min_p2, max_p2 + 1):
+                K = 2**K_p2
+                for N_p2 in range(min_p2, max_p2 + 1):
+                    N = 2**N_p2
+                    shapes.append((f"{name}_{counter}", [M, K, N]))
+                    counter += 1
     else:
         raise NotImplementedError(
-            f"Shape config {name} not supported. Currently only supports custom shapes."
+            f"Shape config {name} not supported. Supported options: custom, llama, pow2, pow2_extended, sweep."
         )
     return shapes
5697

@@ -167,10 +208,7 @@ def run_inference_benchmarks_from_config(configs: List[BenchmarkConfig]) -> None
             if result is not None:  # Only add successful results
                 results.append(result)
         except Exception as e:
-            import traceback
-
             print(f"Error running benchmark {config.name} with error: {e}")
-            print(traceback.format_exc())
             continue
 
     # Add results to csv if there are any
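A quick usage sketch of the new generator, mirroring the assertions added in `test_benchmark_runner.py` later in this commit (run from a torchao checkout; the import path matches the test's usage of this module):

```python
from benchmarks.microbenchmarks.benchmark_runner import get_shapes_for_config

# Expand a pow2_extended entry exactly as the new branch above does.
shapes = get_shapes_for_config(
    [{"name": "pow2_extended", "min_power": 10, "max_power": 11}]
)
for name, (m, k, n) in shapes:
    print(name, m, k, n)
# Expected, per the logic above:
#   pow2_extended_0 1024 1024 1024
#   pow2_extended_1 1536 1536 1536
#   pow2_extended_2 2048 2048 2048
#   pow2_extended_3 3072 3072 3072
```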

benchmarks/microbenchmarks/test/benchmark_config.yml

Lines changed: 52 additions & 32 deletions
@@ -2,51 +2,71 @@
 benchmark_mode: "inference"
 quantization_config_recipe_names:
   # Will run a baseline inference for model by default, without quantization for comparison
-  # - "int4wo-32"
-  # - "marlin"
   - "int8wo"
-# sparsity_config_recipe_names:
+  - "int8dq"
+  - "float8dq"
+sparsity_config_recipe_names:
   # Will run a baseline inference for model by default, without sparsity for comparison
-  # - "semi-sparse"
-  # - "block"
+  - "semi-sparse"
+  - "block"
 output_dir: "benchmarks/microbenchmarks/results"
 model_params:
-  # - name: "small_bf16_linear"
-  #   matrix_shapes:
-  #     - name: "custom"
-  #       shapes: [
-  #         [1024, 1024, 1024], # [m, k, n]
-  #       ]
-  #   high_precision_dtype: "torch.bfloat16"
-  #   use_torch_compile: true
-  #   torch_compile_mode: "max-autotune"
-  #   device: "cuda"
-  #   model_type: "linear"
-  #   enable_profiler: true # Enable profiling for this model
-
-  - name: "large_bf16_ln_linear"
+  - name: "small_bf16_linear"
     matrix_shapes:
       - name: "custom"
         shapes: [
+          [1024, 1024, 1024], # [m, k, n]
           [2048, 4096, 1024],
-          # [4096, 4096, 1024]
+          [4096, 4096, 1024]
         ]
     high_precision_dtype: "torch.bfloat16"
     use_torch_compile: true
     torch_compile_mode: "max-autotune"
     device: "cuda"
     model_type: "linear"
     enable_profiler: true # Enable profiling for this model
-    enable_memory_profile: true # Enable memory profiling for this model
 
-  # - name: "cpu_fp32_linear"
-  #   matrix_shapes:
-  #     - name: "custom"
-  #       shapes: [
-  #         [4096, 4096, 1024]
-  #       ]
-  #   high_precision_dtype: "torch.float32"
-  #   use_torch_compile: false
-  #   device: "cpu"
-  #   model_type: "linear"
-  #   enable_profiler: true # Enable profiling for this model
+  - name: "ln_linear_sigmoid_cuda"
+    matrix_shapes:
+      - name: "custom"
+        shapes: [
+          [2048, 4096, 1024],
+        ]
+    high_precision_dtype: "torch.bfloat16"
+    use_torch_compile: true
+    torch_compile_mode: "max-autotune"
+    device: "cuda"
+    model_type: "ln_linear_sigmoid"
+    enable_profiler: true
+
+  - name: "bf16_transformer_block"
+    matrix_shapes:
+      - name: "custom"
+        shapes: [
+          [2048, 4096, 1024], # For transformer_block, k is the hidden dimension
+        ]
+    high_precision_dtype: "torch.bfloat16"
+    use_torch_compile: true
+    torch_compile_mode: "max-autotune"
+    device: "cuda"
+    model_type: "transformer_block" # TODO: Add a custom model (Figure out how to do this, maybe pass a .py file with model definition)
+    enable_profiler: true
+
+  - name: "large_bf16_ln_linear"
+    matrix_shapes:
+      - name: "llama" # Example of using LLaMa shapes
+      - name: "pow2" # Example of using power of 2 shapes
+        min_power: 10 # 1024
+        max_power: 12 # 4096
+      - name: "pow2_extended" # Example of using extended power of 2 shapes
+        min_power: 10 # 1024
+        max_power: 11 # 2048
+      - name: "sweep" # Example of using sweep shapes (commented out as it generates many shapes)
+        min_power: 8 # 256
+        max_power: 9 # 512
+    high_precision_dtype: "torch.bfloat16"
+    use_torch_compile: true
+    torch_compile_mode: "max-autotune"
+    device: "cuda"
+    model_type: "linear"
+    enable_profiler: true # Enable profiling for this model
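The new `large_bf16_ln_linear` entry mixes all four generators in one `matrix_shapes` list. A small sketch of how many shapes that single entry expands to, based on the generation logic added to `benchmark_runner.py` above (again, illustrative rather than part of the committed files):

```python
from benchmarks.microbenchmarks.benchmark_runner import get_shapes_for_config

# The matrix_shapes list from the large_bf16_ln_linear entry above.
matrix_shapes = [
    {"name": "llama"},
    {"name": "pow2", "min_power": 10, "max_power": 12},
    {"name": "pow2_extended", "min_power": 10, "max_power": 11},
    {"name": "sweep", "min_power": 8, "max_power": 9},
]

shapes = get_shapes_for_config(matrix_shapes)
# 4 llama shapes + 3 pow2 + 4 pow2_extended + 8 sweep = 19 shapes in total.
print(len(shapes))
```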

benchmarks/microbenchmarks/test/test_benchmark_profiler.py

Lines changed: 1 addition & 1 deletion
@@ -12,9 +12,9 @@
 
 from benchmarks.microbenchmarks.utils import (
     BenchmarkConfig,
-    ToyLinearModel,
     generate_model_profile,
 )
+from torchao.testing.model_architectures import ToyLinearModel
 
 
 class TestBenchmarkProfiler(unittest.TestCase):

benchmarks/microbenchmarks/test/test_benchmark_runner.py

Lines changed: 60 additions & 0 deletions
@@ -57,12 +57,72 @@ def tearDown(self):
         shutil.rmtree(self.temp_dir)
 
     def test_get_shapes_for_config(self):
+        # Test custom shapes
         shapes = get_shapes_for_config(
             self.test_config["model_params"][0]["matrix_shapes"]
         )
         self.assertEqual(len(shapes), 1)
         self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024]))
 
+        # Test llama shapes
+        llama_shapes = get_shapes_for_config([{"name": "llama"}])
+        self.assertEqual(len(llama_shapes), 4)  # 4 LLaMa shapes
+        self.assertTrue(
+            any(name.startswith("llama_attn.wqkv") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_attn.w0") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_ffn.w13") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_ffn.w2") for name, _ in llama_shapes)
+        )
+
+        # Test pow2 shapes
+        pow2_shapes = get_shapes_for_config(
+            [{"name": "pow2", "min_power": 10, "max_power": 12}]
+        )
+        self.assertEqual(len(pow2_shapes), 3)  # 3 powers of 2 (10, 11, 12)
+        self.assertEqual(pow2_shapes[0], ("pow2_0", [1024, 1024, 1024]))  # 2^10
+        self.assertEqual(pow2_shapes[1], ("pow2_1", [2048, 2048, 2048]))  # 2^11
+        self.assertEqual(pow2_shapes[2], ("pow2_2", [4096, 4096, 4096]))  # 2^12
+
+        # Test pow2_extended shapes
+        pow2_extended_shapes = get_shapes_for_config(
+            [{"name": "pow2_extended", "min_power": 10, "max_power": 11}]
+        )
+        self.assertEqual(
+            len(pow2_extended_shapes), 4
+        )  # 2 powers of 2, each with 2 variants
+        self.assertEqual(
+            pow2_extended_shapes[0], ("pow2_extended_0", [1024, 1024, 1024])
+        )  # 2^10
+        self.assertEqual(
+            pow2_extended_shapes[1], ("pow2_extended_1", [1536, 1536, 1536])
+        )  # 2^10 + 2^9
+        self.assertEqual(
+            pow2_extended_shapes[2], ("pow2_extended_2", [2048, 2048, 2048])
+        )  # 2^11
+        self.assertEqual(
+            pow2_extended_shapes[3], ("pow2_extended_3", [3072, 3072, 3072])
+        )  # 2^11 + 2^10
+
+        # Test sweep shapes (limited to a small range for testing)
+        sweep_shapes = get_shapes_for_config(
+            [{"name": "sweep", "min_power": 8, "max_power": 9}]
+        )
+        # For min_power=8, max_power=9, we should have 8 shapes (2^3 = 8 combinations)
+        self.assertEqual(len(sweep_shapes), 8)
+        # Check that all shapes have the expected format
+        for name, shape in sweep_shapes:
+            self.assertTrue(name.startswith("sweep_"))
+            self.assertEqual(len(shape), 3)  # [M, K, N]
+            # Check that all dimensions are powers of 2 between 2^8 and 2^9
+            for dim in shape:
+                self.assertTrue(dim in [256, 512])  # 2^8, 2^9
+
     def test_get_param_combinations(self):
         model_param = self.test_config["model_params"][0]
         shapes, params = get_param_combinations(model_param)
