64 changes: 43 additions & 21 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -2035,9 +2035,25 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
}

// Find inputs and outputs of the subgraph

std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::IndexedSubGraph::Create();
std::unordered_map<const NodeArg*, int> original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
std::unordered_map<const NodeArg*, int> original_inputs;

// These maps store the inputs and outputs of the subgraph.
// Note that their contents are updated dynamically during node iteration
// to determine the final inputs and outputs of the subgraph.
std::unordered_map<const NodeArg*, int> fused_inputs, fused_outputs;

// This map stores node outputs that are consumed by nodes outside of this subgraph,
// so they must be added to the subgraph's output list.
std::unordered_map<const NodeArg*, int> fused_outputs_to_add;

// This map stores node outputs that are also outputs of the original graph,
// so they must be added to the subgraph's output list.
std::unordered_map<const NodeArg*, int> graph_outputs_to_add;

std::unordered_set<const NodeArg*> erased;

int input_order = 0;
int output_order = 0;

@@ -2056,7 +2072,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
fused_inputs.insert({input, input_order++});
}
}

@@ -2071,7 +2087,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
fused_inputs.insert({input, input_order++});
}
}
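A side note on the switch from `operator[]` to `insert()` in the hunks above: `std::unordered_map::insert` is a no-op when the key is already present, so a NodeArg recorded earlier keeps its original order index, whereas `operator[]` silently overwrites it. A minimal, self-contained sketch of the difference (hypothetical string keys, not ORT code):

#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> with_bracket, with_insert;
  int order = 0;

  // operator[] overwrites: visiting the same key twice re-assigns its index.
  with_bracket["x"] = order++;  // x -> 0
  with_bracket["x"] = order++;  // x -> 1, the original index is lost

  order = 0;
  // insert() keeps the first mapping; the second call fails to insert.
  // (The argument is still evaluated, so order++ runs either way --
  //  harmless here because only the relative ordering of indices matters.)
  with_insert.insert({"x", order++});  // x -> 0
  with_insert.insert({"x", order++});  // x still -> 0

  assert(with_bracket["x"] == 1);
  assert(with_insert["x"] == 0);
  return 0;
}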

@@ -2099,32 +2115,38 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
erased.insert(output);
} else if (erased.find(output) == erased.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
// This output is also an output of the original graph,
// so it must be added to the subgraph's output list.
graph_outputs_to_add.insert({output, output_order});
}
fused_outputs[output] = output_order++;
fused_outputs.insert({output, output_order++});
}
} else {
fused_outputs_to_add[output] = output_order++;
// This output is consumed by another node outside of this subgraph,
// so it must be added to the subgraph's output list.
fused_outputs_to_add.insert({output, output_order++});
}
}
} else {
for (const auto& output : node->OutputDefs()) {
const auto& it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
erased.insert(output);
}
// Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
else if (erased.find(output) == erased.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
}
}

if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
fused_outputs[output] = output_order++;
}
for (const auto& output : node->OutputDefs()) {
const auto& it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
erased.insert(output);
} else if (erased.find(output) == erased.end()) {
if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
// Add the output to the output list only when it is neither in the input list
// nor in the erased list, and it is consumed by another node.
fused_outputs.insert({output, output_order++});
}
}

if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
// This output is also an output of the original graph,
// so it must be added to the subgraph's output list.
graph_outputs_to_add.insert({output, output_order++});
}
}
}
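Putting the pieces together: below is a simplified standalone sketch of how the three output maps could be folded into one final output list. It uses std::string in place of const NodeArg* and hard-coded example names; the merge step itself is an assumption, since the code that combines the maps falls in the collapsed part of this diff.

#include <iostream>
#include <map>
#include <string>
#include <unordered_map>

int main() {
  // Example contents after iterating a subgraph's nodes (hypothetical names).
  std::unordered_map<std::string, int> fused_outputs        = {{"topk_values", 0}};
  std::unordered_map<std::string, int> fused_outputs_to_add = {{"Div_output_0", 1}};
  std::unordered_map<std::string, int> graph_outputs_to_add = {{"labels", 2}};

  // Merge: insert() keeps the first index recorded for any duplicate key,
  // matching the de-duplicating behavior used in the hunks above.
  std::unordered_map<std::string, int> outputs = fused_outputs;
  outputs.insert(fused_outputs_to_add.begin(), fused_outputs_to_add.end());
  outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end());

  // Emit in order-index order, as the subgraph's output list would be.
  std::map<int, std::string> by_order;
  for (const auto& kv : outputs) by_order.emplace(kv.second, kv.first);
  for (const auto& kv : by_order) std::cout << kv.first << ": " << kv.second << "\n";
  return 0;
}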

56 changes: 55 additions & 1 deletion onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
@@ -1,5 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "onnxruntime_cxx_api.h"
#include "tensorrt_test_utils.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/inference_session.h"
#include "test/providers/provider_test_utils.h"
@@ -1358,5 +1360,57 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) {
ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
}

TEST(TensorrtExecutionProviderTest, TestSessionOutputs) {
/*
* Model #1:
*
* "input" ---> TopK ---
*                     |---> "scores"
*                     |--- Less ---> "Less_output_0"
*                     |--- Div ---> "Div_output_0"
*                     |--- Mod ---> "labels"
*/
{
Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};
OrtTensorRTProviderOptionsV2 provider_options;
Ort::SessionOptions session_options;
session_options.AppendExecutionProvider_TensorRT_V2(provider_options);

auto model_path = ORT_TSTR("model_with_topk_and_multiple_graph_outputs.onnx");
Ort::Status status(CreateModelWithTopKWhichContainsGraphOutput(model_path));
ASSERT_TRUE(status.IsOK());

Ort::Session session(env, model_path, session_options);

size_t output_count = session.GetOutputCount();
ASSERT_TRUE(output_count == 4);
}

/*
* Model #2:
*
* "X" ---> Dropout ---> MatMul ---> "Y"
*             |           ^
*             |           |
*             |          "W"
*             |
*             ----> Can't be graph's output (this second output is not consumed)
*
*/
{
Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};
OrtTensorRTProviderOptionsV2 provider_options;
Ort::SessionOptions session_options;
session_options.AppendExecutionProvider_TensorRT_V2(provider_options);

auto model_path = ORT_TSTR("model_with_node_output_not_used.onnx");
Ort::Status status(CreateModelWithNodeOutputNotUsed(model_path));
ASSERT_TRUE(status.IsOK());

Ort::Session session(env, model_path, session_options);

size_t output_count = session.GetOutputCount();
ASSERT_TRUE(output_count == 1);
}
}
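As a possible follow-up (a hedged sketch, not part of this PR): the same C++ API can verify the output names rather than just the count. The expected names come straight from the Model #1 diagram above; the helper name CheckTopKModelOutputNames is hypothetical, and the model file is assumed to have been created by the test above.

#include <cassert>
#include <set>
#include <string>
#include "onnxruntime_cxx_api.h"

void CheckTopKModelOutputNames() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};
  Ort::SessionOptions session_options;
  Ort::Session session(env, ORT_TSTR("model_with_topk_and_multiple_graph_outputs.onnx"),
                       session_options);

  Ort::AllocatorWithDefaultOptions allocator;
  std::set<std::string> names;
  for (size_t i = 0; i < session.GetOutputCount(); ++i) {
    // GetOutputNameAllocated returns an owning pointer to the i-th output name.
    names.insert(session.GetOutputNameAllocated(i, allocator).get());
  }

  // Expected per the Model #1 diagram above.
  assert(names.count("scores") == 1);
  assert(names.count("Less_output_0") == 1);
  assert(names.count("Div_output_0") == 1);
  assert(names.count("labels") == 1);
}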
} // namespace test
} // namespace onnxruntime