@@ -2902,12 +2902,10 @@ Status Graph::SaveShapeValuesFromDataPropagation(const Node& node,
   const TensorProto* initializer = this->GetConstantInitializer(input_name, true);
 
   if (initializer) {
-    // Get shape from TensorProto and calculate element counts.
+    // Get shape from TensorProto as well as element counts.
     // If shape has dimension size equals zero, it means it's a scalar and has only one element.
-    size_t element_cnt = 1;
-    for (auto& dim : initializer->dims()) {
-      element_cnt *= static_cast<size_t>(dim);
-    }
+    auto tensor_shape = utils::GetTensorShapeFromTensorProto(*initializer);
+    size_t element_cnt = tensor_shape.Size();
 
     // Check if this is in-memory external data (data stored in OrtValue)
     if (utils::HasExternalDataInMemory(*initializer)) {
@@ -2916,9 +2914,16 @@ Status Graph::SaveShapeValuesFromDataPropagation(const Node& node,
       if (this->GetOrtValueInitializer(input_name, ort_value, true)) {
         const Tensor& tensor = ort_value.Get<Tensor>();
         if (initializer->data_type() == TensorProto_DataType_INT32) {
+          auto data_span = tensor.DataAsSpan<int32_t>();
+          ORT_ENFORCE(data_span.size() == element_cnt,
+                      "The element counts from Tensor should be the same"
+                      " from using utils::GetTensorShapeFromTensorProto()");
+
+          size_t index = 0;
           input_values.resize(element_cnt);
-          for (size_t i = 0; i < element_cnt; ++i) {
-            input_values[i] = static_cast<int64_t>(tensor.Data<int32_t>()[i]);
+          for (const auto& v : data_span) {
+            input_values[index] = static_cast<int64_t>(v);
+            ++index;
           }
         } else if (initializer->data_type() == TensorProto_DataType_INT64) {
           const int64_t* src = tensor.Data<int64_t>();
@@ -2933,7 +2938,7 @@ Status Graph::SaveShapeValuesFromDataPropagation(const Node& node,
     // Unpack tensor from raw data, external data (not in memory) or the type specific data field
     else {
       if (initializer->data_type() == TensorProto_DataType_INT32) {
-        std::vector<int32_t> tmp_values;
+        InlinedVector<int32_t> tmp_values;
         tmp_values.resize(element_cnt);
         ORT_RETURN_IF_ERROR(utils::UnpackTensor<int32_t>(*initializer,
                                                          this->ModelPath(),
@@ -2977,7 +2982,7 @@ Status Graph::SaveShapeValuesFromDataPropagation(const Node& node,
   }
 
   // If no custom data propagation is defined for the operator,
-  // fall back to using ONNX's PartialDataPropagationFunction(), if available.
+  // fall back to using the result of ONNX's PartialDataPropagationFunction(), if available.
 
   int dim_size = 0;
   if (onnx_inferred_type_after_data_propagation.has_tensor_type() &&
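For readers skimming the diff, a minimal, self-contained C++ sketch of the two ideas the change leans on follows. This is not the ONNX Runtime code; the helper names ElementCount and WidenInt32ToInt64 are invented here for illustration. The first function mirrors what utils::GetTensorShapeFromTensorProto() followed by TensorShape::Size() yields (a rank-0 shape, i.e. no dims, is a scalar with one element; otherwise the count is the product of the dims), and the second mirrors the span-style int32-to-int64 widening copy added in the INT32 branch.

#include <cstddef>
#include <cstdint>
#include <vector>

// Element count implied by a TensorProto-style dims list: a scalar (empty
// dims) holds exactly one element; otherwise multiply all dimensions.
int64_t ElementCount(const std::vector<int64_t>& dims) {
  int64_t count = 1;  // rank-0 tensor -> one element
  for (int64_t d : dims) {
    count *= d;  // any zero-sized dimension makes the count zero
  }
  return count;
}

// Widen int32 shape values into the int64 buffer the data-propagation path
// expects, iterating a read-only view the way the new span-based loop does.
std::vector<int64_t> WidenInt32ToInt64(const std::vector<int32_t>& values) {
  std::vector<int64_t> out(values.size());
  size_t index = 0;
  for (const auto& v : values) {
    out[index] = static_cast<int64_t>(v);
    ++index;
  }
  return out;
}

// Usage: ElementCount({}) == 1, ElementCount({2, 3}) == 6;
// WidenInt32ToInt64({7, 8}) yields {7, 8} stored as int64_t.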