@@ -96,6 +96,7 @@
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
 )
+from torchao.testing.utils import skip_if_xpu
 from torchao.utils import (
     _is_fbgemm_gpu_genai_available,
     get_current_accelerator_device,
@@ -695,10 +696,7 @@ def test_qat_4w_quantizer_gradients(self):
         self._test_qat_quantized_gradients(quantizer)

     @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
-    @unittest.skipIf(
-        _DEVICE is torch.device("xpu"),
-        "skipped due to https://github.com/intel/torch-xpu-ops/issues/1770",
-    )
+    @skip_if_xpu("skipped due to https://github.com/intel/torch-xpu-ops/issues/1770")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
@@ -2015,6 +2013,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
         )

     @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
+    @skip_if_xpu("XPU enablement in progress")
     @parametrize(
         "weight_dtype, granularity, dtype, module_type",
         [
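For context, here is a minimal sketch of what a skip_if_xpu decorator could look like. The actual helper lives in torchao.testing.utils and may differ; this version is only an assumption inferred from the call sites above, which pass a single skip reason.

# Hypothetical sketch of skip_if_xpu; the real implementation is in
# torchao.testing.utils and may differ.
import functools
import unittest

import torch


def skip_if_xpu(reason: str):
    """Skip the decorated test when an XPU device is available (sketch)."""

    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            # torch.xpu exists only in PyTorch builds with XPU support,
            # so guard the attribute lookup before probing availability.
            xpu = getattr(torch, "xpu", None)
            if xpu is not None and xpu.is_available():
                raise unittest.SkipTest(reason)
            return test_fn(*args, **kwargs)

        return wrapper

    return decorator

Centralizing the check in one decorator keeps each skip to a single line and makes it straightforward to find and lift the XPU skips once the linked upstream issue is resolved.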