
Commit d72d8a2

FusedAttention: Add cudnn 9.20 path for SM arch >100
Signed-off-by: zmelumian972 <zmelumian@gmail.com>
Signed-off-by: zmelumian <zmelumian@lightricks.com>
1 parent 0b4a561 commit d72d8a2

File tree

1 file changed: +4 -1 lines


transformer_engine/common/fused_attn/fused_attn.cpp

Lines changed: 4 additions & 1 deletion
@@ -339,7 +339,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
             attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
           // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
           (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
-           cudnn_runtime_version >= 91100)) &&
+           cudnn_runtime_version >= 91100) ||
+          // 9.20: any head_dim + Blackwell + fprop/bprop + non_paged + any sq
+          (sm_arch_ >= 100 && cudnn_runtime_version >= 92000 &&
+           layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD)) &&
          // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
          // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed
          (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 &&
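
For readers skimming the diff: the new clause extends the existing 9.11 exception so that on Blackwell-class GPUs (sm_arch_ >= 100) running cuDNN 9.20 or newer, any head dimension is accepted for both fprop and bprop, as long as the KV layout is not paged. Below is a minimal standalone sketch that isolates just the added predicate; the helper name and function boundaries are illustrative and not part of the source, only the variable and enum names come from the diff.

#include <transformer_engine/fused_attn.h>

// Illustrative sketch only: isolates the newly added clause from the much larger
// backend-selection condition inside nvte_get_fused_attn_backend().
// cuDNN encodes its runtime version as major*10000 + minor*100 + patch,
// so 92000 corresponds to cuDNN 9.20.0.
static bool blackwell_cudnn920_any_head_dim(int sm_arch_, int cudnn_runtime_version,
                                            NVTE_QKV_Layout_Group layout_group) {
  return sm_arch_ >= 100 &&                 // Blackwell-class SM arch or newer
         cudnn_runtime_version >= 92000 &&  // cuDNN 9.20+
         layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD;  // non-paged KV
}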
