Skip to content

Commit 60c50a4

Browse files
committed
Add weight layout transformation cache for Conv operator
Implement lazy weight layout transformation for WebGPU Conv kernel to avoid redundant GPU transposes on every inference. Key changes: - Add WeightLayoutTransformCache to cache transformed weights by name and format - Implement TransformWeightLayout() helper using existing TransposeKernel for OIHW->HWIO transformation - Cache stored in WebGpuExecutionProvider, shared across all kernels
1 parent d55ade0 commit 60c50a4

File tree

9 files changed

+260
-17
lines changed

9 files changed

+260
-17
lines changed

onnxruntime/core/providers/webgpu/compute_context.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ class ComputeContext final {
139139
return webgpu_context_.Run(*this, program);
140140
}
141141

142+
//
143+
// Get the execution provider.
144+
//
145+
inline const WebGpuExecutionProvider& GetExecutionProvider() const {
146+
return ep_;
147+
}
148+
142149
private:
143150
WebGpuContext& webgpu_context_;
144151
OpKernelContext& kernel_context_;

onnxruntime/core/providers/webgpu/nn/conv.cc

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include "core/providers/webgpu/nn/grouped_conv.h"
99
#include "core/providers/webgpu/webgpu_utils.h"
1010
#include "core/providers/webgpu/math/matmul.h"
11+
#include "core/providers/webgpu/weight_layout_transform.h"
12+
#include "core/providers/webgpu/webgpu_execution_provider.h"
1113

1214
namespace onnxruntime {
1315
namespace webgpu {
@@ -25,6 +27,37 @@ Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const Tens
2527
return Transpose::DoTranspose(context, perm, reshaped_kernel, *transposed_kernel);
2628
}
2729

30+
template <bool is_channels_last, bool is_fused>
31+
Status Conv<is_channels_last, is_fused>::GetTransformedWeight(ComputeContext& context,
32+
const Tensor* original_weight,
33+
const Tensor*& transformed_weight) const {
34+
// Return cached weight if already transformed
35+
if (transformed_weight_) {
36+
transformed_weight = transformed_weight_;
37+
return Status::OK();
38+
}
39+
40+
// If transformation was attempted but failed, return error
41+
if (weight_transform_attempted_) {
42+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Weight transformation previously failed");
43+
}
44+
45+
weight_transform_attempted_ = true;
46+
47+
// Use the weight input name extracted during construction
48+
ORT_ENFORCE(!weight_name_.empty(), "Weight input name must be available for transformation caching");
49+
50+
// Get cache from execution provider
51+
auto& cache = context.GetExecutionProvider().GetWeightLayoutTransformCache();
52+
53+
// Transform weight to HWIO layout
54+
ORT_RETURN_IF_ERROR(TransformWeightLayout(context, original_weight, weight_name_,
55+
"hwio", cache, transformed_weight_));
56+
57+
transformed_weight = transformed_weight_;
58+
return Status::OK();
59+
}
60+
2861
template <bool is_channels_last, bool is_fused>
2962
Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context) const {
3063
bool has_bias = context.InputCount() > 2;
@@ -104,11 +137,11 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
104137
auto pad1 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2;
105138
std::vector<uint32_t> updated_pads{pad0, pad1};
106139
if (conv_attrs_.group > 1) {
107-
Tensor transposed_kernel;
108140
if (is_channels_last) {
109-
ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
110-
inputs[1] = &transposed_kernel;
111-
modified_input_output_shapes[1] = transposed_kernel.Shape();
141+
const Tensor* kernel_to_use = nullptr;
142+
ORT_RETURN_IF_ERROR(GetTransformedWeight(context, kernel, kernel_to_use));
143+
inputs[1] = kernel_to_use;
144+
modified_input_output_shapes[1] = kernel_to_use->Shape();
112145
}
113146
auto output_channels_per_group = output_channels / conv_attrs_.group;
114147
auto components = static_cast<int>(is_channels_last && output_channels_per_group >= 4 ? GetMaxComponents(output_channels) : 1);
@@ -138,17 +171,16 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
138171

139172
const auto same_size = is_channels_last && input_height == kernel_height && input_width == kernel_width && pads[0] == 0 && pads[1] == 0;
140173
if (same_size || (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides[0] == 1 && strides[1] == 1)) {
141-
Tensor transposed_kernel;
142174
TensorShape input_reshape;
143175
TensorShape kernel_reshape;
144176
TensorShape matmul_output_shape;
145177
std::vector<const Tensor*> matmul_inputs;
146178
std::vector<TensorShape> matmul_input_reshapes;
147179
if (is_channels_last) {
148-
// Transpose weights
149-
150-
ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
151-
inputs[1] = &transposed_kernel;
180+
// Transform weights to HWIO layout (cached on first inference)
181+
const Tensor* kernel_to_use = nullptr;
182+
ORT_RETURN_IF_ERROR(GetTransformedWeight(context, kernel, kernel_to_use));
183+
inputs[1] = kernel_to_use;
152184
if (same_size) {
153185
const auto shared_dim = input_height * input_width * input_channels;
154186
input_reshape = TensorShape({1, batch, shared_dim});
@@ -160,7 +192,7 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
160192
matmul_output_shape = TensorShape({batch, output_height * output_width, output_channels});
161193
}
162194
matmul_inputs.push_back(input);
163-
matmul_inputs.push_back(&transposed_kernel);
195+
matmul_inputs.push_back(kernel_to_use);
164196
matmul_input_reshapes.push_back(input_reshape);
165197
matmul_input_reshapes.push_back(kernel_reshape);
166198
} else {
@@ -204,15 +236,14 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
204236
return context.RunProgram(program);
205237
}
206238
}
207-
// Transpose weights
208-
Tensor transposed_kernel;
209-
ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
239+
// Transpose weights - use cached transformation
240+
const Tensor* kernel_to_use = nullptr;
241+
ORT_RETURN_IF_ERROR(GetTransformedWeight(context, kernel, kernel_to_use));
242+
inputs[1] = kernel_to_use;
243+
modified_input_output_shapes[1] = kernel_to_use->Shape();
210244
auto dim_a_outer = static_cast<uint32_t>(is_channels_last ? output_height * output_width : output_channels);
211245
auto dim_b_outer = static_cast<uint32_t>(is_channels_last ? output_channels : output_height * output_width);
212246
auto dim_inner = static_cast<uint32_t>(kernel_height * kernel_width * input_channels);
213-
inputs[1] = &transposed_kernel;
214-
TensorShape transposed_kernel_shape = transposed_kernel.Shape();
215-
modified_input_output_shapes[1] = transposed_kernel.Shape();
216247
Conv2dMMProgram conv2d_mm_program = CreateConv2dMMProgram(activation_, inputs, pads, strides, dilations, output, dim_a_outer, dim_b_outer, dim_inner, is_channels_last, modified_input_output_shapes);
217248
return context.RunProgram(conv2d_mm_program);
218249
}

onnxruntime/core/providers/webgpu/nn/conv.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,26 @@ class Conv : public WebGpuKernel {
2020
if (is_fused) {
2121
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
2222
}
23+
// Extract weight input name (input index 1) for caching
24+
const auto& input_defs = info.node().InputDefs();
25+
if (input_defs.size() > 1 && input_defs[1]->Exists()) {
26+
weight_name_ = input_defs[1]->Name();
27+
}
2328
}
2429
Status ComputeInternal(ComputeContext& context) const override;
2530

2631
protected:
2732
ConvAttributes conv_attrs_;
2833
Activation activation_;
34+
std::string weight_name_; // Name of weight input for cache key
35+
36+
// Cached transformed weight pointer (set on first inference)
37+
mutable const Tensor* transformed_weight_ = nullptr;
38+
mutable bool weight_transform_attempted_ = false;
39+
40+
// Get or create transformed weight (lazy transformation on first inference)
41+
Status GetTransformedWeight(ComputeContext& context, const Tensor* original_weight,
42+
const Tensor*& transformed_weight) const;
2943
};
3044

3145
Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector<size_t>& perm);

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,8 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
800800
context_{context},
801801
preferred_data_layout_{config.data_layout},
802802
force_cpu_node_names_{std::move(config.force_cpu_node_names)},
803-
enable_graph_capture_{config.enable_graph_capture} {
803+
enable_graph_capture_{config.enable_graph_capture},
804+
weight_layout_transform_cache_{std::make_unique<WeightLayoutTransformCache>()} {
804805
// If graph capture is enabled, create a dedicated buffer manager for graph mode
805806
if (enable_graph_capture_) {
806807
// Create buffer manager for graph capture mode with appropriate cache modes
@@ -948,6 +949,12 @@ std::optional<bool> WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::s
948949
}
949950

950951
WebGpuExecutionProvider::~WebGpuExecutionProvider() {
952+
// Clear weight transform cache before releasing WebGPU resources
953+
// This ensures cached GPU tensors are freed while BufferManager is still valid
954+
if (weight_layout_transform_cache_) {
955+
weight_layout_transform_cache_->Clear();
956+
}
957+
951958
// Release all resources associated with the captured graph
952959
if (!captured_commands_.empty()) {
953960
context_.ReleaseGraphResources(captured_commands_);

onnxruntime/core/providers/webgpu/webgpu_execution_provider.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "core/graph/constants.h"
1010
#include "core/providers/providers.h"
1111
#include "core/providers/webgpu/buffer_manager.h"
12+
#include "core/providers/webgpu/weight_layout_transform_cache.h"
1213

1314
struct pthreadpool;
1415
namespace onnxruntime {
@@ -85,6 +86,11 @@ class WebGpuExecutionProvider : public IExecutionProvider {
8586
Status ReplayGraph(int graph_annotation_id) override;
8687
webgpu::BufferManager& BufferManager() const;
8788

89+
// Get weight layout transform cache
90+
webgpu::WeightLayoutTransformCache& GetWeightLayoutTransformCache() const {
91+
return *weight_layout_transform_cache_;
92+
}
93+
8894
private:
8995
bool IsGraphCaptureAllowed() const;
9096
void IncrementRegularRunCountBeforeGraphCapture();
@@ -105,6 +111,9 @@ class WebGpuExecutionProvider : public IExecutionProvider {
105111

106112
// Store captured commands directly in the EP instead of in WebGpuContext
107113
std::vector<webgpu::CapturedCommandInfo> captured_commands_;
114+
115+
// Cache for transformed weights (e.g., OIHW -> HWIO)
116+
std::unique_ptr<webgpu::WeightLayoutTransformCache> weight_layout_transform_cache_;
108117
};
109118

110119
} // namespace onnxruntime
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/weight_layout_transform.h"
5+
#include "core/providers/webgpu/compute_context.h"
6+
#include "core/providers/webgpu/weight_layout_transform_cache.h"
7+
#include "core/providers/webgpu/nn/conv.h" // For TransposeKernel
8+
9+
namespace onnxruntime {
10+
namespace webgpu {
11+
12+
Status TransformWeightLayout(
13+
ComputeContext& context,
14+
const Tensor* weight,
15+
const std::string& weight_name,
16+
const std::string& format_descriptor,
17+
WeightLayoutTransformCache& cache,
18+
/*out*/ const Tensor*& transformed_weight) {
19+
// Check cache first
20+
const auto* cached = cache.GetTransformedWeight(weight_name, format_descriptor);
21+
if (cached != nullptr) {
22+
transformed_weight = cached;
23+
return Status::OK();
24+
}
25+
26+
// Not in cache, need to transform
27+
28+
const auto& original_shape = weight->Shape();
29+
auto num_dims = original_shape.NumDimensions();
30+
31+
// Dispatch transformation based on format
32+
Tensor output_tensor;
33+
if (format_descriptor == "hwio") {
34+
// For 3D tensors, extend to 4D before transposing
35+
TensorShape input_shape_for_transpose = original_shape;
36+
if (num_dims == 3) {
37+
// Extend OIW [O, I, W] to OIHW [O, I, 1, W]
38+
TensorShapeVector extended_shape = original_shape.AsShapeVector();
39+
extended_shape.insert(extended_shape.begin() + 2, 1); // Insert H=1 at position 2
40+
input_shape_for_transpose = TensorShape(extended_shape);
41+
}
42+
43+
// Use existing TransposeKernel: OIHW [O,I,H,W] -> HWIO [H,W,I,O]
44+
// Permutation: [2, 3, 1, 0] means output[i] = input[perm[i]]
45+
// TransposeKernel creates the output tensor internally
46+
const InlinedVector<size_t> perm = {2, 3, 1, 0};
47+
ORT_RETURN_IF_ERROR(TransposeKernel(context, weight, input_shape_for_transpose,
48+
&output_tensor, perm));
49+
} else {
50+
// Add more format implementations here
51+
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
52+
"Format not yet implemented: ", format_descriptor);
53+
}
54+
55+
// Add to cache
56+
cache.AddTransformedWeight(weight_name, format_descriptor, std::move(output_tensor));
57+
58+
// Return cached tensor
59+
const auto* cached_result = cache.GetTransformedWeight(weight_name, format_descriptor);
60+
ORT_ENFORCE(cached_result != nullptr, "Failed to cache transformed weight");
61+
transformed_weight = cached_result;
62+
63+
return Status::OK();
64+
}
65+
66+
} // namespace webgpu
67+
} // namespace onnxruntime
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/common/common.h"
7+
#include "core/framework/tensor.h"
8+
#include <string>
9+
10+
namespace onnxruntime {
11+
namespace webgpu {
12+
13+
class ComputeContext;
14+
class WeightLayoutTransformCache;
15+
16+
// Transform weight tensor to specified format
17+
// Returns the transformed tensor (either from cache or newly created)
18+
Status TransformWeightLayout(
19+
ComputeContext& context,
20+
const Tensor* weight,
21+
const std::string& weight_name,
22+
const std::string& format_descriptor,
23+
WeightLayoutTransformCache& cache,
24+
/*out*/ const Tensor*& transformed_weight);
25+
26+
} // namespace webgpu
27+
} // namespace onnxruntime
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/weight_layout_transform_cache.h"
5+
6+
namespace onnxruntime {
7+
namespace webgpu {
8+
9+
const Tensor* WeightLayoutTransformCache::GetTransformedWeight(
10+
const std::string& weight_name,
11+
const std::string& format_descriptor) const {
12+
std::lock_guard<std::mutex> lock(mutex_);
13+
std::string cache_key = MakeCacheKey(weight_name, format_descriptor);
14+
auto it = cache_.find(cache_key);
15+
if (it != cache_.end()) {
16+
return &it->second;
17+
}
18+
return nullptr;
19+
}
20+
21+
void WeightLayoutTransformCache::AddTransformedWeight(
22+
const std::string& weight_name,
23+
const std::string& format_descriptor,
24+
Tensor&& tensor) {
25+
std::lock_guard<std::mutex> lock(mutex_);
26+
std::string cache_key = MakeCacheKey(weight_name, format_descriptor);
27+
cache_[cache_key] = std::move(tensor);
28+
}
29+
30+
void WeightLayoutTransformCache::Clear() {
31+
std::lock_guard<std::mutex> lock(mutex_);
32+
cache_.clear();
33+
}
34+
35+
} // namespace webgpu
36+
} // namespace onnxruntime
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <unordered_map>
7+
#include <mutex>
8+
#include <string>
9+
#include "core/framework/tensor.h"
10+
#include "core/common/common.h"
11+
12+
namespace onnxruntime {
13+
namespace webgpu {
14+
15+
// Cache manager for transformed weights
16+
// Owned by WebGpuExecutionProvider
17+
class WeightLayoutTransformCache {
18+
public:
19+
WeightLayoutTransformCache() = default;
20+
~WeightLayoutTransformCache() = default;
21+
22+
// Get transformed weight from cache (nullptr if not found)
23+
const Tensor* GetTransformedWeight(const std::string& weight_name,
24+
const std::string& format_descriptor) const;
25+
26+
// Add transformed weight to cache
27+
void AddTransformedWeight(const std::string& weight_name,
28+
const std::string& format_descriptor,
29+
Tensor&& tensor);
30+
31+
// Clear cache (must be called before BufferManager is destroyed)
32+
void Clear();
33+
34+
private:
35+
std::string MakeCacheKey(const std::string& weight_name,
36+
const std::string& format) const {
37+
return weight_name + ":" + format;
38+
}
39+
40+
mutable std::mutex mutex_;
41+
std::unordered_map<std::string, Tensor> cache_;
42+
};
43+
44+
} // namespace webgpu
45+
} // namespace onnxruntime

0 commit comments

Comments
 (0)