@@ -272,22 +272,34 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
272272 << " const rowPerThread = " << elements_per_thread_y << " ;\n "
273273 << " const colPerThread = " << elements_per_thread_x << " ;\n "
274274 << " const innerElementSize = " << inner_elements_size << " ;\n "
275- << " const tileInner = " << tile_inner << " ;\n " ;
275+ << " const tileInner = " << tile_inner << " ;\n "
276+ << " const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n " ;
277+
278+ // Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in
279+ // `ProgramBase.SetDispatchGroupSize()` may be normalized in
280+ // `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use
281+ // `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
282+ shader.MainFunctionBody ()
283+ << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n "
284+ << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n "
285+ << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n "
286+ << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n "
287+ << " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n " ;
276288
277289 shader.MainFunctionBody ()
278290 << " let localRow = i32(local_id.y);\n "
279291 << " let tileRow = localRow * rowPerThread;\n "
280292 << " let tileCol = i32(local_id.x);\n "
281- << " let globalRow = i32(global_id .y) * rowPerThread;\n "
282- << " let globalCol = i32(global_id .x);\n "
283- << " let globalRowStart = i32(workgroup_id .y) * " << tile_a_outer << " ;\n "
284- << " let globalColStart = i32(workgroup_id .x) * " << tile_b_outer << " ;\n "
293+ << " let globalRow = i32(logical_global_id .y) * rowPerThread;\n "
294+ << " let globalCol = i32(logical_global_id .x);\n "
295+ << " let globalRowStart = i32(logical_workgroup_id .y) * " << tile_a_outer << " ;\n "
296+ << " let globalColStart = i32(logical_workgroup_id .x) * " << tile_b_outer << " ;\n "
285297 << " var acc: array<vec4<" << data_type << " >, rowPerThread>;\n " ;
286298
287299 if (split_k) {
288300 // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into
289301 // multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from
290- // `kSplitK * i32(global_id .z)`.
302+ // `kSplitK * i32(logical_global_id .z)`.
291303 //
292304 // For example: considering computing Y = (X * W + B) in one workgroup.
293305 // Let kSplitK = 2, B = [d1, d2]
@@ -305,23 +317,23 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
305317 // Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2)
306318 // Workgroup3: compute (C1 * C2)
307319 // In each workgroup:
308- // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id .z`
320+ // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `logical_global_id .z`
309321 // - When the computation in each workgroup is completed, add the result to Y with several
310322 // atomic built-in functions in `HandleMatMulWithSplitK()`.
311323 shader.MainFunctionBody ()
312324 << " const kSplitK = " << split_dim_inner << " ;\n "
313325 << " let num_tiles = (kSplitK - 1) / tileInner + 1;\n "
314- << " var kStart = kSplitK * i32(global_id .z);\n "
326+ << " var kStart = kSplitK * i32(logical_global_id .z);\n "
315327
316- // When Split-K is used, `batch` should always be 0 and `global_id .z` is used to indicate
328+ // When Split-K is used, `batch` should always be 0 and `logical_global_id .z` is used to indicate
317329 // the index of split-k instead of batch.
318330 << " let batch = 0;\n "
319331 << " let batchIndices = 0u;\n " ;
320332 } else {
321333 shader.MainFunctionBody ()
322334 << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n "
323335 << " var kStart = 0;\n "
324- << " let batch = i32(global_id .z);\n "
336+ << " let batch = i32(logical_global_id .z);\n "
325337 << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices (" u32(batch)" ) + " ;\n " : " " );
326338 }
327339
@@ -496,9 +508,21 @@ Status MakeMatMulPackedSource(ShaderHelper& shader,
496508 << " var<workgroup> mm_Bsub: array<array<" << data_type << " , " << tile_b_outer << " >, " << tile_inner << " >;\n "
497509 << " const rowPerThread = " << elements_per_thread_y << " ;\n "
498510 << " const colPerThread = " << elements_per_thread_x << " ;\n "
499- << " const tileInner = " << tile_inner << " ;\n " ;
511+ << " const tileInner = " << tile_inner << " ;\n "
512+ << " const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n " ;
513+
514+ // Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in
515+ // `ProgramBase.SetDispatchGroupSize()` may be normalized in
516+ // `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use
517+ // `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`.
518+ shader.MainFunctionBody ()
519+ << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n "
520+ << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n "
521+ << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n "
522+ << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n "
523+ << " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n " ;
500524
501- shader.MainFunctionBody () << " let batch = i32(global_id .z);\n "
525+ shader.MainFunctionBody () << " let batch = i32(logical_global_id .z);\n "
502526 << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices (" u32(batch)" ) + " ;\n " : " " )
503527 << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n "
504528 << " var kStart = 0;\n "
@@ -507,10 +531,10 @@ Status MakeMatMulPackedSource(ShaderHelper& shader,
507531 shader.MainFunctionBody ()
508532 << " let tileRow = i32(local_id.y) * rowPerThread;\n "
509533 << " let tileCol = i32(local_id.x) * colPerThread;\n "
510- << " let globalRow = i32(global_id .y) * rowPerThread;\n "
511- << " let globalCol = i32(global_id .x) * colPerThread;\n "
512- << " let globalRowStart = i32(workgroup_id .y) * " << tile_a_outer << " ;\n "
513- << " let globalColStart = i32(workgroup_id .x) * " << tile_b_outer << " ;\n "
534+ << " let globalRow = i32(logical_global_id .y) * rowPerThread;\n "
535+ << " let globalCol = i32(logical_global_id .x) * colPerThread;\n "
536+ << " let globalRowStart = i32(logical_workgroup_id .y) * " << tile_a_outer << " ;\n "
537+ << " let globalColStart = i32(logical_workgroup_id .x) * " << tile_b_outer << " ;\n "
514538 << " let tileRowA = i32(local_id.y) * " << row_per_thread_a << " ;\n "
515539 << " let tileColA = i32(local_id.x) * " << col_per_thread_a << " ;\n "
516540 << " let tileRowB = i32(local_id.y) * " << row_per_thread_b << " ;\n " ;
0 commit comments