@@ -24,7 +24,13 @@ namespace webgpu {
2424class WebGpuContext ;
2525class BufferManager ;
2626
27- class ComputeContext final {
27+ //
28+ // Class ComputeContextBase is designed to provide basic context information
29+ // for running a compute shader program.
30+ //
31+ // An instance of ComputeContextBase does not depend on OpKernelContext, which needs an execution frame to be created.
32+ //
33+ class ComputeContextBase {
2834 public:
2935 // Nested accessor class to provide controlled access to BufferManager
3036 class BufferManagerAccessor {
@@ -34,18 +40,31 @@ class ComputeContext final {
3440 friend class WebGpuContext ;
3541
3642 private:
37- static const webgpu::BufferManager& Get (const ComputeContext & context);
43+ static const webgpu::BufferManager& Get (const ComputeContextBase & context);
3844 };
3945
40- ComputeContext (OpKernelContext& kernel_context,
41- const OpKernel& op_kernel,
42- const WebGpuExecutionProvider& ep,
43- WebGpuContext& webgpu_context);
46+ ComputeContextBase (WebGpuContext& webgpu_context,
47+ const WebGpuExecutionProvider& ep,
48+ const OpKernel& op_kernel);
4449
45- ~ComputeContext () = default ;
50+ ~ComputeContextBase () = default ;
51+
52+ //
53+ // Get the node name.
54+ //
55+ inline decltype (auto ) NodeName() const {
56+ return op_kernel_.Node ().Name ();
57+ }
58+
59+ //
60+ // Get the operator type.
61+ //
62+ inline decltype (auto ) OpType() const {
63+ return op_kernel_.Node ().OpType ();
64+ }
4665
4766 //
48- // Get various information from the context.
67+ // Get various information from the WebGPU context.
4968 //
5069
5170 inline const wgpu::AdapterInfo& AdapterInfo () const {
@@ -57,27 +76,63 @@ class ComputeContext final {
5776 inline bool HasFeature (wgpu::FeatureName feature) const {
5877 return webgpu_context_.DeviceHasFeature (feature);
5978 }
60- inline bool IsGraphCaptureEnabled () const {
61- return ep_.IsGraphCaptureEnabled ();
62- }
6379#if !defined(__wasm__)
6480 inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs () const {
6581 return webgpu_context_.SubgroupMatrixConfigs ();
6682 }
6783#endif
6884
6985 //
70- // Get the kernel context .
86+ // Get Split-K configuration .
7187 //
72- inline OpKernelContext& KernelContext () {
73- return kernel_context_;
88+ inline const SplitKConfig& GetSplitKConfig () const {
89+ return webgpu_context_.GetSplitKConfig ();
90+ }
91+
92+ //
93+ // Get whether graph capture is enabled.
94+ //
95+ inline bool IsGraphCaptureEnabled () const {
96+ return ep_.IsGraphCaptureEnabled ();
7497 }
7598
7699 //
77100 // Get the logger.
78101 //
79102 inline const logging::Logger& Logger () const {
80- return kernel_context_.Logger ();
103+ return *ep_.GetLogger ();
104+ }
105+
106+ //
107+ // Run a compute shader program.
108+ //
109+ inline Status RunProgram (const ProgramBase& program) {
110+ return webgpu_context_.Run (*this , program);
111+ }
112+
113+ protected:
114+ WebGpuContext& webgpu_context_;
115+ const WebGpuExecutionProvider& ep_;
116+ const OpKernel& op_kernel_;
117+ };
118+
119+ //
120+ // Class ComputeContext provides all information a `ComputeContextBase` provides, and also
121+ // access to `OpKernelContext` for input and output tensors.
122+ class ComputeContext final : public ComputeContextBase {
123+ public:
124+ ComputeContext (WebGpuContext& webgpu_context,
125+ const WebGpuExecutionProvider& ep,
126+ const OpKernel& op_kernel,
127+ OpKernelContext& kernel_context);
128+
129+ ~ComputeContext () = default ;
130+
131+ //
132+ // Get the kernel context.
133+ //
134+ inline OpKernelContext& KernelContext () {
135+ return kernel_context_;
81136 }
82137
83138 //
@@ -145,25 +200,8 @@ class ComputeContext final {
145200 return op_kernel_.Info ().GetDataTransferManager ().CopyTensor (src, dst);
146201 }
147202
148- //
149- // Run a compute shader program.
150- //
151- inline Status RunProgram (const ProgramBase& program) {
152- return webgpu_context_.Run (*this , program);
153- }
154-
155- //
156- // Get Split-K configuration.
157- //
158- // `split_k_config_` won't be initialized until the first call to this method.
159- //
160- const SplitKConfig& GetSplitKConfig ();
161-
162203 private:
163- WebGpuContext& webgpu_context_;
164204 OpKernelContext& kernel_context_;
165- const OpKernel& op_kernel_;
166- const WebGpuExecutionProvider& ep_;
167205};
168206
169207} // namespace webgpu
0 commit comments