#include <cuda_fp8.h>
#include <cstdint>
#include <cstdio>
#define BLOCK_ROWS 128
#define BLOCK_COLS 4
// Helper function to compute ceil division
__device__ __forceinline__ int ceil_div(int a, int b) {
    return (a + b - 1) / b;
}
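// e.g., ceil_div(10, 4) == (10 + 3) / 4 == 3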
// Helper function to compute the start index of a group after padding
__device__ __forceinline__ int compute_output_group_start_col(
    int group_id,
    const int* input_group_end_offsets,
    int num_groups,
    int padding_size
) {
    int start_idx = 0;
    // Compute prefix sum of padded group sizes
    for (int i = 0; i < group_id; i++) {
        int prev_offset = (i > 0) ? input_group_end_offsets[i - 1] : 0;
        int group_size = input_group_end_offsets[i] - prev_offset;
        int padded_size = ceil_div(group_size, padding_size) * padding_size;
        start_idx += padded_size;
    }
    return start_idx;
}
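// Worked example for compute_output_group_start_col: with
// input_group_end_offsets = {6, 10} and padding_size = 4, group 0 spans
// 6 columns and pads up to 8, so group 1 starts at output column 8.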
// Compute destination index for swizzled block layout
// For a 128x4 block: r_div_32 = row / 32, r_mod_32 = row % 32
// Swizzle: dest = r_mod_32 * 16 + r_div_32 * 4 + col
__device__ __forceinline__ int compute_swizzled_index(int row, int col) {
    int r_div_32 = row / 32;
    int r_mod_32 = row % 32;
    return r_mod_32 * 16 + r_div_32 * 4 + col;
}
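// e.g., row 37, col 2: r_div_32 = 1, r_mod_32 = 5 -> 5*16 + 1*4 + 2 = 86;
// the swizzle is a bijection on the 512 byte positions of a 128x4 block.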
__global__ void mx_block_rearrange_2d_K_groups_naive_kernel(
    const uint8_t* __restrict__ scales_ptr,
    int scales_stride_dim0,
    int scale_rows,
    const int* __restrict__ input_group_end_offsets,
    uint8_t* __restrict__ output_scales_ptr,
    int padded_rows,
    int output_stride_per_block,
    int num_groups
) {
    const int group_id = blockIdx.x;
    const int block_row_id = blockIdx.y;
    const int tid = threadIdx.x;  // 128 threads, each handles one row
    // Shared memory for one 128x4 block
    __shared__ __align__(16) uint8_t smem_block[BLOCK_ROWS * BLOCK_COLS];
    // Get start/end cols of this input group
    int input_group_start_col = (group_id > 0) ? input_group_end_offsets[group_id - 1] : 0;
    int input_group_end_col = input_group_end_offsets[group_id];
    int num_cols_in_group = input_group_end_col - input_group_start_col;
    // Get output group start column
    int output_group_start_col = compute_output_group_start_col(
        group_id,
        input_group_end_offsets,
        num_groups,
        4);  // scaling factor column padding size
    // Compute base offset for this group in output
    int out_group_base_offset = output_group_start_col * padded_rows;
    // Compute stride per row of blocks in this group
    int num_col_blocks_in_group = ceil_div(num_cols_in_group, BLOCK_COLS);
    int stride_per_row_of_blocks_in_group = num_col_blocks_in_group * output_stride_per_block;
    // Each thread handles one row
    int input_row = block_row_id * BLOCK_ROWS + tid;
    // Loop through column blocks in this group
    int curr_input_start_col = input_group_start_col;
    int curr_out_col_block = 0;
    while (curr_input_start_col < input_group_end_col) {
        // Calculate how many columns to load for this block
        int cols_remaining = input_group_end_col - curr_input_start_col;
        int cols_to_load = min(BLOCK_COLS, cols_remaining);
        // Load data for this row using vectorized loads when possible
        uint32_t row_data = 0;
        if (input_row < scale_rows && curr_input_start_col < input_group_end_col) {
            int input_offset = input_row * scales_stride_dim0 + curr_input_start_col;
            const uint8_t* input_ptr = scales_ptr + input_offset;
            // Check alignment and available columns within this group
            uintptr_t ptr_addr = reinterpret_cast<uintptr_t>(input_ptr);
            if (cols_to_load >= 4 && ptr_addr % 4 == 0 && curr_input_start_col + 4 <= input_group_end_col) {
                // 4-byte aligned and have 4 columns within group: use uint32_t load
                row_data = __ldg(reinterpret_cast<const uint32_t*>(input_ptr));
            } else {
                // Byte-by-byte loads for unaligned or partial blocks
                uint8_t* row_bytes = reinterpret_cast<uint8_t*>(&row_data);
                for (int i = 0; i < cols_to_load && (curr_input_start_col + i) < input_group_end_col; i++) {
                    row_bytes[i] = __ldg(input_ptr + i);
                }
            }
        }
        // Write to swizzled positions in shared memory
        uint8_t* row_bytes = reinterpret_cast<uint8_t*>(&row_data);
        #pragma unroll
        for (int col = 0; col < BLOCK_COLS; col++) {
            int swizzled_idx = compute_swizzled_index(tid, col);
            smem_block[swizzled_idx] = row_bytes[col];
        }
        __syncthreads();
125-
126103 // Write from shared memory to global memory
127104 // Calculate the output offset for this specific block
128105 int offset_in_group = block_row_id * stride_per_row_of_blocks_in_group +
129106 curr_out_col_block * output_stride_per_block;
130107 int final_offset = out_group_base_offset + offset_in_group;
131-
132108 // Each thread writes 4 bytes (one row of the 128x4 block)
133109 uint8_t * output_ptr = output_scales_ptr + final_offset + tid * BLOCK_COLS;
134-
135110 // Check output alignment for vectorized write
136111 uintptr_t out_ptr_addr = reinterpret_cast <uintptr_t >(output_ptr);
        if (out_ptr_addr % 4 == 0) {
            // 4-byte aligned: single vectorized 32-bit store
            *reinterpret_cast<uint32_t*>(output_ptr) =
                *reinterpret_cast<const uint32_t*>(smem_block + tid * BLOCK_COLS);
        } else {
            // Byte-by-byte stores for unaligned output
            const uint8_t* smem_ptr = smem_block + tid * BLOCK_COLS;
            for (int i = 0; i < BLOCK_COLS; i++) {
                output_ptr[i] = smem_ptr[i];
            }
        }
        // Advance to next column block
        curr_input_start_col += BLOCK_COLS;
        curr_out_col_block += 1;
        // Only sync if there's another iteration
        if (curr_input_start_col < input_group_end_col) {
            __syncthreads();
        }
    }
}
// Host function to launch the kernel
namespace mxfp8 {
void launch_mx_block_rearrange_2d_K_groups(
    const uint8_t* scales_ptr,
    int scales_stride_dim0,
    int scale_rows,
    const int* input_group_end_offsets,
    uint8_t* output_scales_ptr,
    int padded_rows,
    int num_groups,
    cudaStream_t stream
) {
    int num_row_blocks = (scale_rows + BLOCK_ROWS - 1) / BLOCK_ROWS;
    // Grid parallelizes over (num_groups, num_row_blocks)
    // Each thread block loops through column blocks within its group
    dim3 grid(num_groups, num_row_blocks);
    dim3 block(128);  // 128 threads, each handling one row
    int output_stride_per_block = BLOCK_ROWS * BLOCK_COLS;
    mx_block_rearrange_2d_K_groups_naive_kernel<<<grid, block, 0, stream>>>(
        scales_ptr,
        scales_stride_dim0,
        scale_rows,
        input_group_end_offsets,
        output_scales_ptr,
        padded_rows,
        output_stride_per_block,
        num_groups
    );
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("CUDA Error: %s\n", cudaGetErrorString(err));
    }
}
} // namespace mxfp8
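// Example host-side usage: a minimal sketch, not from the original source.
// Buffer names (d_scales, d_out, d_offsets), sizes, and the reconstructed
// parameter order of the launcher above are assumptions for illustration.
//
//   int h_offsets[2] = {6, 10};            // two K-groups: cols [0,6) and [6,10)
//   int scale_rows = 256, scale_cols = 10;
//   int padded_rows = 256;                 // already a multiple of BLOCK_ROWS here
//   uint8_t *d_scales, *d_out;
//   int* d_offsets;
//   cudaMalloc(&d_scales, scale_rows * scale_cols);
//   cudaMalloc(&d_out, (8 + 4) * padded_rows);  // group sizes 6 and 4 pad to 8 and 4 cols
//   cudaMalloc(&d_offsets, sizeof(h_offsets));
//   cudaMemcpy(d_offsets, h_offsets, sizeof(h_offsets), cudaMemcpyHostToDevice);
//   mxfp8::launch_mx_block_rearrange_2d_K_groups(
//       d_scales, /*scales_stride_dim0=*/scale_cols, scale_rows,
//       d_offsets, d_out, padded_rows, /*num_groups=*/2, /*stream=*/0);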