Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions onnxruntime/core/providers/webgpu/math/gemm_packed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,21 @@ Status ApplyGemmPacked(const Tensor* a,
}

const uint32_t TILE_SIZE = 32;
const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE;
const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE;
const uint32_t dispatch_x = (N + TILE_SIZE - 1) / TILE_SIZE;
const uint32_t dispatch_y = (M + TILE_SIZE - 1) / TILE_SIZE;

program.CacheHint(alpha, transA, transB, c_is_scalar)
.AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}})
.SetDispatchGroupSize(num_tile_n, num_tile_m, 1)
.SetDispatchGroupSize(dispatch_x, dispatch_y, 1u)
.SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z)
.AddUniformVariables({{alpha},
{beta},
{M}, /* dim_a_outer */
{N}, /* dim_b_outer */
{K}} /*dim_inner */
{M}, /* dim_a_outer */
{N}, /* dim_b_outer */
{K}, /*dim_inner */
{dispatch_x}, /* logical_dispatch_x */
{dispatch_y}, /* logical_dispatch_y */
{1u}} /* logical_dispatch_z */
);

return context.RunProgram(program);
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/webgpu/math/gemm_packed.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ class GemmProgram final : public Program<GemmProgram> {
{"beta", ProgramUniformVariableDataType::Float32},
{"dim_a_outer", ProgramUniformVariableDataType::Uint32},
{"dim_b_outer", ProgramUniformVariableDataType::Uint32},
{"dim_inner", ProgramUniformVariableDataType::Uint32});
{"dim_inner", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_x", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_y", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_z", ProgramUniformVariableDataType::Uint32});

constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8;
constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8;
Expand Down
46 changes: 32 additions & 14 deletions onnxruntime/core/providers/webgpu/math/gemm_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,20 @@ void HandleMatMulWithSplitK(
}
}

// Emit WGSL that reconstructs `logical_workgroup_id` and `logical_global_id` from the flat
// `workgroup_idx`, because the dispatch size passed to `ProgramBase.SetDispatchGroupSize()` may
// have been reshaped by `ProgramManager::NormalizeDispatchGroupSize()`. Shader code must therefore
// always read `logical_workgroup_id` / `logical_global_id`, never the raw built-in
// `workgroup_id` / `global_id`.
void InitializeLogicalWorkgroupIDAndGlobalID(ShaderHelper& shader) {
  auto& body = shader.MainFunctionBody();
  // Unflatten workgroup_idx into (x, y, z) using the logical dispatch extents:
  // z strides over whole x*y planes, y over rows of a plane, x over the remainder.
  body << "  let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n";
  body << "  let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n";
  body << "  let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n";
  body << "  let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n";
  // Rebuild the equivalent of the `global_invocation_id` built-in from the logical
  // workgroup id, the (compile-time) workgroup size, and the thread's local id.
  body << "  const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n";
  body << "  let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n";
}

} // namespace

void MatMulReadFnSource(ShaderHelper& shader,
Expand Down Expand Up @@ -274,20 +288,22 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
<< "const innerElementSize = " << inner_elements_size << ";\n"
<< "const tileInner = " << tile_inner << ";\n";

InitializeLogicalWorkgroupIDAndGlobalID(shader);

shader.MainFunctionBody()
<< " let localRow = i32(local_id.y);\n"
<< " let tileRow = localRow * rowPerThread;\n"
<< " let tileCol = i32(local_id.x);\n"
<< " let globalRow = i32(global_id.y) * rowPerThread;\n"
<< " let globalCol = i32(global_id.x);\n"
<< " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n"
<< " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n"
<< " let globalRow = i32(logical_global_id.y) * rowPerThread;\n"
<< " let globalCol = i32(logical_global_id.x);\n"
<< " let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n"
<< " let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n"
<< " var acc: array<vec4<" << data_type << ">, rowPerThread>;\n";

if (split_k) {
// With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into
// multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from
// `kSplitK * i32(global_id.z)`.
// `kSplitK * i32(logical_global_id.z)`.
//
// For example: considering computing Y = (X * W + B) in one workgroup.
// Let kSplitK = 2, B = [d1, d2]
Expand All @@ -305,23 +321,23 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
// Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2)
// Workgroup3: compute (C1 * C2)
// In each workgroup:
// - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z`
// - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `logical_global_id.z`
// - When the computation in each workgroup is completed, add the result to Y with several
// atomic built-in functions in `HandleMatMulWithSplitK()`.
shader.MainFunctionBody()
<< "const kSplitK = " << split_dim_inner << ";\n"
<< " let num_tiles = (kSplitK - 1) / tileInner + 1;\n"
<< " var kStart = kSplitK * i32(global_id.z);\n"
<< " var kStart = kSplitK * i32(logical_global_id.z);\n"

// When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate
// When Split-K is used, `batch` should always be 0 and `logical_global_id.z` is used to indicate
// the index of split-k instead of batch.
<< " let batch = 0;\n"
<< " let batchIndices = 0u;\n";
} else {
shader.MainFunctionBody()
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
<< " var kStart = 0;\n"
<< " let batch = i32(global_id.z);\n"
<< " let batch = i32(logical_global_id.z);\n"
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "");
}

Expand Down Expand Up @@ -498,7 +514,9 @@ Status MakeMatMulPackedSource(ShaderHelper& shader,
<< "const colPerThread = " << elements_per_thread_x << ";\n"
<< "const tileInner = " << tile_inner << ";\n";

shader.MainFunctionBody() << " let batch = i32(global_id.z);\n"
InitializeLogicalWorkgroupIDAndGlobalID(shader);

shader.MainFunctionBody() << " let batch = i32(logical_global_id.z);\n"
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "")
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
<< " var kStart = 0;\n"
Expand All @@ -507,10 +525,10 @@ Status MakeMatMulPackedSource(ShaderHelper& shader,
shader.MainFunctionBody()
<< "let tileRow = i32(local_id.y) * rowPerThread;\n"
<< "let tileCol = i32(local_id.x) * colPerThread;\n"
<< "let globalRow = i32(global_id.y) * rowPerThread;\n"
<< "let globalCol = i32(global_id.x) * colPerThread;\n"
<< "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n"
<< "let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n"
<< "let globalRow = i32(logical_global_id.y) * rowPerThread;\n"
<< "let globalCol = i32(logical_global_id.x) * colPerThread;\n"
<< "let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n"
<< "let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n"
<< "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n"
<< "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n"
<< "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n";
Expand Down
4 changes: 1 addition & 3 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,6 @@ Status ComputeMatMul(ComputeContext* context,

// With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
// number of splits along `dim_inner`.
// TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize
// the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`.
split_dim_inner = split_k_config.GetSplitDimInner();
dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;

Expand All @@ -271,7 +269,7 @@ Status ComputeMatMul(ComputeContext* context,
.CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner)
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
{b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}})
.AddIndices(outer_dims)
.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
.SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z)
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/webgpu/math/matmul_packed.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ class MatMulProgram final : public Program<MatMulProgram> {
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32},
{"dim_b_outer", ProgramUniformVariableDataType::Uint32},
{"dim_inner", ProgramUniformVariableDataType::Uint32});
{"dim_inner", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_x", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_y", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_z", ProgramUniformVariableDataType::Uint32});

bool NeedSplitK() const;

Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,10 @@ Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::v
{static_cast<uint32_t>(dim_inner)},
{pads},
{strides},
{dilations}});
{dilations},
{dispatch[0]},
{dispatch[1]},
{dispatch[2]}});

return program;
}
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/webgpu/nn/conv2d_mm.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ class Conv2dMMProgram final : public Program<Conv2dMMProgram> {
{"dim_inner", ProgramUniformVariableDataType::Uint32},
{"pads", ProgramUniformVariableDataType::Uint32},
{"strides", ProgramUniformVariableDataType::Uint32},
{"dilations", ProgramUniformVariableDataType::Uint32});
{"dilations", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_x", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_y", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_z", ProgramUniformVariableDataType::Uint32});

private:
const Activation& activation_;
Expand Down
Loading