Skip to content

Commit 46d1859

Browse files
committed
Fix lint error
1 parent 8edd692 commit 46d1859

File tree

1 file changed

+6
-21
lines changed

1 file changed

+6
-21
lines changed

test/xpu/test_decomp.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from torch.testing._internal.common_device_type import (
2020
instantiate_device_type_tests,
2121
onlyCPU,
22-
onlyCUDA,
2322
onlyNativeDeviceTypes,
2423
ops,
2524
)
@@ -45,19 +44,9 @@
4544
from torch.utils._python_dispatch import TorchDispatchMode
4645
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
4746

48-
4947
device_type = (
5048
acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
5149
)
52-
# # XPU 支持导入
53-
# try:
54-
# from xpu_test_utils import XPUPatchForImport
55-
# except Exception as e:
56-
# from .xpu_test_utils import XPUPatchForImport
57-
58-
# with XPUPatchForImport(False):
59-
# from test_decomp import DecompOneOffTests, TestDecomp
60-
6150

6251
aten = torch.ops.aten
6352

@@ -194,7 +183,7 @@ def _getDefaultRtolAndAtol(dtype0, dtype1):
194183
return rtol, atol
195184

196185

197-
# 修改后的 op_assert_ref 函数,包含 XPU 特定的容差配置
186+
198187
def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
199188
assert orig.dtype == decomp.dtype, f"{i} Operation: {op}"
200189
if orig.numel() == 0 or decomp.numel() == 0:
@@ -240,7 +229,7 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
240229
(torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
241230
(torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5,
242231
(torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7,
243-
# XPU specific
232+
# XPU specific
244233
(
245234
torch.float16,
246235
torch.ops.aten._batch_norm_with_update.default,
@@ -781,7 +770,6 @@ def forward(self, x_1, start_1):
781770
def test_masked_fill(self, device):
782771
from torch.fx.experimental.proxy_tensor import make_fx
783772

784-
# 修改设备检查以包含 XPU
785773
if torch.device(device).type not in [
786774
"xpu",
787775
"cuda",
@@ -1068,7 +1056,6 @@ def run_without_python_dispatcher(mode):
10681056
)
10691057

10701058

1071-
# 修改测试实例化以支持 XPU
10721059
instantiate_device_type_tests(TestDecomp, globals(), only_for="xpu", allow_xpu=True)
10731060

10741061

@@ -1101,7 +1088,6 @@ def test_contiguous_log_softmax(self, device):
11011088
res = torch._decomp.decompositions._log_softmax(x, -1, False)
11021089
self.assertEqual(ref.stride(), res.stride())
11031090

1104-
# @onlyCUDA
11051091
def test_exponential_non_inf(self, device):
11061092
inp = torch.empty((4, 400, 256), device=device)
11071093

@@ -1114,7 +1100,6 @@ def test_exponential_non_inf(self, device):
11141100

11151101
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
11161102
@skipIfCrossRef
1117-
# @onlyCUDA
11181103
def test_amp_batch_norm_backward(self):
11191104
device = device_type
11201105
grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
@@ -1265,7 +1250,6 @@ def f(x, w, b):
12651250
for o_ref, o in zip(out_ref, out):
12661251
self.assertEqual(o_ref.dtype, o.dtype)
12671252

1268-
# @onlyCUDA
12691253
@unittest.skipIf(not SM70OrLater, "triton")
12701254
def test_rms_norm_decomp_accelerator(self, device):
12711255
@torch.compile
@@ -1293,8 +1277,9 @@ def forward_pass_fn():
12931277
)
12941278

12951279

1296-
# 修改测试实例化以支持 XPU
1297-
instantiate_device_type_tests(DecompOneOffTests, globals(), only_for="xpu", allow_xpu=True)
1280+
instantiate_device_type_tests(
1281+
DecompOneOffTests, globals(), only_for="xpu", allow_xpu=True
1282+
)
12981283

12991284

13001285
class HasDecompTest(TestCase):
@@ -1376,4 +1361,4 @@ def test_aten_core_operators(self):
13761361

13771362

13781363
if __name__ == "__main__":
1379-
run_tests()
1364+
run_tests()

0 commit comments

Comments
 (0)