Skip to content

Commit 0a89439

Browse files
committed
refactor - 1
1 parent 4dbb05f commit 0a89439

File tree

15 files changed

+232
-85
lines changed

15 files changed

+232
-85
lines changed

onnxruntime/core/providers/webgpu/allocator.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool
1313
OrtMemoryInfo(WEBGPU_BUFFER,
1414
is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator
1515
: OrtAllocatorType::OrtDeviceAllocator,
16-
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0),
16+
WebGpuDevice,
1717
OrtMemTypeDefault)),
1818
buffer_manager_{buffer_manager},
1919
mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} {

onnxruntime/core/providers/webgpu/allocator.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ namespace webgpu {
1111

1212
class BufferManager;
1313

14+
// Shared OrtDevice descriptor for WebGPU buffers: GPU device type, default
// memory type, no vendor id, and 0 as the last argument (presumably the
// device id -- confirm against the OrtDevice constructor).
// GpuBufferAllocator passes this to its OrtMemoryInfo instead of building the
// same OrtDevice inline.
inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU,
15+
OrtDevice::MemType::DEFAULT,
16+
OrtDevice::VendorIds::NONE,
17+
0};
18+
1419
class GpuBufferAllocator : public IAllocator {
1520
public:
1621
GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator);

onnxruntime/core/providers/webgpu/compute_context.cc

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,30 @@
66

77
namespace onnxruntime {
88
namespace webgpu {
9-
ComputeContext::ComputeContext(OpKernelContext& kernel_context,
10-
const OpKernel& op_kernel,
11-
const WebGpuExecutionProvider& ep,
12-
WebGpuContext& webgpu_context)
9+
10+
// Constructs the kernel-context-independent part of a compute context.
// The three arguments are stored as references (see the member declarations
// in compute_context.h), so the context does not own any of them.
// The initializers are listed as webgpu_context_, ep_, op_kernel_, which is
// the order the members are declared in the header.
ComputeContextBase::ComputeContextBase(WebGpuContext& webgpu_context,
11+
const WebGpuExecutionProvider& ep,
12+
const OpKernel& op_kernel)
1313
: webgpu_context_{webgpu_context},
14-
kernel_context_{kernel_context},
15-
op_kernel_{op_kernel},
16-
ep_{ep} {
14+
ep_{ep},
15+
op_kernel_{op_kernel} {
1716
}
1817

19-
const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const ComputeContext& context) {
18+
// Returns the BufferManager of the execution provider that `context` was
// created with. The header declares this accessor private with WebGpuContext
// as a friend, which limits who can reach the buffer manager.
const webgpu::BufferManager& ComputeContextBase::BufferManagerAccessor::Get(const ComputeContextBase& context) {
2019
return context.ep_.BufferManager();
2120
}
2221

23-
const SplitKConfig& ComputeContext::GetSplitKConfig() {
22+
// Forwards to WebGpuContext::GetSplitKConfig().
// NOTE(review): in this same commit compute_context.h defines
// GetSplitKConfig() inline as a `const` member and drops the old non-const
// declaration. This non-const out-of-line definition therefore appears to have
// no matching declaration and would not compile -- confirm and remove it.
const SplitKConfig& ComputeContextBase::GetSplitKConfig() {
2423
return webgpu_context_.GetSplitKConfig();
2524
}
2625

26+
// Constructs a full compute context: forwards the WebGPU context, execution
// provider and kernel to ComputeContextBase, and additionally keeps a
// reference to the OpKernelContext.
// Note that the parameter order differs from the pre-refactor constructor
// (which took kernel_context first), so call sites must be updated.
ComputeContext::ComputeContext(WebGpuContext& webgpu_context,
27+
const WebGpuExecutionProvider& ep,
28+
const OpKernel& op_kernel,
29+
OpKernelContext& kernel_context)
30+
: ComputeContextBase(webgpu_context, ep, op_kernel),
31+
kernel_context_{kernel_context} {
32+
}
33+
2734
} // namespace webgpu
2835
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/compute_context.h

Lines changed: 70 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,13 @@ namespace webgpu {
2424
class WebGpuContext;
2525
class BufferManager;
2626

27-
class ComputeContext final {
27+
//
28+
// Class ComputeContextBase is designed to provide basic context information
29+
// for running a compute shader program.
30+
//
31+
// An instance of ComputeContextBase does not depend on OpKernelContext, which needs an execution frame to be created.
32+
//
33+
class ComputeContextBase {
2834
public:
2935
// Nested accessor class to provide controlled access to BufferManager
3036
class BufferManagerAccessor {
@@ -34,18 +40,31 @@ class ComputeContext final {
3440
friend class WebGpuContext;
3541

3642
private:
37-
static const webgpu::BufferManager& Get(const ComputeContext& context);
43+
static const webgpu::BufferManager& Get(const ComputeContextBase& context);
3844
};
3945

40-
ComputeContext(OpKernelContext& kernel_context,
41-
const OpKernel& op_kernel,
42-
const WebGpuExecutionProvider& ep,
43-
WebGpuContext& webgpu_context);
46+
ComputeContextBase(WebGpuContext& webgpu_context,
47+
const WebGpuExecutionProvider& ep,
48+
const OpKernel& op_kernel);
4449

45-
~ComputeContext() = default;
50+
~ComputeContextBase() = default;
51+
52+
//
53+
// Get the node name.
54+
//
55+
inline decltype(auto) NodeName() const {
56+
return op_kernel_.Node().Name();
57+
}
58+
59+
//
60+
// Get the operator type.
61+
//
62+
inline decltype(auto) OpType() const {
63+
return op_kernel_.Node().OpType();
64+
}
4665

4766
//
48-
// Get various information from the context.
67+
// Get various information from the WebGPU context.
4968
//
5069

5170
inline const wgpu::AdapterInfo& AdapterInfo() const {
@@ -57,27 +76,63 @@ class ComputeContext final {
5776
inline bool HasFeature(wgpu::FeatureName feature) const {
5877
return webgpu_context_.DeviceHasFeature(feature);
5978
}
60-
inline bool IsGraphCaptureEnabled() const {
61-
return ep_.IsGraphCaptureEnabled();
62-
}
6379
#if !defined(__wasm__)
6480
inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const {
6581
return webgpu_context_.SubgroupMatrixConfigs();
6682
}
6783
#endif
6884

6985
//
70-
// Get the kernel context.
86+
// Get Split-K configuration.
7187
//
72-
inline OpKernelContext& KernelContext() {
73-
return kernel_context_;
88+
inline const SplitKConfig& GetSplitKConfig() const {
89+
return webgpu_context_.GetSplitKConfig();
90+
}
91+
92+
//
93+
// Get whether graph capture is enabled.
94+
//
95+
inline bool IsGraphCaptureEnabled() const {
96+
return ep_.IsGraphCaptureEnabled();
7497
}
7598

7699
//
77100
// Get the logger.
78101
//
79102
// The logger now comes from the execution provider rather than from the
// OpKernelContext, so it is available without a kernel context.
// NOTE(review): the result of ep_.GetLogger() is dereferenced without a null
// check -- confirm the EP always has a logger set before this is called.
inline const logging::Logger& Logger() const {
80-
return kernel_context_.Logger();
103+
return *ep_.GetLogger();
104+
}
105+
106+
//
107+
// Run a compute shader program.
108+
//
109+
inline Status RunProgram(const ProgramBase& program) {
110+
return webgpu_context_.Run(*this, program);
111+
}
112+
113+
protected:
114+
WebGpuContext& webgpu_context_;
115+
const WebGpuExecutionProvider& ep_;
116+
const OpKernel& op_kernel_;
117+
};
118+
119+
//
120+
// Class ComputeContext provides all information a `ComputeContextBase` provides, and also
121+
// access to `OpKernelContext` for input and output tensors.
122+
class ComputeContext final : public ComputeContextBase {
123+
public:
124+
ComputeContext(WebGpuContext& webgpu_context,
125+
const WebGpuExecutionProvider& ep,
126+
const OpKernel& op_kernel,
127+
OpKernelContext& kernel_context);
128+
129+
~ComputeContext() = default;
130+
131+
//
132+
// Get the kernel context.
133+
//
134+
inline OpKernelContext& KernelContext() {
135+
return kernel_context_;
81136
}
82137

83138
//
@@ -145,25 +200,8 @@ class ComputeContext final {
145200
return op_kernel_.Info().GetDataTransferManager().CopyTensor(src, dst);
146201
}
147202

148-
//
149-
// Run a compute shader program.
150-
//
151-
inline Status RunProgram(const ProgramBase& program) {
152-
return webgpu_context_.Run(*this, program);
153-
}
154-
155-
//
156-
// Get Split-K configuration.
157-
//
158-
// `split_k_config_` won't be initialized until the first call to this method.
159-
//
160-
const SplitKConfig& GetSplitKConfig();
161-
162203
private:
163-
WebGpuContext& webgpu_context_;
164204
OpKernelContext& kernel_context_;
165-
const OpKernel& op_kernel_;
166-
const WebGpuExecutionProvider& ep_;
167205
};
168206

169207
} // namespace webgpu

onnxruntime/core/providers/webgpu/nn/conv.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,51 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
216216
return context.RunProgram(conv2d_mm_program);
217217
}
218218

219+
// Pre-packs the convolution weight for the channels-last (NHWC) kernel.
//
// Only acts when `is_channels_last` is true, `input_idx` is 1 and `tensor` is
// 4-D; for every other input it leaves the kernel untouched and returns OK.
// In the handled case it:
//   1. enforces that both `tensor` and `alloc` refer to WebGPU buffers,
//   2. allocates `transposed_kernel_` from `alloc` with the source dimensions
//      reordered as {2, 3, 1, 0},
//   3. fills it by running Transpose::DoTranspose with that permutation.
//
// `is_packed` is always reported as false: the assignment that would set it
// to true is commented out below, so the caller is not told the weight was
// packed. `prepacked_weights` is unused.
//
// Returns the error from DoTranspose if the transpose fails, otherwise OK.
template <bool is_channels_last, bool is_fused>
220+
Status Conv<is_channels_last, is_fused>::PrePackInternal(ComputeContextBase& context,
221+
const Tensor& tensor,
222+
int input_idx,
223+
AllocatorPtr alloc,
224+
/*out*/ bool& is_packed,
225+
/*out*/ PrePackedWeights* /*prepacked_weights*/) {
226+
is_packed = false;
227+
228+
if constexpr (is_channels_last) {
229+
if (input_idx == 1 && tensor.Shape().NumDimensions() == 4) {
230+
// only deal with 4D NHWC weights
231+
232+
// Tensors and allocator should both be on GPU
233+
ORT_ENFORCE(tensor.Location().device.Type() == OrtDevice::GPU &&
234+
tensor.Location().mem_type == OrtMemType::OrtMemTypeDefault &&
235+
tensor.Location().name == WEBGPU_BUFFER,
236+
"Tensor must be a WebGPU buffer.");
237+
ORT_ENFORCE(alloc->Info().device.Type() == OrtDevice::GPU &&
238+
alloc->Info().name == WEBGPU_BUFFER,
239+
"Allocator must be for WebGPU.");
240+
241+
// Step.1 - calculate transposed weight shape
242+
TensorShape transposed_kernel_shape{tensor.Shape()[2],
243+
tensor.Shape()[3],
244+
tensor.Shape()[1],
245+
tensor.Shape()[0]};
246+
247+
// Step.2 - create transposed weight tensor
248+
transposed_kernel_ = std::make_unique<Tensor>(tensor.DataType(), transposed_kernel_shape, alloc);
249+
250+
// Step.3 - do transpose
251+
size_t perm[] = {2, 3, 1, 0};
252+
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context,
253+
perm,
254+
tensor,
255+
*transposed_kernel_));
256+
257+
// is_packed = true; // set this flag to true so that ORT will release the initializer tensor
258+
}
259+
}
260+
261+
return Status::OK();
262+
}
263+
219264
// Explicit template instantiation for FusedConv
220265
template class Conv<false, false>;
221266
template class Conv<false, true>;

onnxruntime/core/providers/webgpu/nn/conv.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,13 @@ class Conv : public WebGpuKernel {
2323
}
2424
Status ComputeInternal(ComputeContext& context) const override;
2525

26+
Status PrePackInternal(ComputeContextBase& context, const Tensor& tensor, int input_idx, AllocatorPtr alloc,
27+
/*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override;
28+
2629
protected:
2730
ConvAttributes conv_attrs_;
2831
Activation activation_;
32+
std::unique_ptr<Tensor> transposed_kernel_; // only has a value when `is_initializer` AND `is_4D` AND `is_NHWC`
2933
};
3034

3135
Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector<size_t>& perm);

onnxruntime/core/providers/webgpu/tensor/transpose.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
108108
return Status::OK();
109109
}
110110

111-
Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context,
111+
Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
112112
gsl::span<const size_t> permutations,
113113
const Tensor& input, Tensor& output) {
114114
const auto& input_shape = input.Shape();

onnxruntime/core/providers/webgpu/tensor/transpose.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class Transpose final : public WebGpuKernel, public TransposeBase {
1616
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
1717
}
1818
Status ComputeInternal(ComputeContext& context) const override;
19-
static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span<const size_t> permutations, const Tensor& input, Tensor& output);
19+
static Status DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span<const size_t> permutations, const Tensor& input, Tensor& output);
2020

2121
constexpr static uint32_t TILE_SIZE = 16;
2222
};

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
147147
// create program manager
148148
program_mgr_ = std::make_unique<ProgramManager>(*this);
149149

150+
// create split-k config
151+
split_k_config_ = std::make_unique<SplitKConfig>(adapter_info_);
152+
150153
// set query type
151154
#if !defined(__wasm__)
152155
if (DeviceHasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) {
@@ -178,7 +181,7 @@ Status WebGpuContext::Wait(wgpu::Future f) {
178181
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status));
179182
}
180183

181-
Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
184+
Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) {
182185
const auto& inputs = program.Inputs();
183186
const auto& outputs = program.Outputs();
184187

@@ -288,8 +291,8 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
288291
auto key = CalculateProgramCacheKey(program, inputs_segments, outputs_segments, is_1d_dispatch);
289292

290293
if (is_profiling_) {
291-
PendingKernelInfo pending_kernel_info(context.KernelContext().GetNodeName(),
292-
context.KernelContext().GetOpType(),
294+
PendingKernelInfo pending_kernel_info(context.NodeName(),
295+
context.OpType(),
293296
program.Name(),
294297
key,
295298
inputs,
@@ -442,7 +445,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
442445
const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field;
443446

444447
WGPUBuffer uniform_buffer = nullptr;
445-
const webgpu::BufferManager& buffer_mgr = ComputeContext::BufferManagerAccessor::Get(context);
448+
const webgpu::BufferManager& buffer_mgr = ComputeContextBase::BufferManagerAccessor::Get(context);
446449
if (uniform_buffer_total_size > 0) {
447450
std::vector<uint8_t> uniform_data_buffer(uniform_buffer_total_size);
448451

@@ -910,13 +913,6 @@ void WebGpuContext::ReleaseGraphResources(std::vector<webgpu::CapturedCommandInf
910913
}
911914
}
912915

913-
const SplitKConfig& WebGpuContext::GetSplitKConfig() {
914-
if (!split_k_config_) {
915-
split_k_config_ = SplitKConfig::GetSplitKConfig(adapter_info_);
916-
}
917-
return *split_k_config_;
918-
}
919-
920916
std::unordered_map<int32_t, WebGpuContextFactory::WebGpuContextInfo> WebGpuContextFactory::contexts_;
921917
std::mutex WebGpuContextFactory::mutex_;
922918
std::once_flag WebGpuContextFactory::init_default_flag_;

0 commit comments

Comments
 (0)