@@ -117,6 +117,20 @@ void HandleMatMulWithSplitK(
117117 }
118118}
119119
120+ // Appends WGSL to the shader's main function body that defines `logical_workgroup_id` and
121+ // `logical_global_id`, both derived from the linear `workgroup_idx` and the `logical_dispatch_x` /
122+ // `logical_dispatch_y` uniforms. This is needed because the dispatch workgroup size in `ProgramBase.SetDispatchGroupSize()` may be normalized in `ProgramManager::NormalizeDispatchGroupSize()`.
123+ // In the shader we should always use `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
124+ void InitializeLogicalWorkgroupIDAndGlobalID (ShaderHelper& shader) { // NOTE(review): the emitted code assumes `workgroup_idx`, `local_id`, `workgroup_size_x/y/z` and both uniforms are already visible in the shader - confirm at each call site.
125+ shader.MainFunctionBody ()
126+ << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n " // z: count of whole x*y planes that precede this workgroup
127+ << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n " // y: row inside the remaining x*y plane
128+ << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n " // x: column inside that row
129+ << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n " // pack the three components into one vector
130+ << " const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n " // per-dimension workgroup size as a vector
131+ << " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n " ; // component-wise: workgroup origin plus the invocation's local id
132+ }
133+
120134} // namespace
121135
122136void MatMulReadFnSource (ShaderHelper& shader,
@@ -272,20 +286,9 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
272286 << " const rowPerThread = " << elements_per_thread_y << " ;\n "
273287 << " const colPerThread = " << elements_per_thread_x << " ;\n "
274288 << " const innerElementSize = " << inner_elements_size << " ;\n "
275- << " const tileInner = " << tile_inner << " ;\n "
276- << " const workgroup_size = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n " ;
289+ << " const tileInner = " << tile_inner << " ;\n " ;
277290
278- // Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in
279- // `ProgramBase.SetDispatchGroupSize()` may be normalized in
280- // `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use
281- // `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
282- shader.MainFunctionBody ()
283- << " let logical_workgroups_xy = uniforms.logical_dispatch_x * uniforms.logical_dispatch_y;\n "
284- << " let logical_workgroup_id_z = workgroup_idx / logical_workgroups_xy;\n "
285- << " let logical_workgroup_id_y = (workgroup_idx % logical_workgroups_xy) / uniforms.logical_dispatch_x;\n "
286- << " let logical_workgroup_id_x = (workgroup_idx % logical_workgroups_xy) % uniforms.logical_dispatch_x;\n "
287- << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n "
288- << " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n " ;
291+ InitializeLogicalWorkgroupIDAndGlobalID (shader);
289292
290293 shader.MainFunctionBody ()
291294 << " let localRow = i32(local_id.y);\n "
@@ -509,20 +512,9 @@ Status MakeMatMulPackedSource(ShaderHelper& shader,
509512 << " var<workgroup> mm_Bsub: array<array<" << data_type << " , " << tile_b_outer << " >, " << tile_inner << " >;\n "
510513 << " const rowPerThread = " << elements_per_thread_y << " ;\n "
511514 << " const colPerThread = " << elements_per_thread_x << " ;\n "
512- << " const tileInner = " << tile_inner << " ;\n "
513- << " const workgroup_size = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n " ;
515+ << " const tileInner = " << tile_inner << " ;\n " ;
514516
515- // Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in
516- // `ProgramBase.SetDispatchGroupSize()` may be normalized in
517- // `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use
518- // `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
519- shader.MainFunctionBody ()
520- << " let logical_workgroups_xy = uniforms.logical_dispatch_x * uniforms.logical_dispatch_y;\n "
521- << " let logical_workgroup_id_z = workgroup_idx / logical_workgroups_xy;\n "
522- << " let logical_workgroup_id_y = (workgroup_idx % logical_workgroups_xy) / uniforms.logical_dispatch_x;\n "
523- << " let logical_workgroup_id_x = (workgroup_idx % logical_workgroups_xy) % uniforms.logical_dispatch_x;\n "
524- << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n "
525- << " let logical_global_id = logical_workgroup_id * workgroup_size + local_id;\n " ;
517+ InitializeLogicalWorkgroupIDAndGlobalID (shader);
526518
527519 shader.MainFunctionBody () << " let batch = i32(logical_global_id.z);\n "
528520 << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices (" u32(batch)" ) + " ;\n " : " " )
0 commit comments