Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool
OrtMemoryInfo(WEBGPU_BUFFER,
is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator
: OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0),
WebGpuDevice,
OrtMemTypeDefault)),
buffer_manager_{buffer_manager},
mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} {
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/webgpu/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ namespace webgpu {

class BufferManager;

inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU,
OrtDevice::MemType::DEFAULT,
OrtDevice::VendorIds::NONE,
0};

class GpuBufferAllocator : public IAllocator {
public:
GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator);
Expand Down
23 changes: 13 additions & 10 deletions onnxruntime/core/providers/webgpu/compute_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,25 @@

namespace onnxruntime {
namespace webgpu {
ComputeContext::ComputeContext(OpKernelContext& kernel_context,
const OpKernel& op_kernel,
const WebGpuExecutionProvider& ep,
WebGpuContext& webgpu_context)

ComputeContextBase::ComputeContextBase(WebGpuContext& webgpu_context,
const WebGpuExecutionProvider& ep,
const OpKernel& op_kernel)
: webgpu_context_{webgpu_context},
kernel_context_{kernel_context},
op_kernel_{op_kernel},
ep_{ep} {
ep_{ep},
op_kernel_{op_kernel} {
}

const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const ComputeContext& context) {
const webgpu::BufferManager& ComputeContextBase::BufferManagerAccessor::Get(const ComputeContextBase& context) {
return context.ep_.BufferManager();
}

const SplitKConfig& ComputeContext::GetSplitKConfig() {
return webgpu_context_.GetSplitKConfig();
ComputeContext::ComputeContext(WebGpuContext& webgpu_context,
const WebGpuExecutionProvider& ep,
const OpKernel& op_kernel,
OpKernelContext& kernel_context)
: ComputeContextBase(webgpu_context, ep, op_kernel),
kernel_context_{kernel_context} {
}

} // namespace webgpu
Expand Down
103 changes: 71 additions & 32 deletions onnxruntime/core/providers/webgpu/compute_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ namespace webgpu {
class WebGpuContext;
class BufferManager;

class ComputeContext final {
//
// Class ComputeContextBase is designed to provide basic context information
// for running a compute shader program.
//
// An instance of ComputeContextBase does not depend on OpKernelContext, which needs an execution frame to be created.
//
class ComputeContextBase {
public:
// Nested accessor class to provide controlled access to BufferManager
class BufferManagerAccessor {
Expand All @@ -34,18 +40,31 @@ class ComputeContext final {
friend class WebGpuContext;

private:
static const webgpu::BufferManager& Get(const ComputeContext& context);
static const webgpu::BufferManager& Get(const ComputeContextBase& context);
};

ComputeContext(OpKernelContext& kernel_context,
const OpKernel& op_kernel,
const WebGpuExecutionProvider& ep,
WebGpuContext& webgpu_context);
ComputeContextBase(WebGpuContext& webgpu_context,
const WebGpuExecutionProvider& ep,
const OpKernel& op_kernel);

~ComputeContext() = default;
~ComputeContextBase() = default;

//
// Get the node name.
//
inline decltype(auto) NodeName() const {
return op_kernel_.Node().Name();
}

//
// Get the operator type.
//
inline decltype(auto) OpType() const {
return op_kernel_.Node().OpType();
}

//
// Get various information from the context.
// Get various information from the WebGPU context.
//

inline const wgpu::AdapterInfo& AdapterInfo() const {
Expand All @@ -57,27 +76,64 @@ class ComputeContext final {
inline bool HasFeature(wgpu::FeatureName feature) const {
return webgpu_context_.DeviceHasFeature(feature);
}
inline bool IsGraphCaptureEnabled() const {
return ep_.IsGraphCaptureEnabled();
}
#if !defined(__wasm__)
inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const {
return webgpu_context_.SubgroupMatrixConfigs();
}
#endif

//
// Get the kernel context.
// Get Split-K configuration.
//
inline OpKernelContext& KernelContext() {
return kernel_context_;
inline const SplitKConfig& GetSplitKConfig() const {
return webgpu_context_.GetSplitKConfig();
}

//
// Get whether graph capture is enabled.
//
inline bool IsGraphCaptureEnabled() const {
return ep_.IsGraphCaptureEnabled();
}

//
// Get the logger.
//
inline const logging::Logger& Logger() const {
return kernel_context_.Logger();
return *ep_.GetLogger();
}

//
// Run a compute shader program.
//
inline Status RunProgram(const ProgramBase& program) {
return webgpu_context_.Run(*this, program);
}

protected:
WebGpuContext& webgpu_context_;
const WebGpuExecutionProvider& ep_;
const OpKernel& op_kernel_;
};

//
// Class ComputeContext provides all information a `ComputeContextBase` provides, and also
// access to `OpKernelContext` for input and output tensors.
//
class ComputeContext final : public ComputeContextBase {
public:
ComputeContext(WebGpuContext& webgpu_context,
const WebGpuExecutionProvider& ep,
const OpKernel& op_kernel,
OpKernelContext& kernel_context);

~ComputeContext() = default;

//
// Get the kernel context.
//
inline OpKernelContext& KernelContext() {
return kernel_context_;
}

//
Expand Down Expand Up @@ -145,25 +201,8 @@ class ComputeContext final {
return op_kernel_.Info().GetDataTransferManager().CopyTensor(src, dst);
}

//
// Run a compute shader program.
//
inline Status RunProgram(const ProgramBase& program) {
return webgpu_context_.Run(*this, program);
}

//
// Get Split-K configuration.
//
// `split_k_config_` won't be initialized until the first call to this method.
//
const SplitKConfig& GetSplitKConfig();

private:
WebGpuContext& webgpu_context_;
OpKernelContext& kernel_context_;
const OpKernel& op_kernel_;
const WebGpuExecutionProvider& ep_;
};

} // namespace webgpu
Expand Down
40 changes: 40 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,46 @@
return context.RunProgram(conv2d_mm_program);
}

template <bool is_channels_last, bool is_fused>
Status Conv<is_channels_last, is_fused>::PrePackInternal(ComputeContextBase& /* context */,
const Tensor& tensor,
int input_idx,
AllocatorPtr /* alloc */,
/*out*/ bool& is_packed) {
is_packed = false;

if constexpr (is_channels_last) {
if (input_idx == 1 && tensor.Shape().NumDimensions() == 4) {
// only deal with 4D NHWC weights

// TODO: implement weight transpose for pre-pack here

Check warning on line 231 in onnxruntime/core/providers/webgpu/nn/conv.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/webgpu/nn/conv.cc:231: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// Conv::ComputeInternal() should be updated to reflect the change:
// - if the initializer is packed, `context.Input<Tensor>(1)` will be nullptr.
// - in this case, use `transposed_kernel_` instead.

// // Step.1 - calculate transposed weight shape
// TensorShape transposed_kernel_shape{tensor.Shape()[2],
// tensor.Shape()[3],
// tensor.Shape()[1],
// tensor.Shape()[0]};

// // Step.2 - create transposed weight tensor
// transposed_kernel_ = std::make_unique<Tensor>(tensor.DataType(), transposed_kernel_shape, alloc);

// // Step.3 - do transpose
// size_t perm[] = {2, 3, 1, 0};
// ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context,
// perm,
// tensor,
// *transposed_kernel_));

// is_packed = true; // set this flag to true so that ORT will release the initializer tensor
}
}

return Status::OK();
}

// Explicit template instantiation for FusedConv
template class Conv<false, false>;
template class Conv<false, true>;
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@
}
Status ComputeInternal(ComputeContext& context) const override;

Status PrePackInternal(ComputeContextBase& context,
const Tensor& tensor,
int input_idx,
AllocatorPtr alloc,
/*out*/ bool& is_packed) override;

protected:
ConvAttributes conv_attrs_;
Activation activation_;
std::unique_ptr<Tensor> transposed_kernel_; // should only has value when `is_initializer` AND `is_4D` AND `is_NHWC`

Check warning on line 35 in onnxruntime/core/providers/webgpu/nn/conv.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/nn/conv.h:35: Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]
};

Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector<size_t>& perm);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/tensor/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
return Status::OK();
}

Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context,
Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
gsl::span<const size_t> permutations,
const Tensor& input, Tensor& output) {
const auto& input_shape = input.Shape();
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/tensor/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Transpose final : public WebGpuKernel, public TransposeBase {
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
}
Status ComputeInternal(ComputeContext& context) const override;
static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span<const size_t> permutations, const Tensor& input, Tensor& output);
static Status DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span<const size_t> permutations, const Tensor& input, Tensor& output);

constexpr static uint32_t TILE_SIZE = 16;
};
Expand Down
18 changes: 7 additions & 11 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
// create program manager
program_mgr_ = std::make_unique<ProgramManager>(*this);

// create split-k config
split_k_config_ = std::make_unique<SplitKConfig>(adapter_info_);

// set query type
#if !defined(__wasm__)
if (DeviceHasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) {
Expand Down Expand Up @@ -178,7 +181,7 @@ Status WebGpuContext::Wait(wgpu::Future f) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status));
}

Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) {
const auto& inputs = program.Inputs();
const auto& outputs = program.Outputs();

Expand Down Expand Up @@ -288,8 +291,8 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
auto key = CalculateProgramCacheKey(program, inputs_segments, outputs_segments, is_1d_dispatch);

if (is_profiling_) {
PendingKernelInfo pending_kernel_info(context.KernelContext().GetNodeName(),
context.KernelContext().GetOpType(),
PendingKernelInfo pending_kernel_info(context.NodeName(),
context.OpType(),
program.Name(),
key,
inputs,
Expand Down Expand Up @@ -442,7 +445,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field;

WGPUBuffer uniform_buffer = nullptr;
const webgpu::BufferManager& buffer_mgr = ComputeContext::BufferManagerAccessor::Get(context);
const webgpu::BufferManager& buffer_mgr = ComputeContextBase::BufferManagerAccessor::Get(context);
if (uniform_buffer_total_size > 0) {
std::vector<uint8_t> uniform_data_buffer(uniform_buffer_total_size);

Expand Down Expand Up @@ -910,13 +913,6 @@ void WebGpuContext::ReleaseGraphResources(std::vector<webgpu::CapturedCommandInf
}
}

const SplitKConfig& WebGpuContext::GetSplitKConfig() {
if (!split_k_config_) {
split_k_config_ = SplitKConfig::GetSplitKConfig(adapter_info_);
}
return *split_k_config_;
}

std::unordered_map<int32_t, WebGpuContextFactory::WebGpuContextInfo> WebGpuContextFactory::contexts_;
std::mutex WebGpuContextFactory::mutex_;
std::once_flag WebGpuContextFactory::init_default_flag_;
Expand Down
Loading
Loading