@@ -280,9 +280,10 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
280280 // `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use
281281 // `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
282282 shader.MainFunctionBody ()
283- << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n "
284- << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n "
285- << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n "
283+ << " let logical_workgroups_xy = uniforms.logical_dispatch_x * uniforms.logical_dispatch_y;\n "
284+ << " let logical_workgroup_id_z = workgroup_idx / logical_workgroups_xy;\n "
285+ << " let logical_workgroup_id_y = (workgroup_idx % logical_workgroups_xy) / uniforms.logical_dispatch_x;\n "
286+ << " let logical_workgroup_id_x = (workgroup_idx % logical_workgroups_xy) % uniforms.logical_dispatch_x;\n "
286287 << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n "
287288 << " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n " ;
288289
@@ -516,9 +517,10 @@ Status MakeMatMulPackedSource(ShaderHelper& shader,
516517 // `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use
517518 // `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
518519 shader.MainFunctionBody ()
519- << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n "
520- << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n "
521- << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n "
520+ << " let logical_workgroups_xy = uniforms.logical_dispatch_x * uniforms.logical_dispatch_y;\n "
521+ << " let logical_workgroup_id_z = workgroup_idx / logical_workgroups_xy;\n "
522+ << " let logical_workgroup_id_y = (workgroup_idx % logical_workgroups_xy) / uniforms.logical_dispatch_x;\n "
523+ << " let logical_workgroup_id_x = (workgroup_idx % logical_workgroups_xy) % uniforms.logical_dispatch_x;\n "
522524 << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n "
523525 << " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n " ;
524526
0 commit comments