diff --git a/intel_extension_for_deepspeed/xpu_accelerator.py b/intel_extension_for_deepspeed/xpu_accelerator.py index 329b666..08112d2 100644 --- a/intel_extension_for_deepspeed/xpu_accelerator.py +++ b/intel_extension_for_deepspeed/xpu_accelerator.py @@ -81,6 +81,10 @@ def initial_seed(self, seed): def default_generator(self, device_index): return torch.xpu.default_generators[device_index] + + #WA for xpu Generator in torch api + def xpu_generator(self, device_index): + return torch.xpu.Generator(device=f'xpu:{device_index}') # Streams/Events @property