@@ -98,7 +98,13 @@ def get_configs() -> List[ExperimentConfig]:
9898 (2048 , 131072 // block_size ),
9999 ]
100100 num_groups = [8 ]
101- versions = ["triton" , "cuda_rowmajor" , "cuda_colmajor" ]
101+ versions = [
102+ "triton" ,
103+ "cuda_rowmajor" ,
104+ "cuda_colmajor" ,
105+ "cuda_colmajor_vec" ,
106+ "cuda_colmajor_vec_16B" ,
107+ ]
102108
103109 configs = []
104110 for shape , groups , version in itertools .product (
@@ -152,6 +158,18 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
152158 # Column-major kernel expects column-major input
153159 # Column-major: same shape (rows, cols) but stride(0)=1, stride(1)=rows
154160 kernel_input = input_tensor .T .contiguous ().T
161+ elif version == "cuda_colmajor_vec" :
162+ if mxfp8_cuda is None :
163+ raise RuntimeError ("CUDA kernel not available" )
164+ kernel_fn = mxfp8_cuda .mx_block_rearrange_2d_K_groups_colmajor_vectorized
165+ # Vectorized column-major kernel also expects column-major input
166+ kernel_input = input_tensor .T .contiguous ().T
167+ elif version == "cuda_colmajor_vec_16B" :
168+ if mxfp8_cuda is None :
169+ raise RuntimeError ("CUDA kernel not available" )
170+ kernel_fn = mxfp8_cuda .mx_block_rearrange_2d_K_groups_colmajor_vectorized_16B
171+ # 16B vectorized column-major kernel also expects column-major input
172+ kernel_input = input_tensor .T .contiguous ().T
155173 else :
156174 raise ValueError (f"Unknown version: { version } " )
157175
@@ -217,7 +235,7 @@ def print_results(experiments: List[Experiment]):
217235 for version , result in versions .items ():
218236 # Calculate speedup vs triton
219237 speedup_str = ""
220- if version != "triton" and triton_time_us > 0 :
238+ if version != "triton" :
221239 speedup = triton_time_us / result .time_us
222240 speedup_str = f"{ speedup :.2f} x"
223241
0 commit comments