
Commit daf9ffd

stick with loop for now
1 parent 825340e commit daf9ffd

File tree

1 file changed (+6, -37)

torchao/csrc/cuda/mx_kernels/mx_block_rearrange_2d_K_groups.cu

Lines changed: 6 additions & 37 deletions
@@ -2,15 +2,12 @@
 #include <cuda_fp8.h>
 #include <cstdint>
 #include <cstdio>
-
 #define BLOCK_ROWS 128
 #define BLOCK_COLS 4
-
 // Helper function to compute ceil division
 __device__ __forceinline__ int ceil_div(int a, int b) {
   return (a + b - 1) / b;
 }
-
 // Helper function to compute the start index of a group after padding
 __device__ __forceinline__ int compute_output_group_start_col(
     int group_id,
@@ -19,7 +16,6 @@ __device__ __forceinline__ int compute_output_group_start_col(
     int padding_size
 ) {
   int start_idx = 0;
-
   // Compute prefix sum of padded group sizes
   for (int i = 0; i < group_id; i++) {
     int prev_offset = (i > 0) ? input_group_end_offsets[i - 1] : 0;
@@ -28,10 +24,8 @@ __device__ __forceinline__ int compute_output_group_start_col(
     int padded_size = ceil_div(group_size, padding_size) * padding_size;
     start_idx += padded_size;
   }
-
   return start_idx;
 }
-
 // Compute destination index for swizzled block layout
 // For a 128x4 block: r_div_32 = row / 32, r_mod_32 = row % 32
 // Swizzle: dest = r_mod_32 * 16 + r_div_32 * 4 + col
@@ -40,7 +34,6 @@ __device__ __forceinline__ int compute_swizzled_index(int row, int col) {
   int r_mod_32 = row % 32;
   return r_mod_32 * 16 + r_div_32 * 4 + col;
 }
-
 __global__ void mx_block_rearrange_2d_K_groups_naive_kernel(
     const uint8_t* __restrict__ scales_ptr,
     int scales_stride_dim0,
@@ -55,83 +48,65 @@ __global__ void mx_block_rearrange_2d_K_groups_naive_kernel(
   const int group_id = blockIdx.x;
   const int block_row_id = blockIdx.y;
   const int tid = threadIdx.x; // 128 threads, each handles one row
-
   // Shared memory for one 128x4 block
   __shared__ __align__(16) uint8_t smem_block[BLOCK_ROWS * BLOCK_COLS];
-
   // Get start/end cols of this input group
   int input_group_start_col = (group_id > 0) ? input_group_end_offsets[group_id - 1] : 0;
   int input_group_end_col = input_group_end_offsets[group_id];
   int num_cols_in_group = input_group_end_col - input_group_start_col;
-
   // Get output group start column
   int output_group_start_col = compute_output_group_start_col(
       group_id,
       input_group_end_offsets,
       num_groups,
       4); // scaling factor column padding size
-
   // Compute base offset for this group in output
   int out_group_base_offset = output_group_start_col * padded_rows;
-
   // Compute stride per row of blocks in this group
   int num_col_blocks_in_group = ceil_div(num_cols_in_group, BLOCK_COLS);
   int stride_per_row_of_blocks_in_group = num_col_blocks_in_group * output_stride_per_block;
-
   // Each thread handles one row
   int input_row = block_row_id * BLOCK_ROWS + tid;
-
   // Loop through column blocks in this group
   int curr_input_start_col = input_group_start_col;
   int curr_out_col_block = 0;
-
   while (curr_input_start_col < input_group_end_col) {
     // Calculate how many columns to load for this block
     int cols_remaining = input_group_end_col - curr_input_start_col;
     int cols_to_load = min(BLOCK_COLS, cols_remaining);
-
     // Load data for this row using vectorized loads when possible
     uint32_t row_data = 0;
-
     if (input_row < scale_rows && curr_input_start_col < input_group_end_col) {
       int input_offset = input_row * scales_stride_dim0 + curr_input_start_col;
       const uint8_t* input_ptr = scales_ptr + input_offset;
-
       // Check alignment and available columns within this group
       uintptr_t ptr_addr = reinterpret_cast<uintptr_t>(input_ptr);
-
       if (cols_to_load >= 4 && ptr_addr % 4 == 0 && curr_input_start_col + 4 <= input_group_end_col) {
         // 4-byte aligned and have 4 columns within group: use uint32_t load
-        row_data = *reinterpret_cast<const uint32_t*>(input_ptr);
+        row_data = __ldg(reinterpret_cast<const uint32_t*>(input_ptr));
       } else {
         // Byte-by-byte loads for unaligned or partial blocks
         uint8_t* row_bytes = reinterpret_cast<uint8_t*>(&row_data);
         for (int i = 0; i < cols_to_load && (curr_input_start_col + i) < input_group_end_col; i++) {
-          row_bytes[i] = input_ptr[i];
+          row_bytes[i] = __ldg(input_ptr + i);
         }
       }
     }
-
     // Write to swizzled positions in shared memory
     uint8_t* row_bytes = reinterpret_cast<uint8_t*>(&row_data);
-
     #pragma unroll
     for (int col = 0; col < BLOCK_COLS; col++) {
       int swizzled_idx = compute_swizzled_index(tid, col);
       smem_block[swizzled_idx] = row_bytes[col];
     }
-
     __syncthreads();
-
     // Write from shared memory to global memory
     // Calculate the output offset for this specific block
     int offset_in_group = block_row_id * stride_per_row_of_blocks_in_group +
                           curr_out_col_block * output_stride_per_block;
     int final_offset = out_group_base_offset + offset_in_group;
-
     // Each thread writes 4 bytes (one row of the 128x4 block)
     uint8_t* output_ptr = output_scales_ptr + final_offset + tid * BLOCK_COLS;
-
     // Check output alignment for vectorized write
     uintptr_t out_ptr_addr = reinterpret_cast<uintptr_t>(output_ptr);
     if (out_ptr_addr % 4 == 0) {
@@ -146,18 +121,17 @@ __global__ void mx_block_rearrange_2d_K_groups_naive_kernel(
         output_ptr[i] = smem_ptr[i];
       }
     }
-
-    __syncthreads();
-
     // Advance to next column block
     curr_input_start_col += BLOCK_COLS;
     curr_out_col_block += 1;
+    // Only sync if there's another iteration
+    if (curr_input_start_col < input_group_end_col) {
+      __syncthreads();
+    }
   }
 }
-
 // Host function to launch the kernel
 namespace mxfp8 {
-
 void launch_mx_block_rearrange_2d_K_groups(
     const uint8_t* scales_ptr,
     int scales_stride_dim0,
@@ -170,14 +144,11 @@ void launch_mx_block_rearrange_2d_K_groups(
     cudaStream_t stream
 ) {
   int num_row_blocks = (scale_rows + BLOCK_ROWS - 1) / BLOCK_ROWS;
-
   // Grid parallelizes over (num_groups, num_row_blocks)
   // Each thread block loops through column blocks within its group
   dim3 grid(num_groups, num_row_blocks);
   dim3 block(128); // 128 threads, each handling one row
-
   int output_stride_per_block = BLOCK_ROWS * BLOCK_COLS;
-
   mx_block_rearrange_2d_K_groups_naive_kernel<<<grid, block, 0, stream>>>(
       scales_ptr,
       scales_stride_dim0,
@@ -189,11 +160,9 @@ void launch_mx_block_rearrange_2d_K_groups(
       output_stride_per_block,
       num_groups
   );
-
   cudaError_t err = cudaGetLastError();
   if (err != cudaSuccess) {
     printf("CUDA Error: %s\n", cudaGetErrorString(err));
   }
 }
-
 } // namespace mxfp8
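
As context for the rearrangement above: compute_swizzled_index maps each (row, col) of a 128x4 scale block into a 512-byte tile where dest = (row % 32) * 16 + (row / 32) * 4 + col, i.e. 32 rows of 16 bytes with the four 32-row slices interleaved. A minimal host-side sketch (illustrative only, not part of this commit) that mirrors the formula and prints a few sample mappings:

// Standalone host sketch of the kernel's swizzle formula.
// Assumption: purely illustrative; the device function in the diff is the
// source of truth.
#include <cstdio>

static int swizzled_index(int row, int col) {
  int r_div_32 = row / 32;  // which 32-row slice
  int r_mod_32 = row % 32;  // row within the slice
  return r_mod_32 * 16 + r_div_32 * 4 + col;
}

int main() {
  printf("%d\n", swizzled_index(0, 0));    // 0
  printf("%d\n", swizzled_index(1, 0));    // 16: next row starts 16 bytes later
  printf("%d\n", swizzled_index(32, 0));   // 4: next slice sits 4 bytes over
  printf("%d\n", swizzled_index(127, 3));  // 511: last byte of the 512-byte tile
  return 0;
}

The mapping is a bijection on [0, 512), so each thread's 4 bytes land in distinct shared-memory slots.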

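Likewise, compute_output_group_start_col rounds each preceding group's column count up to a multiple of padding_size (4 here) and prefix-sums the padded sizes. A small host sketch, with made-up end offsets (the values below are examples, not from the commit):

// Standalone host sketch of the group-start prefix sum.
#include <cstdio>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

static int output_group_start_col(int group_id, const int* end_offsets,
                                  int padding_size) {
  int start_idx = 0;
  for (int i = 0; i < group_id; i++) {
    int prev = (i > 0) ? end_offsets[i - 1] : 0;
    int size = end_offsets[i] - prev;                          // raw group width
    start_idx += ceil_div(size, padding_size) * padding_size;  // padded width
  }
  return start_idx;
}

int main() {
  // Example groups of width 5, 3, 8 -> padded to 8, 4, 8 with padding_size = 4.
  const int end_offsets[] = {5, 8, 16};
  for (int g = 0; g < 3; g++) {
    printf("group %d starts at output col %d\n",
           g, output_group_start_col(g, end_offsets, 4));  // prints 0, 8, 12
  }
  return 0;
}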