|
15 | 15 | #include "core/framework/ort_value_pattern_planner.h" |
16 | 16 | #include "core/framework/prepacked_weights_container.h" |
17 | 17 | #include "core/framework/session_state_utils.h" |
| 18 | +#include "core/framework/tensorprotoutils.h" |
18 | 19 | #include "core/framework/utils.h" |
19 | 20 | #include "core/providers/cpu/controlflow/utils.h" |
20 | 21 | #include "core/session/onnxruntime_session_options_config_keys.h" |
@@ -1332,6 +1333,9 @@ Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE |
1332 | 1333 | ORT_RETURN_IF_ERROR(CreateSubgraphSessionState()); |
1333 | 1334 |
|
1334 | 1335 | ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEp(graph_, logger_, execution_providers_)); |
| 1336 | + |
| 1337 | + ORT_RETURN_IF_ERROR(TransformInitializersToPreferredFormat()); |
| 1338 | + |
1335 | 1339 | ORT_RETURN_IF_ERROR(PopulateKernelCreateInfo(kernel_registry_manager, saving_ort_format)); |
1336 | 1340 |
|
1337 | 1341 | InlinedHashMap<std::string, size_t> constant_initializers_use_count; |
@@ -1501,6 +1505,14 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_ |
1501 | 1505 | CreateGraphInfo(save_prepacked_initializers); |
1502 | 1506 | } |
1503 | 1507 |
|
| 1508 | + // Index all initializers including those that may have become unreferenced after transformation. |
| 1509 | + // This runs after CreateGraphInfo() to ensure consistent ordering - CreateGraphInfo indexes based on |
| 1510 | + // graph structure, then we add any remaining initializers (e.g., original weights before transformation). |
| 1511 | + for (const auto& [init_name, tensor_proto] : graph_.GetAllInitializedTensors()) { |
| 1512 | + ORT_UNUSED_PARAMETER(tensor_proto); |
| 1513 | + ort_value_name_idx_map_.Add(init_name); |
| 1514 | + } |
| 1515 | + |
1504 | 1516 | #if defined(ORT_EXTENDED_MINIMAL_BUILD) |
1505 | 1517 | // Remove any unused initializers. |
1506 | 1518 | // Not needed in a full build because unused initializers should have been removed earlier by Graph::Resolve(). |
@@ -1793,4 +1805,152 @@ void SessionState::RecycleDeviceStreamCollection(std::unique_ptr<DeviceStreamCol |
1793 | 1805 | } |
1794 | 1806 | #endif |
1795 | 1807 |
|
| 1808 | +Status SessionState::TransformInitializersToPreferredFormat() { |
| 1809 | + // Build a map from initializer name to all nodes that consume it |
| 1810 | + std::unordered_map<std::string, std::vector<std::pair<NodeIndex, int>>> initializer_to_consumers; |
| 1811 | + |
| 1812 | + const auto& initialized_tensors_map = graph_.GetAllInitializedTensors(); |
| 1813 | + std::unordered_set<std::string> initializer_names; |
| 1814 | + for (const auto& [name, tensor_proto] : initialized_tensors_map) { |
| 1815 | + ORT_UNUSED_PARAMETER(tensor_proto); |
| 1816 | + initializer_names.insert(name); |
| 1817 | + } |
| 1818 | + |
| 1819 | + // Scan nodes to find which initializers they use |
| 1820 | + for (const auto& node : graph_.Nodes()) { |
| 1821 | + int input_index = 0; |
| 1822 | + for (const auto* input_def : node.InputDefs()) { |
| 1823 | + if (input_def && input_def->Exists()) { |
| 1824 | + const auto& input_name = input_def->Name(); |
| 1825 | + if (initializer_names.count(input_name) > 0) { |
| 1826 | + initializer_to_consumers[input_name].emplace_back(node.Index(), input_index); |
| 1827 | + } |
| 1828 | + } |
| 1829 | + ++input_index; |
| 1830 | + } |
| 1831 | + } |
| 1832 | + |
| 1833 | + auto cpu_allocator = GetAllocator(OrtDevice()); |
| 1834 | + |
| 1835 | + for (const auto& [init_name, consumers] : initializer_to_consumers) { |
| 1836 | + if (consumers.empty()) { |
| 1837 | + continue; |
| 1838 | + } |
| 1839 | + |
| 1840 | + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_.GetInitializer(init_name, true); |
| 1841 | + if (!tensor_proto) { |
| 1842 | + continue; |
| 1843 | + } |
| 1844 | + |
| 1845 | + // Skip if this initializer was already transformed (when loading a saved ORT format model) |
| 1846 | + // Transformed initializers have format metadata in string_data |
| 1847 | + bool already_transformed = false; |
| 1848 | + for (const auto& attr_str : tensor_proto->string_data()) { |
| 1849 | + if (attr_str.find("onnxruntime_format:") == 0) { |
| 1850 | + already_transformed = true; |
| 1851 | + break; |
| 1852 | + } |
| 1853 | + } |
| 1854 | + if (already_transformed) { |
| 1855 | + continue; |
| 1856 | + } |
| 1857 | + |
| 1858 | + // Phase 1: Query all consumers to discover what formats are needed |
| 1859 | + // Multiple nodes may request the same format, so we deduplicate by format |
| 1860 | + std::unordered_map<std::string, std::vector<std::pair<NodeIndex, int>>> format_to_consumers; |
| 1861 | + |
| 1862 | + for (const auto& [node_idx, input_idx] : consumers) { |
| 1863 | + const Node* node = graph_.GetNode(node_idx); |
| 1864 | + if (!node) { |
| 1865 | + continue; |
| 1866 | + } |
| 1867 | + |
| 1868 | + const auto& ep_type = node->GetExecutionProviderType(); |
| 1869 | + if (ep_type.empty()) { |
| 1870 | + continue; |
| 1871 | + } |
| 1872 | + |
| 1873 | + const auto* ep = execution_providers_.Get(ep_type); |
| 1874 | + if (!ep) { |
| 1875 | + continue; |
| 1876 | + } |
| 1877 | + |
| 1878 | + // Ask EP if it wants this initializer in a different format |
| 1879 | + std::string format_descriptor; |
| 1880 | + Status query_status = ep->GetPreferredInitializerFormat(*node, input_idx, format_descriptor); |
| 1881 | + |
| 1882 | + if (!query_status.IsOK() || format_descriptor.empty()) { |
| 1883 | + continue; |
| 1884 | + } |
| 1885 | + |
| 1886 | + format_to_consumers[format_descriptor].emplace_back(node_idx, input_idx); |
| 1887 | + } |
| 1888 | + |
| 1889 | + if (format_to_consumers.empty()) { |
| 1890 | + continue; |
| 1891 | + } |
| 1892 | + |
| 1893 | + // Load the original initializer to CPU for transformation |
| 1894 | + TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(*tensor_proto); |
| 1895 | + const auto* tensor_type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto->data_type())->GetElementType(); |
| 1896 | + |
| 1897 | + Tensor original_tensor(tensor_type, tensor_shape, cpu_allocator); |
| 1898 | + ORT_RETURN_IF_ERROR( |
| 1899 | + utils::TensorProtoToTensor(Env::Default(), std::filesystem::path(), *tensor_proto, original_tensor)); |
| 1900 | + |
| 1901 | + // Phase 2: Transform once per unique format requested |
| 1902 | + for (const auto& [format_descriptor, nodes_needing_format] : format_to_consumers) { |
| 1903 | + const Node* first_node = graph_.GetNode(nodes_needing_format[0].first); |
| 1904 | + if (!first_node) { |
| 1905 | + continue; |
| 1906 | + } |
| 1907 | + |
| 1908 | + const auto& ep_type = first_node->GetExecutionProviderType(); |
| 1909 | + const auto* ep = execution_providers_.Get(ep_type); |
| 1910 | + if (!ep) { |
| 1911 | + continue; |
| 1912 | + } |
| 1913 | + |
| 1914 | + // Perform the actual transformation |
| 1915 | + std::unique_ptr<Tensor> transformed_tensor; |
| 1916 | + Status transform_status = ep->TransformInitializerFormat(original_tensor, format_descriptor, transformed_tensor); |
| 1917 | + |
| 1918 | + if (!transform_status.IsOK() || !transformed_tensor) { |
| 1919 | + LOGS(logger_, WARNING) << "Failed to transform initializer '" << init_name << "' to format '" |
| 1920 | + << format_descriptor << "': " << transform_status.ErrorMessage(); |
| 1921 | + continue; |
| 1922 | + } |
| 1923 | + |
| 1924 | + // Set format metadata on the transformed tensor |
| 1925 | + transformed_tensor->SetFormatDescriptor(format_descriptor); |
| 1926 | + |
| 1927 | + // Add the transformed initializer with a new name |
| 1928 | + std::string transformed_name = init_name + "_fmt_" + format_descriptor; |
| 1929 | + |
| 1930 | + ONNX_NAMESPACE::TensorProto transformed_proto = utils::TensorToTensorProto(*transformed_tensor, transformed_name); |
| 1931 | + |
| 1932 | + // Add format metadata as TensorProto attribute |
| 1933 | + auto* format_attr = transformed_proto.add_string_data(); |
| 1934 | + *format_attr = "onnxruntime_format:" + format_descriptor; |
| 1935 | + |
| 1936 | + graph_.AddInitializedTensor(transformed_proto); |
| 1937 | + |
| 1938 | + // Update all nodes that need this format to use the transformed version |
| 1939 | + for (const auto& [node_idx, input_idx] : nodes_needing_format) { |
| 1940 | + Node* node = graph_.GetNode(node_idx); |
| 1941 | + if (!node) { |
| 1942 | + continue; |
| 1943 | + } |
| 1944 | + |
| 1945 | + const auto* original_node_arg = node->InputDefs()[input_idx]; |
| 1946 | + auto* transformed_node_arg = &graph_.GetOrCreateNodeArg(transformed_name, original_node_arg->TypeAsProto()); |
| 1947 | + |
| 1948 | + node->MutableInputDefs()[input_idx] = transformed_node_arg; |
| 1949 | + } |
| 1950 | + } |
| 1951 | + } |
| 1952 | + |
| 1953 | + return Status::OK(); |
| 1954 | +} |
| 1955 | + |
1796 | 1956 | } // namespace onnxruntime |
0 commit comments