77 changes: 46 additions & 31 deletions onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
@@ -1407,9 +1407,30 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
}

// Find inputs and outputs of the subgraph

std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::IndexedSubGraph::Create();
std::unordered_map<const NodeArg*, int> original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
std::unordered_map<const NodeArg*, int> original_inputs;

// These maps store the inputs and outputs of the subgraph.
// Note that their contents are updated dynamically during node iteration
// to determine the final inputs and outputs of the subgraph.
std::unordered_map<const NodeArg*, int> fused_inputs, fused_outputs;

// This map stores the node's output that will be consumed by another node outside of this subgraph.
// So the node's output should be put into the subgraph's output list.
std::unordered_map<const NodeArg*, int> fused_outputs_to_add;

// This map stores the node's output that is also an output of the original graph.
// So the node's output should be put into the subgraph's output list.
std::unordered_map<const NodeArg*, int> graph_outputs_to_add;

std::unordered_set<const NodeArg*> erased;

// These counters assign a relative order index to each input or output added to the 'fused_inputs',
// 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps.
// Items added earlier receive a smaller order index than items added later.
// When constructing the final sub_graph's input or output lists, entries with smaller
// order indices will appear before those with larger indices.
int input_order = 0;
int output_order = 0;

@@ -1428,7 +1449,7 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
fused_inputs.insert({input, input_order++});
}
}

@@ -1443,7 +1464,7 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
fused_inputs.insert({input, input_order++});
}
}

@@ -1464,39 +1485,33 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
} else {
output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast<int>(it->GetNode().InputDefs().size())];
}
if (node_set.find(node_idx) != node_set.end()) {
const auto& iter = fused_inputs.find(output);
if (iter != fused_inputs.end()) {
fused_inputs.erase(iter);
erased.insert(output);
} else if (erased.find(output) == erased.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
}
fused_outputs[output] = output_order++;
}
} else {
fused_outputs_to_add[output] = output_order++;

if (node_set.find(node_idx) == node_set.end()) {
// This output will be consumed by another node outside of this subgraph.
// So the output should be put into the subgraph's output list.
fused_outputs_to_add.insert({output, output_order++});
}
}
} else {
for (const auto& output : node->OutputDefs()) {
const auto& it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
erased.insert(output);
}
// Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
else if (erased.find(output) == erased.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
}
}

if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
fused_outputs[output] = output_order++;
}
for (const auto& output : node->OutputDefs()) {
const auto& it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
erased.insert(output);
} else if (erased.find(output) == erased.end()) {
if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
// Only when output is neither in input list nor erased list,
// and the output is consumed by another node, add the output to output list
fused_outputs.insert({output, output_order++});
}
}

if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
// This output is the graph's output.
// So the output should be put into the subgraph's output list.
graph_outputs_to_add.insert({output, output_order++});
}
}
}
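
The new comments above describe an order-index scheme: every NodeArg added to one of these maps is paired with a monotonically increasing index, and the final subgraph input/output lists are emitted in ascending index order. A minimal, self-contained sketch of that pattern (not the provider's actual code; std::string stands in for const NodeArg*):

#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

int main() {
  // Stand-in for std::unordered_map<const NodeArg*, int> fused_inputs.
  std::unordered_map<std::string, int> fused_inputs;
  int input_order = 0;

  // insert() keeps the index from the first time an arg is seen;
  // the operator[] form it replaces would overwrite the index on a later visit.
  fused_inputs.insert({"b", input_order++});
  fused_inputs.insert({"a", input_order++});
  fused_inputs.insert({"b", input_order++});  // no effect: "b" already holds index 0

  // Build the final, deterministically ordered list by sorting on the order index.
  std::vector<std::pair<std::string, int>> ordered(fused_inputs.begin(), fused_inputs.end());
  std::sort(ordered.begin(), ordered.end(),
            [](const auto& lhs, const auto& rhs) { return lhs.second < rhs.second; });
  // ordered is now {"b", 0}, {"a", 1}: insertion order, independent of hash-map iteration order.
  return 0;
}

Whether the real provider code relies on the insert-vs-operator[] distinction for arguments visited more than once is an assumption here; the visible contract is simply that each entry keeps the index assigned when it was first added.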

77 changes: 46 additions & 31 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -2035,9 +2035,30 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
}

// Find inputs and outputs of the subgraph

std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::IndexedSubGraph::Create();
std::unordered_map<const NodeArg*, int> original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
std::unordered_map<const NodeArg*, int> original_inputs;

// These maps store the inputs and outputs of the subgraph.
// Note that their contents are updated dynamically during node iteration
// to determine the final inputs and outputs of the subgraph.
std::unordered_map<const NodeArg*, int> fused_inputs, fused_outputs;

// This map stores the node's output that will be consumed by another node outside of this subgraph.
// So the node's output should be put into the subgraph's output list.
std::unordered_map<const NodeArg*, int> fused_outputs_to_add;

// This map stores the node's output that is also an output of the original graph.
// So the node's output should be put into the subgraph's output list.
std::unordered_map<const NodeArg*, int> graph_outputs_to_add;

std::unordered_set<const NodeArg*> erased;

// These counters assign a relative order index to each input or output added to the 'fused_inputs',
// 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps.
// Items added earlier receive a smaller order index than items added later.
// When constructing the final sub_graph's input or output lists, entries with smaller
// order indices will appear before those with larger indices.
int input_order = 0;
int output_order = 0;

@@ -2056,7 +2077,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
fused_inputs.insert({input, input_order++});
}
}

@@ -2071,7 +2092,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
fused_inputs.insert({input, input_order++});
}
}

@@ -2092,39 +2113,33 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
} else {
output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast<int>(it->GetNode().InputDefs().size())];
}
if (node_set.find(node_idx) != node_set.end()) {
const auto& iter = fused_inputs.find(output);
if (iter != fused_inputs.end()) {
fused_inputs.erase(iter);
erased.insert(output);
} else if (erased.find(output) == erased.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
}
fused_outputs[output] = output_order++;
}
} else {
fused_outputs_to_add[output] = output_order++;

if (node_set.find(node_idx) == node_set.end()) {
// This output will be consumed by another node outside of this subgraph.
// So the output should be put into the subgraph's output list.
fused_outputs_to_add.insert({output, output_order++});
}
}
} else {
for (const auto& output : node->OutputDefs()) {
const auto& it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
erased.insert(output);
}
// Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
else if (erased.find(output) == erased.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
}
}

if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
fused_outputs[output] = output_order++;
}
for (const auto& output : node->OutputDefs()) {
const auto& it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
erased.insert(output);
} else if (erased.find(output) == erased.end()) {
if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
// Only when output is neither in input list nor erased list,
// and the output is consumed by another node, add the output to output list
fused_outputs.insert({output, output_order++});
}
}

if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
// This output is the graph's output.
// So the output should be put into the subgraph's output list.
graph_outputs_to_add.insert({output, output_order++});
}
}
}

42 changes: 42 additions & 0 deletions onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
@@ -203,6 +203,48 @@ TEST_P(TypeTests, IOTypes) {
}
}

TEST(NvExecutionProviderTest, TestSessionOutputs) {
/*
* Model #1:
*
* "input" ---> TopK ---
* |---> "scores"
* |--- Less ---> "Less_output_0"
* |--- Div ---> "Div_output_0"
* |--- Mod ---> "labels"
*/
{
Ort::SessionOptions session_options;
session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});

auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx");
Ort::Session session(*ort_env, model_path, session_options);

size_t output_count = session.GetOutputCount();
ASSERT_TRUE(output_count == 4);
}

/*
* Model #2:
*
* "X" ---> Dropout ---> MatMul ---> "Y"
* ^ |
* | |
* "W" ------ ----> Can't be graph's output
*
*/
{
Ort::SessionOptions session_options;
session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});

auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx");
Ort::Session session(*ort_env, model_path, session_options);

size_t output_count = session.GetOutputCount();
ASSERT_TRUE(output_count == 1);
}
}
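
Because the underlying fix is about which graph outputs survive fusion and in what relative order, a test like the one above could also assert on the output names rather than only the count. A sketch using the public C++ API (not part of this PR; the expected names are simply read off the model #1 diagram):

// Sketch only: collect the session's output names so a test could also check names and order.
static std::vector<std::string> GetSessionOutputNames(Ort::Session& session) {
  Ort::AllocatorWithDefaultOptions allocator;
  std::vector<std::string> names;
  for (size_t i = 0; i < session.GetOutputCount(); ++i) {
    names.emplace_back(session.GetOutputNameAllocated(i, allocator).get());
  }
  return names;
}
// Possible usage inside the test above:
//   auto names = GetSessionOutputNames(session);
//   EXPECT_THAT(names, ::testing::UnorderedElementsAre(
//       "scores", "Less_output_0", "Div_output_0", "labels"));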

INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests,
::testing::Values(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
49 changes: 48 additions & 1 deletion onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "onnxruntime_cxx_api.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/inference_session.h"
#include "test/providers/provider_test_utils.h"
@@ -18,6 +19,8 @@ using namespace std;
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::logging;

extern std::unique_ptr<Ort::Env> ort_env;

namespace onnxruntime {

namespace test {
@@ -1360,5 +1363,49 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) {
ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
}

TEST(TensorrtExecutionProviderTest, TestSessionOutputs) {
/*
* Model #1:
*
* "input" ---> TopK ---
* |---> "scores"
* |--- Less ---> "Less_output_0"
* |--- Div ---> "Div_output_0"
* |--- Mod ---> "labels"
*/
{
OrtTensorRTProviderOptionsV2 provider_options;
Ort::SessionOptions session_options;
session_options.AppendExecutionProvider_TensorRT_V2(provider_options);

auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx");
Ort::Session session(*ort_env, model_path, session_options);

size_t output_count = session.GetOutputCount();
ASSERT_TRUE(output_count == 4);
}

/*
* Model #2:
*
* "X" ---> Dropout ---> MatMul ---> "Y"
* ^ |
* | |
* "W" ------ ----> Can't be graph's output
*
*/
{
OrtTensorRTProviderOptionsV2 provider_options;
Ort::SessionOptions session_options;
session_options.AppendExecutionProvider_TensorRT_V2(provider_options);

auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx");
Ort::Session session(*ort_env, model_path, session_options);

size_t output_count = session.GetOutputCount();
ASSERT_TRUE(output_count == 1);
}
}
} // namespace test
} // namespace onnxruntime
Binary file added onnxruntime/test/testdata/node_output_not_used.onnx
43 changes: 43 additions & 0 deletions onnxruntime/test/testdata/node_output_not_used.py
@@ -0,0 +1,43 @@
import onnx

Check notice (Code scanning / CodeQL, test): Module is imported with 'import' and 'import from'.
Module 'onnx' is imported with both 'import' and 'import from'.
Module 'onnxruntime.test.onnx' is imported with both 'import' and 'import from'.

Copilot Autofix suggestion: remove the "from onnx import TensorProto, helper" statement and reference TensorProto and helper via the onnx module (that is, use onnx.TensorProto and onnx.helper), updating all usages of helper and TensorProto accordingly. No additional dependencies or code structure changes are required; only onnxruntime/test/testdata/node_output_not_used.py needs to change.

Suggested changeset 1
onnxruntime/test/testdata/node_output_not_used.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/onnxruntime/test/testdata/node_output_not_used.py b/onnxruntime/test/testdata/node_output_not_used.py
--- a/onnxruntime/test/testdata/node_output_not_used.py
+++ b/onnxruntime/test/testdata/node_output_not_used.py
@@ -1,15 +1,14 @@
 import onnx
-from onnx import TensorProto, helper
 
 
 def create_model_with_node_output_not_used(model_path):
     # Create graph
-    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2])
-    W = helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 3])
-    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])
+    X = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2])
+    W = onnx.helper.make_tensor_value_info("W", onnx.TensorProto.FLOAT, [2, 3])
+    Y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [2, 3])
 
     # Dropout node (two outputs)
-    dropout_node = helper.make_node(
+    dropout_node = onnx.helper.make_node(
         "Dropout",
         inputs=["X"],
         outputs=["dropout_out", "dropout_mask"],
@@ -17,21 +10,21 @@
     )
 
     # MatMul node
-    matmul_node = helper.make_node(
+    matmul_node = onnx.helper.make_node(
         "MatMul",
         inputs=["dropout_out", "W"],
         outputs=["Y"],
         name="MatMulNode",
     )
 
-    graph = helper.make_graph(
+    graph = onnx.helper.make_graph(
         nodes=[dropout_node, matmul_node],
         name="DropoutMatMulGraph",
         inputs=[X, W],
         outputs=[Y],
     )
 
-    model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)])
+    model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_operatorsetid("", 13)])
 
     onnx.checker.check_model(model)
     onnx.save(model, model_path)
EOF
@@ -1,15 +1,14 @@
import onnx
from onnx import TensorProto, helper


def create_model_with_node_output_not_used(model_path):
# Create graph
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2])
W = helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 3])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])
X = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2])
W = onnx.helper.make_tensor_value_info("W", onnx.TensorProto.FLOAT, [2, 3])
Y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [2, 3])

# Dropout node (two outputs)
dropout_node = helper.make_node(
dropout_node = onnx.helper.make_node(
"Dropout",
inputs=["X"],
outputs=["dropout_out", "dropout_mask"],
@@ -17,21 +10,21 @@
)

# MatMul node
matmul_node = helper.make_node(
matmul_node = onnx.helper.make_node(
"MatMul",
inputs=["dropout_out", "W"],
outputs=["Y"],
name="MatMulNode",
)

graph = helper.make_graph(
graph = onnx.helper.make_graph(
nodes=[dropout_node, matmul_node],
name="DropoutMatMulGraph",
inputs=[X, W],
outputs=[Y],
)

model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)])
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_operatorsetid("", 13)])

onnx.checker.check_model(model)
onnx.save(model, model_path)
Copilot is powered by AI and may make mistakes. Always verify output.
from onnx import TensorProto, helper


def create_model_with_node_output_not_used(model_path):
# Create graph
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2])
W = helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 3])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])

# Dropout node (two outputs)
dropout_node = helper.make_node(
"Dropout",
inputs=["X"],
outputs=["dropout_out", "dropout_mask"],
name="DropoutNode",
)

# MatMul node
matmul_node = helper.make_node(
"MatMul",
inputs=["dropout_out", "W"],
outputs=["Y"],
name="MatMulNode",
)

graph = helper.make_graph(
nodes=[dropout_node, matmul_node],
name="DropoutMatMulGraph",
inputs=[X, W],
outputs=[Y],
)

model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)])

onnx.checker.check_model(model)
onnx.save(model, model_path)

print(f"Model saved to: {model_path}")


if __name__ == "__main__":
create_model_with_node_output_not_used("node_output_not_used.onnx")
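
Presumably this script was run once to produce the node_output_not_used.onnx binary added above; the new TestSessionOutputs cases then load it from testdata/ to verify that the unused Dropout mask output is not promoted to a session output.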