From 83450b47084fd87bf5fa112b83f5063e751173b7 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Wed, 26 Nov 2025 14:51:47 +0800 Subject: [PATCH 1/4] [webgpu] Remove `global_id` and `workgroup_id` in gemm_utils.cc This patch replaces `global_id` and `workgroup_id` with `logical_global_id` and `logical_workgroup_id` which are computed from `workgroup_idx` and the dispatch workgroup sizes set in `ProgramBase::SetDispatchGroupSize()` because the dispatch workgroup sizes may be normalized in `ProgramManager::NormalizeDispatchGroupSize()`. --- .../core/providers/webgpu/math/gemm_packed.cc | 15 +++-- .../core/providers/webgpu/math/gemm_packed.h | 5 +- .../core/providers/webgpu/math/gemm_utils.cc | 56 +++++++++++++------ .../core/providers/webgpu/math/matmul.cc | 4 +- .../providers/webgpu/math/matmul_packed.h | 5 +- .../core/providers/webgpu/nn/conv2d_mm.cc | 5 +- .../core/providers/webgpu/nn/conv2d_mm.h | 5 +- 7 files changed, 66 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index 6aefa90a59285..1d261a175850e 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -93,18 +93,21 @@ Status ApplyGemmPacked(const Tensor* a, } const uint32_t TILE_SIZE = 32; - const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE; - const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t dispatch_x = (N + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t dispatch_y = (M + TILE_SIZE - 1) / TILE_SIZE; program.CacheHint(alpha, transA, transB, c_is_scalar) .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}}) - .SetDispatchGroupSize(num_tile_n, num_tile_m, 1) + .SetDispatchGroupSize(dispatch_x, dispatch_y, 1) .SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z) 
.AddUniformVariables({{alpha}, {beta}, - {M}, /* dim_a_outer */ - {N}, /* dim_b_outer */ - {K}} /*dim_inner */ + {M}, /* dim_a_outer */ + {N}, /* dim_b_outer */ + {K}, /*dim_inner */ + {dispatch_x}, /* logical_dispatch_x */ + {dispatch_y}, /* logical_dispatch_y */ + {1}} /* logical_dispatch_z */ ); return context.RunProgram(program); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.h b/onnxruntime/core/providers/webgpu/math/gemm_packed.h index dce5164693aa8..cb89ccefba313 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.h @@ -32,7 +32,10 @@ class GemmProgram final : public Program { {"beta", ProgramUniformVariableDataType::Float32}, {"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8; constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 7cbc7f6a4a821..75e2247d26d24 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -272,22 +272,34 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << "const rowPerThread = " << elements_per_thread_y << ";\n" << "const colPerThread = " << elements_per_thread_x << ";\n" << "const innerElementSize = " << inner_elements_size << ";\n" - << "const tileInner = " << tile_inner << ";\n"; + << "const tileInner = " << tile_inner << ";\n" + << "const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, 
workgroup_size_z);\n"; + + // Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in + // `ProgramBase.SetDispatchGroupSize()` may be normalized in + // `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use + // `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`. + shader.MainFunctionBody() + << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n" + << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" + << " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n"; shader.MainFunctionBody() << " let localRow = i32(local_id.y);\n" << " let tileRow = localRow * rowPerThread;\n" << " let tileCol = i32(local_id.x);\n" - << " let globalRow = i32(global_id.y) * rowPerThread;\n" - << " let globalCol = i32(global_id.x);\n" - << " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" - << " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" + << " let globalRow = i32(logical_global_id.y) * rowPerThread;\n" + << " let globalCol = i32(logical_global_id.x);\n" + << " let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" + << " let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n" << " var acc: array, rowPerThread>;\n"; if (split_k) { // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into // multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from - // `kSplitK * i32(global_id.z)`. 
+ // `kSplitK * i32(logical_global_id.z)`. // // For example: considering computing Y = (X * W + B) in one workgroup. // Let kSplitK = 2, B = [d1, d2] @@ -305,15 +317,15 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, // Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2) // Workgroup3: compute (C1 * C2) // In each workgroup: - // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z` + // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `logical_global_id.z` // - When the computation in each workgroup is completed, add the result to Y with several // atomic built-in functions in `HandleMatMulWithSplitK()`. shader.MainFunctionBody() << "const kSplitK = " << split_dim_inner << ";\n" << " let num_tiles = (kSplitK - 1) / tileInner + 1;\n" - << " var kStart = kSplitK * i32(global_id.z);\n" + << " var kStart = kSplitK * i32(logical_global_id.z);\n" - // When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate + // When Split-K is used, `batch` should always be 0 and `logical_global_id.z` is used to indicate // the index of split-k instead of batch. << " let batch = 0;\n" << " let batchIndices = 0u;\n"; @@ -321,7 +333,7 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, shader.MainFunctionBody() << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" - << " let batch = i32(global_id.z);\n" + << " let batch = i32(logical_global_id.z);\n" << (nullptr != batch_dims ? 
" let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : ""); } @@ -496,9 +508,21 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, << "var mm_Bsub: array, " << tile_inner << ">;\n" << "const rowPerThread = " << elements_per_thread_y << ";\n" << "const colPerThread = " << elements_per_thread_x << ";\n" - << "const tileInner = " << tile_inner << ";\n"; + << "const tileInner = " << tile_inner << ";\n" + << "const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n"; + + // Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in + // `ProgramBase.SetDispatchGroupSize()` may be normalized in + // `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use + // `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`. + shader.MainFunctionBody() + << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n" + << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" + << " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n"; - shader.MainFunctionBody() << " let batch = i32(global_id.z);\n" + shader.MainFunctionBody() << " let batch = i32(logical_global_id.z);\n" << (nullptr != batch_dims ? 
" let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" @@ -507,10 +531,10 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, shader.MainFunctionBody() << "let tileRow = i32(local_id.y) * rowPerThread;\n" << "let tileCol = i32(local_id.x) * colPerThread;\n" - << "let globalRow = i32(global_id.y) * rowPerThread;\n" - << "let globalCol = i32(global_id.x) * colPerThread;\n" - << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" - << "let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" + << "let globalRow = i32(logical_global_id.y) * rowPerThread;\n" + << "let globalCol = i32(logical_global_id.x) * colPerThread;\n" + << "let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" + << "let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n" << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n" << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n" << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n"; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 55c2c5773cc1f..72dd235eb820a 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -256,8 +256,6 @@ Status ComputeMatMul(ComputeContext* context, // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the // number of splits along `dim_inner`. - // TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize - // the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`. 
split_dim_inner = split_k_config.GetSplitDimInner(); dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner; @@ -271,7 +269,7 @@ Status ComputeMatMul(ComputeContext* context, .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) - .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) + .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}}) .AddIndices(outer_dims) .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z) diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index 143ba61c99e13..dbd193bc38f58 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -24,7 +24,10 @@ class MatMulProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); bool NeedSplitK() const; diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc index 2d5424c52a3f2..c66f2cbd582d9 100644 --- 
a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc @@ -226,7 +226,10 @@ Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::v {static_cast(dim_inner)}, {pads}, {strides}, - {dilations}}); + {dilations}, + {dispatch[0]}, + {dispatch[1]}, + {dispatch[2]}}); return program; } diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h index d7cc08aae26f3..e161bffb0c503 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h @@ -38,7 +38,10 @@ class Conv2dMMProgram final : public Program { {"dim_inner", ProgramUniformVariableDataType::Uint32}, {"pads", ProgramUniformVariableDataType::Uint32}, {"strides", ProgramUniformVariableDataType::Uint32}, - {"dilations", ProgramUniformVariableDataType::Uint32}); + {"dilations", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); private: const Activation& activation_; From 8b7017ad7151d14d0b261c9dee80d1ebe3a544d3 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Wed, 26 Nov 2025 21:25:24 +0800 Subject: [PATCH 2/4] Address the comments from copilot and fix errors on the bots --- onnxruntime/core/providers/webgpu/math/gemm_packed.cc | 4 ++-- onnxruntime/core/providers/webgpu/math/gemm_utils.cc | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index 1d261a175850e..c26b58a7af1f4 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -98,7 +98,7 @@ Status ApplyGemmPacked(const Tensor* a, program.CacheHint(alpha, transA, transB, c_is_scalar) 
.AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}}) - .SetDispatchGroupSize(dispatch_x, dispatch_y, 1) + .SetDispatchGroupSize(dispatch_x, dispatch_y, 1u) .SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z) .AddUniformVariables({{alpha}, {beta}, @@ -107,7 +107,7 @@ Status ApplyGemmPacked(const Tensor* a, {K}, /*dim_inner */ {dispatch_x}, /* logical_dispatch_x */ {dispatch_y}, /* logical_dispatch_y */ - {1}} /* logical_dispatch_z */ + {1u}} /* logical_dispatch_z */ ); return context.RunProgram(program); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 75e2247d26d24..97731ea2bf101 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -273,7 +273,7 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << "const colPerThread = " << elements_per_thread_x << ";\n" << "const innerElementSize = " << inner_elements_size << ";\n" << "const tileInner = " << tile_inner << ";\n" - << "const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n"; + << "const workgroup_size = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n"; // Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in // `ProgramBase.SetDispatchGroupSize()` may be normalized in @@ -284,7 +284,7 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n" << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n" << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" - 
<< " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n"; + << " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n"; shader.MainFunctionBody() << " let localRow = i32(local_id.y);\n" @@ -509,7 +509,7 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, << "const rowPerThread = " << elements_per_thread_y << ";\n" << "const colPerThread = " << elements_per_thread_x << ";\n" << "const tileInner = " << tile_inner << ";\n" - << "const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n"; + << "const workgroup_size = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n"; // Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in // `ProgramBase.SetDispatchGroupSize()` may be normalized in @@ -520,7 +520,7 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n" << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n" << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" - << " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n"; + << " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n"; shader.MainFunctionBody() << " let batch = i32(logical_global_id.z);\n" << (nullptr != batch_dims ? 
" let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") From 2786714acd2a36e3ed1de66c9f9fe3534f819059 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Wed, 26 Nov 2025 21:33:24 +0800 Subject: [PATCH 3/4] Address the comments from copilot --- .../core/providers/webgpu/math/gemm_utils.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 97731ea2bf101..b19d488534be6 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -280,9 +280,10 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, // `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use // `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`. shader.MainFunctionBody() - << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n" - << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n" - << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n" + << " let logical_workgroups_xy = uniforms.logical_dispatch_x * uniforms.logical_dispatch_y;\n" + << " let logical_workgroup_id_z = workgroup_idx / logical_workgroups_xy;\n" + << " let logical_workgroup_id_y = (workgroup_idx % logical_workgroups_xy) / uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id_x = (workgroup_idx % logical_workgroups_xy) % uniforms.logical_dispatch_x;\n" << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" << " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n"; @@ -516,9 +517,10 @@ Status 
MakeMatMulPackedSource(ShaderHelper& shader, // `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use // `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`. shader.MainFunctionBody() - << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n" - << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n" - << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n" + << " let logical_workgroups_xy = uniforms.logical_dispatch_x * uniforms.logical_dispatch_y;\n" + << " let logical_workgroup_id_z = workgroup_idx / logical_workgroups_xy;\n" + << " let logical_workgroup_id_y = (workgroup_idx % logical_workgroups_xy) / uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id_x = (workgroup_idx % logical_workgroups_xy) % uniforms.logical_dispatch_x;\n" << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" << " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n"; From efc937cf960294a0ff4a5eeebdf8f19a49c0d474 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Fri, 28 Nov 2025 09:42:13 +0800 Subject: [PATCH 4/4] Address reviewer's comment --- .../core/providers/webgpu/math/gemm_utils.cc | 44 ++++++++----------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index b19d488534be6..89718149cea88 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -117,6 +117,20 @@ void HandleMatMulWithSplitK( } } +// Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in +// 
`ProgramBase::SetDispatchGroupSize()` may be normalized in +// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use +// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`. +void InitializeLogicalWorkgroupIDAndGlobalID(ShaderHelper& shader) { + shader.MainFunctionBody() + << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n" + << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" + << " const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n" + << " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n"; +} + } // namespace void MatMulReadFnSource(ShaderHelper& shader, @@ -272,20 +286,9 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << "const rowPerThread = " << elements_per_thread_y << ";\n" << "const colPerThread = " << elements_per_thread_x << ";\n" << "const innerElementSize = " << inner_elements_size << ";\n" - << "const tileInner = " << tile_inner << ";\n" - << "const workgroup_size = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n"; + << "const tileInner = " << tile_inner << ";\n"; - // Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in - // `ProgramBase.SetDispatchGroupSize()` may be normalized in - // `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use - // `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`. 
- shader.MainFunctionBody() - << " let logical_workgroups_xy = uniforms.logical_dispatch_x * uniforms.logical_dispatch_y;\n" - << " let logical_workgroup_id_z = workgroup_idx / logical_workgroups_xy;\n" - << " let logical_workgroup_id_y = (workgroup_idx % logical_workgroups_xy) / uniforms.logical_dispatch_x;\n" - << " let logical_workgroup_id_x = (workgroup_idx % logical_workgroups_xy) % uniforms.logical_dispatch_x;\n" - << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" - << " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n"; + InitializeLogicalWorkgroupIDAndGlobalID(shader); shader.MainFunctionBody() << " let localRow = i32(local_id.y);\n" @@ -509,20 +512,9 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, << "var mm_Bsub: array, " << tile_inner << ">;\n" << "const rowPerThread = " << elements_per_thread_y << ";\n" << "const colPerThread = " << elements_per_thread_x << ";\n" - << "const tileInner = " << tile_inner << ";\n" - << "const workgroup_size = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n"; + << "const tileInner = " << tile_inner << ";\n"; - // Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in - // `ProgramBase.SetDispatchGroupSize()` may be normalized in - // `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use - // `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`. 
- shader.MainFunctionBody() - << " let logical_workgroups_xy = uniforms.logical_dispatch_x * uniforms.logical_dispatch_y;\n" - << " let logical_workgroup_id_z = workgroup_idx / logical_workgroups_xy;\n" - << " let logical_workgroup_id_y = (workgroup_idx % logical_workgroups_xy) / uniforms.logical_dispatch_x;\n" - << " let logical_workgroup_id_x = (workgroup_idx % logical_workgroups_xy) % uniforms.logical_dispatch_x;\n" - << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" - << " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n"; + InitializeLogicalWorkgroupIDAndGlobalID(shader); shader.MainFunctionBody() << " let batch = i32(logical_global_id.z);\n" << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "")