Skip to content

Commit c58116d

Browse files
committed
Add EP-specific weight layout transformation framework
This infrastructure enables execution providers to optimize operator weights with custom memory layouts (such as blocked formats) during session initialization, dramatically improving inference performance through better cache utilization and memory access patterns.

Current implementation:
- HWIO transpose (WebGPU EP): transposes Conv weights from OIHW to HWIO layout as the first application of the framework.
- ABcd16a4b blocking (proof of concept): a OneDNN-style blocked format with 16×4 tiles demonstrates the framework's primary purpose.

The framework is generic and extensible, allowing any EP to implement custom weight transformations optimized for its target hardware.
1 parent c30905d commit c58116d

File tree

13 files changed

+736
-15
lines changed

13 files changed

+736
-15
lines changed

include/onnxruntime/core/framework/execution_provider.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,39 @@ class IExecutionProvider {
378378
return std::nullopt;
379379
}
380380

381+
/**
382+
Query the preferred format descriptor for an initializer without performing the transformation.
383+
This is a lightweight query called during session initialization to determine what format
384+
transformations are needed.
385+
386+
@param node The node that consumes the initializer
387+
@param input_index The input index of the initializer in the node
388+
@param[out] format_descriptor A string that uniquely identifies the preferred format.
389+
Empty string means no transformation is needed.
390+
Examples: "ABcd16a4b", "hwio".
391+
@return Status::OK() if query succeeded (format_descriptor will be set).
392+
Failed status indicates no transformation is needed.
393+
*/
394+
virtual Status GetPreferredInitializerFormat(const Node& /*node*/, int /*input_index*/,
395+
std::string& /*format_descriptor*/) const {
396+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "No format transformation needed");
397+
}
398+
399+
  /**
  Transform an initializer to the specified format.

  This performs the actual data transformation. It is only called once per unique format
  even if multiple nodes need the same format.

  @param original_tensor The original initializer tensor. The caller materializes it on
         CPU before invoking this, so the data is directly readable.
  @param format_descriptor The target format (a value previously returned by
         GetPreferredInitializerFormat).
  @param[out] transformed_tensor The EP should allocate and fill this with the transformed data.
  @return Status::OK() if the transformation succeeded. The default implementation always
          fails because the base class supports no custom formats; on failure the caller
          logs a warning and falls back to the original layout.
  */
  virtual Status TransformInitializerFormat(const Tensor& /*original_tensor*/, const std::string& /*format_descriptor*/,
                                            std::unique_ptr<Tensor>& /*transformed_tensor*/) const {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Format transformation not supported");
  }
413+
381414
virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/, AllocatorMap&) const {}
382415

383416
/** Does the EP support concurrent calls to InferenceSession::Run to execute the model.

include/onnxruntime/core/framework/tensor.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,19 @@ class Tensor final {
282282
byte_offset_ = byte_offset;
283283
}
284284

285+
/**
286+
* Get the memory format descriptor for this tensor.
287+
* Returns empty string if the tensor is in standard format.
288+
*/
289+
inline const std::string& GetFormatDescriptor() const { return format_descriptor_; }
290+
291+
/**
292+
* Set the memory format descriptor for this tensor.
293+
* Used for EP-specific memory layouts (e.g., "ABcd16a4b" for blocked format).
294+
* The format string encodes all necessary information including block sizes.
295+
*/
296+
inline void SetFormatDescriptor(const std::string& format) { format_descriptor_ = format; }
297+
285298
/// <summary>
286299
/// The number of Tensor "storage" elements. A single storage element may contain multiple sub-elements for
287300
/// sub-byte data types (e.g., int4/float4).
@@ -349,6 +362,9 @@ class Tensor final {
349362
const PrimitiveDataTypeBase* dtype_;
350363
OrtMemoryInfo alloc_info_;
351364
ptrdiff_t byte_offset_;
365+
366+
// Memory format descriptor for EP-specific layouts (e.g., "ABcd16a4b")
367+
std::string format_descriptor_;
352368
};
353369
#ifdef __GNUC__
354370
#pragma GCC diagnostic pop

onnxruntime/core/framework/session_state.cc

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "core/framework/ort_value_pattern_planner.h"
1616
#include "core/framework/prepacked_weights_container.h"
1717
#include "core/framework/session_state_utils.h"
18+
#include "core/framework/tensorprotoutils.h"
1819
#include "core/framework/utils.h"
1920
#include "core/providers/cpu/controlflow/utils.h"
2021
#include "core/session/onnxruntime_session_options_config_keys.h"
@@ -1332,6 +1333,9 @@ Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE
13321333
ORT_RETURN_IF_ERROR(CreateSubgraphSessionState());
13331334

13341335
ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEp(graph_, logger_, execution_providers_));
1336+
1337+
ORT_RETURN_IF_ERROR(TransformInitializersToPreferredFormat());
1338+
13351339
ORT_RETURN_IF_ERROR(PopulateKernelCreateInfo(kernel_registry_manager, saving_ort_format));
13361340

13371341
InlinedHashMap<std::string, size_t> constant_initializers_use_count;
@@ -1501,6 +1505,14 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
15011505
CreateGraphInfo(save_prepacked_initializers);
15021506
}
15031507

1508+
// Index all initializers including those that may have become unreferenced after transformation.
1509+
// This runs after CreateGraphInfo() to ensure consistent ordering - CreateGraphInfo indexes based on
1510+
// graph structure, then we add any remaining initializers (e.g., original weights before transformation).
1511+
for (const auto& [init_name, tensor_proto] : graph_.GetAllInitializedTensors()) {
1512+
ORT_UNUSED_PARAMETER(tensor_proto);
1513+
ort_value_name_idx_map_.Add(init_name);
1514+
}
1515+
15041516
#if defined(ORT_EXTENDED_MINIMAL_BUILD)
15051517
// Remove any unused initializers.
15061518
// Not needed in a full build because unused initializers should have been removed earlier by Graph::Resolve().
@@ -1793,4 +1805,152 @@ void SessionState::RecycleDeviceStreamCollection(std::unique_ptr<DeviceStreamCol
17931805
}
17941806
#endif
17951807

1808+
Status SessionState::TransformInitializersToPreferredFormat() {
  // Marker stored in TensorProto::string_data to record that an initializer already
  // carries EP-specific layout metadata (e.g. when loading a saved ORT format model).
  // Must stay in sync with the reader in tensorprotoutils.cc (TensorProtoToTensor).
  static constexpr const char* kFormatMetadataPrefix = "onnxruntime_format:";

  // Map each initializer name to every (node index, input index) that consumes it.
  std::unordered_map<std::string, std::vector<std::pair<NodeIndex, int>>> initializer_to_consumers;

  const auto& initialized_tensors_map = graph_.GetAllInitializedTensors();
  std::unordered_set<std::string> initializer_names;
  initializer_names.reserve(initialized_tensors_map.size());
  for (const auto& name_to_proto : initialized_tensors_map) {
    initializer_names.insert(name_to_proto.first);
  }

  // Single scan over all nodes to discover which initializers they consume.
  for (const auto& node : graph_.Nodes()) {
    int input_index = 0;
    for (const auto* input_def : node.InputDefs()) {
      if (input_def && input_def->Exists() && initializer_names.count(input_def->Name()) > 0) {
        initializer_to_consumers[input_def->Name()].emplace_back(node.Index(), input_index);
      }
      ++input_index;
    }
  }

  auto cpu_allocator = GetAllocator(OrtDevice());
  ORT_RETURN_IF(cpu_allocator == nullptr,
                "A CPU allocator is required to transform initializer formats.");

  for (const auto& [init_name, consumers] : initializer_to_consumers) {
    if (consumers.empty()) {
      continue;
    }

    const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_.GetInitializer(init_name, true);
    if (!tensor_proto) {
      continue;
    }

    // Skip initializers that were already transformed (when loading a saved ORT
    // format model): they carry a format marker in string_data.
    bool already_transformed = false;
    for (const auto& attr_str : tensor_proto->string_data()) {
      if (attr_str.rfind(kFormatMetadataPrefix, 0) == 0) {
        already_transformed = true;
        break;
      }
    }
    if (already_transformed) {
      continue;
    }

    // Phase 1: query every consumer for its preferred format. Deduplicate by
    // (EP type, format descriptor) so that:
    //  - each transformation is performed at most once per EP, and
    //  - one EP's transformation is never applied on behalf of a different EP
    //    that happens to use the same descriptor string.
    // ep_type -> (format descriptor -> consumers needing that format)
    std::unordered_map<std::string,
                       std::unordered_map<std::string, std::vector<std::pair<NodeIndex, int>>>>
        ep_format_to_consumers;

    for (const auto& [node_idx, input_idx] : consumers) {
      const Node* node = graph_.GetNode(node_idx);
      if (!node) {
        continue;
      }

      const auto& ep_type = node->GetExecutionProviderType();
      if (ep_type.empty()) {
        continue;
      }

      const auto* ep = execution_providers_.Get(ep_type);
      if (!ep) {
        continue;
      }

      // Ask the EP whether it wants this initializer in a different format.
      std::string format_descriptor;
      Status query_status = ep->GetPreferredInitializerFormat(*node, input_idx, format_descriptor);
      if (!query_status.IsOK() || format_descriptor.empty()) {
        continue;  // EP is satisfied with the canonical layout.
      }

      ep_format_to_consumers[ep_type][format_descriptor].emplace_back(node_idx, input_idx);
    }

    if (ep_format_to_consumers.empty()) {
      continue;
    }

    // Load the original initializer into a CPU tensor the EP can read.
    TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(*tensor_proto);
    const auto* tensor_type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto->data_type())->GetElementType();

    Tensor original_tensor(tensor_type, tensor_shape, cpu_allocator);
    ORT_RETURN_IF_ERROR(
        utils::TensorProtoToTensor(Env::Default(), std::filesystem::path(), *tensor_proto, original_tensor));

    // Phase 2: transform once per unique (EP, format) pair.
    for (const auto& [ep_type, format_to_consumers] : ep_format_to_consumers) {
      const auto* ep = execution_providers_.Get(ep_type);
      if (!ep) {
        continue;
      }

      for (const auto& [format_descriptor, nodes_needing_format] : format_to_consumers) {
        std::unique_ptr<Tensor> transformed_tensor;
        Status transform_status =
            ep->TransformInitializerFormat(original_tensor, format_descriptor, transformed_tensor);

        if (!transform_status.IsOK() || !transformed_tensor) {
          // Best effort: fall back to the original layout instead of failing session init.
          LOGS(logger_, WARNING) << "Failed to transform initializer '" << init_name << "' to format '"
                                 << format_descriptor << "': " << transform_status.ErrorMessage();
          continue;
        }

        // Tag the tensor so the format survives copies to device memory.
        transformed_tensor->SetFormatDescriptor(format_descriptor);

        // The name includes the EP type so two EPs requesting the same descriptor
        // for the same initializer cannot collide.
        std::string transformed_name = init_name + "_" + ep_type + "_fmt_" + format_descriptor;

        ONNX_NAMESPACE::TensorProto transformed_proto =
            utils::TensorToTensorProto(*transformed_tensor, transformed_name);

        // Persist the format marker in string_data so serialization round-trips it.
        // NOTE(review): per the ONNX spec string_data is normally reserved for STRING
        // tensors; confirm downstream tooling tolerates it on other dtypes.
        *transformed_proto.add_string_data() = std::string(kFormatMetadataPrefix) + format_descriptor;

        graph_.AddInitializedTensor(transformed_proto);

        // Re-point every consumer that requested this format at the transformed copy.
        for (const auto& [node_idx, input_idx] : nodes_needing_format) {
          Node* node = graph_.GetNode(node_idx);
          if (!node) {
            continue;
          }

          const auto* original_node_arg = node->InputDefs()[input_idx];
          // NOTE(review): a blocked layout may change the tensor's shape; reusing the
          // original TypeAsProto assumes consumers ignore the declared shape for
          // formatted inputs — confirm with the EP kernel implementations.
          auto* transformed_node_arg =
              &graph_.GetOrCreateNodeArg(transformed_name, original_node_arg->TypeAsProto());
          node->MutableInputDefs()[input_idx] = transformed_node_arg;
        }
      }
    }
  }

  return Status::OK();
}
1955+
17961956
} // namespace onnxruntime

onnxruntime/core/framework/session_state.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,12 @@ class SessionState {
431431
const InlinedHashMap<OrtValueName, OrtDevice>& outer_scope_node_arg_to_location_map = {},
432432
bool graph_info_already_created = false);
433433

434+
  /**
   * Transform initializer tensors to EP-preferred memory formats.
   * Called from FinalizeSessionState after every node has been assigned an execution
   * provider and before kernel creation, so kernels see weights in their EP's
   * preferred layout. Best effort: if an EP's transformation fails, a warning is
   * logged and the original initializer is left in place.
   */
  Status TransformInitializersToPreferredFormat();
439+
434440
#ifdef ENABLE_TRAINING
435441
Status GeneratePatternGroupCache(
436442
gsl::span<const OrtValue> inputs,

onnxruntime/core/framework/session_state_utils.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,11 @@ common::Status CopyTensorFromCPUToDevice(
227227
}
228228
return copy_status;
229229
} else {
230+
// Preserve format descriptor when copying from CPU to device
231+
const std::string& format = deserialized_tensor.GetFormatDescriptor();
232+
if (!format.empty()) {
233+
tensor.SetFormatDescriptor(format);
234+
}
230235
Tensor::InitOrtValue(std::move(tensor), ort_value);
231236
return common::Status::OK();
232237
}

onnxruntime/core/framework/tensor.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ Tensor::Tensor(Tensor&& other) noexcept
195195
#endif
196196
dtype_(other.dtype_),
197197
alloc_info_(other.alloc_info_),
198-
byte_offset_(other.byte_offset_) {
198+
byte_offset_(other.byte_offset_),
199+
format_descriptor_(std::move(other.format_descriptor_)) {
199200
other.p_data_ = nullptr;
200201
other.buffer_deleter_ = nullptr;
201202
other.dtype_ = DataTypeImpl::GetType<float>()->AsPrimitiveDataType();
@@ -221,6 +222,7 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept {
221222
dtype_ = other.dtype_;
222223
alloc_info_ = other.alloc_info_;
223224
byte_offset_ = other.byte_offset_;
225+
format_descriptor_ = std::move(other.format_descriptor_);
224226

225227
other.p_data_ = nullptr;
226228
other.buffer_deleter_ = nullptr;

onnxruntime/core/framework/tensorprotoutils.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,6 +1424,15 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa
14241424
}
14251425
}
14261426

1427+
// Read format metadata from TensorProto string_data
1428+
for (const auto& attr_str : tensor_proto.string_data()) {
1429+
if (attr_str.find("onnxruntime_format:") == 0) {
1430+
std::string format = attr_str.substr(19); // Skip "onnxruntime_format:"
1431+
tensor.SetFormatDescriptor(format);
1432+
break; // Only one format descriptor expected
1433+
}
1434+
}
1435+
14271436
return Status::OK();
14281437
}
14291438

0 commit comments

Comments
 (0)