
Commit 825340e

[mxfp8 moe training] parallelize along col blocks in scale blocked format kernel for groups along K
stack-info: PR: #3416, branch: danielvegamyhre/stack/85
1 parent a6dbf45 commit 825340e
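For context on what the rearrange kernels in this commit compute: the scale columns for each jagged token group along K are converted to the blocked (swizzled) layout independently, and per the commit title the parallel Triton version additionally splits that work across column blocks within each group. Below is a rough eager-mode reference, included only as an illustration; it assumes torchao's to_blocked helper and cumulative group-end offsets along K, and is not the kernel implemented in this PR.

import torch
from torchao.prototype.mx_formats.utils import to_blocked


def blocked_scales_per_K_group_reference(
    scales: torch.Tensor,   # (M, K_total // 32) E8M0 scales stored as uint8
    offsets: torch.Tensor,  # cumulative group-end column indices along K (assumed)
) -> torch.Tensor:
    # Convert each group's slice of scale columns to the blocked layout on its own,
    # then concatenate the per-group results into one flat output buffer.
    out = []
    start = 0
    for end in offsets.tolist():
        out.append(to_blocked(scales[:, start:end]).reshape(-1))
        start = end
    return torch.cat(out)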

File tree

10 files changed: +2927 -12 lines

Lines changed: 229 additions & 0 deletions
@@ -0,0 +1,229 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import os
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from torch.utils.cpp_extension import load
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8 import (
    triton_mx_block_rearrange_2d_K_groups,
)
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    triton_mx_block_rearrange_2d_K_groups_naive,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs

# Build CUDA kernel directly using torch.utils.cpp_extension.load
mxfp8_cuda = None
try:
    # Get the kernel source directory
    KERNEL_DIR = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "..",
        "..",
        "..",
        "torchao",
        "csrc",
        "cuda",
        "mx_kernels",
    )
    KERNEL_DIR = os.path.normpath(KERNEL_DIR)

    print("Compiling CUDA kernel...")
    mxfp8_cuda = load(
        name="mxfp8_cuda",
        sources=[
            os.path.join(KERNEL_DIR, "mxfp8_extension.cpp"),
            os.path.join(KERNEL_DIR, "mxfp8_cuda.cu"),
            os.path.join(KERNEL_DIR, "mx_block_rearrange_2d_K_groups.cu"),
        ],
        extra_cuda_cflags=[
            "-O3",
            "--use_fast_math",
            "-std=c++17",
            "-gencode=arch=compute_90,code=sm_90",
        ],
        extra_cflags=["-O3", "-std=c++17"],
        verbose=True,
    )
    print("✓ CUDA kernel compilation successful!")
except (ImportError, RuntimeError) as e:
    print(f"⚠ CUDA kernel not available: {e}")
    print("The benchmark will only run 'naive' and 'parallel' Triton versions.\n")

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: tuple[int, int]
    num_groups: int
    version: str  # "naive", "parallel", or "cuda"


@dataclass(frozen=True)
class ExperimentResult:
    time_us: float
    mem_bw_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    # Llama4 and DSV3 671b shapes. Input activations are scaled along the total token dim, which is the K dim here and contains all the token groups.
    block_size = 32
    input_shapes = [
        (5120, 16384 // block_size),
        (5120, 131072 // block_size),
        (8192, 16384 // block_size),
        (8192, 131072 // block_size),
        (7168, 16384 // block_size),
        (7168, 131072 // block_size),
        (2048, 16384 // block_size),
        (2048, 131072 // block_size),
    ]
    num_groups = [8]
    # Only include the CUDA version if the extension compiled successfully.
    versions = ["naive", "parallel"] + (["cuda"] if mxfp8_cuda is not None else [])

    configs = []
    for shape, groups, version in itertools.product(
        input_shapes,
        num_groups,
        versions,
    ):
        configs.append(
            ExperimentConfig(
                input_shape=shape,
                num_groups=groups,
                version=version,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    input_shape, num_groups, version = (
        config.input_shape,
        config.num_groups,
        config.version,
    )
    input_tensor = torch.randint(
        low=0,
        high=256,
        size=input_shape,
        dtype=torch.uint8,
        device=device,
    )

    M, Kg = input_shape
    block_size = 32
    input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size)

    # Select which kernel to benchmark based on version
    if version == "naive":
        kernel_fn = triton_mx_block_rearrange_2d_K_groups_naive
    elif version == "parallel":
        kernel_fn = triton_mx_block_rearrange_2d_K_groups
    elif version == "cuda":
        kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups
    else:
        raise ValueError(f"Unknown version: {version}")

    # Run kernel to get output shape
    out_scales = kernel_fn(
        input_tensor,
        input_group_offsets,
    )

    # Benchmark the kernel
    assert input_tensor.is_contiguous()
    time_us = benchmark_cuda_function_in_microseconds(
        kernel_fn,
        input_tensor,
        input_group_offsets,
    )

    # Calculate memory bandwidth
    bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = out_scales.numel() * bytes_per_output_el

    mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6)

    return ExperimentResult(
        time_us=time_us,
        mem_bw_gbps=mem_bw_gbps,
    )


def print_results(experiments: List[Experiment]):
    # Group experiments by input shape
    shapes_dict = {}
    for exp in experiments:
        shape_key = exp.config.input_shape
        if shape_key not in shapes_dict:
            shapes_dict[shape_key] = {}
        shapes_dict[shape_key][exp.config.version] = exp.result

    headers = [
        "kernel_version",
        "input_shape",
        "time_us",
        "mem_bw_gbps",
        "fastest_version",
    ]

    rows = []
    for shape, versions in shapes_dict.items():
        # Find fastest version for this shape
        fastest_version = min(versions.items(), key=lambda x: x[1].time_us)[0]

        # Add rows for each version
        for version, result in versions.items():
            rows.append(
                [
                    version,
                    f"({shape[0]}, {shape[1]})",
                    f"{result.time_us:.2f}",
                    round(result.mem_bw_gbps, 3),
                    fastest_version,
                ]
            )

    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
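Note: benchmark_cuda_function_in_microseconds is imported from benchmarks.utils and is not part of this diff. A minimal CUDA-event-based equivalent, shown purely as a sketch of the measurement (the function name and the n_warmup/n_iters parameters here are illustrative assumptions, not the repo's API), could look like:

import torch


def benchmark_cuda_function_in_microseconds_sketch(fn, *args, n_warmup: int = 3, n_iters: int = 10) -> float:
    # Warm up first so one-time compilation / autotuning cost is excluded.
    for _ in range(n_warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    # Event.elapsed_time returns milliseconds; convert to average microseconds per call.
    return start.elapsed_time(end) * 1e3 / n_iters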
