Skip to content

Commit 248a403

Browse files
committed
added checks to run only on CUDA with compatibility >=9
1 parent 49dd2ce commit 248a403

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

test/prototype/mx_formats/test_mxfp8_allgather.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
import torch
23
import torch.distributed as dist
34
from torch.testing._internal.common_distributed import (
@@ -9,6 +10,11 @@
910

1011
from torchao.prototype.mx_formats.mx_tensor import MXTensor
1112

13+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0):
14+
pytest.skip(
15+
"Test Requires CUDA and compute capability >= 9.0", allow_module_level=True
16+
)
17+
1218

1319
@instantiate_parametrized_tests
1420
class MXFP8OnDeviceAllGatherTest(MultiProcessTestCase):

0 commit comments

Comments
 (0)