Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions onnxruntime/core/providers/webgpu/generator/range.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Status Range<T>::ComputeInternal(ComputeContext& context) const {
}

uint32_t output_size = onnxruntime::narrow<uint32_t>(n);
RangeProgram program{};
RangeProgram program{output_tensor->GetElementType()};
#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
Expand All @@ -48,9 +48,19 @@ Status Range<T>::ComputeInternal(ComputeContext& context) const {
Status RangeProgram::GenerateShaderCode(ShaderHelper& sh) const {
const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);

sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
<< " let value = bitcast<output_value_t>(uniforms.start) + output_value_t(global_idx) * bitcast<output_value_t>(uniforms.delta);\n"
<< output.SetByOffset("global_idx", "value");
sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size");

// For int64, we need to cast to i32 first, then assign to output (which handles vec2<u32> conversion)
// For int32 and float, we can use output_value_t directly
if (data_type_ == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
// int64 case: bitcast to i32, compute with i32, then assign (automatic conversion to vec2<u32>)
sh.MainFunctionBody() << " let value = bitcast<i32>(uniforms.start) + i32(global_idx) * bitcast<i32>(uniforms.delta);\n"
<< output.SetByOffset("global_idx", "value");
} else {
// float or int32 case: use output_value_t
sh.MainFunctionBody() << " let value = bitcast<output_value_t>(uniforms.start) + output_value_t(global_idx) * bitcast<output_value_t>(uniforms.delta);\n"
<< output.SetByOffset("global_idx", "value");
}

return Status();
}
Expand All @@ -71,6 +81,7 @@ Status RangeProgram::GenerateShaderCode(ShaderHelper& sh) const {

WEBGPU_RANGE_KERNEL(float)
WEBGPU_RANGE_KERNEL(int32_t)
WEBGPU_RANGE_KERNEL(int64_t)

} // namespace webgpu
} // namespace onnxruntime
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/webgpu/generator/range.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
class RangeProgram : public Program<RangeProgram> {
public:
RangeProgram() : Program{"Range"} {}
RangeProgram(int32_t data_type) : Program{"Range"}, data_type_(data_type) {}

Check warning on line 22 in onnxruntime/core/providers/webgpu/generator/range.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Single-parameter constructors should be marked explicit. [runtime/explicit] [4] Raw Output: onnxruntime/core/providers/webgpu/generator/range.h:22: Single-parameter constructors should be marked explicit. [runtime/explicit] [4]

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
{"start", ProgramUniformVariableDataType::Uint32},
{"delta", ProgramUniformVariableDataType::Uint32});

private:
int32_t data_type_{0};
};

} // namespace webgpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCD

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, float, Range);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, int32_t, Range);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, int64_t, Range);

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, Einsum);

Expand Down Expand Up @@ -723,6 +724,7 @@ std::unique_ptr<KernelRegistry> RegisterKernels(bool enable_graph_capture = fals

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, float, Range)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, int32_t, Range)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, int64_t, Range)>,

BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, Einsum)>,

Expand Down
Loading