@@ -79,14 +79,12 @@ void rowwise_scaled_linear_sparse_kernel_cutlass_sm9x(
7979
8080 using ProblemShape = cute::Shape<int , int , int , int >;
8181
82- // If KernelTmaWarpSpecializedPingpong used for kernel schedule, the
83- // performance is really bad; on the other side, using
84- // KernelTmaWarpSpecializedPingpongFP8FastAccum doesn't seem to
85- // affect the precision much - thus, sticking with it.
82+ // If an FP8FastAccum kernel schedule is not used, the performance is
83+ // really bad; on the other hand, using it doesn't seem to affect
84+ // the precision much - thus, sticking with it.
8685 using KernelSchedule =
8786 cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
8887 using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
89-
9088 constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
9189 using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
9290 using AScale =
@@ -256,13 +254,29 @@ static void select_config(
256254 std::is_same<DtypeWq, cutlass::float_e4m3_t >::value) ||
257255 (std::is_same<DtypeXq, cutlass::float_e5m2_t >::value &&
258256 std::is_same<DtypeWq, cutlass::float_e5m2_t >::value)) {
259- // TODO: add proper tuning here.
260- using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
261- using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
262- rowwise_scaled_linear_sparse_kernel_cutlass_sm9x<
263- DtypeXq, DtypeWq, Types..., TileShape, ClusterShape>(
264- Xq, X_scale, Wq, W_meta, W_scale, bias, Y);
265- return ;
257+ const auto m = Y.size (0 );
258+ if (m <= 64 ) {
259+ using TileShape = cute::Shape<cute::_64, cute::_32, cute::_256>;
260+ using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
261+ rowwise_scaled_linear_sparse_kernel_cutlass_sm9x<
262+ DtypeXq, DtypeWq, Types..., TileShape, ClusterShape>(
263+ Xq, X_scale, Wq, W_meta, W_scale, bias, Y);
264+ return ;
265+ } else if (m <= 128 ) {
266+ using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
267+ using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
268+ rowwise_scaled_linear_sparse_kernel_cutlass_sm9x<
269+ DtypeXq, DtypeWq, Types..., TileShape, ClusterShape>(
270+ Xq, X_scale, Wq, W_meta, W_scale, bias, Y);
271+ return ;
272+ } else {
273+ using TileShape = cute::Shape<cute::_128, cute::_128, cute::_256>;
274+ using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
275+ rowwise_scaled_linear_sparse_kernel_cutlass_sm9x<
276+ DtypeXq, DtypeWq, Types..., TileShape, ClusterShape>(
277+ Xq, X_scale, Wq, W_meta, W_scale, bias, Y);
278+ return ;
279+ }
266280 }
267281 }
268282
0 commit comments