Skip to content

Commit 29ff80f

Browse files
committed
address revewer's comments
1 parent b959403 commit 29ff80f

13 files changed

+88
-40
lines changed

include/onnxruntime/core/graph/graph.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1760,10 +1760,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
17601760
std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
17611761
const Graph::ResolveOptions& options);
17621762

1763-
// If the shape values are inferred after executing ONNX operator's PartialDataPropagationFunction(),
1763+
// If ONNX operator's PartialDataPropagationFunction() infers concrete shape values in the output
17641764
// save them to the output NodeArg as a TensorShapeProto or a scalar value so that downstream (consumer) nodes
17651765
// can use them later for their TypeAndShapeInferenceFunction() and PartialDataPropagationFunction().
1766-
common::Status SaveShapeValuesFromDataPropagation(Node& node, NodeArg& output_def,
1766+
common::Status SaveShapeValuesFromDataPropagation(const Node& node, NodeArg& output_def,
17671767
const ONNX_NAMESPACE::TypeProto& propagated_value_as_type_proto) const;
17681768

17691769
// Remove intermediate inferred shape values stored in all NodeArgs to reduce memory usage.

include/onnxruntime/core/graph/node_arg.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ class NodeArg {
151151
//
152152
// The PartialDataPropagationFunction(), defined in the ONNX operator schema, must also
153153
// be executed to obtain the concrete output shape values, allowing accurate propagation
154-
// of shape information throughout the graph.
154+
// of shape information throughout the graph. If the concrete output shape value is not
155+
// computed, then no shape value is saved here that's why this is optional.
155156
std::optional<ONNX_NAMESPACE::TensorShapeProto> inferred_shape_values_;
156157

157158
// This variable stores the actual scalar value.

onnxruntime/core/graph/data_propagation/add_op_data_propagation.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,15 @@ Status AddOpDataPropagation::infer() {
1515
// Get "B" input
1616
const auto* input_1 = node_.InputDefs()[1];
1717

18+
// Return and do nothing if input doesn't exist
19+
if (!input_0 || !input_1) {
20+
return Status::OK();
21+
}
22+
1823
if (input_0->GetInferredShapeScalarValue().has_value() && input_1->GetInferredShapeScalarValue().has_value()) {
19-
output_def_.SetInferredShapeScalarValue(input_0->GetInferredShapeScalarValue().value() + input_1->GetInferredShapeScalarValue().value());
24+
output_def_.SetInferredShapeScalarValue(
25+
input_0->GetInferredShapeScalarValue().value() +
26+
input_1->GetInferredShapeScalarValue().value());
2027
}
2128

2229
return Status::OK();

onnxruntime/core/graph/data_propagation/custom_data_propagation.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,11 @@ class CustomDataPropagationBase {
6565
* @param logger The reference to a logger
6666
* @return std::unique_ptr<CustomDataPropagation> Returns a CustomDataPropagation object if available
6767
*/
68-
std::unique_ptr<CustomDataPropagationBase> CreateCustomDataPropagation(const Node& node,
69-
NodeArg& output_def,
70-
std::function<Status(const std::string&, TensorShapeVector&)> func,
71-
const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
72-
const logging::Logger& logger);
68+
std::unique_ptr<CustomDataPropagationBase> CreateCustomDataPropagation(
69+
const Node& node,
70+
NodeArg& output_def,
71+
std::function<Status(const std::string&, TensorShapeVector&)> func,
72+
const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
73+
const logging::Logger& logger);
7374

7475
} // namespace onnxruntime

onnxruntime/core/graph/data_propagation/div_op_data_propagation.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,15 @@ Status DivOpDataPropagation::infer() {
1515
// Get "B" input
1616
const auto* input_1 = node_.InputDefs()[1];
1717

18+
// Return and do nothing if input doesn't exist
19+
if (!input_0 || !input_1) {
20+
return Status::OK();
21+
}
22+
1823
if (input_0->GetInferredShapeScalarValue().has_value() && input_1->GetInferredShapeScalarValue().has_value()) {
19-
output_def_.SetInferredShapeScalarValue(input_0->GetInferredShapeScalarValue().value() / input_1->GetInferredShapeScalarValue().value());
24+
output_def_.SetInferredShapeScalarValue(
25+
input_0->GetInferredShapeScalarValue().value() /
26+
input_1->GetInferredShapeScalarValue().value());
2027
}
2128

2229
return Status::OK();

onnxruntime/core/graph/data_propagation/gather_op_data_propagation.cc

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,36 +10,47 @@
1010
namespace onnxruntime {
1111

1212
Status GatherOpDataPropagation::infer() {
13-
int dim_size = 0;
1413
if (output_from_onnx_op_data_propagation_.has_tensor_type() &&
1514
output_from_onnx_op_data_propagation_.tensor_type().has_shape()) {
16-
dim_size = output_from_onnx_op_data_propagation_.tensor_type().shape().dim_size();
15+
int dim_size = output_from_onnx_op_data_propagation_.tensor_type().shape().dim_size();
16+
// Check there is no result from Gather's PartialDataPropagationFunction(),
17+
// so that it can run custom data propagation below.
18+
// Otherwise, this infer() function won't be called as the result from Gather's PartialDataPropagationFunction()
19+
// will be used in Graph::SaveShapeValuesFromDataPropagation().
20+
if (dim_size != 0) {
21+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
22+
"ORT shouldn't run Gather's custom data propagation here as Gather's"
23+
"PartialDataPropagationFunction() already infers shape values in the output.");
24+
}
1725
}
18-
// Check there is no result from Gather's PartialDataPropagationFunction(),
19-
// so that it can run custom data propagation below.
20-
// Otherwise, this infer() function won't be called as the result from Gather's PartialDataPropagationFunction()
21-
// will be used in Graph::SaveShapeValuesFromDataPropagation().
22-
ORT_ENFORCE(dim_size == 0);
2326

2427
// Following code extracts an element from a 1D array if all conditions are met.
2528
// e.g.
2629
// shape data is [1, 3, 64, 64] -> gets 64 if the index is 2.
2730
// shape data is [1, 3, 64, 64] -> gets 3 if the index is 1.
2831

2932
// Get "data" input
30-
// Note: The "data" input should be a one dimension array in this case.
33+
// Note: The "data" input should be an one dimension array in this case.
3134
const auto* input_0 = node_.InputDefs()[0];
3235

3336
// Get "indices" input
3437
// Note: The "indices" input could be one of the three cases:
3538
// 1. A tensor with rank > 0 and all tensor values are known.
36-
// 2. A tensor with rank > 0 but not all tensor values are know.
39+
// 2. A tensor with rank > 0 but not all tensor values are known.
3740
// 3. A scalar.
3841
//
3942
// If it's case #1, ONNX operator's PartialDataPropagationFunction()
4043
// should have inferred the output shape value.
44+
// If it's case #2, neither ONNX operator's PartialDataPropagationFunction()
45+
// nor Gather's custom data propagation can handle it.
4146
// This Gather's custom data propagation handles case #3.
4247
const auto* input_1 = node_.InputDefs()[1];
48+
49+
// Return and do nothing if input doesn't exist
50+
if (!input_0 || !input_1) {
51+
return Status::OK();
52+
}
53+
4354
TensorShapeVector indices;
4455
ORT_RETURN_IF_ERROR(get_initialized_input_values_func_(input_1->Name(), indices));
4556

@@ -51,7 +62,10 @@ Status GatherOpDataPropagation::infer() {
5162
if (indices.size() == 1) {
5263
ORT_TRY {
5364
// Note: Index value is expected to be within bounds [-s, s-1] along axis of size s
54-
auto& dim = tensor_shape_proto.dim(static_cast<int32_t>(HandleNegativeAxis(indices[0], tensor_shape_proto.dim_size())));
65+
auto index = static_cast<int32_t>(
66+
HandleNegativeAxis(indices[0], tensor_shape_proto.dim_size()));
67+
68+
auto& dim = tensor_shape_proto.dim(index);
5569
if (dim.has_dim_value()) {
5670
output_def_.SetInferredShapeScalarValue(dim.dim_value());
5771
}

onnxruntime/core/graph/data_propagation/mul_op_data_propagation.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,15 @@ Status MulOpDataPropagation::infer() {
1515
// Get "B" input
1616
const auto* input_1 = node_.InputDefs()[1];
1717

18+
// Return and do nothing if input doesn't exist
19+
if (!input_0 || !input_1) {
20+
return Status::OK();
21+
}
22+
1823
if (input_0->GetInferredShapeScalarValue().has_value() && input_1->GetInferredShapeScalarValue().has_value()) {
19-
output_def_.SetInferredShapeScalarValue(input_0->GetInferredShapeScalarValue().value() * input_1->GetInferredShapeScalarValue().value());
24+
output_def_.SetInferredShapeScalarValue(
25+
input_0->GetInferredShapeScalarValue().value() *
26+
input_1->GetInferredShapeScalarValue().value());
2027
}
2128

2229
return Status::OK();

onnxruntime/core/graph/data_propagation/size_op_data_propagation.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ Status SizeOpDataPropagation::infer() {
1212
// Size operator generates a scalar output
1313
const auto* input_0 = node_.InputDefs()[0];
1414

15+
// Return and do nothing if input doesn't exist
16+
if (!input_0) {
17+
return Status::OK();
18+
}
19+
1520
if (input_0->GetInferredShapeValues().has_value()) {
1621
const auto& tensor_shape_proto = input_0->GetInferredShapeValues().value();
1722

onnxruntime/core/graph/data_propagation/size_op_data_propagation.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ namespace onnxruntime {
1010

1111
/**
1212
* @brief Class to infer the output scalar for 'Size' operator given the input is shape values.
13-
*
13+
*
14+
* 'Size' operator takes a tensor as input and outputs a int64 scalar that equals to the total
15+
* number of elements of the input tensor.
1416
*/
1517
class SizeOpDataPropagation : public CustomDataPropagationBase {
1618
public:

onnxruntime/core/graph/data_propagation/squeeze_op_data_propagation.cc

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,18 @@
66
#include "core/graph/node_arg.h"
77
#include "core/graph/onnx_protobuf.h"
88
#include "core/providers/common.h"
9-
10-
#if defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
11-
#include <unordered_set>
12-
#endif
9+
#include "core/common/inlined_containers.h"
1310

1411
namespace onnxruntime {
1512

1613
Status SqueezeOpDataPropagation::infer() {
1714
const auto* input_0 = node_.InputDefs()[0];
1815

16+
// Return and do nothing if input doesn't exist
17+
if (!input_0) {
18+
return Status::OK();
19+
}
20+
1921
if (input_0->GetInferredShapeValues().has_value()) {
2022
const auto& tensor_shape_proto = input_0->GetInferredShapeValues().value();
2123

@@ -34,11 +36,7 @@ Status SqueezeOpDataPropagation::infer() {
3436
} else if (tensor_shape_proto.dim_size() > 1) {
3537
// Get axes value
3638
TensorShapeVector axes;
37-
#if !defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD)
3839
InlinedHashSet<int64_t> axes_set;
39-
#else
40-
std::unordered_set<int64_t> axes_set;
41-
#endif
4240

4341
// Note: Starting from opset 13, "axes" is provided as a second input to the Squeeze operator.
4442
// In opset 11 and earlier, "axes" is defined as a node attribute instead.

0 commit comments

Comments
 (0)