diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 170cb2cd34..d3eb70a2ad 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -631,16 +631,13 @@ def get_attention_backend( if use_flash_attention_2 and ( head_dim_qk > 256 or head_dim_qk % 8 != 0 - or ( - head_dim_qk > 192 - and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0)) - ) + or (head_dim_qk > 192 and device_compute_capability < (8, 0)) ): if FlashAttentionUtils.is_installed: logger.debug( "Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. " "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " - "head_dim_qk <= 256 (>192 requires sm80/90/100+). " + "head_dim_qk <= 256 (>192 requires sm80+). " "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", head_dim_qk, head_dim_v,