
Commit 83d4ca2
add back torch.compile
1 parent: e464d85

File tree: 1 file changed (+6, −3)
  • src/fairseq2/models/transformer/_sdpa/_flex.py

src/fairseq2/models/transformer/_sdpa/_flex.py

Lines changed: 6 additions & 3 deletions
@@ -8,11 +8,13 @@
 
 from typing import Callable, TypeAlias, final
 
+import torch
 from torch import Tensor
 from torch.nn.attention.flex_attention import flex_attention
 from typing_extensions import override
 
 from fairseq2.models.transformer._block_mask import BlockMaskCache
+from fairseq2.logging import log
 from fairseq2.nn import BatchLayout
 
 # isort: split
@@ -25,9 +27,10 @@
 
 MaskFunction: TypeAlias = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
 
-# TODO: Hitting some torch.compile issues with this enabled for different builds.
-# Commenting out for now until we can investigate.
-# flex_attention = torch.compile(flex_attention, dynamic=False)
+# NOTE: Flex attention only has performance benefits when torch.compiled, but this is
+# not possible on certain platforms (e.g., CPU).
+if torch.cuda.is_available():
+    flex_attention = torch.compile(flex_attention, dynamic=False)
 
 
 @final
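
For context, the pattern this commit reinstates can be sketched with only public PyTorch APIs. The example below is a minimal, self-contained illustration, not the fairseq2 module itself: the fairseq2-specific pieces (BlockMaskCache, BatchLayout, log) are omitted, and the causal mask helper and tensor shapes are illustrative assumptions.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Compile flex_attention only when a CUDA device is present. Compiling is
# where the performance benefit comes from (eager flex_attention falls back
# to a slower reference path), but torch.compile of this kernel is not
# supported on every platform (e.g., CPU-only builds).
if torch.cuda.is_available():
    # dynamic=False specializes the compiled kernel to the observed shapes
    # instead of tracing with dynamic shapes.
    flex_attention = torch.compile(flex_attention, dynamic=False)


def causal(b, h, q_idx, kv_idx):
    # mask_mod for create_block_mask: keep only keys at or before the query position.
    return q_idx >= kv_idx


device = "cuda" if torch.cuda.is_available() else "cpu"

# (batch, heads, seq_len, head_dim) tensors; shapes chosen arbitrarily for the demo.
q = torch.randn(1, 8, 128, 64, device=device)
k = torch.randn(1, 8, 128, 64, device=device)
v = torch.randn(1, 8, 128, 64, device=device)

block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=128, KV_LEN=128, device=device)

out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 8, 128, 64])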
