From dbb1ad039cfa48f5594737c91fb7937dfb68d742 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Fri, 28 Nov 2025 21:01:31 +0800 Subject: [PATCH] [webgpu] Get data from bias with `GetByOffset()` in `gemm_utils.cc` --- onnxruntime/core/providers/webgpu/math/gemm_utils.cc | 5 +++-- onnxruntime/core/providers/webgpu/math/matmul_packed.cc | 7 ------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 7cbc7f6a4a821..7390f7cda35f8 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -34,7 +34,7 @@ void HandleMaybeHaveBiasForGEMM(ShaderHelper& shader, } else if (c_is_scalar) { shader.AdditionalImplementation() << "output_value_t(C[0]);\n"; } else { - shader.AdditionalImplementation() << "output_value_t(C[row]);\n"; + shader.AdditionalImplementation() << "output_value_t(" << C.GetByOffset("row") << ");\n"; } } shader.AdditionalImplementation() << output.SetByIndices("coords", "value") << "\n"; @@ -47,7 +47,8 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader, bool is_channels_last) { shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n"; if (has_bias) { - shader.AdditionalImplementation() << " value = value + output_value_t(" << (is_channels_last ? "bias[colIn]" : "bias[row]") << ");\n"; + const ShaderVariableHelper& bias = shader.AddInput("bias", ShaderUsage::UseUniform); + shader.AdditionalImplementation() << " value = value + output_value_t(" << (is_channels_last ? bias.GetByOffset("colIn") : bias.GetByOffset("row")) << ");\n"; } shader.AdditionalImplementation() << " " << activation_snippet << "\n" << output.SetByIndices("coords", "value") << "\n"; diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index 4daabe8246aa7..c75b5c0efc1e9 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -26,9 +26,6 @@ Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - if (has_bias_) { - shader.AddInput("bias", ShaderUsage::UseUniform); - } std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); ProgramVariableDataType output_var_type = this->Outputs()[0].var_type; // declare the read and write functions @@ -54,10 +51,6 @@ bool MatMulProgram::NeedSplitK() const { Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - if (has_bias_) { - shader.AddInput("bias", ShaderUsage::UseUniform); - } - // Handle bias with `MatMulWriteFnSource()`. // Here `use_split_k` is false because we just initialize `output` with bias. // `use_split_k` is true only when we do the actual MatMul with Split-K.