File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed
transformer_engine/common/fused_attn Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -339,7 +339,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
339339 attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
340340 // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
341341 (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
342- cudnn_runtime_version >= 91100 )) &&
342+ cudnn_runtime_version >= 91100 ) ||
343+ // 9.20: any head_dim + Blackwell + fprop/bprop + non_paged + any sq
344+ (sm_arch_ >= 100 && cudnn_runtime_version >= 92000 &&
345+ layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD)) &&
343346 // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
344347 // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed
345348 (!((cudnn_runtime_version >= 91100 ) && is_training && sm_arch_ == 90 &&
You can’t perform that action at this time.
0 commit comments