Skip to content

Commit 2786714

Browse files
committed
Address the comments from copilot
1 parent 8b7017a commit 2786714

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

onnxruntime/core/providers/webgpu/math/gemm_utils.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,10 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
280280
// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use
281281
// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
282282
shader.MainFunctionBody()
283-
<< " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n"
284-
<< " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n"
285-
<< " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n"
283+
<< " let logical_workgroups_xy = uniforms.logical_dispatch_x * uniforms.logical_dispatch_y;\n"
284+
<< " let logical_workgroup_id_z = workgroup_idx / logical_workgroups_xy;\n"
285+
<< " let logical_workgroup_id_y = (workgroup_idx % logical_workgroups_xy) / uniforms.logical_dispatch_x;\n"
286+
<< " let logical_workgroup_id_x = (workgroup_idx % logical_workgroups_xy) % uniforms.logical_dispatch_x;\n"
286287
<< " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n"
287288
<< " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n";
288289

@@ -516,9 +517,10 @@ Status MakeMatMulPackedSource(ShaderHelper& shader,
516517
// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use
517518
// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
518519
shader.MainFunctionBody()
519-
<< " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n"
520-
<< " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n"
521-
<< " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n"
520+
<< " let logical_workgroups_xy = uniforms.logical_dispatch_x * uniforms.logical_dispatch_y;\n"
521+
<< " let logical_workgroup_id_z = workgroup_idx / logical_workgroups_xy;\n"
522+
<< " let logical_workgroup_id_y = (workgroup_idx % logical_workgroups_xy) / uniforms.logical_dispatch_x;\n"
523+
<< " let logical_workgroup_id_x = (workgroup_idx % logical_workgroups_xy) % uniforms.logical_dispatch_x;\n"
522524
<< " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n"
523525
<< " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n";
524526

0 commit comments

Comments
 (0)