Skip to content

Commit 60ef5d8

Browse files
committed
address comments
1 parent 5ad41d4 commit 60ef5d8

File tree

1 file changed

+52
-26
lines changed

1 file changed

+52
-26
lines changed

onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc

Lines changed: 52 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
#include <charconv>
5+
#include <mutex>
56

67
#include "core/framework/error_code_helper.h"
78
#include "core/providers/webgpu/buffer_manager.h"
@@ -16,12 +17,13 @@
1617
using namespace onnxruntime::webgpu::options;
1718

1819
namespace onnxruntime {
19-
// Helper to get default context config, buffer cache config, backend type, and enable_pix_capture
20+
// Helper struct that holds configuration parameters for creating a WebGPU context with default settings.
21+
// This is used during lazy initialization of the data transfer to create a context if one doesn't exist.
2022
struct WebGpuContextParams {
21-
webgpu::WebGpuContextConfig context_config;
22-
webgpu::WebGpuBufferCacheConfig buffer_cache_config;
23-
int backend_type;
24-
bool enable_pix_capture;
23+
webgpu::WebGpuContextConfig context_config; // WebGPU context configuration
24+
webgpu::WebGpuBufferCacheConfig buffer_cache_config; // Buffer cache settings
25+
int backend_type; // Dawn backend type (D3D12, Vulkan, etc.)
26+
bool enable_pix_capture; // Enable PIX GPU capture for debugging
2527
};
2628

2729
static WebGpuContextParams GetDefaultWebGpuContextParams() {
@@ -336,11 +338,12 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl {
336338
: ort_api{ort_api_in},
337339
ep_api{*ort_api_in.GetEpApi()},
338340
data_transfer_{nullptr},
339-
context_id_{-1} {
341+
context_id_{-1},
342+
init_mutex_{} {
340343
ort_version_supported = ORT_API_VERSION;
341-
CanCopy = CanCopyImpl;
342-
CopyTensors = CopyTensorsImpl;
343-
Release = ReleaseImpl;
344+
CanCopy = CanCopyImpl; // OrtDataTransferImpl::CanCopy callback
345+
CopyTensors = CopyTensorsImpl; // OrtDataTransferImpl::CopyTensors callback
346+
Release = ReleaseImpl; // OrtDataTransferImpl::Release callback
344347
}
345348

346349
static bool CanCopyImpl(const OrtDataTransferImpl* this_ptr,
@@ -414,24 +417,40 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl {
414417
}
415418
}
416419

417-
// Initialize data_transfer if not already done or if context_id changed
418-
if (impl.data_transfer_ == nullptr || impl.context_id_ != context_id) {
419-
impl.context_id_ = context_id;
420+
// If no GPU tensor found, return an error as this indicates an invalid state
421+
// CanCopy should have rejected CPU-only tensor copies
422+
if (!found_gpu_tensor) {
423+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
424+
"No GPU tensor found in CopyTensors call - all tensors are CPU-only. "
425+
"This indicates an invalid call as CanCopy should have rejected this.");
426+
}
420427

421-
// Check if context exists, create a default one if it doesn't
422-
webgpu::WebGpuContext* context_ptr = nullptr;
423-
if (webgpu::WebGpuContextFactory::HasContext(context_id)) {
424-
context_ptr = &webgpu::WebGpuContextFactory::GetContext(context_id);
425-
} else {
426-
WebGpuContextParams params = GetDefaultWebGpuContextParams();
427-
params.context_config.context_id = context_id;
428-
context_ptr = &webgpu::WebGpuContextFactory::CreateContext(params.context_config);
429-
context_ptr->Initialize(params.buffer_cache_config, params.backend_type, params.enable_pix_capture);
428+
// Initialize data_transfer if not already done or if context_id changed
429+
// Use mutex to ensure thread-safe lazy initialization
430+
{
431+
std::lock_guard<std::mutex> lock(impl.init_mutex_);
432+
433+
if (impl.data_transfer_ == nullptr || impl.context_id_ != context_id) {
434+
impl.context_id_ = context_id;
435+
436+
// Check if context exists, create a default one if it doesn't
437+
webgpu::WebGpuContext* context_ptr = nullptr;
438+
if (webgpu::WebGpuContextFactory::HasContext(context_id)) {
439+
context_ptr = &webgpu::WebGpuContextFactory::GetContext(context_id);
440+
} else {
441+
WebGpuContextParams params = GetDefaultWebGpuContextParams();
442+
params.context_config.context_id = context_id;
443+
context_ptr = &webgpu::WebGpuContextFactory::CreateContext(params.context_config);
444+
context_ptr->Initialize(params.buffer_cache_config, params.backend_type, params.enable_pix_capture);
445+
}
446+
447+
// Create the DataTransfer instance
448+
// Note: The DataTransfer holds a const reference to BufferManager. The BufferManager's lifecycle
449+
// is managed by the WebGpuContext, which is stored in a static WebGpuContextFactory and persists
450+
// for the lifetime of the application, ensuring the reference remains valid.
451+
impl.data_transfer_ = std::make_unique<webgpu::DataTransfer>(context_ptr->BufferManager());
430452
}
431-
432-
// Create the DataTransfer instance
433-
impl.data_transfer_ = std::make_unique<webgpu::DataTransfer>(context_ptr->BufferManager());
434-
}
453+
} // Release lock
435454

436455
// Now perform the actual tensor copy
437456
for (size_t idx = 0; idx < num_tensors; ++idx) {
@@ -453,10 +472,17 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl {
453472
const OrtEpApi& ep_api;
454473
std::unique_ptr<webgpu::DataTransfer> data_transfer_; // Lazy-initialized
455474
int context_id_; // Track which context we're using
475+
std::mutex init_mutex_; // Protects lazy initialization
456476
};
457477

458478
OrtDataTransferImpl* OrtWebGpuCreateDataTransfer() {
459-
return new WebGpuDataTransferImpl(*OrtApis::GetApi(ORT_API_VERSION));
479+
// Validate API version is supported
480+
const OrtApi* api = OrtApis::GetApi(ORT_API_VERSION);
481+
if (!api) {
482+
// API version not supported - return nullptr to indicate failure
483+
return nullptr;
484+
}
485+
return new WebGpuDataTransferImpl(*api);
460486
}
461487

462488
} // namespace onnxruntime

0 commit comments

Comments
 (0)