22// Licensed under the MIT License.
33
44#include < charconv>
5+ #include < mutex>
56
67#include " core/framework/error_code_helper.h"
78#include " core/providers/webgpu/buffer_manager.h"
1617using namespace onnxruntime ::webgpu::options;
1718
1819namespace onnxruntime {
19- // Helper to get default context config, buffer cache config, backend type, and enable_pix_capture
20+ // Helper struct that holds configuration parameters for creating a WebGPU context with default settings.
21+ // This is used during lazy initialization of the data transfer to create a context if one doesn't exist.
2022struct WebGpuContextParams {
21- webgpu::WebGpuContextConfig context_config;
22- webgpu::WebGpuBufferCacheConfig buffer_cache_config;
23- int backend_type;
24- bool enable_pix_capture;
23+ webgpu::WebGpuContextConfig context_config; // WebGPU context configuration
24+ webgpu::WebGpuBufferCacheConfig buffer_cache_config; // Buffer cache settings
25+ int backend_type; // Dawn backend type (D3D12, Vulkan, etc.)
26+ bool enable_pix_capture; // Enable PIX GPU capture for debugging
2527};
2628
2729static WebGpuContextParams GetDefaultWebGpuContextParams () {
@@ -336,11 +338,12 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl {
336338 : ort_api{ort_api_in},
337339 ep_api{*ort_api_in.GetEpApi ()},
338340 data_transfer_{nullptr },
339- context_id_{-1 } {
341+ context_id_{-1 },
342+ init_mutex_{} {
340343 ort_version_supported = ORT_API_VERSION;
341- CanCopy = CanCopyImpl;
342- CopyTensors = CopyTensorsImpl;
343- Release = ReleaseImpl;
344+ CanCopy = CanCopyImpl; // OrtDataTransferImpl::CanCopy callback
345+ CopyTensors = CopyTensorsImpl; // OrtDataTransferImpl::CopyTensors callback
346+ Release = ReleaseImpl; // OrtDataTransferImpl::Release callback
344347 }
345348
346349 static bool CanCopyImpl (const OrtDataTransferImpl* this_ptr,
@@ -414,24 +417,40 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl {
414417 }
415418 }
416419
417- // Initialize data_transfer if not already done or if context_id changed
418- if (impl.data_transfer_ == nullptr || impl.context_id_ != context_id) {
419- impl.context_id_ = context_id;
420+ // If no GPU tensor found, return an error as this indicates an invalid state
421+ // CanCopy should have rejected CPU-only tensor copies
422+ if (!found_gpu_tensor) {
423+ return OrtApis::CreateStatus (ORT_INVALID_ARGUMENT,
424+ " No GPU tensor found in CopyTensors call - all tensors are CPU-only. "
425+ " This indicates an invalid call as CanCopy should have rejected this." );
426+ }
420427
421- // Check if context exists, create a default one if it doesn't
422- webgpu::WebGpuContext* context_ptr = nullptr ;
423- if (webgpu::WebGpuContextFactory::HasContext (context_id)) {
424- context_ptr = &webgpu::WebGpuContextFactory::GetContext (context_id);
425- } else {
426- WebGpuContextParams params = GetDefaultWebGpuContextParams ();
427- params.context_config .context_id = context_id;
428- context_ptr = &webgpu::WebGpuContextFactory::CreateContext (params.context_config );
429- context_ptr->Initialize (params.buffer_cache_config , params.backend_type , params.enable_pix_capture );
428+ // Initialize data_transfer if not already done or if context_id changed
429+ // Use mutex to ensure thread-safe lazy initialization
430+ {
431+ std::lock_guard<std::mutex> lock (impl.init_mutex_ );
432+
433+ if (impl.data_transfer_ == nullptr || impl.context_id_ != context_id) {
434+ impl.context_id_ = context_id;
435+
436+ // Check if context exists, create a default one if it doesn't
437+ webgpu::WebGpuContext* context_ptr = nullptr ;
438+ if (webgpu::WebGpuContextFactory::HasContext (context_id)) {
439+ context_ptr = &webgpu::WebGpuContextFactory::GetContext (context_id);
440+ } else {
441+ WebGpuContextParams params = GetDefaultWebGpuContextParams ();
442+ params.context_config .context_id = context_id;
443+ context_ptr = &webgpu::WebGpuContextFactory::CreateContext (params.context_config );
444+ context_ptr->Initialize (params.buffer_cache_config , params.backend_type , params.enable_pix_capture );
445+ }
446+
447+ // Create the DataTransfer instance
448+ // Note: The DataTransfer holds a const reference to BufferManager. The BufferManager's lifecycle
449+ // is managed by the WebGpuContext, which is stored in a static WebGpuContextFactory and persists
450+ // for the lifetime of the application, ensuring the reference remains valid.
451+ impl.data_transfer_ = std::make_unique<webgpu::DataTransfer>(context_ptr->BufferManager ());
430452 }
431-
432- // Create the DataTransfer instance
433- impl.data_transfer_ = std::make_unique<webgpu::DataTransfer>(context_ptr->BufferManager ());
434- }
453+ } // Release lock
435454
436455 // Now perform the actual tensor copy
437456 for (size_t idx = 0 ; idx < num_tensors; ++idx) {
@@ -453,10 +472,17 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl {
453472 const OrtEpApi& ep_api;
454473 std::unique_ptr<webgpu::DataTransfer> data_transfer_; // Lazy-initialized
455474 int context_id_; // Track which context we're using
475+ std::mutex init_mutex_; // Protects lazy initialization
456476};
457477
458478OrtDataTransferImpl* OrtWebGpuCreateDataTransfer () {
459- return new WebGpuDataTransferImpl (*OrtApis::GetApi (ORT_API_VERSION));
479+ // Validate API version is supported
480+ const OrtApi* api = OrtApis::GetApi (ORT_API_VERSION);
481+ if (!api) {
482+ // API version not supported - return nullptr to indicate failure
483+ return nullptr ;
484+ }
485+ return new WebGpuDataTransferImpl (*api);
460486}
461487
462488} // namespace onnxruntime
0 commit comments