diff --git a/onnxruntime/core/providers/webgpu/generator/range.cc b/onnxruntime/core/providers/webgpu/generator/range.cc index 99c5a1c1b5566..bf68e534d7c7f 100644 --- a/onnxruntime/core/providers/webgpu/generator/range.cc +++ b/onnxruntime/core/providers/webgpu/generator/range.cc @@ -24,7 +24,7 @@ Status Range::ComputeInternal(ComputeContext& context) const { } uint32_t output_size = onnxruntime::narrow(n); - RangeProgram program{}; + RangeProgram program{output_tensor->GetElementType()}; #if defined(__GNUC__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" @@ -48,9 +48,19 @@ Status Range::ComputeInternal(ComputeContext& context) const { Status RangeProgram::GenerateShaderCode(ShaderHelper& sh) const { const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << " let value = bitcast(uniforms.start) + output_value_t(global_idx) * bitcast(uniforms.delta);\n" - << output.SetByOffset("global_idx", "value"); + sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"); + + // For int64, we need to cast to i32 first, then assign to output (which handles vec2 conversion) + // For int32 and float, we can use output_value_t directly + if (data_type_ == ONNX_NAMESPACE::TensorProto_DataType_INT64) { + // int64 case: bitcast to i32, compute with i32, then assign (automatic conversion to vec2) + sh.MainFunctionBody() << " let value = bitcast(uniforms.start) + i32(global_idx) * bitcast(uniforms.delta);\n" + << output.SetByOffset("global_idx", "value"); + } else { + // float or int32 case: use output_value_t + sh.MainFunctionBody() << " let value = bitcast(uniforms.start) + output_value_t(global_idx) * bitcast(uniforms.delta);\n" + << output.SetByOffset("global_idx", "value"); + } return Status(); } @@ -71,6 +81,7 @@ Status RangeProgram::GenerateShaderCode(ShaderHelper& sh) const { WEBGPU_RANGE_KERNEL(float) WEBGPU_RANGE_KERNEL(int32_t) +WEBGPU_RANGE_KERNEL(int64_t) } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/generator/range.h b/onnxruntime/core/providers/webgpu/generator/range.h index 2f5812bb460ad..7f6b17b69c115 100644 --- a/onnxruntime/core/providers/webgpu/generator/range.h +++ b/onnxruntime/core/providers/webgpu/generator/range.h @@ -19,12 +19,16 @@ class Range : public WebGpuKernel { class RangeProgram : public Program { public: RangeProgram() : Program{"Range"} {} + RangeProgram(int32_t data_type) : Program{"Range"}, data_type_(data_type) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, {"start", ProgramUniformVariableDataType::Uint32}, {"delta", ProgramUniformVariableDataType::Uint32}); + + private: + int32_t data_type_{0}; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 6b764d51bcf75..c9752dcd4e5d4 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -382,6 +382,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCD class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, float, Range); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, int32_t, Range); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, int64_t, Range); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, Einsum); @@ -723,6 +724,7 @@ std::unique_ptr RegisterKernels(bool enable_graph_capture = fals BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo,