
Commit 2b659e4

[TRT/TRT RTX EP] Fix bug for missing outputs in the returned ComputeCapability/IndexedSubGraph (#26444)
### Description

For TRT EP's `GetCapability()`, in some cases `GetSubGraph()` won't add the graph's outputs to the `ComputeCapability`/`IndexedSubGraph` returned to ORT. The issue stems from the following code:

````c++
...
if (node->GetOutputEdgesCount() > node->OutputDefs().size()) {
  ...  // execution takes this branch
} else {
  ...
  if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
    graph_outputs_to_add[output] = output_order;  // so this line is never reached
  }
}
````

Update TRT RTX EP as well.

### Motivation and Context

#25373
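In short, the graph-output bookkeeping lived only on the `else` path, so a node whose outputs fan out to more edges than it has output definitions never had its outputs checked against the graph's output list. The sketch below is a minimal, self-contained illustration of the corrected logic, not the actual EP code: `Output`, `has_consumers`, and the string-keyed maps are hypothetical stand-ins for ORT's `NodeArg*`-based bookkeeping (the real fix is in the diffs that follow).

````c++
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a node's output; ORT uses NodeArg* instead.
struct Output {
  std::string name;
  bool has_consumers;  // consumed by some other node in the graph
};

int main() {
  // "values" feeds downstream nodes; "labels" is a graph output nothing consumes.
  std::set<std::string> graph_output_names = {"values", "labels"};
  std::vector<Output> node_outputs = {{"values", true}, {"labels", false}};

  std::map<std::string, int> fused_outputs, graph_outputs_to_add;
  int output_order = 0;

  for (const auto& out : node_outputs) {
    if (out.has_consumers) {
      fused_outputs.insert({out.name, output_order++});
    }
    // The fix: check for graph outputs unconditionally, not only on one branch.
    if (graph_output_names.count(out.name) > 0) {
      graph_outputs_to_add.insert({out.name, output_order++});
    }
  }

  // Both graph outputs are now recorded; before the fix, the check could be
  // skipped entirely for nodes whose outputs fan out to other nodes.
  for (const auto& [name, order] : graph_outputs_to_add) {
    std::cout << name << " -> order " << order << "\n";
  }
}
````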
1 parent c156e93 commit 2b659e4

8 files changed: +303 -63 lines changed

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Lines changed: 46 additions & 31 deletions

````diff
@@ -1407,9 +1407,30 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
   }
 
   // Find inputs and outputs of the subgraph
+
   std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::IndexedSubGraph::Create();
-  std::unordered_map<const NodeArg*, int> original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
+  std::unordered_map<const NodeArg*, int> original_inputs;
+
+  // These maps store the inputs and outputs of the subgraph.
+  // Please note that their contents are dynamically updated during node iteration
+  // to determine the final inputs and outputs of the subgraph.
+  std::unordered_map<const NodeArg*, int> fused_inputs, fused_outputs;
+
+  // This map stores node outputs that will be consumed by another node outside of this subgraph.
+  // Such outputs should be put into the subgraph's output list.
+  std::unordered_map<const NodeArg*, int> fused_outputs_to_add;
+
+  // This map stores node outputs that are also outputs of the original graph.
+  // Such outputs should be put into the subgraph's output list.
+  std::unordered_map<const NodeArg*, int> graph_outputs_to_add;
+
   std::unordered_set<const NodeArg*> erased;
+
+  // Relative ordering: every node input or output added to the 'fused_inputs',
+  // 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps is associated
+  // with a relative order index. Items added earlier receive a smaller order index than
+  // items added later. When constructing the final sub_graph's input or output lists,
+  // entries with smaller order indices will appear before those with larger indices.
   int input_order = 0;
   int output_order = 0;
 
@@ -1428,7 +1449,7 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
         erased.insert(input);
       } else if (erased.find(input) == erased.end()) {
         // Only when input is neither in output list nor erased list, add the input to input list
-        fused_inputs[input] = input_order++;
+        fused_inputs.insert({input, input_order++});
       }
     }
 
@@ -1443,7 +1464,7 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
         erased.insert(input);
       } else if (erased.find(input) == erased.end()) {
         // Only when input is neither in output list nor erased list, add the input to input list
-        fused_inputs[input] = input_order++;
+        fused_inputs.insert({input, input_order++});
       }
     }
 
@@ -1464,39 +1485,33 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
       } else {
         output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast<int>(it->GetNode().InputDefs().size())];
       }
-      if (node_set.find(node_idx) != node_set.end()) {
-        const auto& iter = fused_inputs.find(output);
-        if (iter != fused_inputs.end()) {
-          fused_inputs.erase(iter);
-          erased.insert(output);
-        } else if (erased.find(output) == erased.end()) {
-          if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
-            graph_outputs_to_add[output] = output_order;
-          }
-          fused_outputs[output] = output_order++;
-        }
-      } else {
-        fused_outputs_to_add[output] = output_order++;
+
+      if (node_set.find(node_idx) == node_set.end()) {
+        // This output will be consumed by another node outside of this subgraph.
+        // So the output should be put into the subgraph's output list.
+        fused_outputs_to_add.insert({output, output_order++});
       }
     }
-  } else {
-    for (const auto& output : node->OutputDefs()) {
-      const auto& it = fused_inputs.find(output);
-      if (it != fused_inputs.end()) {
-        fused_inputs.erase(it);
-        erased.insert(output);
-      }
-      // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
-      else if (erased.find(output) == erased.end()) {
-        if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
-          graph_outputs_to_add[output] = output_order;
-        }
+  }
 
-        if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
-          fused_outputs[output] = output_order++;
-        }
+  for (const auto& output : node->OutputDefs()) {
+    const auto& it = fused_inputs.find(output);
+    if (it != fused_inputs.end()) {
+      fused_inputs.erase(it);
+      erased.insert(output);
+    } else if (erased.find(output) == erased.end()) {
+      if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
+        // Only when output is neither in input list nor erased list,
+        // and the output is consumed by another node, add the output to output list
+        fused_outputs.insert({output, output_order++});
       }
     }
+
+    if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
+      // This output is the graph's output.
+      // So the output should be put into the subgraph's output list.
+      graph_outputs_to_add.insert({output, output_order++});
+    }
   }
 }
 
````
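The new ordering comments describe a contract rather than code: each recorded input or output carries an order index, and the final `sub_graph` input/output lists are laid out by ascending index. Below is a self-contained sketch of that final ordering step, with `std::string` keys standing in for `NodeArg*`; it is an illustration under those simplifications, not the actual ORT list-building code.

````c++
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

int main() {
  // Simplified stand-in for one of the NodeArg* -> order-index maps above.
  std::unordered_map<std::string, int> fused_outputs = {
      {"scores", 0}, {"Less_output_0", 1}, {"Div_output_0", 2}, {"labels", 3}};

  // Entries with smaller order indices appear earlier in the final list,
  // regardless of the unordered_map's internal iteration order.
  std::vector<std::pair<std::string, int>> ordered(fused_outputs.begin(),
                                                   fused_outputs.end());
  std::sort(ordered.begin(), ordered.end(),
            [](const auto& a, const auto& b) { return a.second < b.second; });

  for (const auto& [name, order] : ordered) {
    std::cout << order << ": " << name << "\n";  // scores, Less_output_0, ...
  }
}
````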

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 46 additions & 31 deletions

````diff
@@ -2035,9 +2035,30 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
   }
 
   // Find inputs and outputs of the subgraph
+
   std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::IndexedSubGraph::Create();
-  std::unordered_map<const NodeArg*, int> original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
+  std::unordered_map<const NodeArg*, int> original_inputs;
+
+  // These maps store the inputs and outputs of the subgraph.
+  // Please note that their contents are dynamically updated during node iteration
+  // to determine the final inputs and outputs of the subgraph.
+  std::unordered_map<const NodeArg*, int> fused_inputs, fused_outputs;
+
+  // This map stores node outputs that will be consumed by another node outside of this subgraph.
+  // Such outputs should be put into the subgraph's output list.
+  std::unordered_map<const NodeArg*, int> fused_outputs_to_add;
+
+  // This map stores node outputs that are also outputs of the original graph.
+  // Such outputs should be put into the subgraph's output list.
+  std::unordered_map<const NodeArg*, int> graph_outputs_to_add;
+
   std::unordered_set<const NodeArg*> erased;
+
+  // Relative ordering: every node input or output added to the 'fused_inputs',
+  // 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps is associated
+  // with a relative order index. Items added earlier receive a smaller order index than
+  // items added later. When constructing the final sub_graph's input or output lists,
+  // entries with smaller order indices will appear before those with larger indices.
   int input_order = 0;
   int output_order = 0;
 
@@ -2056,7 +2077,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
         erased.insert(input);
       } else if (erased.find(input) == erased.end()) {
         // Only when input is neither in output list nor erased list, add the input to input list
-        fused_inputs[input] = input_order++;
+        fused_inputs.insert({input, input_order++});
       }
     }
 
@@ -2071,7 +2092,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
         erased.insert(input);
       } else if (erased.find(input) == erased.end()) {
         // Only when input is neither in output list nor erased list, add the input to input list
-        fused_inputs[input] = input_order++;
+        fused_inputs.insert({input, input_order++});
       }
     }
 
@@ -2092,39 +2113,33 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
       } else {
        output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast<int>(it->GetNode().InputDefs().size())];
       }
-      if (node_set.find(node_idx) != node_set.end()) {
-        const auto& iter = fused_inputs.find(output);
-        if (iter != fused_inputs.end()) {
-          fused_inputs.erase(iter);
-          erased.insert(output);
-        } else if (erased.find(output) == erased.end()) {
-          if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
-            graph_outputs_to_add[output] = output_order;
-          }
-          fused_outputs[output] = output_order++;
-        }
-      } else {
-        fused_outputs_to_add[output] = output_order++;
+
+      if (node_set.find(node_idx) == node_set.end()) {
+        // This output will be consumed by another node outside of this subgraph.
+        // So the output should be put into the subgraph's output list.
+        fused_outputs_to_add.insert({output, output_order++});
       }
     }
-  } else {
-    for (const auto& output : node->OutputDefs()) {
-      const auto& it = fused_inputs.find(output);
-      if (it != fused_inputs.end()) {
-        fused_inputs.erase(it);
-        erased.insert(output);
-      }
-      // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
-      else if (erased.find(output) == erased.end()) {
-        if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
-          graph_outputs_to_add[output] = output_order;
-        }
+  }
 
-        if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
-          fused_outputs[output] = output_order++;
-        }
+  for (const auto& output : node->OutputDefs()) {
+    const auto& it = fused_inputs.find(output);
+    if (it != fused_inputs.end()) {
+      fused_inputs.erase(it);
+      erased.insert(output);
+    } else if (erased.find(output) == erased.end()) {
+      if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
+        // Only when output is neither in input list nor erased list,
+        // and the output is consumed by another node, add the output to output list
+        fused_outputs.insert({output, output_order++});
      }
    }
+
+    if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
+      // This output is the graph's output.
+      // So the output should be put into the subgraph's output list.
+      graph_outputs_to_add.insert({output, output_order++});
+    }
   }
 }
 
````
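One detail shared by both providers' diffs: assignments like `fused_inputs[input] = input_order++` became `fused_inputs.insert({input, input_order++})`. For `std::unordered_map`, `operator[]` overwrites the value of an existing key, while `insert` is a no-op when the key is already present, so a `NodeArg` encountered more than once keeps its first (smallest) order index; the counter still advances either way because the pair argument is evaluated before the lookup. Presumably that stability is the point of the switch; a standalone demonstration:

````c++
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  int order = 0;

  std::unordered_map<std::string, int> with_brackets;
  with_brackets["x"] = order++;  // x -> 0
  with_brackets["x"] = order++;  // operator[] overwrites: x -> 1

  order = 0;
  std::unordered_map<std::string, int> with_insert;
  with_insert.insert({"x", order++});  // x -> 0
  with_insert.insert({"x", order++});  // key exists: no-op, x stays 0 (order still advances)

  std::cout << "operator[]: " << with_brackets["x"]
            << ", insert: " << with_insert["x"] << "\n";  // prints 1 and 0
}
````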

onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc

Lines changed: 42 additions & 0 deletions

````diff
@@ -203,6 +203,48 @@ TEST_P(TypeTests, IOTypes) {
   }
 }
 
+TEST(NvExecutionProviderTest, TestSessionOutputs) {
+  /*
+   * Model #1:
+   *
+   * "input" ---> TopK ---
+   *                     |---> "scores"
+   *                     |--- Less ---> "Less_output_0"
+   *                     |--- Div ---> "Div_output_0"
+   *                     |--- Mod ---> "labels"
+   */
+  {
+    Ort::SessionOptions session_options;
+    session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});
+
+    auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx");
+    Ort::Session session(*ort_env, model_path, session_options);
+
+    size_t output_count = session.GetOutputCount();
+    ASSERT_TRUE(output_count == 4);
+  }
+
+  /*
+   * Model #2:
+   *
+   * "X" ---> Dropout ---> MatMul ---> "Y"
+   *             |            ^
+   *             |            |
+   *             |   "W" ------
+   *             ----> Can't be graph's output
+   */
+  {
+    Ort::SessionOptions session_options;
+    session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});
+
+    auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx");
+    Ort::Session session(*ort_env, model_path, session_options);
+
+    size_t output_count = session.GetOutputCount();
+    ASSERT_TRUE(output_count == 1);
+  }
+}
+
 INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests,
                          ::testing::Values(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
                                            ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
````

onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc

Lines changed: 48 additions & 1 deletion

````diff
@@ -1,5 +1,6 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+#include "onnxruntime_cxx_api.h"
 #include "core/graph/onnx_protobuf.h"
 #include "core/session/inference_session.h"
 #include "test/providers/provider_test_utils.h"
@@ -18,6 +19,8 @@ using namespace std;
 using namespace ONNX_NAMESPACE;
 using namespace ::onnxruntime::logging;
 
+extern std::unique_ptr<Ort::Env> ort_env;
+
 namespace onnxruntime {
 
 namespace test {
@@ -1360,5 +1363,49 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) {
   ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
   VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
 }
+
+TEST(TensorrtExecutionProviderTest, TestSessionOutputs) {
+  /*
+   * Model #1:
+   *
+   * "input" ---> TopK ---
+   *                     |---> "scores"
+   *                     |--- Less ---> "Less_output_0"
+   *                     |--- Div ---> "Div_output_0"
+   *                     |--- Mod ---> "labels"
+   */
+  {
+    OrtTensorRTProviderOptionsV2 provider_options;
+    Ort::SessionOptions session_options;
+    session_options.AppendExecutionProvider_TensorRT_V2(provider_options);
+
+    auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx");
+    Ort::Session session(*ort_env, model_path, session_options);
+
+    size_t output_count = session.GetOutputCount();
+    ASSERT_TRUE(output_count == 4);
+  }
+
+  /*
+   * Model #2:
+   *
+   * "X" ---> Dropout ---> MatMul ---> "Y"
+   *             |            ^
+   *             |            |
+   *             |   "W" ------
+   *             ----> Can't be graph's output
+   */
+  {
+    OrtTensorRTProviderOptionsV2 provider_options;
+    Ort::SessionOptions session_options;
+    session_options.AppendExecutionProvider_TensorRT_V2(provider_options);
+
+    auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx");
+    Ort::Session session(*ort_env, model_path, session_options);
+
+    size_t output_count = session.GetOutputCount();
+    ASSERT_TRUE(output_count == 1);
+  }
+}
 } // namespace test
 } // namespace onnxruntime
````
Binary ONNX test model (189 Bytes); binary file not shown.
Lines changed: 43 additions & 0 deletions (new file; generates `node_output_not_used.onnx`)

````python
import onnx
from onnx import TensorProto, helper


def create_model_with_node_output_not_used(model_path):
    # Create graph
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2])
    w = helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 3])
    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 3])  # [3, 2] @ [2, 3] -> [3, 3]

    # Dropout node (two outputs; "dropout_mask" has no consumers)
    dropout_node = helper.make_node(
        "Dropout",
        inputs=["X"],
        outputs=["dropout_out", "dropout_mask"],
        name="DropoutNode",
    )

    # MatMul node
    matmul_node = helper.make_node(
        "MatMul",
        inputs=["dropout_out", "W"],
        outputs=["Y"],
        name="MatMulNode",
    )

    graph = helper.make_graph(
        nodes=[dropout_node, matmul_node],
        name="DropoutMatMulGraph",
        inputs=[x, w],
        outputs=[y],
    )

    model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)])

    onnx.checker.check_model(model)
    onnx.save(model, model_path)

    print(f"Model saved to: {model_path}")


if __name__ == "__main__":
    create_model_with_node_output_not_used("node_output_not_used.onnx")
````
Binary ONNX test model (393 Bytes); binary file not shown.
