Skip to content

Commit efc937c

Browse files
committed
Address reviewer's comment
1 parent 2786714 commit efc937c

File tree

1 file changed

+18
-26
lines changed

1 file changed

+18
-26
lines changed

onnxruntime/core/providers/webgpu/math/gemm_utils.cc

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,20 @@ void HandleMatMulWithSplitK(
117117
}
118118
}
119119

120+
// Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in
121+
// `ProgramBase.SetDispatchGroupSize()` may be normalized in
122+
// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use
123+
// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
124+
void InitializeLogicalWorkgroupIDAndGlobalID(ShaderHelper& shader) {
125+
shader.MainFunctionBody()
126+
<< " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n"
127+
<< " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n"
128+
<< " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n"
129+
<< " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n"
130+
<< " const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n"
131+
<< " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n";
132+
}
133+
120134
} // namespace
121135

122136
void MatMulReadFnSource(ShaderHelper& shader,
@@ -272,20 +286,9 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
272286
<< "const rowPerThread = " << elements_per_thread_y << ";\n"
273287
<< "const colPerThread = " << elements_per_thread_x << ";\n"
274288
<< "const innerElementSize = " << inner_elements_size << ";\n"
275-
<< "const tileInner = " << tile_inner << ";\n"
276-
<< "const workgroup_size = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n";
289+
<< "const tileInner = " << tile_inner << ";\n";
277290

278-
// Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in
279-
// `ProgramBase.SetDispatchGroupSize()` may be normalized in
280-
// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use
281-
// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
282-
shader.MainFunctionBody()
283-
<< " let logical_workgroups_xy = uniforms.logical_dispatch_x * uniforms.logical_dispatch_y;\n"
284-
<< " let logical_workgroup_id_z = workgroup_idx / logical_workgroups_xy;\n"
285-
<< " let logical_workgroup_id_y = (workgroup_idx % logical_workgroups_xy) / uniforms.logical_dispatch_x;\n"
286-
<< " let logical_workgroup_id_x = (workgroup_idx % logical_workgroups_xy) % uniforms.logical_dispatch_x;\n"
287-
<< " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n"
288-
<< " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n";
291+
InitializeLogicalWorkgroupIDAndGlobalID(shader);
289292

290293
shader.MainFunctionBody()
291294
<< " let localRow = i32(local_id.y);\n"
@@ -509,20 +512,9 @@ Status MakeMatMulPackedSource(ShaderHelper& shader,
509512
<< "var<workgroup> mm_Bsub: array<array<" << data_type << ", " << tile_b_outer << ">, " << tile_inner << ">;\n"
510513
<< "const rowPerThread = " << elements_per_thread_y << ";\n"
511514
<< "const colPerThread = " << elements_per_thread_x << ";\n"
512-
<< "const tileInner = " << tile_inner << ";\n"
513-
<< "const workgroup_size = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n";
515+
<< "const tileInner = " << tile_inner << ";\n";
514516

515-
// Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in
516-
// `ProgramBase.SetDispatchGroupSize()` may be normalized in
517-
// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use
518-
// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
519-
shader.MainFunctionBody()
520-
<< " let logical_workgroups_xy = uniforms.logical_dispatch_x * uniforms.logical_dispatch_y;\n"
521-
<< " let logical_workgroup_id_z = workgroup_idx / logical_workgroups_xy;\n"
522-
<< " let logical_workgroup_id_y = (workgroup_idx % logical_workgroups_xy) / uniforms.logical_dispatch_x;\n"
523-
<< " let logical_workgroup_id_x = (workgroup_idx % logical_workgroups_xy) % uniforms.logical_dispatch_x;\n"
524-
<< " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n"
525-
<< " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n";
517+
InitializeLogicalWorkgroupIDAndGlobalID(shader);
526518

527519
shader.MainFunctionBody() << " let batch = i32(logical_global_id.z);\n"
528520
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "")

0 commit comments

Comments
 (0)