Skip to content

Commit 5aa4f75

Browse files
Add KernelBuilder aliasing API functions. Add a relu kernel to test it.
1 parent a43d7d6 commit 5aa4f75

File tree

13 files changed

+310
-11
lines changed

13 files changed

+310
-11
lines changed

cmake/onnxruntime_unittests.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2105,6 +2105,8 @@ if (onnxruntime_BUILD_SHARED_LIB AND
21052105
"${TEST_SRC_DIR}/autoep/library/example_kernel_plugin_ep/kernels/data_types.cc"
21062106
"${TEST_SRC_DIR}/autoep/library/example_kernel_plugin_ep/kernels/squeeze.h"
21072107
"${TEST_SRC_DIR}/autoep/library/example_kernel_plugin_ep/kernels/squeeze.cc"
2108+
"${TEST_SRC_DIR}/autoep/library/example_kernel_plugin_ep/kernels/relu.h"
2109+
"${TEST_SRC_DIR}/autoep/library/example_kernel_plugin_ep/kernels/relu.cc"
21082110
"${TEST_SRC_DIR}/autoep/library/example_kernel_plugin_ep/kernels/mul.h"
21092111
"${TEST_SRC_DIR}/autoep/library/example_kernel_plugin_ep/kernels/mul.cc")
21102112
onnxruntime_add_shared_library_module(example_kernel_plugin_ep ${onnxruntime_autoep_test_example_kernel_plugin_ep_src})

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3331,6 +3331,8 @@ struct KernelDef : detail::ConstKernelDefImpl<OrtKernelDef> {
33313331
* Used by plugin EPs to build a kernel definition.
33323332
*/
33333333
struct KernelDefBuilder : detail::Base<OrtKernelDefBuilder> {
3334+
using InOutAliasPair = std::pair<int, int>;
3335+
33343336
KernelDefBuilder(); ///< Wraps OrtEpApi::CreateKernelDefBuilder
33353337
explicit KernelDefBuilder(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used
33363338
explicit KernelDefBuilder(OrtKernelDefBuilder* ort_kernel_def_builder);
@@ -3343,6 +3345,10 @@ struct KernelDefBuilder : detail::Base<OrtKernelDefBuilder> {
33433345
KernelDefBuilder& SetOutputMemType(size_t output_index, OrtMemType mem_type);
33443346
KernelDefBuilder& AddTypeConstraint(const char* arg_name, const OrtMLDataType* data_type);
33453347
KernelDefBuilder& AddTypeConstraint(const char* arg_name, const std::vector<const OrtMLDataType*>& data_types);
3348+
KernelDefBuilder& AddInputOutputAlias(int input_index, int output_index);
3349+
KernelDefBuilder& AddInputOutputAliases(const std::vector<InOutAliasPair>& aliases);
3350+
KernelDefBuilder& AddInputOutputMutableAlias(int input_index, int output_index);
3351+
KernelDefBuilder& AddInputOutputMutableAliases(const std::vector<InOutAliasPair>& aliases);
33463352

33473353
KernelDef Build();
33483354
};

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3644,6 +3644,48 @@ inline KernelDefBuilder& KernelDefBuilder::AddTypeConstraint(const char* arg_nam
36443644
return *this;
36453645
}
36463646

3647+
/// Registers a single (input_index, output_index) alias pair on this builder.
/// Forwards to OrtEpApi::KernelDefBuilder_AddInputOutputAliases with a pair count of one.
/// Any error status returned by the C API is passed to ThrowOnError. Returns *this for chaining.
inline KernelDefBuilder& KernelDefBuilder::AddInputOutputAlias(int input_index, int output_index) {
  const auto& ep_api = GetEpApi();

  // The C API takes parallel arrays; a single pair is passed as two one-element arrays.
  const int* in_idx = &input_index;
  const int* out_idx = &output_index;
  ThrowOnError(ep_api.KernelDefBuilder_AddInputOutputAliases(p_, in_idx, out_idx, 1));

  return *this;
}
3651+
3652+
/// Registers every (input index, output index) pair in `aliases` on this builder.
/// The pairs are split into two parallel index arrays and forwarded in one call to
/// OrtEpApi::KernelDefBuilder_AddInputOutputAliases. Any error status returned by the C API is
/// passed to ThrowOnError. Returns *this for chaining.
inline KernelDefBuilder& KernelDefBuilder::AddInputOutputAliases(const std::vector<InOutAliasPair>& aliases) {
  const size_t num_pairs = aliases.size();

  // The C API expects inputs and outputs as separate arrays of equal length.
  std::vector<int> in_idx;
  std::vector<int> out_idx;
  in_idx.reserve(num_pairs);
  out_idx.reserve(num_pairs);

  for (const InOutAliasPair& io_pair : aliases) {
    in_idx.push_back(io_pair.first);
    out_idx.push_back(io_pair.second);
  }

  const auto& ep_api = GetEpApi();
  ThrowOnError(ep_api.KernelDefBuilder_AddInputOutputAliases(p_, in_idx.data(), out_idx.data(), in_idx.size()));

  return *this;
}
3667+
3668+
/// Registers a single mutable (in-place) alias between `input_index` and `output_index`.
/// Forwards to OrtEpApi::KernelDefBuilder_AddInputOutputMutableAliases with a pair count of one.
/// Any error status returned by the C API is passed to ThrowOnError. Returns *this for chaining.
inline KernelDefBuilder& KernelDefBuilder::AddInputOutputMutableAlias(int input_index, int output_index) {
  const auto& ep_api = GetEpApi();

  // The C API takes parallel arrays; a single pair is passed as two one-element arrays.
  const int* in_idx = &input_index;
  const int* out_idx = &output_index;
  ThrowOnError(ep_api.KernelDefBuilder_AddInputOutputMutableAliases(p_, in_idx, out_idx, 1));

  return *this;
}
3672+
3673+
/// Registers every (input index, output index) pair in `aliases` as a mutable (in-place) alias.
/// The pairs are split into two parallel index arrays and forwarded in one call to
/// OrtEpApi::KernelDefBuilder_AddInputOutputMutableAliases. Any error status returned by the C API
/// is passed to ThrowOnError. Returns *this for chaining.
inline KernelDefBuilder& KernelDefBuilder::AddInputOutputMutableAliases(const std::vector<InOutAliasPair>& aliases) {
  const size_t num_pairs = aliases.size();

  // The C API expects inputs and outputs as separate arrays of equal length.
  std::vector<int> in_idx;
  std::vector<int> out_idx;
  in_idx.reserve(num_pairs);
  out_idx.reserve(num_pairs);

  for (const InOutAliasPair& io_pair : aliases) {
    in_idx.push_back(io_pair.first);
    out_idx.push_back(io_pair.second);
  }

  const auto& ep_api = GetEpApi();
  ThrowOnError(
      ep_api.KernelDefBuilder_AddInputOutputMutableAliases(p_, in_idx.data(), out_idx.data(), in_idx.size()));

  return *this;
}
3688+
36473689
inline KernelDef KernelDefBuilder::Build() {
36483690
OrtKernelDef* kernel_def = nullptr;
36493691
ThrowOnError(GetEpApi().KernelDefBuilder_Build(p_, &kernel_def));

include/onnxruntime/core/session/onnxruntime_ep_c_api.h

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ struct OrtEpApi {
683683
ORT_API2_STATUS(KernelDefBuilder_SetOutputMemType, _In_ OrtKernelDefBuilder* kernel_def_builder,
684684
_In_ size_t output_index, _In_ OrtMemType mem_type);
685685

686-
/** \brief Sets type constraints for a kernel argument represented as a string (e.g., "T").
686+
/** \brief Adds type constraints for a kernel argument represented as a string (e.g., "T").
687687
*
688688
* \param[in] kernel_def_builder The OrtKernelDefBuilder instance.
689689
* \param[in] arg_name A null-terminated string representing the argument to constrain (e.g., "T").
@@ -699,6 +699,45 @@ struct OrtEpApi {
699699
_In_ const char* arg_name, _In_reads_(num_types) const OrtMLDataType* const* types,
700700
_In_ size_t num_types);
701701

702+
/** \brief Adds aliases for the given input and output pairs.
703+
*
704+
* \note Used for operators like Identity and Reshape to allow ORT to reuse the input buffer for the output
705+
* without modification.
706+
*
707+
* \param[in] kernel_def_builder The OrtKernelDefBuilder instance.
708+
* \param[in] input_indices Array of input indices. Array must contain `num_io_indices` elements.
709+
* \param[in] output_indices Array of output indices. Each output index is aliased with a corresponding
710+
* input index in `input_indices`. Array must contain `num_io_indices` elements.
711+
* \param[in] num_io_indices The number of input/output index pairs to alias. Must be greater than zero.
712+
*
713+
* \snippet{doc} snippets.dox OrtStatus Return Value
714+
*
715+
* \since Version 1.24.
716+
*/
717+
ORT_API2_STATUS(KernelDefBuilder_AddInputOutputAliases, _In_ OrtKernelDefBuilder* kernel_def_builder,
718+
_In_reads_(num_io_indices) int const* input_indices,
719+
_In_reads_(num_io_indices) int const* output_indices,
720+
_In_ size_t num_io_indices);
721+
722+
/** \brief Adds mutable aliases for the given input and output pairs.
723+
*
724+
* \note Allows ORT to reuse and *modify* an input buffer (in-place) for the output buffer.
725+
*
726+
* \param[in] kernel_def_builder The OrtKernelDefBuilder instance.
727+
* \param[in] input_indices Array of input indices. Array must contain `num_io_indices` elements.
728+
* \param[in] output_indices Array of output indices. Each output index is aliased with a corresponding
729+
* input index in `input_indices`. Array must contain `num_io_indices` elements.
730+
* \param[in] num_io_indices The number of input/output index pairs to alias. Must be greater than zero.
731+
*
732+
* \snippet{doc} snippets.dox OrtStatus Return Value
733+
*
734+
* \since Version 1.24.
735+
*/
736+
ORT_API2_STATUS(KernelDefBuilder_AddInputOutputMutableAliases, _In_ OrtKernelDefBuilder* kernel_def_builder,
737+
_In_reads_(num_io_indices) int const* input_indices,
738+
_In_reads_(num_io_indices) int const* output_indices,
739+
_In_ size_t num_io_indices);
740+
702741
/** \brief Creates an OrtKernelDef instance from the given kernel definition builder.
703742
*
704743
* \param[in] kernel_def_builder The OrtKernelDefBuilder instance.

onnxruntime/core/session/plugin_ep/ep_api.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,76 @@ ORT_API_STATUS_IMPL(KernelDefBuilder_AddTypeConstraint, _In_ OrtKernelDefBuilder
364364
API_IMPL_END
365365
}
366366

367+
// Implements OrtEpApi::KernelDefBuilder_AddInputOutputAliases.
//
// Records that the output at `output_indices[i]` aliases the input at `input_indices[i]` for every
// i in [0, num_io_indices) by forwarding the pairs to OrtKernelDefBuilder::Alias.
//
// Returns an ORT_INVALID_ARGUMENT status if `kernel_def_builder`, `input_indices` or `output_indices`
// is null, or if `num_io_indices` is zero. Returns nullptr on success.
ORT_API_STATUS_IMPL(KernelDefBuilder_AddInputOutputAliases, _In_ OrtKernelDefBuilder* kernel_def_builder,
                    _In_reads_(num_io_indices) int const* input_indices,
                    _In_reads_(num_io_indices) int const* output_indices,
                    _In_ size_t num_io_indices) {
  API_IMPL_BEGIN
  // The builder is dereferenced unconditionally below, so reject a null pointer like the other
  // pointer arguments instead of crashing.
  if (kernel_def_builder == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid OrtKernelDefBuilder instance");
  }

  if (num_io_indices == 0) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify at least one input/output alias");
  }

  if (input_indices == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of input indices to alias");
  }

  if (output_indices == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of output indices to alias");
  }

  if (num_io_indices == 1) {
    // Single pair: use the two-index overload and skip building a temporary vector.
    kernel_def_builder->Alias(input_indices[0], output_indices[0]);
  } else {
    // Zip the two parallel arrays into (input, output) pairs for the vector overload.
    std::vector<std::pair<int, int>> pairs;
    pairs.reserve(num_io_indices);

    for (size_t i = 0; i < num_io_indices; ++i) {
      pairs.emplace_back(input_indices[i], output_indices[i]);
    }

    kernel_def_builder->Alias(pairs);
  }

  return nullptr;
  API_IMPL_END
}
400+
401+
// Implements OrtEpApi::KernelDefBuilder_AddInputOutputMutableAliases.
//
// Records that the output at `output_indices[i]` may be computed in place over the input at
// `input_indices[i]` for every i in [0, num_io_indices) by forwarding the pairs to
// OrtKernelDefBuilder::MayInplace.
//
// Returns an ORT_INVALID_ARGUMENT status if `kernel_def_builder`, `input_indices` or `output_indices`
// is null, or if `num_io_indices` is zero. Returns nullptr on success.
ORT_API_STATUS_IMPL(KernelDefBuilder_AddInputOutputMutableAliases, _In_ OrtKernelDefBuilder* kernel_def_builder,
                    _In_reads_(num_io_indices) int const* input_indices,
                    _In_reads_(num_io_indices) int const* output_indices,
                    _In_ size_t num_io_indices) {
  API_IMPL_BEGIN
  // The builder is dereferenced unconditionally below, so reject a null pointer like the other
  // pointer arguments instead of crashing.
  if (kernel_def_builder == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a valid OrtKernelDefBuilder instance (mutable)");
  }

  if (num_io_indices == 0) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify at least one input/output alias (mutable)");
  }

  if (input_indices == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a valid array of input indices to alias (mutable)");
  }

  if (output_indices == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a valid array of output indices to alias (mutable)");
  }

  if (num_io_indices == 1) {
    // Single pair: use the two-index overload and skip building a temporary vector.
    kernel_def_builder->MayInplace(input_indices[0], output_indices[0]);
  } else {
    // Zip the two parallel arrays into (input, output) pairs for the vector overload.
    std::vector<std::pair<int, int>> pairs;
    pairs.reserve(num_io_indices);

    for (size_t i = 0; i < num_io_indices; ++i) {
      pairs.emplace_back(input_indices[i], output_indices[i]);
    }

    kernel_def_builder->MayInplace(pairs);
  }

  return nullptr;
  API_IMPL_END
}
436+
367437
ORT_API_STATUS_IMPL(KernelDefBuilder_Build, _In_ OrtKernelDefBuilder* kernel_def_builder,
368438
_Outptr_ OrtKernelDef** kernel_def_out) {
369439
API_IMPL_BEGIN
@@ -550,6 +620,8 @@ static constexpr OrtEpApi ort_ep_api = {
550620
&OrtExecutionProviderApi::KernelDefBuilder_SetInputMemType,
551621
&OrtExecutionProviderApi::KernelDefBuilder_SetOutputMemType,
552622
&OrtExecutionProviderApi::KernelDefBuilder_AddTypeConstraint,
623+
&OrtExecutionProviderApi::KernelDefBuilder_AddInputOutputAliases,
624+
&OrtExecutionProviderApi::KernelDefBuilder_AddInputOutputMutableAliases,
553625
&OrtExecutionProviderApi::KernelDefBuilder_Build,
554626
&OrtExecutionProviderApi::ReleaseKernelDef,
555627
&OrtExecutionProviderApi::KernelDef_GetOperatorType,

onnxruntime/core/session/plugin_ep/ep_api.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ ORT_API_STATUS_IMPL(KernelDefBuilder_SetOutputMemType, _In_ OrtKernelDefBuilder*
7373
ORT_API_STATUS_IMPL(KernelDefBuilder_AddTypeConstraint, _In_ OrtKernelDefBuilder* kernel_def_builder,
7474
_In_ const char* arg_name, _In_reads_(num_types) const OrtMLDataType* const* types,
7575
_In_ size_t num_types);
76+
ORT_API_STATUS_IMPL(KernelDefBuilder_AddInputOutputAliases, _In_ OrtKernelDefBuilder* kernel_def_builder,
77+
_In_reads_(num_io_indices) int const* input_indices,
78+
_In_reads_(num_io_indices) int const* output_indices,
79+
_In_ size_t num_io_indices);
80+
ORT_API_STATUS_IMPL(KernelDefBuilder_AddInputOutputMutableAliases, _In_ OrtKernelDefBuilder* kernel_def_builder,
81+
_In_reads_(num_io_indices) int const* input_indices,
82+
_In_reads_(num_io_indices) int const* output_indices,
83+
_In_ size_t num_io_indices);
7684
ORT_API_STATUS_IMPL(KernelDefBuilder_Build, _In_ OrtKernelDefBuilder* kernel_def_builder,
7785
_Outptr_ OrtKernelDef** kernel_def_out);
7886

onnxruntime/test/autoep/library/example_kernel_plugin_ep/ep_kernel_registration.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88

99
// Include kernels:
1010
#include "kernels/mul.h"
11+
#include "kernels/relu.h"
1112
#include "kernels/squeeze.h"
1213

1314
// Table of BuildKernelCreateInfo functions for each operator
1415
// One BuildKernelCreateInfo instantiation per kernel class from the kernels/*.h headers included above.
// Entries are listed alphabetically by operator name (Mul, Relu, Squeeze).
static const BuildKernelCreateInfoFn build_kernel_create_info_funcs[] = {
1516
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kOnnxDomain, 7, 24, Mul)>,
17+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kOnnxDomain, 14, 24, Relu)>,
1618
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kOnnxDomain, 13, 24, Squeeze)>,
1719
};
1820

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "relu.h"
5+
6+
#include <gsl/span>
7+
#include <algorithm>
8+
#include <cassert>
9+
10+
#include "utils.h"
11+
12+
// Kernel registration for the example Relu kernel in kOnnxDomain.
// NOTE(review): the "14, 24" arguments are presumably the first and last supported opset versions -- confirm
// against the macro definition.
// The kernel definition constrains "T" to float tensors and declares output 0 as a mutable alias of
// input 0, which allows ORT to reuse and modify the input buffer in place for the output.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
13+
Relu,
14+
kOnnxDomain,
15+
14, 24,
16+
(Ort::KernelDefBuilder()
17+
.AddTypeConstraint("T", MLDataTypes::GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
18+
.AddInputOutputMutableAlias(0, 0)),
19+
Relu)
20+
21+
Relu::Relu(const OrtKernelInfo* info, void* state, PrivateTag)
22+
: info_(info),
23+
state_(state) {
24+
ort_version_supported = ORT_API_VERSION;
25+
Compute = ComputeImpl;
26+
Release = ReleaseImpl;
27+
}
28+
29+
/*static*/
30+
// Factory for the Relu kernel.
//
// On success, stores the newly constructed kernel in `kernel` and returns nullptr.
// If construction throws, the exception is converted into an OrtStatus that is released to the caller:
// an Ort::Exception is wrapped directly, any other std::exception is reported with code ORT_EP_FAIL.
// `kernel` is left untouched in that case.
OrtStatus* Relu::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr<Relu>& kernel) {
  try {
    // PrivateTag is only nameable inside Relu, so construction always goes through this factory.
    kernel = std::make_unique<Relu>(info, state, PrivateTag{});
  } catch (const Ort::Exception& ex) {
    Ort::Status status(ex);
    return status.release();
  } catch (const std::exception& ex) {
    Ort::Status status(ex.what(), ORT_EP_FAIL);
    return status.release();
  }

  return nullptr;
}
45+
46+
/*static*/
47+
// OrtKernelImpl::Compute callback: recovers the Relu instance from the base pointer and
// delegates to the member function DoCompute.
OrtStatus* ORT_API_CALL Relu::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept {
  return static_cast<Relu*>(this_ptr)->DoCompute(kernel_ctx);
}
51+
52+
/*static*/
53+
// OrtKernelImpl::Release callback: destroys the Relu instance behind the base pointer.
void ORT_API_CALL Relu::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept {
  Relu* self = static_cast<Relu*>(this_ptr);
  delete self;
}
56+
57+
// Computes output[i] = max(0, input[i]) for every element of float input 0, writing to output 0,
// which is requested with the same shape as the input.
//
// Returns nullptr on success. Exceptions thrown inside the body are converted into an OrtStatus
// (ORT_EP_FAIL for non-ORT exceptions) so that nothing escapes this noexcept function.
OrtStatus* Relu::DoCompute(OrtKernelContext* kernel_ctx) noexcept {
  // NOTE: Neither member is used in this example.
  static_cast<void>(this->state_);
  static_cast<void>(this->info_);

  Ort::KernelContext kernel_context(kernel_ctx);

  try {
    gsl::span<const float> x_data;
    std::vector<int64_t> x_shape;
    RETURN_IF_ERROR(GetKernelInputDataAndShape<float>(kernel_context, 0, x_data, x_shape));

    Ort::UnownedValue y_value = kernel_context.GetOutput(0, x_shape);
    float* y_data = y_value.GetTensorMutableData<float>();

    // The kernel is registered with a mutable alias between input 0 and output 0, so y_data may
    // point at the same buffer as x_data. Each element is read before the same index is written.
    const size_t num_elems = x_data.size();
    for (size_t idx = 0; idx < num_elems; ++idx) {
      const float x = x_data[idx];
      y_data[idx] = std::max(0.0f, x);
    }
  } catch (const Ort::Exception& ex) {
    Ort::Status status(ex);
    return status.release();
  } catch (const std::exception& ex) {
    Ort::Status status(ex.what(), ORT_EP_FAIL);
    return status.release();
  }

  return nullptr;
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "utils.h"
7+
#include "../../plugin_ep_utils.h"
8+
9+
// Forward declarations of kernel classes used as template args for BuildKernelCreateInfo
10+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kOnnxDomain, 14, 24, Relu);
11+
12+
// Example Relu kernel for the example kernel plugin EP. Derives from OrtKernelImpl; the constructor
// (see relu.cc) sets the inherited Compute and Release callbacks to ComputeImpl and ReleaseImpl.
struct Relu : public OrtKernelImpl {
13+
private:
14+
// Private tag type: the constructor below is public but requires a PrivateTag argument, which only
// code inside Relu (e.g. Create) can name.
struct PrivateTag {};
15+
16+
public:
17+
// Factory. On success stores the new kernel in `kernel` and returns nullptr; on failure returns an
// OrtStatus describing the error.
static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr<Relu>& kernel);
18+
19+
// Stores `info` and `state`; the unnamed PrivateTag parameter restricts who can call this.
Relu(const OrtKernelInfo* info, void* state, PrivateTag);
20+
21+
// Static callbacks assigned to the OrtKernelImpl Compute/Release fields. `this_ptr` is cast back to Relu.
static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept;
22+
static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept;
23+
24+
// Member implementation behind ComputeImpl: applies max(0, x) element-wise to float input 0.
OrtStatus* DoCompute(OrtKernelContext* kernel_ctx) noexcept;
25+
26+
private:
27+
// Kernel info pointer received at construction; stored but not otherwise used in this example.
const OrtKernelInfo* info_;
28+
void* state_{nullptr}; // Custom state passed from OrtEp
29+
};

onnxruntime/test/autoep/library/example_kernel_plugin_ep/kernels/squeeze.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
1414
13, 24,
1515
(Ort::KernelDefBuilder()
1616
.AddTypeConstraint("T", MLDataTypes::GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
17-
.AddTypeConstraint("axes", MLDataTypes::GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))),
17+
.AddTypeConstraint("axes", MLDataTypes::GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))
18+
.AddInputOutputAlias(0, 0)),
1819
Squeeze)
1920

2021
Squeeze::Squeeze(const OrtKernelInfo* info, void* state, PrivateTag)

0 commit comments

Comments
 (0)