         "-O3",
         "--use_fast_math",
         "-std=c++17",
-        "-gencode=arch=compute_90,code=sm_90",
+        "-gencode=arch=compute_100,code=sm_100",
     ],
     extra_cflags=["-O3", "-std=c++17"],
     verbose=True,
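These flags are the nvcc options used when JIT-compiling the CUDA extension; the change retargets code generation from compute_90/sm_90 (Hopper) to compute_100/sm_100 (Blackwell), which needs a recent CUDA toolkit (12.8+). A minimal sketch of how such a build might be invoked with torch.utils.cpp_extension.load; the extension name and source path below are placeholders, not values from this file:

# Sketch only: assumes torch.utils.cpp_extension.load is used to JIT-build the
# extension; "mxfp8_cuda_ext" and the .cu path are illustrative placeholders.
from torch.utils.cpp_extension import load

mxfp8_cuda = load(
    name="mxfp8_cuda_ext",
    sources=["mxfp8_quantize.cu"],  # placeholder source path
    extra_cuda_cflags=[
        "-O3",
        "--use_fast_math",
        "-std=c++17",
        "-gencode=arch=compute_100,code=sm_100",  # Blackwell (SM 10.0)
    ],
    extra_cflags=["-O3", "-std=c++17"],
    verbose=True,
)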
@@ -101,7 +101,7 @@ def get_configs() -> List[ExperimentConfig]:
         (2048, 131072 // block_size),
     ]
     num_groups = [8]
-    versions = ["naive", "parallel", "cuda"]
+    versions = ["triton_naive", "triton_parallel", "cuda_parallel", "cuda_naive"]
 
     configs = []
     for shape, groups, version in itertools.product(
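The benchmark grid is the Cartesian product of shapes, group counts, and versions, so the expanded `versions` list adds one extra run per (shape, num_groups) pair. A toy illustration of how the product expands (values are illustrative, not the file's full config list):

# Illustrative only: shows how itertools.product expands the benchmark grid.
import itertools

block_size = 32
shapes = [(2048, 131072 // block_size)]
num_groups = [8]
versions = ["triton_naive", "triton_parallel", "cuda_parallel", "cuda_naive"]

grid = list(itertools.product(shapes, num_groups, versions))
print(len(grid))  # 4 configs: one per version for this single shape/group pair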
@@ -138,12 +138,18 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size)
 
     # Select which kernel to benchmark based on version
-    if version == "naive":
+    if version == "triton_naive":
         kernel_fn = triton_mx_block_rearrange_2d_K_groups_naive
-    elif version == "parallel":
+    elif version == "triton_parallel":
         kernel_fn = triton_mx_block_rearrange_2d_K_groups
-    elif version == "cuda":
+    elif version == "cuda_parallel":
+        if mxfp8_cuda is None:
+            raise RuntimeError("CUDA kernel not available")
         kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups
+    elif version == "cuda_naive":
+        if mxfp8_cuda is None:
+            raise RuntimeError("CUDA kernel not available")
+        kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups_naive
     else:
         raise ValueError(f"Unknown version: {version}")
 
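The new `mxfp8_cuda is None` guards assume the CUDA extension is imported (or built) behind a fallback, so the Triton-only versions still run when the extension is unavailable. A minimal sketch of such a guard; only the module name `mxfp8_cuda` comes from the diff, the try/except shape is an assumption:

# Assumed pattern, not the file's actual import: fall back to None so that
# "cuda_parallel" / "cuda_naive" raise a clear RuntimeError instead of a NameError.
try:
    import mxfp8_cuda  # extension exposing mx_block_rearrange_2d_K_groups[_naive]
except ImportError:
    mxfp8_cuda = None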
@@ -191,22 +197,35 @@ def print_results(experiments: List[Experiment]):
         "time_us",
         "mem_bw_gbps",
         "fastest_version",
+        "speedup_vs_triton_naive",
     ]
 
     rows = []
     for shape, versions in shapes_dict.items():
         # Find fastest version for this shape
         fastest_version = min(versions.items(), key=lambda x: x[1].time_us)[0]
 
+        # Get naive baseline time for speedup calculation
+        naive_time_us = (
+            versions.get("triton_naive").time_us if "triton_naive" in versions else None
+        )
+
         # Add rows for each version
         for version, result in versions.items():
+            # Calculate speedup vs naive
+            speedup_str = ""
+            if naive_time_us and naive_time_us > 0:
+                speedup = naive_time_us / result.time_us
+                speedup_str = f"{speedup:.2f}x"
+
             rows.append(
                 [
                     version,
                     f"({shape[0]}, {shape[1]})",
                     f"{result.time_us:.2f}",
                     round(result.mem_bw_gbps, 3),
                     fastest_version,
+                    speedup_str,
                 ]
             )
 
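The new column reports each version's speedup relative to the `triton_naive` baseline for the same shape, computed as `naive_time_us / result.time_us`, so values above 1.00x are faster than the baseline; for example, a 100 us baseline and a 25 us CUDA run would print 4.00x. The column is left blank when no `triton_naive` result exists for that shape.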