Skip to content

Commit a8d2159

Browse files
Add config selection for row-wise scaled FP8 sparse CUTLASS-based kernel (#1940)
1 parent 83eb490 commit a8d2159

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,12 @@ void rowwise_scaled_linear_sparse_kernel_cutlass_sm9x(
7979

8080
using ProblemShape = cute::Shape<int, int, int, int>;
8181

82-
// If KernelTmaWarpSpecializedPingpong used for kernel schedule, the
83-
// performance is really bad; on the other side, using
84-
// KernelTmaWarpSpecializedPingpongFP8FastAccum doesn't seem to
85-
// affect the precision much - thus, sticking with it.
82+
// If FP8FastAccum not used for kernel schedule, the performance is
83+
// really bad; on the other hand, using it doesn't seem to affect
84+
// the precision much - thus, sticking with it.
8685
using KernelSchedule =
8786
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
8887
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
89-
9088
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
9189
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
9290
using AScale =
@@ -256,13 +254,29 @@ static void select_config(
256254
std::is_same<DtypeWq, cutlass::float_e4m3_t>::value) ||
257255
(std::is_same<DtypeXq, cutlass::float_e5m2_t>::value &&
258256
std::is_same<DtypeWq, cutlass::float_e5m2_t>::value)) {
259-
// TODO: add proper tuning here.
260-
using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
261-
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
262-
rowwise_scaled_linear_sparse_kernel_cutlass_sm9x<
263-
DtypeXq, DtypeWq, Types..., TileShape, ClusterShape>(
264-
Xq, X_scale, Wq, W_meta, W_scale, bias, Y);
265-
return;
257+
const auto m = Y.size(0);
258+
if (m <= 64) {
259+
using TileShape = cute::Shape<cute::_64, cute::_32, cute::_256>;
260+
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
261+
rowwise_scaled_linear_sparse_kernel_cutlass_sm9x<
262+
DtypeXq, DtypeWq, Types..., TileShape, ClusterShape>(
263+
Xq, X_scale, Wq, W_meta, W_scale, bias, Y);
264+
return;
265+
} else if (m <= 128) {
266+
using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
267+
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
268+
rowwise_scaled_linear_sparse_kernel_cutlass_sm9x<
269+
DtypeXq, DtypeWq, Types..., TileShape, ClusterShape>(
270+
Xq, X_scale, Wq, W_meta, W_scale, bias, Y);
271+
return;
272+
} else {
273+
using TileShape = cute::Shape<cute::_128, cute::_128, cute::_256>;
274+
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
275+
rowwise_scaled_linear_sparse_kernel_cutlass_sm9x<
276+
DtypeXq, DtypeWq, Types..., TileShape, ClusterShape>(
277+
Xq, X_scale, Wq, W_meta, W_scale, bias, Y);
278+
return;
279+
}
266280
}
267281
}
268282

0 commit comments

Comments
 (0)