-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[webgpu] Optimize Conv by im2col-matmul #26603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b1e5290
2f7487e
4efeff4
e6d48e4
6a4bede
07073e1
749c6e5
58a38c3
17988a4
8eb8ecc
aa7fa8e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,233 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
| #include <string> | ||
| #include <utility> | ||
| #include <vector> | ||
|
|
||
| #include "core/providers/webgpu/webgpu_utils.h" | ||
| #include "core/providers/webgpu/nn/im2col_matmul.h" | ||
| #include "core/providers/webgpu/nn/activation_util.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace webgpu { | ||
|
|
||
| namespace { | ||
|
|
||
// TODO: move to common header.
// Integer ceiling division: smallest T such that result * denominator >= numerator.
// Written as quotient plus a remainder check instead of the classic
// (numerator + denominator - 1) / denominator, which wraps around when
// `numerator` is close to the maximum value of T.
template <typename T>
inline T ceil_div(T numerator, T denominator) {
  return numerator / denominator + static_cast<T>(numerator % denominator != 0 ? 1 : 0);
}
|
|
||
// Chooses the optimal tile size (M, N) for the im2col operation.
// This tile size is performance-tuned and varies depending on the target device.
std::pair<uint32_t, uint32_t> ChooseTileSize(uint32_t im2col_m, uint32_t im2col_n) {
  // Candidate (tile_m, tile_n) shapes, most preferred first.
  const std::vector<std::pair<uint32_t, uint32_t>> candidates = {
      {32, 64},
      {16, 64},
  };

  // Pick the first candidate that still produces enough workgroups (>= 128)
  // to keep the GPU reasonably occupied.
  for (const auto& [tile_m, tile_n] : candidates) {
    const uint32_t groups_m = (im2col_m + tile_m - 1) / tile_m;
    const uint32_t groups_n = (im2col_n + tile_n - 1) / tile_n;
    if (groups_m * groups_n >= 128) {
      return {tile_m, tile_n};
    }
  }

  // No candidate reaches 128 workgroups; fall back to the smallest tile,
  // which maximizes the dispatch count.
  return candidates.back();
}
|
|
||
| // Add support for more devices. | ||
| bool IsDeviceSupported(ComputeContext& context) { | ||
| const wgpu::AdapterInfo& adapter_info = context.AdapterInfo(); | ||
|
|
||
| if (adapter_info.vendor == std::string_view("intel")) { | ||
| if (adapter_info.architecture == std::string_view("xe-2lpg")) { | ||
| return true; | ||
| } | ||
| } | ||
|
|
||
| return false; | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
// Emits the WGSL shader that rewrites the weight tensor from OIHW to OHWI
// layout. The shader body lives in nn/oihw_to_ohwi.wgsl.template; this method
// only declares the bindings the template expects.
Status OIHW2OHWIProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  return WGSL_TEMPLATE_APPLY(shader, "nn/oihw_to_ohwi.wgsl.template",
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src));
}
|
|
||
// Emits the fused im2col + matmul WGSL shader from
// nn/im2col_matmul.wgsl.template, specialized by tile shape, bias presence,
// and the subgroup toggle.
Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& src = shader.AddInput("src", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& weight = shader.AddInput("weight", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  if (has_bias_) {
    // The bias binding only exists when the Conv node has a third input.
    shader.AddInput("bias", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  }
  const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  // The WGSL template is only written for these tile shapes (see ChooseTileSize);
  // fail fast rather than generate a broken shader.
  ORT_ENFORCE(tile_m_ == 16 || tile_m_ == 32, "tile_m must be 16 or 32.");
  ORT_ENFORCE(tile_n_ == 64, "tile_n must be 64.");

  return WGSL_TEMPLATE_APPLY(shader, "nn/im2col_matmul.wgsl.template",
                             WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
                             WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_),
                             WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_),
                             WGSL_TEMPLATE_PARAMETER(use_subgroup, use_subgroup_),
                             WGSL_TEMPLATE_VARIABLE(output, output),
                             WGSL_TEMPLATE_VARIABLE(src, src),
                             WGSL_TEMPLATE_VARIABLE(weight, weight));
}
|
|
||
| Status ApplyIm2ColMatMulProgram(ComputeContext& context, | ||
| bool is_channels_last, | ||
| const std::vector<uint32_t>& dilations, | ||
| const std::vector<uint32_t>& pads, | ||
| const std::vector<uint32_t>& strides, | ||
| Tensor* output) { | ||
| const auto* src = context.Input<Tensor>(0); | ||
| const auto* weight = context.Input<Tensor>(1); | ||
| const bool has_bias = context.InputCount() > 2; | ||
| const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr; | ||
|
|
||
| // Transpose OIHW Weight to OHWI | ||
| // TODO: Move to `Transpose` | ||
| // TODO: Use prepack | ||
| TensorShape weight_shape = weight->Shape(); | ||
| const uint32_t channel_output = onnxruntime::narrow<uint32_t>(weight_shape[0]); | ||
| const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]); | ||
| const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(weight_shape[2]); | ||
| const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(weight_shape[3]); | ||
|
|
||
| TensorShape ohwi_weight_shape{channel_output, kernel_height, kernel_width, channel_input}; | ||
| Tensor ohwi_weight = context.CreateGPUTensor(weight->DataType(), ohwi_weight_shape); | ||
| OIHW2OHWIProgram transpose_program{}; | ||
| transpose_program.SetWorkgroupSize(64); | ||
|
|
||
| const uint32_t Ci_tiles = ceil_div(channel_input, 64u); | ||
| transpose_program.SetDispatchGroupSize(channel_output, Ci_tiles); | ||
|
|
||
| transpose_program.AddInput({weight, | ||
| ProgramTensorMetadataDependency::TypeAndRank}); | ||
| transpose_program.AddOutput({&ohwi_weight, | ||
| ProgramTensorMetadataDependency::TypeAndRank}); | ||
| transpose_program.AddUniformVariables({{channel_output}, | ||
| {channel_input}, | ||
| {kernel_height}, | ||
| {kernel_width}, | ||
| {Ci_tiles}, | ||
| {ceil_div(kernel_height * kernel_height, 4u)}}); | ||
| ORT_RETURN_IF_ERROR(context.RunProgram(transpose_program)); | ||
|
|
||
| // im2col-matmul | ||
| const TensorShape src_shape = src->Shape(); | ||
| const TensorShape output_shape = output->Shape(); | ||
|
|
||
| const uint32_t batch = onnxruntime::narrow<uint32_t>(src_shape[0]); | ||
| const uint32_t src_height = onnxruntime::narrow<uint32_t>(src_shape[is_channels_last ? 1 : 2]); | ||
| const uint32_t src_width = onnxruntime::narrow<uint32_t>(src_shape[is_channels_last ? 2 : 3]); | ||
| const uint32_t output_height = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 1 : 2]); | ||
| const uint32_t output_width = onnxruntime::narrow<uint32_t>(output_shape[is_channels_last ? 2 : 3]); | ||
|
|
||
| const uint32_t im2col_m = output_height * output_width; | ||
| const uint32_t im2col_k = kernel_height * kernel_width * channel_input; | ||
| const uint32_t im2col_n = channel_output; | ||
|
|
||
| const auto [tile_m, tile_n] = ChooseTileSize(im2col_m, im2col_n); | ||
| const uint32_t workgroup_size = tile_n; | ||
|
|
||
| // Check the device's subgroup size before shader compilation to avoid potential performance penalties | ||
| // associated with conditional checks in the shader runtime. | ||
| // | ||
| // Ensure the subgroup size must be greater than or equal to `tile_m` to safely enable `use_subgroup`. | ||
| // If the status of this condition is uncertain, the feature must be disabled. | ||
| const bool use_subgroup = false; | ||
| Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, use_subgroup}; | ||
| im2col_mm_program.SetWorkgroupSize(workgroup_size); | ||
|
|
||
| const uint32_t M_tiles = ceil_div(im2col_m, tile_m); | ||
| const uint32_t N_tiles = ceil_div(im2col_n, tile_n); | ||
| im2col_mm_program.SetDispatchGroupSize(M_tiles, N_tiles, batch); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about enhancing the current TransposeProgram with shared path instead of adding a new one?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Understood.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FYI, prepack is supported with this PR #26602. You can move the weights transpose to prepack to get better performance. |
||
|
|
||
| im2col_mm_program.AddInput({src, | ||
| ProgramTensorMetadataDependency::TypeAndRank, | ||
| 4}); | ||
| im2col_mm_program.AddInput({&ohwi_weight, | ||
| ProgramTensorMetadataDependency::TypeAndRank, | ||
| 4}); | ||
| if (has_bias) { | ||
| im2col_mm_program.AddInput({bias, | ||
| ProgramTensorMetadataDependency::TypeAndRank}); | ||
| } | ||
| im2col_mm_program.AddOutput({output, | ||
| ProgramTensorMetadataDependency::TypeAndRank}); | ||
| im2col_mm_program.AddUniformVariables({{batch}, | ||
| {src_height}, | ||
| {src_width}, | ||
| {channel_input}, | ||
| {kernel_height}, | ||
| {kernel_width}, | ||
| {output_height}, | ||
| {output_width}, | ||
| {im2col_m}, | ||
| {im2col_k}, | ||
| {im2col_n}, | ||
| {M_tiles}, | ||
| {N_tiles}, | ||
| {ceil_div(ceil_div(im2col_k, 4u), 4u)}, | ||
| {dilations}, | ||
| {pads}, | ||
| {strides}}); | ||
| im2col_mm_program.CacheHint(has_bias, tile_m, tile_n, use_subgroup); | ||
|
|
||
| return context.RunProgram(im2col_mm_program); | ||
| } | ||
|
|
||
| bool CanApplyIm2ColMatMulProgram(ComputeContext& context, | ||
| const bool is_channels_last, | ||
| const ActivationKind activation_kind, | ||
| const TensorShape weight_shape, | ||
| const AutoPadType auto_pad, | ||
| const uint32_t group) { | ||
| if (!IsDeviceSupported(context)) { | ||
| return false; | ||
| } | ||
|
|
||
| // TODO: Support !is_channels_last | ||
| // TODO: Support fuse | ||
| // TODO: Support auto pad | ||
| // TODO: Support group conv | ||
| if (!is_channels_last || activation_kind != ActivationKind::None || auto_pad != AutoPadType::NOTSET || group != 1) { | ||
| return false; | ||
| } | ||
|
|
||
| // TODO: Support conv1d | ||
| // TODO: Support conv2d_1x1 | ||
| const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(weight_shape[2]); | ||
| const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(weight_shape[3]); | ||
| if (kernel_height == 1 || kernel_width == 1) { | ||
| return false; | ||
| } | ||
|
|
||
| // TODO: Support channel input vec1 | ||
| const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]); | ||
| if (channel_input % 4 != 0) { | ||
| return false; | ||
| } | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| } // namespace webgpu | ||
| } // namespace onnxruntime | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <vector> | ||
|
|
||
| #include "core/framework/tensor_shape.h" | ||
| #include "core/framework/tensor.h" | ||
| #include "core/framework/op_kernel.h" | ||
| #include "core/providers/cpu/nn/conv_attributes.h" | ||
| #include "core/providers/webgpu/program.h" | ||
| #include "core/providers/webgpu/webgpu_supported_types.h" | ||
| #include "core/providers/webgpu/shader_helper.h" | ||
| #include "core/providers/webgpu/webgpu_kernel.h" | ||
| #include "core/providers/webgpu/nn/fuse_utils.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace webgpu { | ||
|
|
||
// Transpose OIHW Weight to OHWI
// (so the im2col-matmul shader can read input channels contiguously).
class OIHW2OHWIProgram final : public Program<OIHW2OHWIProgram> {
 public:
  OIHW2OHWIProgram() : Program("OIHW2OHWI") {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  // O/I/H/W: weight dims (output channels, input channels, kernel height, kernel width).
  // Ci_tiles: number of 64-wide tiles over I (ceil(I / 64), matching the dispatch).
  // H_W_tiles: number of vec4 tiles over the H*W kernel plane — presumably
  //            ceil(H*W / 4); confirm against the .cc call site, which currently
  //            computes it from H*H.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"O", ProgramUniformVariableDataType::Uint32},
      {"I", ProgramUniformVariableDataType::Uint32},
      {"H", ProgramUniformVariableDataType::Uint32},
      {"W", ProgramUniformVariableDataType::Uint32},
      {"Ci_tiles", ProgramUniformVariableDataType::Uint32},
      {"H_W_tiles", ProgramUniformVariableDataType::Uint32});
};
|
|
||
// Fused im2col + matmul program implementing Conv2D over channels-last input
// with an OHWI weight produced by OIHW2OHWIProgram.
class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
 public:
  // has_bias: whether a bias input is bound as a third shader input.
  // tile_m / tile_n: output tile shape per workgroup (chosen by ChooseTileSize).
  // use_subgroup: compile the subgroup-optimized shader variant; the caller is
  //               responsible for verifying device subgroup support first.
  Im2ColMatMulProgram(bool has_bias,
                      uint32_t tile_m,
                      uint32_t tile_n,
                      bool use_subgroup) : Program("Im2ColMatMul"),
                                           has_bias_(has_bias),
                                           tile_m_(tile_m),
                                           tile_n_(tile_n),
                                           use_subgroup_(use_subgroup) {}

  Status GenerateShaderCode(ShaderHelper& shader) const override;

  // im2col matmul view: im2col_m = output_h * output_w,
  // im2col_k = kernel_h * kernel_w * channel_i, im2col_n = output channels.
  // M_tiles / N_tiles / K_tiles are the corresponding tiled iteration counts.
  // dilations / pads / strides are the Conv attributes forwarded to the shader.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"batch", ProgramUniformVariableDataType::Uint32},
      {"src_h", ProgramUniformVariableDataType::Uint32},
      {"src_w", ProgramUniformVariableDataType::Uint32},
      {"channel_i", ProgramUniformVariableDataType::Uint32},
      {"kernel_h", ProgramUniformVariableDataType::Uint32},
      {"kernel_w", ProgramUniformVariableDataType::Uint32},
      {"output_h", ProgramUniformVariableDataType::Uint32},
      {"output_w", ProgramUniformVariableDataType::Uint32},
      {"im2col_m", ProgramUniformVariableDataType::Uint32},
      {"im2col_k", ProgramUniformVariableDataType::Uint32},
      {"im2col_n", ProgramUniformVariableDataType::Uint32},
      {"M_tiles", ProgramUniformVariableDataType::Uint32},
      {"N_tiles", ProgramUniformVariableDataType::Uint32},
      {"K_tiles", ProgramUniformVariableDataType::Uint32},
      {"dilations", ProgramUniformVariableDataType::Uint32},
      {"pads", ProgramUniformVariableDataType::Uint32},
      {"strides", ProgramUniformVariableDataType::Uint32});

 private:
  bool has_bias_;  // whether the "bias" input is present

  uint32_t tile_m_;    // output-tile rows per workgroup (16 or 32)
  uint32_t tile_n_;    // output-tile columns per workgroup (64)
  bool use_subgroup_;  // emit the subgroup shader variant
};
|
|
||
// Returns true when the im2col-matmul fast path supports the given device,
// Conv attributes, and weight shape; callers fall back to the generic Conv
// path otherwise.
// NOTE(review): the .cc definition names this parameter `weight_shape` — it is
// the full OIHW weight shape, not just spatial dims; consider aligning names.
bool CanApplyIm2ColMatMulProgram(ComputeContext& context,
                                 const bool is_channels_last,
                                 const ActivationKind activation_kind,
                                 const TensorShape kernel_shape,
                                 const AutoPadType auto_pad,
                                 const uint32_t group);

// Runs Conv via an OIHW->OHWI weight transpose followed by the fused
// im2col + matmul program, writing the result into `output`.
Status ApplyIm2ColMatMulProgram(ComputeContext& context,
                                const bool is_channels_last,
                                const std::vector<uint32_t>& dilations,
                                const std::vector<uint32_t>& pads,
                                const std::vector<uint32_t>& strides,
                                Tensor* output);
|
|
||
| } // namespace webgpu | ||
| } // namespace onnxruntime |
Uh oh!
There was an error while loading. Please reload this page.