Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
5dce33a
Graphics ORT Interop related changes
praneshgo Oct 15, 2025
b65c59e
Adding a few fixes based on CodeRabbit's review comments
praneshgo Oct 15, 2025
36f6130
Few bug fixes in the fall back path and a API linking issue fix
praneshgo Oct 16, 2025
db281b5
Adding additional compile parameter and macro for DX interop
praneshgo Oct 16, 2025
36086b5
Added documentation for APIs, modified the error handling in EpSignal…
praneshgo Oct 16, 2025
9830d42
Adding Vulkan compilation support, adding more members to GraphicsInt…
praneshgo Oct 24, 2025
d032260
Adding CIG support for DX
praneshgo Oct 30, 2025
0432087
Not looping for EPs in InteropWait and InteropSignal calls
praneshgo Nov 4, 2025
49549bf
FenceInteropParams moved out of GraphicsInteropParams; SemaphoreEpMap…
praneshgo Nov 5, 2025
8931f16
A few code fixes, highlighted by CodeRabbit
praneshgo Nov 5, 2025
cb90d98
Fixing windows_nvtensorrtrtx_build and build_linux_nv_only pipelines
praneshgo Nov 5, 2025
0861626
Adding a small fix to further get build-linux-nv-only pipeline to pass
praneshgo Nov 6, 2025
21697fa
A few parameter level changes that do not change the functionality much
praneshgo Nov 7, 2025
a42b35e
Moving newly added APIs to the end to avoid ABI breakage
praneshgo Nov 7, 2025
c3e7886
Not exposing DX/Vulkan datatypes and headers in public ORT headers
praneshgo Nov 10, 2025
ad61f8c
Merge remote-tracking branch 'origin/main' into pgonegandla/graphics_…
praneshgo Nov 10, 2025
8d7a5c8
Couple of tiny fixes
praneshgo Nov 10, 2025
613ffbf
Adding sample test that demonstrates DX-ORT interop feature with NV T…
praneshgo Nov 10, 2025
9d1f961
Update nv_basic_ort_interop_test.cc
praneshgo Nov 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,31 @@ option(onnxruntime_USE_AZURE "Build with azure inferencing support" OFF)
option(onnxruntime_USE_LOCK_FREE_QUEUE "Build with lock-free task queue for threadpool." OFF)
option(onnxruntime_FORCE_GENERIC_ALGORITHMS "Disable optimized arch-specific algorithms. Use only for testing and debugging generic algorithms." OFF)

# DirectX interop: expose the build option and always define DX_FOR_INTEROP
# (as 1 or 0) so sources can test it with a plain `#if DX_FOR_INTEROP`.
option(onnxruntime_USE_DX_FOR_INTEROP "Build with the DX for Interop feature." OFF)

if (NOT onnxruntime_USE_DX_FOR_INTEROP)
  add_compile_definitions(DX_FOR_INTEROP=0)
else()
  add_compile_definitions(DX_FOR_INTEROP=1)
endif()

# Vulkan interop: expose the build option and always define VULKAN_FOR_INTEROP
# (as 1 or 0) so sources can test it with a plain `#if VULKAN_FOR_INTEROP`.
# The feature is only enabled when it was requested AND a Vulkan SDK was found.
find_package(Vulkan QUIET)
option(onnxruntime_USE_VULKAN_FOR_INTEROP "Build with the Vulkan for Interop feature." OFF)

if (onnxruntime_USE_VULKAN_FOR_INTEROP AND Vulkan_FOUND)
  if (WIN32)
    add_compile_definitions(VK_USE_PLATFORM_WIN32_KHR=1)
  endif()
  add_compile_definitions(VULKAN_FOR_INTEROP=1)
else()
  add_compile_definitions(VULKAN_FOR_INTEROP=0)
  # Reaching this branch with the option ON means Vulkan was not found.
  # Only report in that case, and as a warning, so the silent downgrade of an
  # explicitly requested feature is visible; stay quiet when the option is OFF.
  if (onnxruntime_USE_VULKAN_FOR_INTEROP)
    message(WARNING "onnxruntime_USE_VULKAN_FOR_INTEROP is ON but Vulkan was not found. Vulkan interop disabled.")
  endif()
endif()

option(onnxruntime_USE_TENSORRT_INTERFACE "Build ONNXRuntime shared lib which is compatible with TensorRT EP interface" OFF)
option(onnxruntime_USE_NV_INTERFACE "Build ONNXRuntime shared lib which is compatible with NV EP interface" OFF)
option(onnxruntime_USE_CUDA_INTERFACE "Build ONNXRuntime shared lib which is compatible with Cuda EP interface" OFF)
Expand Down Expand Up @@ -1198,6 +1223,9 @@ function(onnxruntime_configure_target target_name)
set_target_properties(${target_name} PROPERTIES VS_USER_PROPS ${PROJECT_SOURCE_DIR}/EnableVisualStudioCodeAnalysis.props)
endif()
target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT})
if(Vulkan_FOUND)
target_include_directories(${target_name} PRIVATE ${Vulkan_INCLUDE_DIRS})
endif()
if (onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(${target_name} PRIVATE ${ORTTRAINING_ROOT})
endif()
Expand Down
4 changes: 2 additions & 2 deletions cmake/onnxruntime_providers_nv.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ endif ()
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE Eigen3::Eigen onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface Eigen3::Eigen)
add_dependencies(onnxruntime_providers_nv_tensorrt_rtx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER)
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver)
else()
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver)
endif()
target_include_directories(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${TENSORRT_RTX_INCLUDE_DIR} ${onnx_tensorrt_SOURCE_DIR}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
Expand Down
54 changes: 50 additions & 4 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,10 @@ public struct OrtApi
public IntPtr Graph_GetModelMetadata;
public IntPtr GetModelCompatibilityForEpDevices;
public IntPtr CreateExternalInitializerInfo;

public IntPtr GetOrtFenceForGraphicsInterop;
public IntPtr InteropEpWait;
public IntPtr InteropEpSignal;
}

internal static class NativeMethods
Expand Down Expand Up @@ -482,7 +486,7 @@ static NativeMethods()
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(OrtGetApiBase().GetApi, typeof(DOrtGetApi));
#endif

const uint ORT_API_VERSION = 14;
const uint ORT_API_VERSION = 15;
#if NETSTANDARD2_0
IntPtr ortApiPtr = OrtGetApi(ORT_API_VERSION);
api_ = (OrtApi)Marshal.PtrToStructure(ortApiPtr, typeof(OrtApi));
Expand Down Expand Up @@ -847,7 +851,7 @@ static NativeMethods()
api_.CreateSyncStreamForEpDevice,
typeof(DOrtCreateSyncStreamForEpDevice));

OrtSyncStream_GetHandle =
OrtSyncStream_GetHandle =
(DOrtSyncStream_GetHandle)Marshal.GetDelegateForFunctionPointer(
api_.SyncStream_GetHandle,
typeof(DOrtSyncStream_GetHandle));
Expand All @@ -861,6 +865,21 @@ static NativeMethods()
(DOrtCopyTensors)Marshal.GetDelegateForFunctionPointer(
api_.CopyTensors,
typeof(DOrtCopyTensors));

OrtGetOrtFenceForGraphicsInterop =
(DOrtGetOrtFenceForGraphicsInterop)Marshal.GetDelegateForFunctionPointer(
api_.GetOrtFenceForGraphicsInterop,
typeof(DOrtGetOrtFenceForGraphicsInterop));

OrtInteropEpWait =
(DOrtInteropEpWait)Marshal.GetDelegateForFunctionPointer(
api_.InteropEpWait,
typeof(DOrtInteropEpWait));

OrtInteropEpSignal =
(DOrtInteropEpSignal)Marshal.GetDelegateForFunctionPointer(
api_.InteropEpSignal,
typeof(DOrtInteropEpSignal));
}

internal class NativeLib
Expand Down Expand Up @@ -2644,7 +2663,7 @@ public delegate void DOrtAddKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps,
byte[] /* const char* */ value);

/// <summary>
/// Get the value for the provided key.
/// Get the value for the provided key.
/// </summary>
/// <returns>Value. Returns IntPtr.Zero if key was not found.</returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
Expand Down Expand Up @@ -2743,6 +2762,30 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps,
out IntPtr /* OrtSyncStream** */ stream
);

/// <summary>
/// Managed delegate type for the native GetOrtFenceForGraphicsInterop entry in OrtApi.
/// </summary>
/// <param name="session">Native OrtSession pointer.</param>
/// <param name="graphicsInteropParams">Pointer to a native GraphicsInteropParams struct.</param>
/// <param name="fenceInteropParams">Pointer to a native FenceInteropParams struct.</param>
/// <param name="ortFence">Receives the opaque fence handle (native void*).</param>
/// <returns>Native OrtStatus pointer.</returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */ DOrtGetOrtFenceForGraphicsInterop(
IntPtr /* OrtSession* */ session,
IntPtr /* struct GraphicsInteropParams* */ graphicsInteropParams,
IntPtr /* struct FenceInteropParams* */ fenceInteropParams,
out IntPtr /* void** */ ortFence
);

/// <summary>
/// Managed delegate type for the native InteropEpWait entry in OrtApi.
/// </summary>
/// <param name="session">Native OrtSession pointer.</param>
/// <param name="ortFence">Opaque fence handle (native void*).</param>
/// <param name="stream">Native OrtSyncStream pointer.</param>
/// <param name="fenceValue">
/// Fence value. The native parameter is a 64-bit uint64_t, so it is marshalled as
/// <c>ulong</c>; a 32-bit <c>uint</c> would not match the native signature.
/// </param>
/// <returns>Native OrtStatus pointer.</returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */ DOrtInteropEpWait(
    IntPtr /* OrtSession* */ session,
    IntPtr /* void* */ ortFence,
    IntPtr /* OrtSyncStream* */ stream,
    ulong /* uint64_t */ fenceValue
);

/// <summary>
/// Managed delegate type for the native InteropEpSignal entry in OrtApi.
/// </summary>
/// <param name="session">Native OrtSession pointer.</param>
/// <param name="ortFence">Opaque fence handle (native void*).</param>
/// <param name="stream">Native OrtSyncStream pointer.</param>
/// <param name="fenceValue">
/// Fence value. The native parameter is a 64-bit uint64_t, so it is marshalled as
/// <c>ulong</c>; a 32-bit <c>uint</c> would not match the native signature.
/// </param>
/// <returns>Native OrtStatus pointer.</returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */ DOrtInteropEpSignal(
    IntPtr /* OrtSession* */ session,
    IntPtr /* void* */ ortFence,
    IntPtr /* OrtSyncStream* */ stream,
    ulong /* uint64_t */ fenceValue
);

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* void* */ DOrtSyncStream_GetHandle(
IntPtr /* OrtSyncStream* */ stream
Expand All @@ -2760,14 +2803,17 @@ out IntPtr /* OrtSyncStream** */ stream
public static DOrtEpDevice_Device OrtEpDevice_Device;
public static DOrtEpDevice_MemoryInfo OrtEpDevice_MemoryInfo;
public static DOrtCreateSyncStreamForEpDevice OrtCreateSyncStreamForEpDevice;
public static DOrtGetOrtFenceForGraphicsInterop OrtGetOrtFenceForGraphicsInterop;
public static DOrtInteropEpWait OrtInteropEpWait;
public static DOrtInteropEpSignal OrtInteropEpSignal;
public static DOrtSyncStream_GetHandle OrtSyncStream_GetHandle;
public static DOrtReleaseSyncStream OrtReleaseSyncStream;

//
// Auto Selection EP registration and selection customization

/// <summary>
/// Register an execution provider library.
/// Register an execution provider library.
/// The library must implement CreateEpFactories and ReleaseEpFactory.
/// </summary>
/// <param name="env">Environment to add the EP library to.</param>
Expand Down
88 changes: 88 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ class GraphOptimizerRegistry;
#include "core/framework/tuning_context.h"
#include "core/session/onnxruntime_c_api.h"

#if DX_FOR_INTEROP && _WIN32
#include <d3d12.h>
#endif

#if VULKAN_FOR_INTEROP
#include <vulkan/vulkan.h>
#endif

struct OrtEpDevice;
struct OrtRunOptions;

Expand Down Expand Up @@ -92,6 +100,86 @@ class IExecutionProvider {
public:
virtual ~IExecutionProvider() = default;

/**
 * Default (fallback) implementation of external-semaphore creation.
 *
 * Copies *fenceInteropParams into a shared_ptr-owned FenceInteropParams and
 * stores the address of a heap-allocated std::shared_ptr wrapper in
 * *extSemFence. The wrapper allocated here is deleted by the default
 * SetupInteropEpWait when it is handed that handle.
 *
 * @param graphicsInteropParams Unused by this default implementation.
 * @param fenceInteropParams Fence description to copy; must not be null.
 * @param extSemFence Receives the opaque handle; must not be null.
 * @return Status::OK() on success, INVALID_ARGUMENT if a required pointer is null.
 */
virtual Status GetExtSemaphore(const struct GraphicsInteropParams* graphicsInteropParams, struct FenceInteropParams* fenceInteropParams, void** extSemFence) {
  ORT_UNUSED_PARAMETER(graphicsInteropParams);

  // Both pointers are dereferenced below; reject null instead of crashing.
  if (fenceInteropParams == nullptr || extSemFence == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fenceInteropParams and extSemFence must not be null");
  }

  // The handle crosses a void* boundary, so the shared_ptr itself has to live on the heap.
  *extSemFence = new std::shared_ptr<FenceInteropParams>(std::make_shared<FenceInteropParams>(*fenceInteropParams));
  return Status::OK();
}

/**
 * Default (fallback) implementation of an interop wait: blocks the calling
 * host thread until the external sync primitive described by extSemFence is
 * signalled.
 *
 * extSemFence is interpreted as a heap-allocated
 * std::shared_ptr<FenceInteropParams>*. This call takes ownership of that
 * wrapper and deletes it before returning, so the same handle must not be
 * passed to this function again.
 *
 * - ExternalSyncPrimitive_D3D12Fence (only when built with DX_FOR_INTEROP on
 *   Windows): waits on an event armed via ID3D12Fence::SetEventOnCompletion
 *   for fenceValue.
 * - ExternalSyncPrimitive_VulkanSemaphore (only when built with
 *   VULKAN_FOR_INTEROP): waits on the stored fence with vkWaitForFences and
 *   then resets it (reset is best-effort). fenceValue is not used on this path.
 *
 * @param extSemFence Opaque handle; must not be null.
 * @param stream Unused by this default implementation.
 * @param fenceValue Fence value to wait for (D3D12 path only).
 * @return Status::OK() once the wait completed; INVALID_ARGUMENT for a null
 *         handle; FAIL if the primitive is unsupported in this build or a
 *         platform call fails.
 */
virtual Status SetupInteropEpWait(void* extSemFence, OrtSyncStream* stream, uint64_t fenceValue) {
  ORT_UNUSED_PARAMETER(stream);
  ORT_UNUSED_PARAMETER(fenceValue);

  if (extSemFence == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "extSemFence must not be null");
  }

  // Take ownership of the heap-allocated wrapper so it is freed on every return path.
  std::unique_ptr<std::shared_ptr<FenceInteropParams>> wrapper(
      static_cast<std::shared_ptr<FenceInteropParams>*>(extSemFence));
  FenceInteropParams* interopWaitParams = wrapper->get();
  if (interopWaitParams == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "extSemFence does not reference any fence parameters");
  }

  const ExternalSyncPrimitive extSyncPrimitive = interopWaitParams->extSyncPrimitive;
  // TODO: The fallback logic needs more refinement to deal with multi threaded scenarios.
  if (extSyncPrimitive == ExternalSyncPrimitive_D3D12Fence) {
#if DX_FOR_INTEROP && _WIN32
    auto* d3dFence = reinterpret_cast<ID3D12Fence*>(interopWaitParams->FencePtr.pFence);
    if (d3dFence == nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "D3D12 fence pointer must not be null");
    }

    HANDLE hEvent = CreateEvent(nullptr, FALSE, FALSE, nullptr);
    if (hEvent == nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create event for D3D12 fence wait");
    }

    // If arming the event fails it will never be signalled; bail out instead of
    // blocking forever in WaitForSingleObject.
    const HRESULT hr = d3dFence->SetEventOnCompletion(fenceValue, hEvent);
    if (FAILED(hr)) {
      CloseHandle(hEvent);
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ID3D12Fence::SetEventOnCompletion failed");
    }

    const DWORD waitResult = WaitForSingleObject(hEvent, INFINITE);
    CloseHandle(hEvent);
    if (waitResult != WAIT_OBJECT_0) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Waiting on the D3D12 fence event failed");
    }
    return Status::OK();
#endif
  } else if (extSyncPrimitive == ExternalSyncPrimitive_VulkanSemaphore) {
#if VULKAN_FOR_INTEROP
    const auto getDeviceProcAddr =
        reinterpret_cast<PFN_vkGetDeviceProcAddr>(interopWaitParams->VulkanDeviceParams.pVkGetDeviceProcAddr);
    if (getDeviceProcAddr == nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "vkGetDeviceProcAddr pointer must not be null");
    }
    const VkDevice vkDevice = reinterpret_cast<VkDevice>(interopWaitParams->VulkanDeviceParams.pVkDevice);
    // The stored pointer-sized value is reinterpreted in place as the VkFence handle.
    const VkFence* vkFence = reinterpret_cast<const VkFence*>(&interopWaitParams->FencePtr.pVkFence);

    const auto pfnVkWaitForFences = reinterpret_cast<PFN_vkWaitForFences>(getDeviceProcAddr(vkDevice, "vkWaitForFences"));
    if (!pfnVkWaitForFences) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get function pointer for vkWaitForFences");
    }

    const VkResult result = pfnVkWaitForFences(vkDevice, 1, vkFence, VK_TRUE, UINT64_MAX);
    if (result != VK_SUCCESS) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "vkWaitForFences failed with Vulkan error code: " + std::to_string(static_cast<int>(result)));
    }

    // Resetting the fence is best-effort: a missing entry point is not treated as an error.
    const auto pfnVkResetFences = reinterpret_cast<PFN_vkResetFences>(getDeviceProcAddr(vkDevice, "vkResetFences"));
    if (pfnVkResetFences) {
      pfnVkResetFences(vkDevice, 1, vkFence);
    }

    return Status::OK();
#endif
  }

  // Reached for unknown primitives and for primitives whose support was compiled out.
  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported External Sync primitive");
}
/**
 * Default (fallback) implementation of an interop signal.
 *
 * It does not signal the external fence. Instead it creates a notification on
 * the stream's implementation, activates it, and blocks the calling host
 * thread in WaitOnHost before returning.
 *
 * @param ortEpApi EP API table used to obtain the stream implementation; must not be null.
 * @param extSemFence Unused by this default implementation.
 * @param stream Stream to synchronize with; must not be null.
 * @param fenceValue Unused by this default implementation.
 * @return Status::OK() on success; INVALID_ARGUMENT for null inputs; FAIL if
 *         any stream/notification call reports an error.
 */
virtual Status SetupInteropEpSignal(const OrtEpApi* ortEpApi, void* extSemFence, OrtSyncStream* stream, uint64_t fenceValue) {
  ORT_UNUSED_PARAMETER(extSemFence);
  ORT_UNUSED_PARAMETER(fenceValue);

  if (ortEpApi == nullptr || stream == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ortEpApi and stream must not be null");
  }

  const OrtSyncStreamImpl* streamImpl = ortEpApi->SyncStream_GetImpl(stream);
  if (streamImpl == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get sync stream implementation");
  }

  // NOTE(review): the OrtStatus* returned on the failure paths below is not
  // released here; presumably it should be freed via the ORT status release API - confirm.
  OrtSyncNotificationImpl* streamNotification = nullptr;
  OrtStatus* status = streamImpl->CreateNotification(const_cast<OrtSyncStreamImpl*>(streamImpl), &streamNotification);
  if (status != nullptr || streamNotification == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create notification");
  }

  // Release the notification on every exit path; previously it was leaked on each call.
  const auto releaseNotification = [](OrtSyncNotificationImpl* notification) {
    notification->Release(notification);
  };
  const std::unique_ptr<OrtSyncNotificationImpl, decltype(releaseNotification)> notificationGuard(
      streamNotification, releaseNotification);

  status = streamNotification->Activate(streamNotification);
  if (status != nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to activate notification");
  }

  status = streamNotification->WaitOnHost(streamNotification);
  if (status != nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait on host");
  }
  return Status::OK();
}

/**
* Returns a data transfer object that implements methods to copy to and
* from this device.
Expand Down
Loading
Loading