 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     onlyCPU,
-    onlyCUDA,
     onlyNativeDeviceTypes,
     ops,
 )
@@ -45,19 +44,9 @@
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
 
-
 device_type = (
     acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
 )
-# # XPU support imports
-# try:
-#     from xpu_test_utils import XPUPatchForImport
-# except Exception as e:
-#     from .xpu_test_utils import XPUPatchForImport
-
-# with XPUPatchForImport(False):
-#     from test_decomp import DecompOneOffTests, TestDecomp
-
 
 aten = torch.ops.aten
 
@@ -194,7 +183,7 @@ def _getDefaultRtolAndAtol(dtype0, dtype1):
     return rtol, atol
 
 
-# Modified op_assert_ref function, including XPU-specific tolerance configuration
+
 def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
     assert orig.dtype == decomp.dtype, f"{i} Operation: {op}"
     if orig.numel() == 0 or decomp.numel() == 0:
@@ -240,7 +229,7 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
         (torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
         (torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5,
         (torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7,
-        # XPU specific
+        # XPU specific
         (
             torch.float16,
             torch.ops.aten._batch_norm_with_update.default,
@@ -781,7 +770,6 @@ def forward(self, x_1, start_1):
     def test_masked_fill(self, device):
         from torch.fx.experimental.proxy_tensor import make_fx
 
-        # Modify the device check to include XPU
         if torch.device(device).type not in [
             "xpu",
             "cuda",
@@ -1068,7 +1056,6 @@ def run_without_python_dispatcher(mode):
         )
 
 
-# Modify test instantiation to support XPU
 instantiate_device_type_tests(TestDecomp, globals(), only_for="xpu", allow_xpu=True)
 
 
@@ -1101,7 +1088,6 @@ def test_contiguous_log_softmax(self, device):
         res = torch._decomp.decompositions._log_softmax(x, -1, False)
         self.assertEqual(ref.stride(), res.stride())
 
-    # @onlyCUDA
     def test_exponential_non_inf(self, device):
         inp = torch.empty((4, 400, 256), device=device)
 
@@ -1114,7 +1100,6 @@ def test_exponential_non_inf(self, device):
 
     @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
     @skipIfCrossRef
-    # @onlyCUDA
     def test_amp_batch_norm_backward(self):
         device = device_type
         grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
@@ -1265,7 +1250,6 @@ def f(x, w, b):
         for o_ref, o in zip(out_ref, out):
             self.assertEqual(o_ref.dtype, o.dtype)
 
-    # @onlyCUDA
     @unittest.skipIf(not SM70OrLater, "triton")
     def test_rms_norm_decomp_accelerator(self, device):
         @torch.compile
@@ -1293,8 +1277,9 @@ def forward_pass_fn():
         )
 
 
-# Modify test instantiation to support XPU
-instantiate_device_type_tests(DecompOneOffTests, globals(), only_for="xpu", allow_xpu=True)
+instantiate_device_type_tests(
+    DecompOneOffTests, globals(), only_for="xpu", allow_xpu=True
+)
 
 
 class HasDecompTest(TestCase):
@@ -1376,4 +1361,4 @@ def test_aten_core_operators(self):
 
 
 if __name__ == "__main__":
-    run_tests()
+    run_tests()