Commit a99255a

updated test with rebase changes

1 parent 565e813

2 files changed (+15, -20 lines)


torchao/prototype/tests/test_mxfp8_allgather.py renamed to test/prototype/mx_formats/test_mxfp8_allgather.py

Lines changed: 7 additions & 12 deletions
@@ -1,9 +1,4 @@
-import pytest
 import torch
-
-if not torch.cuda.is_available() or torch.cuda.get_device_capability() != (10, 0):
-    pytest.skip("Test requires CUDA build on SM100", allow_module_level=True)
-
 import torch.distributed as dist
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
@@ -23,7 +18,7 @@ def setUp(self) -> None:
 
     @property
     def world_size(self) -> int:
-        return 4
+        return 2
 
     @property
     def device(self) -> torch.device:
@@ -64,9 +59,9 @@ def test_allgather(self):
             elem_dtype=torch.float8_e5m2,
             block_size=32,
             orig_dtype=torch.float32,
-            gemm_kernel_choice=None,
-            pack_fp6=None,
+            kernel_preference=None,
             act_quant_kwargs=None,
+            is_swizzled_scales=None,
         )
 
         world_size = self.world_size
@@ -82,9 +77,9 @@ def test_allgather(self):
             elem_dtype=torch.float8_e5m2,
             block_size=32,
             orig_dtype=torch.float32,
-            gemm_kernel_choice=None,
-            pack_fp6=None,
+            kernel_preference=None,
             act_quant_kwargs=None,
+            is_swizzled_scales=None,
         )
 
         # Perform all_gather
@@ -111,12 +106,12 @@ def test_allgather(self):
 
             # Verify scale matches golden exactly
             if not torch.equal(
-                gathered_mx._scale_e8m0.view(torch.uint8),
+                gathered_mx.scale.view(torch.uint8),
                 golden_scale.view(torch.uint8),
             ):
                 assert False, "scale mismatch"
 
-            assert gathered_mx._block_size == 32
+            assert gathered_mx.block_size == 32
 
         finally:
             dist.destroy_process_group()
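
For readers unfamiliar with the harness, here is a minimal sketch of the MultiProcessTestCase pattern the test above follows. The class name _ExampleTest and its body are illustrative assumptions, not code from this commit; only the world_size of 2 comes from the diff.

import torch.distributed as dist
from torch.testing._internal.common_distributed import MultiProcessTestCase

class _ExampleTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()  # launch one process per rank

    @property
    def world_size(self) -> int:
        return 2  # the value this commit settles on

    def test_example(self):
        # Each spawned rank runs the test body independently.
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            init_method=f"file://{self.file_name}",
        )
        try:
            pass  # collective assertions go here
        finally:
            dist.destroy_process_group()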

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 8 additions & 8 deletions
@@ -869,7 +869,7 @@ def mx_all_gather(func, types, args, kwargs):
     )
 
     gathered_scale = torch.ops._c10d_functional.all_gather_into_tensor.default(
-        mx_tensor._scale_e8m0.view(
+        mx_tensor.scale.view(
             torch.uint8
         ),  # The scale factors; cast to uint8 because float8_e8m0fnu is not supported for all_gather.
         group_tag,
@@ -884,11 +884,11 @@ def mx_all_gather(func, types, args, kwargs):
         gathered_qdata,
         gathered_scale,
         mx_tensor._elem_dtype,
-        mx_tensor._block_size,
+        mx_tensor.block_size,
         mx_tensor._orig_dtype,
-        mx_tensor._gemm_kernel_choice,
-        mx_tensor._pack_fp6,
+        mx_tensor.kernel_preference,
         mx_tensor.act_quant_kwargs,
+        mx_tensor._is_swizzled_scales,
     )
 
 
@@ -908,16 +908,16 @@ def mx_wait_tensor(func, types, args, kwargs):
     )
 
     waited_scale = torch.ops._c10d_functional.wait_tensor.default(
-        mx_tensor._scale_e8m0, *args[1:], **kwargs
+        mx_tensor.scale, *args[1:], **kwargs
     )
 
     return MXTensor(
         waited_qdata,
         waited_scale,
         mx_tensor._elem_dtype,
-        mx_tensor._block_size,
+        mx_tensor.block_size,
         mx_tensor._orig_dtype,
-        mx_tensor._gemm_kernel_choice,
-        mx_tensor._pack_fp6,
+        mx_tensor.kernel_preference,
         mx_tensor.act_quant_kwargs,
+        mx_tensor._is_swizzled_scales,
     )
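
Both overrides rebuild the MXTensor from the gathered or waited parts using the renamed accessors (scale, block_size, kernel_preference) plus the new _is_swizzled_scales field, in place of the old _scale_e8m0, _block_size, _gemm_kernel_choice, and _pack_fp6. The uint8 detour exists because float8_e8m0fnu is not a supported dtype for the collective. Below is a minimal self-contained sketch of that round-trip on a plain tensor; it assumes an already-initialized process group and a PyTorch build that provides torch.float8_e8m0fnu, and is not code from this commit.

import torch
import torch.distributed as dist

def gather_e8m0_scale(scale: torch.Tensor, world_size: int) -> torch.Tensor:
    # Reinterpret the e8m0 bits as uint8 (no copy), gather the raw bytes,
    # then view the gathered buffer back as float8_e8m0fnu.
    scale_u8 = scale.view(torch.uint8)
    out = scale_u8.new_empty((world_size * scale_u8.shape[0], *scale_u8.shape[1:]))
    dist.all_gather_into_tensor(out, scale_u8)
    return out.view(torch.float8_e8m0fnu)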
