diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc
index b3eb4b5061423..3e1b87821fe2f 100644
--- a/onnxruntime/core/providers/webgpu/allocator.cc
+++ b/onnxruntime/core/providers/webgpu/allocator.cc
@@ -13,7 +13,7 @@ GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool
                     OrtMemoryInfo(WEBGPU_BUFFER,
                                   is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator : OrtAllocatorType::OrtDeviceAllocator,
-                                  OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0),
+                                  WebGpuDevice,
                                   OrtMemTypeDefault)),
       buffer_manager_{buffer_manager},
       mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} {
diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h
index 7c38b4557e078..74b3d669fcf3b 100644
--- a/onnxruntime/core/providers/webgpu/allocator.h
+++ b/onnxruntime/core/providers/webgpu/allocator.h
@@ -11,6 +11,11 @@ namespace webgpu {
 
 class BufferManager;
 
+inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU,
+                                        OrtDevice::MemType::DEFAULT,
+                                        OrtDevice::VendorIds::NONE,
+                                        0};
+
 class GpuBufferAllocator : public IAllocator {
  public:
  GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator);
diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc
index ebe71c6ccfacd..d1a2011c8e191 100644
--- a/onnxruntime/core/providers/webgpu/compute_context.cc
+++ b/onnxruntime/core/providers/webgpu/compute_context.cc
@@ -6,22 +6,25 @@
 namespace onnxruntime {
 namespace webgpu {
 
-ComputeContext::ComputeContext(OpKernelContext& kernel_context,
-                               const OpKernel& op_kernel,
-                               const WebGpuExecutionProvider& ep,
-                               WebGpuContext& webgpu_context)
+
+ComputeContextBase::ComputeContextBase(WebGpuContext& webgpu_context,
+                                       const WebGpuExecutionProvider& ep,
+                                       const OpKernel& op_kernel)
     : webgpu_context_{webgpu_context},
-      kernel_context_{kernel_context},
-      op_kernel_{op_kernel},
-      ep_{ep} {
+      ep_{ep},
+      op_kernel_{op_kernel} {
 }
 
-const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const ComputeContext& context) {
+const webgpu::BufferManager& ComputeContextBase::BufferManagerAccessor::Get(const ComputeContextBase& context) {
   return context.ep_.BufferManager();
 }
 
-const SplitKConfig& ComputeContext::GetSplitKConfig() {
-  return webgpu_context_.GetSplitKConfig();
+ComputeContext::ComputeContext(WebGpuContext& webgpu_context,
+                               const WebGpuExecutionProvider& ep,
+                               const OpKernel& op_kernel,
+                               OpKernelContext& kernel_context)
+    : ComputeContextBase(webgpu_context, ep, op_kernel),
+      kernel_context_{kernel_context} {
 }
 
 }  // namespace webgpu
diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h
index ed16f2f0a1345..fdf89854469d6 100644
--- a/onnxruntime/core/providers/webgpu/compute_context.h
+++ b/onnxruntime/core/providers/webgpu/compute_context.h
@@ -24,7 +24,13 @@ namespace webgpu {
 class WebGpuContext;
 class BufferManager;
 
-class ComputeContext final {
+//
+// Class ComputeContextBase is designed to provide basic context information
+// for running a compute shader program.
+//
+// An instance of ComputeContextBase does not depend on OpKernelContext, which needs an execution frame to be created.
+//
+class ComputeContextBase {
  public:
   // Nested accessor class to provide controlled access to BufferManager
   class BufferManagerAccessor {
@@ -34,18 +40,31 @@ class ComputeContext final {
     friend class WebGpuContext;
 
    private:
-    static const webgpu::BufferManager& Get(const ComputeContext& context);
+    static const webgpu::BufferManager& Get(const ComputeContextBase& context);
   };
 
-  ComputeContext(OpKernelContext& kernel_context,
-                 const OpKernel& op_kernel,
-                 const WebGpuExecutionProvider& ep,
-                 WebGpuContext& webgpu_context);
+  ComputeContextBase(WebGpuContext& webgpu_context,
+                     const WebGpuExecutionProvider& ep,
+                     const OpKernel& op_kernel);
 
-  ~ComputeContext() = default;
+  ~ComputeContextBase() = default;
+
+  //
+  // Get the node name.
+  //
+  inline decltype(auto) NodeName() const {
+    return op_kernel_.Node().Name();
+  }
+
+  //
+  // Get the operator type.
+  //
+  inline decltype(auto) OpType() const {
+    return op_kernel_.Node().OpType();
+  }
 
   //
-  // Get various information from the context.
+  // Get various information from the WebGPU context.
   //
 
   inline const wgpu::AdapterInfo& AdapterInfo() const {
@@ -57,9 +76,6 @@ class ComputeContext final {
   inline bool HasFeature(wgpu::FeatureName feature) const {
     return webgpu_context_.DeviceHasFeature(feature);
   }
-  inline bool IsGraphCaptureEnabled() const {
-    return ep_.IsGraphCaptureEnabled();
-  }
 #if !defined(__wasm__)
   inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const {
     return webgpu_context_.SubgroupMatrixConfigs();
@@ -67,17 +83,57 @@ class ComputeContext final {
 #endif
 
   //
-  // Get the kernel context.
+  // Get Split-K configuration.
   //
-  inline OpKernelContext& KernelContext() {
-    return kernel_context_;
+  inline const SplitKConfig& GetSplitKConfig() const {
+    return webgpu_context_.GetSplitKConfig();
+  }
+
+  //
+  // Get whether graph capture is enabled.
+  //
+  inline bool IsGraphCaptureEnabled() const {
+    return ep_.IsGraphCaptureEnabled();
   }
 
   //
   // Get the logger.
   //
   inline const logging::Logger& Logger() const {
-    return kernel_context_.Logger();
+    return *ep_.GetLogger();
+  }
+
+  //
+  // Run a compute shader program.
+  //
+  inline Status RunProgram(const ProgramBase& program) {
+    return webgpu_context_.Run(*this, program);
+  }
+
+ protected:
+  WebGpuContext& webgpu_context_;
+  const WebGpuExecutionProvider& ep_;
+  const OpKernel& op_kernel_;
+};
+
+//
+// Class ComputeContext provides all information a `ComputeContextBase` provides, and also
+// access to `OpKernelContext` for input and output tensors.
+//
+class ComputeContext final : public ComputeContextBase {
+ public:
+  ComputeContext(WebGpuContext& webgpu_context,
+                 const WebGpuExecutionProvider& ep,
+                 const OpKernel& op_kernel,
+                 OpKernelContext& kernel_context);
+
+  ~ComputeContext() = default;
+
+  //
+  // Get the kernel context.
+  //
+  inline OpKernelContext& KernelContext() {
+    return kernel_context_;
   }
 
   //
@@ -145,25 +201,8 @@ class ComputeContext final {
     return op_kernel_.Info().GetDataTransferManager().CopyTensor(src, dst);
   }
 
-  //
-  // Run a compute shader program.
-  //
-  inline Status RunProgram(const ProgramBase& program) {
-    return webgpu_context_.Run(*this, program);
-  }
-
-  //
-  // Get Split-K configuration.
-  //
-  // `split_k_config_` won't be initialized until the first call to this method.
-  //
-  const SplitKConfig& GetSplitKConfig();
-
  private:
-  WebGpuContext& webgpu_context_;
   OpKernelContext& kernel_context_;
-  const OpKernel& op_kernel_;
-  const WebGpuExecutionProvider& ep_;
 };
 
 }  // namespace webgpu
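The split above is easiest to see from the caller's side. Below is a minimal sketch (not part of this change; the helper name is made up) of code that compiles against the ComputeContextBase surface alone — node name, op type, logger, graph-capture flag — i.e. without any execution frame:

// Hypothetical helper, for illustration only.
void LogDispatchInfo(ComputeContextBase& context) {
  LOGS(context.Logger(), VERBOSE)
      << "Running " << context.OpType() << " node '" << context.NodeName() << "'"
      << (context.IsGraphCaptureEnabled() ? " (graph capture enabled)" : "");
}

Because ComputeContext derives from ComputeContextBase, the same helper is callable from ComputeInternal (which receives the derived ComputeContext) and from PrePackInternal (which never sees an OpKernelContext).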
diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc
index 77fa46cb87518..4fff736fd2f32 100644
--- a/onnxruntime/core/providers/webgpu/nn/conv.cc
+++ b/onnxruntime/core/providers/webgpu/nn/conv.cc
@@ -216,6 +216,46 @@ Status Conv::ComputeInternal(ComputeContext& context
   return context.RunProgram(conv2d_mm_program);
 }
 
+template <bool is_channels_last, bool is_fused>
+Status Conv<is_channels_last, is_fused>::PrePackInternal(ComputeContextBase& /* context */,
+                                                         const Tensor& tensor,
+                                                         int input_idx,
+                                                         AllocatorPtr /* alloc */,
+                                                         /*out*/ bool& is_packed) {
+  is_packed = false;
+
+  if constexpr (is_channels_last) {
+    if (input_idx == 1 && tensor.Shape().NumDimensions() == 4) {
+      // only deal with 4D NHWC weights
+
+      // TODO: implement weight transpose for pre-pack here
+      // Conv::ComputeInternal() should be updated to reflect the change:
+      // - if the initializer is packed, `context.Input<Tensor>(1)` will be nullptr.
+      // - in this case, use `transposed_kernel_` instead.
+
+      // // Step.1 - calculate transposed weight shape
+      // TensorShape transposed_kernel_shape{tensor.Shape()[2],
+      //                                     tensor.Shape()[3],
+      //                                     tensor.Shape()[1],
+      //                                     tensor.Shape()[0]};
+
+      // // Step.2 - create transposed weight tensor
+      // transposed_kernel_ = std::make_unique<Tensor>(tensor.DataType(), transposed_kernel_shape, alloc);
+
+      // // Step.3 - do transpose
+      // size_t perm[] = {2, 3, 1, 0};
+      // ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context,
+      //                                            perm,
+      //                                            tensor,
+      //                                            *transposed_kernel_));
+
+      // is_packed = true;  // set this flag to true so that ORT will release the initializer tensor
+    }
+  }
+
+  return Status::OK();
+}
+
 // Explicit template instantiation for FusedConv
 template class Conv;
 template class Conv;
diff --git a/onnxruntime/core/providers/webgpu/nn/conv.h b/onnxruntime/core/providers/webgpu/nn/conv.h
index cafaa272c0613..7c2afb444e754 100644
--- a/onnxruntime/core/providers/webgpu/nn/conv.h
+++ b/onnxruntime/core/providers/webgpu/nn/conv.h
@@ -23,9 +23,16 @@ class Conv : public WebGpuKernel {
   }
   Status ComputeInternal(ComputeContext& context) const override;
 
+  Status PrePackInternal(ComputeContextBase& context,
+                         const Tensor& tensor,
+                         int input_idx,
+                         AllocatorPtr alloc,
+                         /*out*/ bool& is_packed) override;
+
 protected:
   ConvAttributes conv_attrs_;
   Activation activation_;
+  std::unique_ptr<Tensor> transposed_kernel_;  // should only have a value when `is_initializer` AND `is_4D` AND `is_NHWC`
 };
 
 Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector<size_t>& perm);
diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc
index cec321d0da80e..5415d4a5ead5b 100644
--- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc
+++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc
@@ -108,7 +108,7 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
   return Status::OK();
 }
 
-Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context,
+Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
                               gsl::span<const size_t> permutations, const Tensor& input, Tensor& output) {
   const auto& input_shape = input.Shape();
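Widening DoTranspose from ComputeContext& to ComputeContextBase& is what allows the Conv pre-pack path above to call it before an execution frame exists, while existing ComputeContext call sites still bind through the derived-to-base conversion. A hedged sketch of such a pre-pack caller (the helper name and its reuse of the {2, 3, 1, 0} permutation from the Conv comments are assumptions, not part of this change):

// Hypothetical pre-pack helper, for illustration only.
Status TransposeWeightDuringPrePack(ComputeContextBase& context,
                                    const Tensor& weight,
                                    Tensor& transposed) {
  const size_t perm[] = {2, 3, 1, 0};  // same permutation as the Conv pre-pack sketch above
  return Transpose::DoTranspose(context, perm, weight, transposed);
}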
diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h
index b62a419fa12bc..5e9ccc6750cd6 100644
--- a/onnxruntime/core/providers/webgpu/tensor/transpose.h
+++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h
@@ -16,7 +16,7 @@ class Transpose final : public WebGpuKernel, public TransposeBase {
   Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
   }
   Status ComputeInternal(ComputeContext& context) const override;
-  static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span<const size_t> permutations, const Tensor& input, Tensor& output);
+  static Status DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span<const size_t> permutations, const Tensor& input, Tensor& output);
 
   constexpr static uint32_t TILE_SIZE = 16;
 };
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc
index 28decb076951e..b8d5adc421124 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -147,6 +147,9 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
   // create program manager
   program_mgr_ = std::make_unique<ProgramManager>(*this);
 
+  // create split-k config
+  split_k_config_ = std::make_unique<SplitKConfig>(adapter_info_);
+
   // set query type
 #if !defined(__wasm__)
   if (DeviceHasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) {
@@ -178,7 +181,7 @@ Status WebGpuContext::Wait(wgpu::Future f) {
   return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status));
 }
 
-Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
+Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) {
   const auto& inputs = program.Inputs();
   const auto& outputs = program.Outputs();
 
@@ -288,8 +291,8 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
   auto key = CalculateProgramCacheKey(program, inputs_segments, outputs_segments, is_1d_dispatch);
 
   if (is_profiling_) {
-    PendingKernelInfo pending_kernel_info(context.KernelContext().GetNodeName(),
-                                          context.KernelContext().GetOpType(),
+    PendingKernelInfo pending_kernel_info(context.NodeName(),
+                                          context.OpType(),
                                           program.Name(),
                                           key,
                                           inputs,
@@ -442,7 +445,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
   const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field;
 
   WGPUBuffer uniform_buffer = nullptr;
-  const webgpu::BufferManager& buffer_mgr = ComputeContext::BufferManagerAccessor::Get(context);
+  const webgpu::BufferManager& buffer_mgr = ComputeContextBase::BufferManagerAccessor::Get(context);
   if (uniform_buffer_total_size > 0) {
     std::vector uniform_data_buffer(uniform_buffer_total_size);
@@ -910,13 +913,6 @@ void WebGpuContext::ReleaseGraphResources(std::vector
 WebGpuContextFactory::contexts_;
 std::mutex WebGpuContextFactory::mutex_;
 std::once_flag WebGpuContextFactory::init_default_flag_;
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h
index bd7dae75f2e2d..84dfb47ef4687 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -5,7 +5,6 @@
 #include 
 #include 
-#include 
 
 #include "core/providers/webgpu/webgpu_external_header.h"
 
@@ -23,7 +22,7 @@ class Tensor;
 namespace webgpu {
 class WebGpuContext;
-class ComputeContext;
+class ComputeContextBase;
 class ProgramBase;
 
 // Definition for CapturedCommandInfo in the webgpu namespace
@@ -152,6 +151,13 @@ class WebGpuContext final {
     return validation_mode_;
   }
 
+  //
+  // Get Split-K configuration.
+  //
+  const SplitKConfig& GetSplitKConfig() const {
+    return *split_k_config_;
+  }
+
   void StartProfiling();
   void CollectProfilingData(profiling::Events& events);
   void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events);
@@ -170,16 +176,9 @@ class WebGpuContext final {
   //
   Status PopErrorScope();
 
-  Status Run(ComputeContext& context, const ProgramBase& program);
+  Status Run(ComputeContextBase& context, const ProgramBase& program);
   void OnRunEnd();
 
-  //
-  // Get Split-K configuration.
-  //
-  // `split_k_config_` won't be initialized until the first call to this method.
-  //
-  const SplitKConfig& GetSplitKConfig();
-
 private:
   enum class TimestampQueryType {
     None = 0,
@@ -277,7 +276,7 @@ class WebGpuContext final {
   uint32_t num_pending_dispatches_ = 0;
   const uint32_t max_num_pending_dispatches_ = 16;
 
-  std::optional<SplitKConfig> split_k_config_;
+  std::unique_ptr<SplitKConfig> split_k_config_;
 
   // profiling
   TimestampQueryType query_type_;
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
index e0b84fef51f1f..882c084bf6157 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -794,8 +794,7 @@ using namespace webgpu;
 WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, WebGpuContext& context,
                                                  WebGpuExecutionProviderConfig&& config)
-    : IExecutionProvider{kWebGpuExecutionProvider,
-                         OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)},
+    : IExecutionProvider{kWebGpuExecutionProvider, WebGpuDevice},
      context_id_{context_id},
      context_{context},
      preferred_data_layout_{config.data_layout},
diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc
index 8d6ae6caeaf83..ea38e9415e1fe 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc
@@ -11,25 +11,58 @@ namespace webgpu {
 
 WebGpuKernel::WebGpuKernel(const OpKernelInfo& info)
     : OpKernel(info),
-      ep_(*static_cast<const WebGpuExecutionProvider*>(info.GetExecutionProvider())) {
+      ep_(*static_cast<const WebGpuExecutionProvider*>(info.GetExecutionProvider())),
+      webgpu_context_(WebGpuContextFactory::GetContext(ep_.GetDeviceId())) {
 }
 
 Status WebGpuKernel::Compute(OpKernelContext* p_op_kernel_context) const {
-  WebGpuContext& webgpu_context = WebGpuContextFactory::GetContext(ep_.GetDeviceId());
-  ComputeContext context{*p_op_kernel_context, *this, ep_, webgpu_context};
+  ComputeContext context{webgpu_context_,
+                         ep_,
+                         *this,
+                         *p_op_kernel_context};
 
-  if (webgpu_context.ValidationMode() >= ValidationMode::Full) {
-    webgpu_context.PushErrorScope();
+  if (webgpu_context_.ValidationMode() >= ValidationMode::Full) {
+    webgpu_context_.PushErrorScope();
   }
 
   Status s = ComputeInternal(context);
 
-  if (webgpu_context.ValidationMode() >= ValidationMode::Full) {
-    ORT_RETURN_IF_ERROR(webgpu_context.PopErrorScope());
+  if (webgpu_context_.ValidationMode() >= ValidationMode::Full) {
+    ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope());
   }
 
   return s;
 }
 
+Status WebGpuKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+                             /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) {
+  ComputeContextBase context{webgpu_context_, ep_, *this};
+
+  if (webgpu_context_.ValidationMode() >= ValidationMode::Full) {
+    webgpu_context_.PushErrorScope();
+  }
+
+  // Currently, ORT does not allow using prepacked weights in non-CPU EPs.
+  // So we do not pass prepacked_weights to PrePackInternal.
+  // Kernel implementation that supports prepacking should manage its own storage.
+
+  Status s = PrePackInternal(context, tensor, input_idx, alloc, is_packed);
+
+  if (webgpu_context_.ValidationMode() >= ValidationMode::Full) {
+    ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope());
+  }
+
+  return s;
+}
+
+Status WebGpuKernel::PrePackInternal(ComputeContextBase& /*context*/,
+                                     const Tensor& /*tensor*/,
+                                     int /*input_idx*/,
+                                     AllocatorPtr /*alloc*/,
+                                     /*out*/ bool& is_packed) {
+  is_packed = false;
+  return Status::OK();
+}
+
 }  // namespace webgpu
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h
index 3c750e305421c..78cd2dfd480d5 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h
@@ -23,8 +23,21 @@ class WebGpuKernel : public OpKernel {
 
   virtual Status ComputeInternal(ComputeContext& context) const = 0;
 
+  Status PrePack(const Tensor& tensor,
+                 int input_idx,
+                 AllocatorPtr alloc,
+                 /*out*/ bool& is_packed,
+                 /*out*/ PrePackedWeights* prepacked_weights) override;
+
+  virtual Status PrePackInternal(ComputeContextBase& context,
+                                 const Tensor& tensor,
+                                 int input_idx,
+                                 AllocatorPtr alloc,
+                                 /*out*/ bool& is_packed);
+
 private:
  const WebGpuExecutionProvider& ep_;
+  WebGpuContext& webgpu_context_;
 };
 
 }  // namespace webgpu
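WebGpuKernel::PrePack() above brackets the call with the same error-scope validation as Compute() and forwards to PrePackInternal(), whose default implementation is a no-op. It deliberately does not forward PrePackedWeights, so a kernel that opts in owns the packed data itself. A hedged sketch of the intended override pattern (hypothetical kernel; assumptions are marked in comments):

class MyPackedKernel final : public WebGpuKernel {
 public:
  explicit MyPackedKernel(const OpKernelInfo& info) : WebGpuKernel(info) {}

  Status PrePackInternal(ComputeContextBase& /*context*/,
                         const Tensor& tensor,
                         int input_idx,
                         AllocatorPtr alloc,
                         /*out*/ bool& is_packed) override {
    is_packed = false;
    if (input_idx == 1) {  // assumption: input 1 is the constant weight
      packed_weight_ = std::make_unique<Tensor>(tensor.DataType(), tensor.Shape(), alloc);
      // ... repack `tensor` into *packed_weight_ here ...
      is_packed = true;  // the initializer can now be released by ORT
    }
    return Status::OK();
  }

  Status ComputeInternal(ComputeContext& context) const override {
    // When the weight was packed, context.Input<Tensor>(1) returns nullptr; use the member.
    const Tensor* weight = packed_weight_ ? packed_weight_.get() : context.Input<Tensor>(1);
    ORT_UNUSED_PARAMETER(weight);
    return Status::OK();
  }

 private:
  std::unique_ptr<Tensor> packed_weight_;
};

Conv::PrePackInternal above follows this same contract once its commented-out transpose steps are filled in.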
diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc
index 568d29a96cb88..5fd24b2bff037 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc
@@ -21,27 +21,24 @@ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components
   return TensorShape(shape_vector);
 }
 
-SplitKConfig SplitKConfig::GetSplitKConfig(const wgpu::AdapterInfo& adapter_info) {
-  SplitKConfig config = {};
-
+SplitKConfig::SplitKConfig(const wgpu::AdapterInfo& adapter_info) {
   if (adapter_info.vendor == std::string_view{"intel"}) {
     if (adapter_info.architecture == std::string_view{"xe-2lpg"} ||
         adapter_info.architecture == std::string_view{"xe-2hpg"} ||
         adapter_info.architecture == std::string_view{"xe-lpg"} ||
         adapter_info.architecture == std::string_view{"gen-12hp"}) {
-      config.enable_split_k_ = true;
+      enable_split_k_ = true;
 
       // Below thresholds are only verified on the above Intel GPUs without any regressions. The
       // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be
       // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more
       // atomic calls for each output value.
-      config.split_dim_inner_ = 256;
-      config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2;
-      config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9;
-      config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f;
+      split_dim_inner_ = 256;
+      min_dim_inner_with_split_k_ = split_dim_inner_ * 2;
+      max_dim_inner_with_split_k_ = split_dim_inner_ * 9;
+      max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f;
     }
   }
-
-  return config;
 }
 
 bool SplitKConfig::UseSplitK(
diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h
index d45b9bf4dd119..7d5ab5fea8006 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_utils.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h
@@ -91,9 +91,12 @@ inline Tensor CreateTensorView(const Tensor& tensor, MLDataType new_data_type, c
   return {new_data_type, new_shape, const_cast<void*>(tensor.DataRaw()), tensor.Location()};
 }
 
+/**
+ * Configuration for Split-K optimization (Conv|MatMul).
+ */
 class SplitKConfig {
 public:
-  static SplitKConfig GetSplitKConfig(const wgpu::AdapterInfo& adapter_info);
+  explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info);
 
  bool UseSplitK(
      bool is_vec4, ActivationKind activation_kind, uint64_t batch_size,