Skip to content

Commit 83450b4

Browse files
committed
[webgpu] Remove global_id and workgroup_id in gemm_utils.cc
This patch replaces `global_id` and `workgroup_id` with `logical_global_id` and `logical_workgroup_id`, which are computed from `workgroup_idx` and the dispatch workgroup sizes passed to `ProgramBase::SetDispatchGroupSize()`. This is needed because the dispatch workgroup sizes may be normalized in `ProgramManager::NormalizeDispatchGroupSize()`, so the built-in `global_id` and `workgroup_id` may not reflect the originally requested dispatch sizes.
1 parent 7845ea8 commit 83450b4

File tree

7 files changed

+66
-29
lines changed

7 files changed

+66
-29
lines changed

onnxruntime/core/providers/webgpu/math/gemm_packed.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,18 +93,21 @@ Status ApplyGemmPacked(const Tensor* a,
9393
}
9494

9595
const uint32_t TILE_SIZE = 32;
96-
const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE;
97-
const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE;
96+
const uint32_t dispatch_x = (N + TILE_SIZE - 1) / TILE_SIZE;
97+
const uint32_t dispatch_y = (M + TILE_SIZE - 1) / TILE_SIZE;
9898

9999
program.CacheHint(alpha, transA, transB, c_is_scalar)
100100
.AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}})
101-
.SetDispatchGroupSize(num_tile_n, num_tile_m, 1)
101+
.SetDispatchGroupSize(dispatch_x, dispatch_y, 1)
102102
.SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z)
103103
.AddUniformVariables({{alpha},
104104
{beta},
105-
{M}, /* dim_a_outer */
106-
{N}, /* dim_b_outer */
107-
{K}} /*dim_inner */
105+
{M}, /* dim_a_outer */
106+
{N}, /* dim_b_outer */
107+
{K}, /*dim_inner */
108+
{dispatch_x}, /* logical_dispatch_x */
109+
{dispatch_y}, /* logical_dispatch_y */
110+
{1}} /* logical_dispatch_z */
108111
);
109112

110113
return context.RunProgram(program);

onnxruntime/core/providers/webgpu/math/gemm_packed.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ class GemmProgram final : public Program<GemmProgram> {
3232
{"beta", ProgramUniformVariableDataType::Float32},
3333
{"dim_a_outer", ProgramUniformVariableDataType::Uint32},
3434
{"dim_b_outer", ProgramUniformVariableDataType::Uint32},
35-
{"dim_inner", ProgramUniformVariableDataType::Uint32});
35+
{"dim_inner", ProgramUniformVariableDataType::Uint32},
36+
{"logical_dispatch_x", ProgramUniformVariableDataType::Uint32},
37+
{"logical_dispatch_y", ProgramUniformVariableDataType::Uint32},
38+
{"logical_dispatch_z", ProgramUniformVariableDataType::Uint32});
3639

3740
constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8;
3841
constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8;

onnxruntime/core/providers/webgpu/math/gemm_utils.cc

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -272,22 +272,34 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
272272
<< "const rowPerThread = " << elements_per_thread_y << ";\n"
273273
<< "const colPerThread = " << elements_per_thread_x << ";\n"
274274
<< "const innerElementSize = " << inner_elements_size << ";\n"
275-
<< "const tileInner = " << tile_inner << ";\n";
275+
<< "const tileInner = " << tile_inner << ";\n"
276+
<< "const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n";
277+
278+
// Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in
279+
// `ProgramBase::SetDispatchGroupSize()` may be normalized in
280+
// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use
281+
// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
282+
shader.MainFunctionBody()
283+
<< " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n"
284+
<< " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n"
285+
<< " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n"
286+
<< " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n"
287+
<< " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n";
276288

277289
shader.MainFunctionBody()
278290
<< " let localRow = i32(local_id.y);\n"
279291
<< " let tileRow = localRow * rowPerThread;\n"
280292
<< " let tileCol = i32(local_id.x);\n"
281-
<< " let globalRow = i32(global_id.y) * rowPerThread;\n"
282-
<< " let globalCol = i32(global_id.x);\n"
283-
<< " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n"
284-
<< " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n"
293+
<< " let globalRow = i32(logical_global_id.y) * rowPerThread;\n"
294+
<< " let globalCol = i32(logical_global_id.x);\n"
295+
<< " let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n"
296+
<< " let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n"
285297
<< " var acc: array<vec4<" << data_type << ">, rowPerThread>;\n";
286298

287299
if (split_k) {
288300
// With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into
289301
// multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from
290-
// `kSplitK * i32(global_id.z)`.
302+
// `kSplitK * i32(logical_global_id.z)`.
291303
//
292304
// For example: considering computing Y = (X * W + B) in one workgroup.
293305
// Let kSplitK = 2, B = [d1, d2]
@@ -305,23 +317,23 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
305317
// Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2)
306318
// Workgroup3: compute (C1 * C2)
307319
// In each workgroup:
308-
// - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z`
320+
// - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `logical_global_id.z`
309321
// - When the computation in each workgroup is completed, add the result to Y with several
310322
// atomic built-in functions in `HandleMatMulWithSplitK()`.
311323
shader.MainFunctionBody()
312324
<< "const kSplitK = " << split_dim_inner << ";\n"
313325
<< " let num_tiles = (kSplitK - 1) / tileInner + 1;\n"
314-
<< " var kStart = kSplitK * i32(global_id.z);\n"
326+
<< " var kStart = kSplitK * i32(logical_global_id.z);\n"
315327

316-
// When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate
328+
// When Split-K is used, `batch` should always be 0 and `logical_global_id.z` is used to indicate
317329
// the index of split-k instead of batch.
318330
<< " let batch = 0;\n"
319331
<< " let batchIndices = 0u;\n";
320332
} else {
321333
shader.MainFunctionBody()
322334
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
323335
<< " var kStart = 0;\n"
324-
<< " let batch = i32(global_id.z);\n"
336+
<< " let batch = i32(logical_global_id.z);\n"
325337
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "");
326338
}
327339

@@ -496,9 +508,21 @@ Status MakeMatMulPackedSource(ShaderHelper& shader,
496508
<< "var<workgroup> mm_Bsub: array<array<" << data_type << ", " << tile_b_outer << ">, " << tile_inner << ">;\n"
497509
<< "const rowPerThread = " << elements_per_thread_y << ";\n"
498510
<< "const colPerThread = " << elements_per_thread_x << ";\n"
499-
<< "const tileInner = " << tile_inner << ";\n";
511+
<< "const tileInner = " << tile_inner << ";\n"
512+
<< "const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n";
513+
514+
// Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in
515+
// `ProgramBase.SetDispatchGroupSize()` may be normalized in
516+
// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use
517+
// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
518+
shader.MainFunctionBody()
519+
<< " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n"
520+
<< " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n"
521+
<< " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n"
522+
<< " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n"
523+
<< " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n";
500524

501-
shader.MainFunctionBody() << " let batch = i32(global_id.z);\n"
525+
shader.MainFunctionBody() << " let batch = i32(logical_global_id.z);\n"
502526
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "")
503527
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
504528
<< " var kStart = 0;\n"
@@ -507,10 +531,10 @@ Status MakeMatMulPackedSource(ShaderHelper& shader,
507531
shader.MainFunctionBody()
508532
<< "let tileRow = i32(local_id.y) * rowPerThread;\n"
509533
<< "let tileCol = i32(local_id.x) * colPerThread;\n"
510-
<< "let globalRow = i32(global_id.y) * rowPerThread;\n"
511-
<< "let globalCol = i32(global_id.x) * colPerThread;\n"
512-
<< "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n"
513-
<< "let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n"
534+
<< "let globalRow = i32(logical_global_id.y) * rowPerThread;\n"
535+
<< "let globalCol = i32(logical_global_id.x) * colPerThread;\n"
536+
<< "let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n"
537+
<< "let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n"
514538
<< "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n"
515539
<< "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n"
516540
<< "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n";

onnxruntime/core/providers/webgpu/math/matmul.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,6 @@ Status ComputeMatMul(ComputeContext* context,
256256

257257
// With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
258258
// number of splits along `dim_inner`.
259-
// TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize
260-
// the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`.
261259
split_dim_inner = split_k_config.GetSplitDimInner();
262260
dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;
263261

@@ -271,7 +269,7 @@ Status ComputeMatMul(ComputeContext* context,
271269
.CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner)
272270
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
273271
{b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
274-
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}})
272+
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}})
275273
.AddIndices(outer_dims)
276274
.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
277275
.SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z)

onnxruntime/core/providers/webgpu/math/matmul_packed.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ class MatMulProgram final : public Program<MatMulProgram> {
2424
Status GenerateShaderCode(ShaderHelper& sh) const override;
2525
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32},
2626
{"dim_b_outer", ProgramUniformVariableDataType::Uint32},
27-
{"dim_inner", ProgramUniformVariableDataType::Uint32});
27+
{"dim_inner", ProgramUniformVariableDataType::Uint32},
28+
{"logical_dispatch_x", ProgramUniformVariableDataType::Uint32},
29+
{"logical_dispatch_y", ProgramUniformVariableDataType::Uint32},
30+
{"logical_dispatch_z", ProgramUniformVariableDataType::Uint32});
2831

2932
bool NeedSplitK() const;
3033

onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,10 @@ Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::v
226226
{static_cast<uint32_t>(dim_inner)},
227227
{pads},
228228
{strides},
229-
{dilations}});
229+
{dilations},
230+
{dispatch[0]},
231+
{dispatch[1]},
232+
{dispatch[2]}});
230233

231234
return program;
232235
}

onnxruntime/core/providers/webgpu/nn/conv2d_mm.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ class Conv2dMMProgram final : public Program<Conv2dMMProgram> {
3838
{"dim_inner", ProgramUniformVariableDataType::Uint32},
3939
{"pads", ProgramUniformVariableDataType::Uint32},
4040
{"strides", ProgramUniformVariableDataType::Uint32},
41-
{"dilations", ProgramUniformVariableDataType::Uint32});
41+
{"dilations", ProgramUniformVariableDataType::Uint32},
42+
{"logical_dispatch_x", ProgramUniformVariableDataType::Uint32},
43+
{"logical_dispatch_y", ProgramUniformVariableDataType::Uint32},
44+
{"logical_dispatch_z", ProgramUniformVariableDataType::Uint32});
4245

4346
private:
4447
const Activation& activation_;

0 commit comments

Comments
 (0)