Commit a2c4374
Add partial data propagation to enhance shape inference (#26269)
### Description

Calling an operator's `TypeAndShapeInferenceFunction()` alone is sometimes insufficient for complete shape inference. For example, the `Shape` operator's `TypeAndShapeInferenceFunction()` only infers the output's rank (as a 1-dimensional tensor), not its actual dimension values: given an input of shape [1, 3, 64, 64], it produces an output shape tensor of type int64[4], where 4 is the rank of the input tensor. As a result, the output shape of the graph below can't be properly inferred (even though the input shape is known), because the concrete shape data is lost at the `Shape` operator.

<img width="563" height="488" alt="image" src="https://github.com/user-attachments/assets/bfa9fd8f-5291-4c6d-a679-3ce4a8c48669" />

To solve this, the `PartialDataPropagationFunction()` defined in the ONNX operator schema must also be executed to obtain the concrete output shape values, allowing accurate propagation of shape information throughout the graph. This PR adds support for executing an operator's `PartialDataPropagationFunction()` in ORT and makes sure the shape values are properly propagated throughout the graph.

### Motivation and Context

When using the Compile API to generate an EPContext model, all graph optimizations are disabled by default except for free dimension overrides. For certain models, such as a VAE decoder, the output shape may still fail to be properly inferred even when free dimension override values are provided beforehand. This issue does not occur when all graph optimizations are enabled, because nodes such as `Shape` and `Reshape` are then constant-folded.
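To make the gap concrete, here is a small, self-contained C++ sketch of what each of the two passes contributes for a `Shape` node. All names here are illustrative stand-ins, not the ORT implementation:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Toy model of the gap this PR closes. For an input of shape [1, 3, 64, 64],
// Shape's TypeAndShapeInferenceFunction() only yields "1-D tensor of length 4"
// (the rank); running the schema's PartialDataPropagationFunction() as well is
// what recovers the actual dimension values [1, 3, 64, 64].
struct ShapeOutput {
  int64_t length;                              // known from type/shape inference alone
  std::optional<std::vector<int64_t>> values;  // known only after data propagation
};

ShapeOutput InferShapeNode(const std::vector<int64_t>& input_shape,
                           bool run_partial_data_propagation) {
  ShapeOutput out{static_cast<int64_t>(input_shape.size()), std::nullopt};
  if (run_partial_data_propagation) {
    out.values = input_shape;  // the concrete dims that downstream nodes need
  }
  return out;
}
```

Without the propagation pass the downstream consumers only see "some int64[4] tensor"; with it, they can resolve concrete output shapes.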
1 parent 8f8069d commit a2c4374

31 files changed (+1637, −4 lines)

include/onnxruntime/core/graph/graph.h
Lines changed: 9 additions & 0 deletions

@@ -1753,6 +1753,15 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
                               std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
                               const Graph::ResolveOptions& options);

+  // If the ONNX operator's PartialDataPropagationFunction() infers concrete shape values in the output,
+  // save them to the output NodeArg as a TensorShapeProto or a scalar value so that downstream (consumer)
+  // nodes can use them later in their TypeAndShapeInferenceFunction() and PartialDataPropagationFunction().
+  common::Status SaveShapeValuesFromDataPropagation(const Node& node, NodeArg& output_def,
+                                                    const ONNX_NAMESPACE::TypeProto& propagated_value_as_type_proto) const;
+
+  // Remove intermediate inferred shape values stored in all NodeArgs to reduce memory usage.
+  common::Status CleanUpShapeValuesFromDataPropagation();
+
   // Apply type-inference and type-checking to all inputs and initializers:
   common::Status TypeCheckInputsAndInitializers();
include/onnxruntime/core/graph/node_arg.h
Lines changed: 32 additions & 0 deletions

@@ -9,6 +9,8 @@
 #include "core/common/status.h"
 #include "core/common/logging/logging.h"

+#include <optional>
+
 namespace onnxruntime {

 // Node argument definition, for both input and output,
@@ -107,6 +109,18 @@ class NodeArg {
   /** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto. */
   const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }

+  /** Gets the inferred shape values as a TensorShapeProto. */
+  const std::optional<ONNX_NAMESPACE::TensorShapeProto>& GetInferredShapeValues() const noexcept { return inferred_shape_values_; }
+
+  /** Gets mutable inferred shape values as a TensorShapeProto. */
+  std::optional<ONNX_NAMESPACE::TensorShapeProto>& GetMutableInferredShapeValues() noexcept { return inferred_shape_values_; }
+
+  /** Gets the inferred shape scalar value. */
+  const std::optional<int64_t> GetInferredShapeScalarValue() const noexcept { return inferred_scalar_value_; }
+
+  /** Sets the inferred shape scalar value. */
+  void SetInferredShapeScalarValue(int64_t value) noexcept { inferred_scalar_value_ = value; }
+
   /** Gets a flag indicating whether this NodeArg exists or not.
       Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */
   bool Exists() const noexcept;
@@ -128,6 +142,24 @@ class NodeArg {
   // Node arg name, type and shape.
   NodeArgInfo node_arg_info_;

+  // This variable stores the actual tensor data of the shape as a TensorShapeProto after executing
+  // the ONNX operator's PartialDataPropagationFunction(). It is used for shape inference purposes.
+  //
+  // Calling an operator's TypeAndShapeInferenceFunction() alone is sometimes insufficient
+  // for complete shape inference. For example, the Shape operator's TypeAndShapeInferenceFunction()
+  // only provides the output's rank (which is 1), not its actual shape values.
+  //
+  // The PartialDataPropagationFunction(), defined in the ONNX operator schema, must also
+  // be executed to obtain the concrete shape values, allowing accurate propagation
+  // of shape information throughout the graph. If no concrete shape values are
+  // computed, nothing is stored here, which is why this member is optional.
+  std::optional<ONNX_NAMESPACE::TensorShapeProto> inferred_shape_values_;
+
+  // This variable stores the actual scalar value.
+  // It is also used for shape inference and data propagation to ensure consistent shape and
+  // value information throughout the graph.
+  std::optional<int64_t> inferred_scalar_value_;
+
   // Flag indicates whether <*this> node arg exists or not.
   bool exists_;
 };
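The two new members can be modeled outside of ORT as a pair of `std::optional` fields, which is a minimal sketch (hypothetical names, not the NodeArg API) of how consumers decide whether any propagated data is available:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Simplified stand-in for the two new NodeArg members: a propagated shape
// value (a vector of dims, mirroring TensorShapeProto) and a propagated
// scalar. Both are std::optional because propagation may produce nothing,
// in which case downstream nodes simply fall back to rank-only inference.
struct PropagatedValues {
  std::optional<std::vector<int64_t>> shape_values;  // ~ inferred_shape_values_
  std::optional<int64_t> scalar_value;               // ~ inferred_scalar_value_
};

// A downstream consumer checks whether either form of propagated data exists.
bool HasPropagatedData(const PropagatedValues& v) {
  return v.scalar_value.has_value() || v.shape_values.has_value();
}
```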

include/onnxruntime/core/session/onnxruntime_cxx_api.h
Lines changed: 6 additions & 0 deletions

@@ -1441,6 +1441,12 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {

   ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI
   SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map<std::string, std::string>& provider_options = {});
+
+  ///< Wraps OrtApi::AddFreeDimensionOverride
+  SessionOptionsImpl& AddFreeDimensionOverride(const char* dim_denotation, int64_t dim_value);
+
+  ///< Wraps OrtApi::AddFreeDimensionOverrideByName
+  SessionOptionsImpl& AddFreeDimensionOverrideByName(const char* dim_name, int64_t dim_value);
 };
 }  // namespace detail
include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Lines changed: 12 additions & 0 deletions

@@ -1503,6 +1503,18 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunct
   return *this;
 }

+template <typename T>
+inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddFreeDimensionOverride(const char* dim_denotation, int64_t dim_value) {
+  // Forward to the denotation-based C API function (not AddFreeDimensionOverrideByName,
+  // which matches dimensions by name rather than by denotation).
+  ThrowOnError(GetApi().AddFreeDimensionOverride(this->p_, dim_denotation, dim_value));
+  return *this;
+}
+
+template <typename T>
+inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddFreeDimensionOverrideByName(const char* dim_name, int64_t dim_value) {
+  ThrowOnError(GetApi().AddFreeDimensionOverrideByName(this->p_, dim_name, dim_value));
+  return *this;
+}
+
 /// Session
 template <typename T>
 inline size_t ConstSessionImpl<T>::GetInputCount() const {
Lines changed: 32 additions & 0 deletions (new file)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "add_op_data_propagation.h"
#include "core/common/common.h"
#include "core/graph/node_arg.h"
#include "core/graph/onnx_protobuf.h"
#include "core/providers/common.h"

namespace onnxruntime {

Status AddOpDataPropagation::infer() {
  // Get "A" input
  const auto* input_0 = node_.InputDefs()[0];
  // Get "B" input
  const auto* input_1 = node_.InputDefs()[1];

  // Return and do nothing if an input doesn't exist
  if (!input_0 || !input_1 || !input_0->Exists() || !input_1->Exists()) {
    return Status::OK();
  }

  if (input_0->GetInferredShapeScalarValue().has_value() && input_1->GetInferredShapeScalarValue().has_value()) {
    output_def_.SetInferredShapeScalarValue(
        input_0->GetInferredShapeScalarValue().value() +
        input_1->GetInferredShapeScalarValue().value());
  }

  return Status::OK();
}

}  // namespace onnxruntime
Lines changed: 52 additions & 0 deletions (new file)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "custom_data_propagation.h"
#include "core/graph/graph.h"

namespace onnxruntime {

/**
 * @brief Class to infer the output scalar for the 'Add' operator when the input is a scalar related to shape.
 *
 * For example:
 *
 *   (input with shape float32[1, 3, 64, 64])
 *        |
 *        v
 *      Shape          (It saves [1, 3, 64, 64] in inferred_shape_values_ in the output's node_arg
 *        |             during Graph::SaveShapeValuesFromDataPropagation())
 *        |
 *        |______
 *        |      |
 *        v      v
 *     Gather  Gather  (The first 'Gather' saves 3 in inferred_scalar_value_ in its output node_arg, and
 *        |      |      the second 'Gather' saves 64 in inferred_scalar_value_ in its output node_arg
 *        |      |      during GatherOpDataPropagation(), if the 'index' attributes
 *        |      |      are 1 and 2 respectively)
 *         \    /
 *          v  v
 *          Add        (It gets 3 from inferred_scalar_value_ in input A's node_arg and 64 from
 *           |          inferred_scalar_value_ in input B's node_arg, then performs the add operation
 *           |          to get 67 and saves it in inferred_scalar_value_ in the output's node_arg)
 *           v
 *          ...
 */
class AddOpDataPropagation : public CustomDataPropagationBase {
 public:
  AddOpDataPropagation(const Node& node,
                       NodeArg& output_def,
                       std::function<Status(const std::string&, TensorShapeVector&)> func,
                       const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
                       const logging::Logger& logger) noexcept
      : CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {}

  Status infer() override;
};

}  // namespace onnxruntime
Lines changed: 53 additions & 0 deletions (new file)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "custom_data_propagation.h"
#include "core/common/common.h"
#include "core/graph/graph.h"
#include "core/common/logging/logging.h"
#include "size_op_data_propagation.h"
#include "squeeze_op_data_propagation.h"
#include "unsqueeze_op_data_propagation.h"
#include "gather_op_data_propagation.h"
#include "add_op_data_propagation.h"
#include "sub_op_data_propagation.h"
#include "mul_op_data_propagation.h"
#include "div_op_data_propagation.h"
#include <onnx/onnx-ml.pb.h>

namespace onnxruntime {

std::unique_ptr<CustomDataPropagationBase> CreateCustomDataPropagation(const Node& node,
                                                                       NodeArg& output_def,
                                                                       std::function<Status(const std::string&, TensorShapeVector&)> func,
                                                                       const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
                                                                       const logging::Logger& logger) {
  int dim_size = 0;
  if (output_from_onnx_op_data_propagation.has_tensor_type() &&
      output_from_onnx_op_data_propagation.tensor_type().has_shape()) {
    dim_size = output_from_onnx_op_data_propagation.tensor_type().shape().dim_size();
  }

  if (node.OpType() == "Size") {
    return std::make_unique<SizeOpDataPropagation>(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger);
  } else if (node.OpType() == "Squeeze") {
    return std::make_unique<SqueezeOpDataPropagation>(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger);
  } else if (node.OpType() == "Unsqueeze") {
    return std::make_unique<UnsqueezeOpDataPropagation>(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger);
  } else if (dim_size == 0) {  // the propagated output is a scalar
    if (node.OpType() == "Gather") {
      return std::make_unique<GatherOpDataPropagation>(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger);
    } else if (node.OpType() == "Add") {
      return std::make_unique<AddOpDataPropagation>(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger);
    } else if (node.OpType() == "Sub") {
      return std::make_unique<SubOpDataPropagation>(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger);
    } else if (node.OpType() == "Mul") {
      return std::make_unique<MulOpDataPropagation>(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger);
    } else if (node.OpType() == "Div") {
      return std::make_unique<DivOpDataPropagation>(node, output_def, std::move(func), output_from_onnx_op_data_propagation, logger);
    }
  }
  return nullptr;
}

}  // namespace onnxruntime
Lines changed: 75 additions & 0 deletions (new file)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/graph/graph.h"
#include "core/common/logging/logging.h"
#include <onnx/onnx-ml.pb.h>

#include <functional>
#include <memory>

namespace onnxruntime {

/**
 * @class CustomDataPropagationBase
 * Custom data propagation for an operator to help enhance shape inference.
 *
 * Calling infer() infers the output values for the specific operator when the input carries shape values,
 * and saves the output values in the output node_arg for other operators to use later.
 * The purpose of this class is to ensure shape values are correctly inferred and propagated through the graph.
 */
class CustomDataPropagationBase {
 public:
  ORT_DISALLOW_COPY(CustomDataPropagationBase);
  virtual ~CustomDataPropagationBase() = default;
  virtual Status infer() = 0;

 protected:
  CustomDataPropagationBase(const Node& node,
                            NodeArg& output_def,
                            std::function<Status(const std::string&, TensorShapeVector&)> func,
                            const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
                            const logging::Logger& logger) noexcept
      : node_(node),
        output_def_(output_def),
        get_initialized_input_values_func_(std::move(func)),
        output_from_onnx_op_data_propagation_(output_from_onnx_op_data_propagation),
        logger_(logger) {}

  const Node& node_;
  NodeArg& output_def_;
  std::function<Status(const std::string&, TensorShapeVector&)> get_initialized_input_values_func_;
  const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation_;
  const logging::Logger& logger_;
};

/**
 * @brief Create custom data propagation for the operator.
 *
 * For certain operators (e.g., Size, Squeeze, Unsqueeze), ONNX's
 * PartialDataPropagationFunction() does not always produce complete or accurate
 * inferred shape values.
 *
 * In particular:
 * - Scalar inputs and outputs are not handled correctly.
 * - Some operators require additional logic that the default
 *   PartialDataPropagationFunction() does not cover.
 *
 * Therefore, for these cases, we perform custom data propagation to ensure
 * correct and complete inference.
 *
 * @param node The ORT node
 * @param output_def The node's output NodeArg to save the inferred shape values to if needed
 * @param func Helper function to get the input value if it's an initializer
 * @param output_from_onnx_op_data_propagation The result from executing the ONNX operator's data propagation
 * @param logger The reference to a logger
 * @return std::unique_ptr<CustomDataPropagationBase> Returns a custom data propagation object if available
 */
std::unique_ptr<CustomDataPropagationBase> CreateCustomDataPropagation(
    const Node& node,
    NodeArg& output_def,
    std::function<Status(const std::string&, TensorShapeVector&)> func,
    const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
    const logging::Logger& logger);

}  // namespace onnxruntime
Lines changed: 32 additions & 0 deletions (new file)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "div_op_data_propagation.h"
#include "core/common/common.h"
#include "core/graph/node_arg.h"
#include "core/graph/onnx_protobuf.h"
#include "core/providers/common.h"

namespace onnxruntime {

Status DivOpDataPropagation::infer() {
  // Get "A" input
  const auto* input_0 = node_.InputDefs()[0];
  // Get "B" input
  const auto* input_1 = node_.InputDefs()[1];

  // Return and do nothing if an input doesn't exist
  if (!input_0 || !input_1 || !input_0->Exists() || !input_1->Exists()) {
    return Status::OK();
  }

  if (input_0->GetInferredShapeScalarValue().has_value() && input_1->GetInferredShapeScalarValue().has_value()) {
    // Note: C++ integer division truncates toward zero.
    output_def_.SetInferredShapeScalarValue(
        input_0->GetInferredShapeScalarValue().value() /
        input_1->GetInferredShapeScalarValue().value());
  }

  return Status::OK();
}

}  // namespace onnxruntime
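One subtlety worth noting: the division above is C++ integer division on `int64_t`, which truncates toward zero, so dividing shape values that are not exact multiples loses the remainder. A hypothetical mirror of the logic (with an added divide-by-zero guard, which is my addition for safety, not part of the diff above):

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Mirrors DivOpDataPropagation::infer()'s scalar path: only produces a value
// when both propagated scalars are present, and guards against division by
// zero. Integer division truncates toward zero (65 / 2 == 32).
std::optional<int64_t> DivScalars(std::optional<int64_t> a, std::optional<int64_t> b) {
  if (!a.has_value() || !b.has_value() || *b == 0) return std::nullopt;
  return *a / *b;
}
```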
