Skip to content

Commit d08da8c

Browse files
committed
address reviewer's comments
1 parent 29ff80f commit d08da8c

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

onnxruntime/core/graph/graph.cc

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2902,12 +2902,10 @@ Status Graph::SaveShapeValuesFromDataPropagation(const Node& node,
29022902
const TensorProto* initializer = this->GetConstantInitializer(input_name, true);
29032903

29042904
if (initializer) {
2905-
// Get shape from TensorProto and calculate element counts.
2905+
// Get shape from TensorProto as well as element counts.
29062906
// If shape has dimension size equals zero, it means it's a scalar and has only one element.
2907-
size_t element_cnt = 1;
2908-
for (auto& dim : initializer->dims()) {
2909-
element_cnt *= static_cast<size_t>(dim);
2910-
}
2907+
auto tensor_shape = utils::GetTensorShapeFromTensorProto(*initializer);
2908+
size_t element_cnt = tensor_shape.Size();
29112909

29122910
// Check if this is in-memory external data (data stored in OrtValue)
29132911
if (utils::HasExternalDataInMemory(*initializer)) {
@@ -2916,9 +2914,16 @@ Status Graph::SaveShapeValuesFromDataPropagation(const Node& node,
29162914
if (this->GetOrtValueInitializer(input_name, ort_value, true)) {
29172915
const Tensor& tensor = ort_value.Get<Tensor>();
29182916
if (initializer->data_type() == TensorProto_DataType_INT32) {
2917+
auto data_span = tensor.DataAsSpan<int32_t>();
2918+
ORT_ENFORCE(data_span.size() == element_cnt,
2919+
"The element counts from Tensor should be the same"
2920+
"from using utils::GetTensorShapeFromTensorProto()");
2921+
2922+
size_t index = 0;
29192923
input_values.resize(element_cnt);
2920-
for (size_t i = 0; i < element_cnt; ++i) {
2921-
input_values[i] = static_cast<int64_t>(tensor.Data<int32_t>()[i]);
2924+
for (const auto& v : data_span) {
2925+
input_values[index] = static_cast<int64_t>(v);
2926+
++index;
29222927
}
29232928
} else if (initializer->data_type() == TensorProto_DataType_INT64) {
29242929
const int64_t* src = tensor.Data<int64_t>();
@@ -2933,7 +2938,7 @@ Status Graph::SaveShapeValuesFromDataPropagation(const Node& node,
29332938
// Unpack tensor from raw data, external data (not in memory) or the type specific data field
29342939
else {
29352940
if (initializer->data_type() == TensorProto_DataType_INT32) {
2936-
std::vector<int32_t> tmp_values;
2941+
InlinedVector<int32_t> tmp_values;
29372942
tmp_values.resize(element_cnt);
29382943
ORT_RETURN_IF_ERROR(utils::UnpackTensor<int32_t>(*initializer,
29392944
this->ModelPath(),
@@ -2977,7 +2982,7 @@ Status Graph::SaveShapeValuesFromDataPropagation(const Node& node,
29772982
}
29782983

29792984
// If no custom data propagation is defined for the operator,
2980-
// fall back to using ONNX's PartialDataPropagationFunction(), if available.
2985+
// fall back to using the result of ONNX's PartialDataPropagationFunction(), if available.
29812986

29822987
int dim_size = 0;
29832988
if (onnx_inferred_type_after_data_propagation.has_tensor_type() &&

0 commit comments

Comments
 (0)