Skip to content

Commit a5d83e3

Browse files
cuda parallel idea working now
1 parent daf9ffd commit a5d83e3

File tree

4 files changed

+352
-78
lines changed

4 files changed

+352
-78
lines changed

benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_K_groups.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
"-O3",
5353
"--use_fast_math",
5454
"-std=c++17",
55-
"-gencode=arch=compute_90,code=sm_90",
55+
"-gencode=arch=compute_100,code=sm_100",
5656
],
5757
extra_cflags=["-O3", "-std=c++17"],
5858
verbose=True,
@@ -101,7 +101,7 @@ def get_configs() -> List[ExperimentConfig]:
101101
(2048, 131072 // block_size),
102102
]
103103
num_groups = [8]
104-
versions = ["naive", "parallel", "cuda"]
104+
versions = ["triton_naive", "triton_parallel", "cuda_parallel", "cuda_naive"]
105105

106106
configs = []
107107
for shape, groups, version in itertools.product(
@@ -138,12 +138,18 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
138138
input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size)
139139

140140
# Select which kernel to benchmark based on version
141-
if version == "naive":
141+
if version == "triton_naive":
142142
kernel_fn = triton_mx_block_rearrange_2d_K_groups_naive
143-
elif version == "parallel":
143+
elif version == "triton_parallel":
144144
kernel_fn = triton_mx_block_rearrange_2d_K_groups
145-
elif version == "cuda":
145+
elif version == "cuda_parallel":
146+
if mxfp8_cuda is None:
147+
raise RuntimeError("CUDA kernel not available")
146148
kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups
149+
elif version == "cuda_naive":
150+
if mxfp8_cuda is None:
151+
raise RuntimeError("CUDA kernel not available")
152+
kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_naive
147153
else:
148154
raise ValueError(f"Unknown version: {version}")
149155

@@ -191,22 +197,35 @@ def print_results(experiments: List[Experiment]):
191197
"time_us",
192198
"mem_bw_gbps",
193199
"fastest_version",
200+
"speedup_vs_triton_naive",
194201
]
195202

196203
rows = []
197204
for shape, versions in shapes_dict.items():
198205
# Find fastest version for this shape
199206
fastest_version = min(versions.items(), key=lambda x: x[1].time_us)[0]
200207

208+
# Get naive baseline time for speedup calculation
209+
naive_time_us = (
210+
versions.get("triton_naive").time_us if "triton_naive" in versions else None
211+
)
212+
201213
# Add rows for each version
202214
for version, result in versions.items():
215+
# Calculate speedup vs naive
216+
speedup_str = ""
217+
if naive_time_us and naive_time_us > 0:
218+
speedup = naive_time_us / result.time_us
219+
speedup_str = f"{speedup:.2f}x"
220+
203221
rows.append(
204222
[
205223
version,
206224
f"({shape[0]}, {shape[1]})",
207225
f"{result.time_us:.2f}",
208226
round(result.mem_bw_gbps, 3),
209227
fastest_version,
228+
speedup_str,
210229
]
211230
)
212231

0 commit comments

Comments
 (0)