Skip to content

Commit 5834bfe

Browse files
authored
Add API to access config entries from KernelInfo (#26589)
## Description This PR adds a new API function `KernelInfo_GetConfigEntries` that allows custom operators to access all configuration entries from the `OrtKernelInfo` object during kernel construction. ## Motivation and Context Custom operators may need to access session configuration options to adjust their behavior. Previously, there was no way to retrieve all config entries from `KernelInfo`. This PR provides a convenient method to get all configuration key-value pairs that were set on the `OrtSessionOptions`. ## Changes ### API Additions - **C API**: Added `KernelInfo_GetConfigEntries` function to `OrtApi` (Version 1.24) - Takes an `OrtKernelInfo*` as input - Returns all config entries as `OrtKeyValuePairs*` - Properly documented with usage examples - **C++ API**: Added `GetConfigEntries()` method to `KernelInfoImpl` template class - Returns `KeyValuePairs` object - Follows existing C++ wrapper patterns ### Implementation - Implemented in `onnxruntime/core/session/custom_ops.cc` - Iterates through `config_options_map` from `OpKernelInfo` - Creates and populates `OrtKeyValuePairs` with all configuration entries ### Testing - Updated `shape_inference_test.cc` with test case - Verifies config entries can be retrieved in custom kernel constructor - Tests both existing and non-existing config keys ## Files Changed - `include/onnxruntime/core/session/onnxruntime_c_api.h` - API declaration - `include/onnxruntime/core/session/onnxruntime_cxx_api.h` - C++ wrapper declaration - `include/onnxruntime/core/session/onnxruntime_cxx_inline.h` - C++ wrapper implementation - `onnxruntime/core/session/custom_ops.cc` - Core implementation - `onnxruntime/core/session/onnxruntime_c_api.cc` - API registration - `onnxruntime/core/session/ort_apis.h` - API header declaration - `onnxruntime/test/framework/shape_inference_test.cc` - Test coverage ## API Version This change is part of ORT API Version 1.24. ## Breaking Changes None. This is a backward-compatible addition to the API.
1 parent 4870d45 commit 5834bfe

File tree

7 files changed

+59
-1
lines changed

7 files changed

+59
-1
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6591,6 +6591,23 @@ struct OrtApi {
65916591
* \since Version 1.24
65926592
*/
65936593
ORT_API_T(bool, TensorTypeAndShape_HasShape, _In_ const OrtTensorTypeAndShapeInfo* info);
6594+
6595+
/** \brief Get all config entries from ::OrtKernelInfo.
6596+
*
6597+
* Gets all configuration entries from the ::OrtKernelInfo object as key-value pairs.
6598+
* Config entries are set on the ::OrtSessionOptions and are accessible in custom operator kernels.
6599+
*
6600+
* Used in the CreateKernel callback of an OrtCustomOp to access all session configuration entries
6601+
* during kernel construction.
6602+
*
6603+
* \param[in] info An instance of ::OrtKernelInfo.
6604+
* \param[out] out A pointer to a newly created OrtKeyValuePairs instance containing all config entries.
6605+
* Note: the user should call OrtApi::ReleaseKeyValuePairs.
6606+
*
6607+
* \snippet{doc} snippets.dox OrtStatus Return Value
6608+
* \since Version 1.24
6609+
*/
6610+
ORT_API2_STATUS(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out);
65946611
};
65956612

65966613
/*

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2768,6 +2768,8 @@ struct KernelInfoImpl : Base<T> {
27682768

27692769
std::string GetNodeName() const;
27702770
Logger GetLogger() const;
2771+
2772+
KeyValuePairs GetConfigEntries() const;
27712773
};
27722774

27732775
} // namespace detail

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2822,6 +2822,13 @@ inline Logger KernelInfoImpl<T>::GetLogger() const {
28222822
return Logger{out};
28232823
}
28242824

2825+
template <typename T>
2826+
inline KeyValuePairs KernelInfoImpl<T>::GetConfigEntries() const {
2827+
OrtKeyValuePairs* out = nullptr;
2828+
Ort::ThrowOnError(GetApi().KernelInfo_GetConfigEntries(this->p_, &out));
2829+
return KeyValuePairs{out};
2830+
}
2831+
28252832
inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
28262833
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
28272834
}

onnxruntime/core/session/custom_ops.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,21 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAllocator, _In_ const OrtKernelInfo* i
755755
});
756756
}
757757

758+
ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out) {
759+
return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
760+
const auto* op_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
761+
const auto& config_options_map = op_info->GetConfigOptions().GetConfigOptionsMap();
762+
763+
auto kvps = std::make_unique<OrtKeyValuePairs>();
764+
for (const auto& kv : config_options_map) {
765+
kvps->Add(kv.first.c_str(), kv.second.c_str());
766+
}
767+
768+
*out = kvps.release();
769+
return nullptr;
770+
});
771+
}
772+
758773
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) {
759774
if (count_or_bytes == 0) {
760775
*out = nullptr;

onnxruntime/core/session/onnxruntime_c_api.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4237,6 +4237,7 @@ static constexpr OrtApi ort_api_1_to_24 = {
42374237
// End of Version 23 - DO NOT MODIFY ABOVE (see above text for more information)
42384238

42394239
&OrtApis::TensorTypeAndShape_HasShape,
4240+
&OrtApis::KernelInfo_GetConfigEntries,
42404241
};
42414242

42424243
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.

onnxruntime/core/session/ort_apis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,4 +751,7 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env,
751751
_In_reads_(num_tensors) OrtValue* const* dst_tensors,
752752
_In_opt_ OrtSyncStream* stream,
753753
_In_ size_t num_tensors);
754+
755+
ORT_API_STATUS_IMPL(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out);
756+
754757
} // namespace OrtApis

onnxruntime/test/framework/shape_inference_test.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,18 @@ TEST_F(ShapeInferenceTest, BasicTest) {
7878

7979
namespace {
8080
struct MyCustomKernelWithOptionalInput {
81-
MyCustomKernelWithOptionalInput(const OrtKernelInfo* /*info*/) {
81+
MyCustomKernelWithOptionalInput(const OrtKernelInfo* info) {
82+
Ort::ConstKernelInfo k_info(info);
83+
84+
Ort::KeyValuePairs kvp = k_info.GetConfigEntries();
85+
86+
EXPECT_NE(nullptr, kvp.GetValue("session.inter_op.allow_spinning"));
87+
EXPECT_STREQ("0", kvp.GetValue("session.inter_op.allow_spinning"));
88+
89+
EXPECT_NE(nullptr, kvp.GetValue("session.intra_op.allow_spinning"));
90+
EXPECT_STREQ("0", kvp.GetValue("session.intra_op.allow_spinning"));
91+
92+
EXPECT_EQ(nullptr, kvp.GetValue("__not__exist__"));
8293
}
8394

8495
OrtStatusPtr ComputeV2(OrtKernelContext* /* context */) const {
@@ -143,6 +154,8 @@ TEST(ShapeInferenceCustomOpTest, custom_op_optional_input_inference_test) {
143154
SessionOptions sess_opts;
144155
sess_opts.inter_op_param.thread_pool_size = 1;
145156
sess_opts.intra_op_param.thread_pool_size = 1;
157+
ASSERT_STATUS_OK(sess_opts.config_options.AddConfigEntry("session.inter_op.allow_spinning", "0"));
158+
ASSERT_STATUS_OK(sess_opts.config_options.AddConfigEntry("session.intra_op.allow_spinning", "0"));
146159

147160
InferenceSessionWrapper session{sess_opts, env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2};
148161
ASSERT_STATUS_OK(session.AddCustomOpDomains(AsSpan(op_domains)));

0 commit comments

Comments
 (0)