|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import itertools |
| 8 | +import os |
| 9 | +from dataclasses import dataclass |
| 10 | +from typing import List |
| 11 | + |
| 12 | +import torch |
| 13 | +from tabulate import tabulate |
| 14 | +from torch.utils.cpp_extension import load |
| 15 | +from tqdm import tqdm |
| 16 | + |
| 17 | +from benchmarks.utils import benchmark_cuda_function_in_microseconds |
| 18 | +from torchao.prototype.moe_training.kernels.mxfp8 import ( |
| 19 | + triton_mx_block_rearrange_2d_K_groups, |
| 20 | +) |
| 21 | +from torchao.prototype.moe_training.kernels.mxfp8.quant import ( |
| 22 | + triton_mx_block_rearrange_2d_K_groups_naive, |
| 23 | +) |
| 24 | +from torchao.prototype.moe_training.utils import generate_jagged_offs |
| 25 | + |
| 26 | +# Build CUDA kernel directly using torch.utils.cpp_extension.load |
| 27 | +mxfp8_cuda = None |
| 28 | +try: |
| 29 | + # Get the kernel source directory |
| 30 | + KERNEL_DIR = os.path.join( |
| 31 | + os.path.dirname(os.path.abspath(__file__)), |
| 32 | + "..", |
| 33 | + "..", |
| 34 | + "..", |
| 35 | + "..", |
| 36 | + "torchao", |
| 37 | + "csrc", |
| 38 | + "cuda", |
| 39 | + "mx_kernels", |
| 40 | + ) |
| 41 | + KERNEL_DIR = os.path.normpath(KERNEL_DIR) |
| 42 | + |
| 43 | + print("Compiling CUDA kernel...") |
| 44 | + mxfp8_cuda = load( |
| 45 | + name="mxfp8_cuda", |
| 46 | + sources=[ |
| 47 | + os.path.join(KERNEL_DIR, "mxfp8_extension.cpp"), |
| 48 | + os.path.join(KERNEL_DIR, "mxfp8_cuda.cu"), |
| 49 | + os.path.join(KERNEL_DIR, "mx_block_rearrange_2d_K_groups.cu"), |
| 50 | + ], |
| 51 | + extra_cuda_cflags=[ |
| 52 | + "-O3", |
| 53 | + "--use_fast_math", |
| 54 | + "-std=c++17", |
| 55 | + "-gencode=arch=compute_90,code=sm_90", |
| 56 | + ], |
| 57 | + extra_cflags=["-O3", "-std=c++17"], |
| 58 | + verbose=True, |
| 59 | + ) |
| 60 | + print("✓ CUDA kernel compilation successful!") |
| 61 | +except (ImportError, RuntimeError) as e: |
| 62 | + print(f"⚠ CUDA kernel not available: {e}") |
| 63 | + print("The benchmark will only run 'naive' and 'parallel' Triton versions.\n") |
| 64 | + |
| 65 | +device = torch.device("cuda") |
| 66 | + |
| 67 | +# Needed since changing args to function causes recompiles |
| 68 | +torch._dynamo.config.cache_size_limit = 1000 |
| 69 | + |
| 70 | + |
| 71 | +@dataclass(frozen=True) |
| 72 | +class ExperimentConfig: |
| 73 | + input_shape: tuple[int] |
| 74 | + num_groups: int |
| 75 | + version: str # "naive" or "parallel" |
| 76 | + |
| 77 | + |
| 78 | +@dataclass(frozen=True) |
| 79 | +class ExperimentResult: |
| 80 | + time_us: float |
| 81 | + mem_bw_gbps: float |
| 82 | + |
| 83 | + |
| 84 | +@dataclass(frozen=True) |
| 85 | +class Experiment: |
| 86 | + config: ExperimentConfig |
| 87 | + result: ExperimentResult |
| 88 | + |
| 89 | + |
| 90 | +def get_configs() -> List[ExperimentConfig]: |
| 91 | + # Llama4 and DSV3 671b shapes. Input activations are scaled along the total_M dim, which contains all the token groups. |
| 92 | + block_size = 32 |
| 93 | + input_shapes = [ |
| 94 | + (5120, 16384 // block_size), |
| 95 | + (5120, 131072 // block_size), |
| 96 | + (8192, 16384 // block_size), |
| 97 | + (8192, 131072 // block_size), |
| 98 | + (7168, 16384 // block_size), |
| 99 | + (7168, 131072 // block_size), |
| 100 | + (2048, 16384 // block_size), |
| 101 | + (2048, 131072 // block_size), |
| 102 | + ] |
| 103 | + num_groups = [8] |
| 104 | + versions = ["naive", "parallel", "cuda"] |
| 105 | + |
| 106 | + configs = [] |
| 107 | + for shape, groups, version in itertools.product( |
| 108 | + input_shapes, |
| 109 | + num_groups, |
| 110 | + versions, |
| 111 | + ): |
| 112 | + configs.append( |
| 113 | + ExperimentConfig( |
| 114 | + input_shape=shape, |
| 115 | + num_groups=groups, |
| 116 | + version=version, |
| 117 | + ) |
| 118 | + ) |
| 119 | + return configs |
| 120 | + |
| 121 | + |
| 122 | +def run_experiment(config: ExperimentConfig) -> ExperimentResult: |
| 123 | + input_shape, num_groups, version = ( |
| 124 | + config.input_shape, |
| 125 | + config.num_groups, |
| 126 | + config.version, |
| 127 | + ) |
| 128 | + input_tensor = torch.randint( |
| 129 | + low=0, |
| 130 | + high=256, |
| 131 | + size=input_shape, |
| 132 | + dtype=torch.uint8, |
| 133 | + device=device, |
| 134 | + ) |
| 135 | + |
| 136 | + M, Kg = input_shape |
| 137 | + block_size = 32 |
| 138 | + input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size) |
| 139 | + |
| 140 | + # Select which kernel to benchmark based on version |
| 141 | + if version == "naive": |
| 142 | + kernel_fn = triton_mx_block_rearrange_2d_K_groups_naive |
| 143 | + elif version == "parallel": |
| 144 | + kernel_fn = triton_mx_block_rearrange_2d_K_groups |
| 145 | + elif version == "cuda": |
| 146 | + kernel_fn = mxfp8_cuda.mx_block_rearrange_2d_K_groups |
| 147 | + else: |
| 148 | + raise ValueError(f"Unknown version: {version}") |
| 149 | + |
| 150 | + # Run kernel to get output shape |
| 151 | + out_scales = kernel_fn( |
| 152 | + input_tensor, |
| 153 | + input_group_offsets, |
| 154 | + ) |
| 155 | + |
| 156 | + # Benchmark the kernel |
| 157 | + assert input_tensor.is_contiguous() |
| 158 | + time_us = benchmark_cuda_function_in_microseconds( |
| 159 | + kernel_fn, |
| 160 | + input_tensor, |
| 161 | + input_group_offsets, |
| 162 | + ) |
| 163 | + |
| 164 | + # Calculate memory bandwidth |
| 165 | + bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8 |
| 166 | + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 |
| 167 | + |
| 168 | + read_bytes = input_tensor.numel() * bytes_per_input_el |
| 169 | + write_bytes = out_scales.numel() * bytes_per_output_el |
| 170 | + |
| 171 | + mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6) |
| 172 | + |
| 173 | + return ExperimentResult( |
| 174 | + time_us=time_us, |
| 175 | + mem_bw_gbps=mem_bw_gbps, |
| 176 | + ) |
| 177 | + |
| 178 | + |
| 179 | +def print_results(experiments: List[Experiment]): |
| 180 | + # Group experiments by input shape |
| 181 | + shapes_dict = {} |
| 182 | + for exp in experiments: |
| 183 | + shape_key = exp.config.input_shape |
| 184 | + if shape_key not in shapes_dict: |
| 185 | + shapes_dict[shape_key] = {} |
| 186 | + shapes_dict[shape_key][exp.config.version] = exp.result |
| 187 | + |
| 188 | + headers = [ |
| 189 | + "kernel_version", |
| 190 | + "input_shape", |
| 191 | + "time_us", |
| 192 | + "mem_bw_gbps", |
| 193 | + "fastest_version", |
| 194 | + ] |
| 195 | + |
| 196 | + rows = [] |
| 197 | + for shape, versions in shapes_dict.items(): |
| 198 | + # Find fastest version for this shape |
| 199 | + fastest_version = min(versions.items(), key=lambda x: x[1].time_us)[0] |
| 200 | + |
| 201 | + # Add rows for each version |
| 202 | + for version, result in versions.items(): |
| 203 | + rows.append( |
| 204 | + [ |
| 205 | + version, |
| 206 | + f"({shape[0]}, {shape[1]})", |
| 207 | + f"{result.time_us:.2f}", |
| 208 | + round(result.mem_bw_gbps, 3), |
| 209 | + fastest_version, |
| 210 | + ] |
| 211 | + ) |
| 212 | + |
| 213 | + print(tabulate(rows, headers=headers)) |
| 214 | + |
| 215 | + |
| 216 | +def main(): |
| 217 | + torch.random.manual_seed(123) |
| 218 | + configs = get_configs() |
| 219 | + results = [] |
| 220 | + for config in tqdm(configs): |
| 221 | + result = run_experiment(config) |
| 222 | + results.append(Experiment(config=config, result=result)) |
| 223 | + |
| 224 | + # Use Tabulate to print results |
| 225 | + print_results(results) |
| 226 | + |
| 227 | + |
| 228 | +if __name__ == "__main__": |
| 229 | + main() |
0 commit comments