Skip to content

Commit 8f8069d

Browse files
authored
[webgpu] Optimize Conv by im2col-matmul (#26603)
### Description

This PR optimizes the `Conv` operation by implementing two new compute shaders: `oihw_to_ohwi` and `im2col-matmul`.

`oihw_to_ohwi`: Improves performance over the default Transpose shader by utilizing workgroup memory to ensure continuous memory read/write patterns.

`im2col-matmul`:
- Employs a workgroup size of 64.
- Dynamically selects tile sizes (32x64 or 16x64) based on the source/weight shape.
- Each invocation handles a dedicated weight element.
- Uses subgroupShuffle to efficiently access the source tile, leveraging k_vec4 vectorization for better memory throughput.

Testing on Lunar Lake demonstrated **up to an 87%** performance improvement in Conv_2D operations.

### Motivation and Context

See above.
1 parent 817a44f commit 8f8069d

File tree

5 files changed

+543
-7
lines changed

5 files changed

+543
-7
lines changed

onnxruntime/core/providers/webgpu/nn/conv.cc

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33
#include "core/providers/webgpu/nn/conv.h"
44
#include "core/providers/webgpu/nn/conv2d_mm.h"
5+
#include "core/providers/webgpu/nn/im2col_matmul.h"
56
#include "core/providers/webgpu/shader_helper.h"
67
#include "core/providers/webgpu/webgpu_supported_types.h"
78
#include "core/providers/webgpu/tensor/transpose.h"
@@ -99,10 +100,34 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
99100
modified_input_output_shapes.push_back(bias->Shape());
100101
}
101102
modified_input_output_shapes.push_back(TensorShape(output_shape_vector));
103+
104+
const auto input_height = input_shape[is_channels_last ? 1 : 2];
105+
const auto input_width = input_shape[is_channels_last ? 2 : 3];
106+
const auto input_channels = input_shape[is_channels_last ? 3 : 1];
107+
const auto kernel_height = kernel_shape[2];
108+
const auto kernel_width = kernel_shape[3];
109+
const auto output_height = output_shape_vector[is_channels_last ? 1 : 2];
110+
const auto output_width = output_shape_vector[is_channels_last ? 2 : 3];
111+
102112
uint32_t auto_pad_adjust = conv_attrs_.auto_pad == AutoPadType::SAME_LOWER ? 1 : 0;
103113
auto pad0 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[0] : (pads[0] + pads[2] + auto_pad_adjust) / 2;
104114
auto pad1 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2;
105115
std::vector<uint32_t> updated_pads{pad0, pad1};
116+
117+
if (CanApplyIm2ColMatMulProgram(context,
118+
is_channels_last,
119+
activation_.activation_kind_,
120+
kernel_shape,
121+
conv_attrs_.auto_pad,
122+
onnxruntime::narrow<uint32_t>(conv_attrs_.group))) {
123+
return ApplyIm2ColMatMulProgram(context,
124+
is_channels_last,
125+
dilations,
126+
pads,
127+
strides,
128+
output);
129+
}
130+
106131
if (conv_attrs_.group > 1) {
107132
Tensor transposed_kernel;
108133
if (is_channels_last) {
@@ -128,13 +153,6 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
128153
}
129154
return context.RunProgram(program);
130155
}
131-
const auto input_height = input_shape[is_channels_last ? 1 : 2];
132-
const auto input_width = input_shape[is_channels_last ? 2 : 3];
133-
const auto input_channels = input_shape[is_channels_last ? 3 : 1];
134-
const auto kernel_height = kernel_shape[2];
135-
const auto kernel_width = kernel_shape[3];
136-
const auto output_height = output_shape_vector[is_channels_last ? 1 : 2];
137-
const auto output_width = output_shape_vector[is_channels_last ? 2 : 3];
138156

139157
const auto same_size = is_channels_last && input_height == kernel_height && input_width == kernel_width && pads[0] == 0 && pads[1] == 0;
140158
if (same_size || (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides[0] == 1 && strides[1] == 1)) {
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
#include <array>
#include <string>
#include <utility>
#include <vector>

#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/nn/im2col_matmul.h"
#include "core/providers/webgpu/nn/activation_util.h"
10+
11+
namespace onnxruntime {
12+
namespace webgpu {
13+
14+
namespace {
15+
16+
// TODO: move to common header.
// Integer ceiling division: smallest T such that result * denominator >= numerator.
// Intended for unsigned operands; `numerator + denominator - 1` must not overflow T.
template <typename T>
constexpr T ceil_div(T numerator, T denominator) {
  return (numerator + denominator - 1) / denominator;
}

// Chooses the optimal tile size (M, N) for the im2col operation.
// This tile size is performance-tuned and varies depending on the target device.
std::pair<uint32_t, uint32_t> ChooseTileSize(uint32_t im2col_m, uint32_t im2col_n) {
  // Preferred (tile_m, tile_n) pairs in descending order of preference.
  // A compile-time array avoids constructing (and heap-allocating) a
  // std::vector on every call.
  static constexpr std::array<std::pair<uint32_t, uint32_t>, 2> kTileSizes{{
      {32u, 64u},
      {16u, 64u},
  }};

  for (const auto& [tile_m, tile_n] : kTileSizes) {
    // Number of workgroups this tiling would dispatch across the GEMM grid.
    const uint32_t dispatch_m = ceil_div(im2col_m, tile_m);
    const uint32_t dispatch_n = ceil_div(im2col_n, tile_n);
    const uint32_t dispatch = dispatch_m * dispatch_n;

    // Prefer the largest tile that still yields enough workgroups (>= 128)
    // to keep the GPU occupied.
    if (dispatch >= 128) {
      return {tile_m, tile_n};
    }
  }

  // If none of the tile sizes meet the dispatch >= 128 requirement,
  // fall back to the smallest tile to maximize the workgroup count.
  return kTileSizes.back();
}
47+
48+
// Add support for more devices.
49+
bool IsDeviceSupported(ComputeContext& context) {
50+
const wgpu::AdapterInfo& adapter_info = context.AdapterInfo();
51+
52+
if (adapter_info.vendor == std::string_view("intel")) {
53+
if (adapter_info.architecture == std::string_view("xe-2lpg")) {
54+
return true;
55+
}
56+
}
57+
58+
return false;
59+
}
60+
61+
} // namespace
62+
63+
// Emits the WGSL for the OIHW -> OHWI weight transpose by binding the src/output
// variables and instantiating the `oihw_to_ohwi` shader template with them.
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template",
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src));
}
71+
72+
// Emits the WGSL for the fused im2col + matmul shader: binds src/weight (and
// bias when present), validates the tuned tile shape, then instantiates the
// `im2col_matmul` template with the tile parameters and bound variables.
Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  if (has_bias_) {
    shader.AddInput("bias", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  }
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  // The shader template is only written for these tile shapes; reject anything
  // else before generating code.
  ORT_ENFORCE(tile_m_ == 16 || tile_m_ == 32, "tile_m must be 16 or 32.");
  ORT_ENFORCE(tile_n_ == 64, "tile_n must be 64.");

  return WGSL_TEMPLATE_APPLY(shader, "nn/im2col_matmul.wgsl.template",
                             WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
                             WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_),
                             WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_),
                             WGSL_TEMPLATE_PARAMETER(use_subgroup, use_subgroup_),
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src),
                             WGSL_TEMPLATE_VARIABLE(weight, weight));
}
92+
93+
// Runs Conv as two GPU programs: (1) transpose the OIHW weight to OHWI layout,
// then (2) a fused im2col + matmul shader that computes the convolution as the
// logical GEMM [im2col_m x im2col_k] * [im2col_k x im2col_n].
// Only call this after CanApplyIm2ColMatMulProgram() returned true.
Status ApplyIm2ColMatMulProgram(ComputeContext& context,
                                bool is_channels_last,
                                const std::vector<uint32_t>& dilations,
                                const std::vector<uint32_t>& pads,
                                const std::vector<uint32_t>& strides,
                                Tensor* output) {
  const auto* src = context.Input<Tensor>(0);
  const auto* weight = context.Input<Tensor>(1);
  const bool has_bias = context.InputCount() > 2;
  const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;

  // Transpose OIHW Weight to OHWI
  // TODO: Move to `Transpose`
  // TODO: Use prepack
  TensorShape weight_shape = weight->Shape();
  const uint32_t channel_output = onnxruntime::narrow<uint32_t>(weight_shape[0]);
  const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]);
  const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(weight_shape[2]);
  const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(weight_shape[3]);

  TensorShape ohwi_weight_shape{channel_output, kernel_height, kernel_width, channel_input};
  Tensor ohwi_weight = context.CreateGPUTensor(weight->DataType(), ohwi_weight_shape);
  OIHW2OHWIProgram transpose_program{};
  transpose_program.SetWorkgroupSize(64);

  // One workgroup per (output channel, 64-wide input-channel tile).
  const uint32_t Ci_tiles = ceil_div(channel_input, 64u);
  transpose_program.SetDispatchGroupSize(channel_output, Ci_tiles);

  transpose_program.AddInput({weight,
                              ProgramTensorMetadataDependency::TypeAndRank});
  transpose_program.AddOutput({&ohwi_weight,
                               ProgramTensorMetadataDependency::TypeAndRank});
  // Uniforms: O, I, H, W, Ci_tiles, H_W_tiles (see OIHW2OHWIProgram).
  // H_W_tiles is the number of vec4 tiles covering the kernel's H*W spatial
  // extent — it must be height * width, not height * height (fixed here;
  // the squared form was wrong for non-square kernels).
  transpose_program.AddUniformVariables({{channel_output},
                                         {channel_input},
                                         {kernel_height},
                                         {kernel_width},
                                         {Ci_tiles},
                                         {ceil_div(kernel_height * kernel_width, 4u)}});
  ORT_RETURN_IF_ERROR(context.RunProgram(transpose_program));

  // im2col-matmul
  const TensorShape src_shape = src->Shape();
  const TensorShape output_shape = output->Shape();

  const uint32_t batch = onnxruntime::narrow<uint32_t>(src_shape[0]);
  const uint32_t src_height = onnxruntime::narrow<uint32_t>(src_shape[is_channels_last ? 1 : 2]);
  const uint32_t src_width = onnxruntime::narrow<uint32_t>(src_shape[is_channels_last ? 2 : 3]);
  const uint32_t output_height = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 1 : 2]);
  const uint32_t output_width = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 2 : 3]);

  // Logical GEMM dimensions: M = output pixels, K = im2col depth, N = output channels.
  const uint32_t im2col_m = output_height * output_width;
  const uint32_t im2col_k = kernel_height * kernel_width * channel_input;
  const uint32_t im2col_n = channel_output;

  const auto [tile_m, tile_n] = ChooseTileSize(im2col_m, im2col_n);
  const uint32_t workgroup_size = tile_n;

  // Check the device's subgroup size before shader compilation to avoid potential performance penalties
  // associated with conditional checks in the shader runtime.
  //
  // Ensure the subgroup size must be greater than or equal to `tile_m` to safely enable `use_subgroup`.
  // If the status of this condition is uncertain, the feature must be disabled.
  const bool use_subgroup = false;
  Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, use_subgroup};
  im2col_mm_program.SetWorkgroupSize(workgroup_size);

  const uint32_t M_tiles = ceil_div(im2col_m, tile_m);
  const uint32_t N_tiles = ceil_div(im2col_n, tile_n);
  im2col_mm_program.SetDispatchGroupSize(M_tiles, N_tiles, batch);

  // src and weight are bound with component size 4 (k_vec4 vectorization).
  im2col_mm_program.AddInput({src,
                              ProgramTensorMetadataDependency::TypeAndRank,
                              4});
  im2col_mm_program.AddInput({&ohwi_weight,
                              ProgramTensorMetadataDependency::TypeAndRank,
                              4});
  if (has_bias) {
    im2col_mm_program.AddInput({bias,
                                ProgramTensorMetadataDependency::TypeAndRank});
  }
  im2col_mm_program.AddOutput({output,
                               ProgramTensorMetadataDependency::TypeAndRank});
  // K_tiles uniform counts vec4-of-vec4 groups along K: ceil(ceil(K / 4) / 4).
  im2col_mm_program.AddUniformVariables({{batch},
                                         {src_height},
                                         {src_width},
                                         {channel_input},
                                         {kernel_height},
                                         {kernel_width},
                                         {output_height},
                                         {output_width},
                                         {im2col_m},
                                         {im2col_k},
                                         {im2col_n},
                                         {M_tiles},
                                         {N_tiles},
                                         {ceil_div(ceil_div(im2col_k, 4u), 4u)},
                                         {dilations},
                                         {pads},
                                         {strides}});
  // Cache key must include everything that changes the generated WGSL.
  im2col_mm_program.CacheHint(has_bias, tile_m, tile_n, use_subgroup);

  return context.RunProgram(im2col_mm_program);
}
196+
197+
// Returns true iff the im2col-matmul fast path supports this Conv
// configuration on the current device. Kept conservative: any unsupported
// feature falls back to the existing Conv paths.
bool CanApplyIm2ColMatMulProgram(ComputeContext& context,
                                 const bool is_channels_last,
                                 const ActivationKind activation_kind,
                                 const TensorShape weight_shape,
                                 const AutoPadType auto_pad,
                                 const uint32_t group) {
  if (!IsDeviceSupported(context)) {
    return false;
  }

  // TODO: Support !is_channels_last
  // TODO: Support fuse
  // TODO: Support auto pad
  // TODO: Support group conv
  const bool config_supported = is_channels_last &&
                                activation_kind == ActivationKind::None &&
                                auto_pad == AutoPadType::NOTSET &&
                                group == 1;
  if (!config_supported) {
    return false;
  }

  // TODO: Support conv1d
  // TODO: Support conv2d_1x1
  const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(weight_shape[2]);
  const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(weight_shape[3]);
  const bool kernel_supported = kernel_height != 1 && kernel_width != 1;

  // TODO: Support channel input vec1
  // k_vec4 vectorization requires the input-channel count to be a multiple of 4.
  const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]);
  const bool channels_supported = channel_input % 4 == 0;

  return kernel_supported && channels_supported;
}
231+
232+
} // namespace webgpu
233+
} // namespace onnxruntime
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <vector>
7+
8+
#include "core/framework/tensor_shape.h"
9+
#include "core/framework/tensor.h"
10+
#include "core/framework/op_kernel.h"
11+
#include "core/providers/cpu/nn/conv_attributes.h"
12+
#include "core/providers/webgpu/program.h"
13+
#include "core/providers/webgpu/webgpu_supported_types.h"
14+
#include "core/providers/webgpu/shader_helper.h"
15+
#include "core/providers/webgpu/webgpu_kernel.h"
16+
#include "core/providers/webgpu/nn/fuse_utils.h"
17+
18+
namespace onnxruntime {
19+
namespace webgpu {
20+
21+
// Transpose OIHW Weight to OHWI
// Per the PR description, this uses workgroup memory to keep reads and writes
// contiguous, outperforming the generic Transpose shader for this permutation.
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
 public:
  OIHW2OHWIProgram() : Program("OIHW2OHWI") {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  // O/I/H/W: weight dims (output channels, input channels, kernel height, kernel width).
  // Ci_tiles: number of 64-wide input-channel tiles (second dispatch dimension).
  // H_W_tiles: number of vec4 tiles covering the kernel's H*W spatial extent.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"O", ProgramUniformVariableDataType::Uint32},
      {"I", ProgramUniformVariableDataType::Uint32},
      {"H", ProgramUniformVariableDataType::Uint32},
      {"W", ProgramUniformVariableDataType::Uint32},
      {"Ci_tiles", ProgramUniformVariableDataType::Uint32},
      {"H_W_tiles", ProgramUniformVariableDataType::Uint32});
};
36+
37+
// Fused im2col + matmul Conv program. Each workgroup computes one
// tile_m x tile_n tile of the logical [im2col_m x im2col_n] GEMM output.
class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
 public:
  // has_bias: whether a bias input is bound.
  // tile_m / tile_n: GEMM tile shape; GenerateShaderCode enforces
  //                  tile_m in {16, 32} and tile_n == 64.
  // use_subgroup: enable subgroupShuffle-based source-tile access. The caller
  //               must only enable this when the device's subgroup size is
  //               known to be >= tile_m.
  Im2ColMatMulProgram(bool has_bias,
                      uint32_t tile_m,
                      uint32_t tile_n,
                      bool use_subgroup) : Program("Im2ColMatMul"),
                                           has_bias_(has_bias),
                                           tile_m_(tile_m),
                                           tile_n_(tile_n),
                                           use_subgroup_(use_subgroup) {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  // batch: N dimension of the source tensor.
  // src_h/src_w, channel_i: source spatial dims and input-channel count.
  // kernel_h/kernel_w: kernel spatial dims.
  // output_h/output_w: output spatial dims.
  // im2col_m/k/n: logical GEMM dimensions (M = output pixels,
  //               K = kernel_h * kernel_w * channel_i, N = output channels).
  // M_tiles/N_tiles/K_tiles: tile counts along each GEMM dimension.
  // dilations/pads/strides: Conv attributes (uint32 pairs).
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"batch", ProgramUniformVariableDataType::Uint32},
      {"src_h", ProgramUniformVariableDataType::Uint32},
      {"src_w", ProgramUniformVariableDataType::Uint32},
      {"channel_i", ProgramUniformVariableDataType::Uint32},
      {"kernel_h", ProgramUniformVariableDataType::Uint32},
      {"kernel_w", ProgramUniformVariableDataType::Uint32},
      {"output_h", ProgramUniformVariableDataType::Uint32},
      {"output_w", ProgramUniformVariableDataType::Uint32},
      {"im2col_m", ProgramUniformVariableDataType::Uint32},
      {"im2col_k", ProgramUniformVariableDataType::Uint32},
      {"im2col_n", ProgramUniformVariableDataType::Uint32},
      {"M_tiles", ProgramUniformVariableDataType::Uint32},
      {"N_tiles", ProgramUniformVariableDataType::Uint32},
      {"K_tiles", ProgramUniformVariableDataType::Uint32},
      {"dilations", ProgramUniformVariableDataType::Uint32},
      {"pads", ProgramUniformVariableDataType::Uint32},
      {"strides", ProgramUniformVariableDataType::Uint32});

 private:
  bool has_bias_;

  uint32_t tile_m_;
  uint32_t tile_n_;
  bool use_subgroup_;
};
76+
77+
// Returns true iff the im2col-matmul fast path supports this Conv
// configuration (device, layout, activation, auto-pad, group, and the
// kernel/input-channel shape constraints).
bool CanApplyIm2ColMatMulProgram(ComputeContext& context,
                                 const bool is_channels_last,
                                 const ActivationKind activation_kind,
                                 const TensorShape kernel_shape,
                                 const AutoPadType auto_pad,
                                 const uint32_t group);

// Runs Conv as an OIHW->OHWI weight transpose followed by the fused
// im2col + matmul shader. Reads src/weight (and optional bias) from `context`
// and writes the convolution result to `output`. Only call after
// CanApplyIm2ColMatMulProgram() returned true.
Status ApplyIm2ColMatMulProgram(ComputeContext& context,
                                const bool is_channels_last,
                                const std::vector<uint32_t>& dilations,
                                const std::vector<uint32_t>& pads,
                                const std::vector<uint32_t>& strides,
                                Tensor* output);
90+
91+
} // namespace webgpu
92+
} // namespace onnxruntime

0 commit comments

Comments
 (0)