Skip to content

Commit 3791188

Browse files
add 512 row kernel
1 parent 1ff0461 commit 3791188

File tree

4 files changed

+593
-23
lines changed

4 files changed

+593
-23
lines changed

benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_K_groups.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,13 @@ def get_configs() -> List[ExperimentConfig]:
9898
(2048, 131072 // block_size),
9999
]
100100
num_groups = [8]
101-
versions = ["triton", "cuda_rowmajor", "cuda_colmajor"]
101+
versions = [
102+
"triton",
103+
"cuda_rowmajor",
104+
"cuda_colmajor",
105+
"cuda_colmajor_vec",
106+
"cuda_colmajor_vec_16B",
107+
]
102108

103109
configs = []
104110
for shape, groups, version in itertools.product(
@@ -152,6 +158,18 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
152158
# Column-major kernel expects column-major input
153159
# Column-major: same shape (rows, cols) but stride(0)=1, stride(1)=rows
154160
kernel_input = input_tensor.T.contiguous().T
161+
elif version == "cuda_colmajor_vec":
162+
if mxfp8_cuda is None:
163+
raise RuntimeError("CUDA kernel not available")
164+
kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_colmajor_vectorized
165+
# Vectorized column-major kernel also expects column-major input
166+
kernel_input = input_tensor.T.contiguous().T
167+
elif version == "cuda_colmajor_vec_16B":
168+
if mxfp8_cuda is None:
169+
raise RuntimeError("CUDA kernel not available")
170+
kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_colmajor_vectorized_16B
171+
# 16B vectorized column-major kernel also expects column-major input
172+
kernel_input = input_tensor.T.contiguous().T
155173
else:
156174
raise ValueError(f"Unknown version: {version}")
157175

@@ -217,7 +235,7 @@ def print_results(experiments: List[Experiment]):
217235
for version, result in versions.items():
218236
# Calculate speedup vs triton
219237
speedup_str = ""
220-
if version != "triton" and triton_time_us > 0:
238+
if version != "triton":
221239
speedup = triton_time_us / result.time_us
222240
speedup_str = f"{speedup:.2f}x"
223241

0 commit comments

Comments
 (0)