diff --git a/.gitignore b/.gitignore index f08d97d448..6a4c6eda74 100644 --- a/.gitignore +++ b/.gitignore @@ -81,3 +81,4 @@ coverage.xml *.log *.pt2 examples/torchtrt_aoti_example/torchtrt_aoti_example +CLAUDE.md \ No newline at end of file diff --git a/core/runtime/executorch/TensorRTBackend.cpp b/core/runtime/executorch/TensorRTBackend.cpp new file mode 100644 index 0000000000..93f97dc8ce --- /dev/null +++ b/core/runtime/executorch/TensorRTBackend.cpp @@ -0,0 +1,313 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "core/runtime/executorch/TensorRTBackend.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include "core/runtime/TRTEngine.h" +#include "core/util/prelude.h" + +namespace torch_tensorrt { +namespace executorch_backend { + +using ::executorch::runtime::ArrayRef; +using ::executorch::runtime::BackendExecutionContext; +using ::executorch::runtime::BackendInitContext; +using ::executorch::runtime::CompileSpec; +using ::executorch::runtime::DelegateHandle; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; +using ::executorch::runtime::FreeableBuffer; +using ::executorch::runtime::MemoryAllocator; +using ::executorch::runtime::Result; +using ::executorch::runtime::Span; + +namespace { + +// --------------------------------------------------------------------------- +// Blob deserialization +// +// Wire format written by +// py/torch_tensorrt/executorch/serialization.py::serialize_engine_info() +// +// [uint32_t count (LE)] +// for each of `count` entries: +// [uint32_t len (LE)] [uint8_t data[len]] +// +// The resulting vector is passed directly to +// core::runtime::TRTEngine(std::vector serialized_info) +// which expects the 11-element list defined by SerializedInfoIndex in +// core/runtime/runtime.h +// --------------------------------------------------------------------------- +std::vector deserialize_engine_info(const void* data, size_t size) { + const uint8_t* ptr = static_cast(data); + const uint8_t* const end = ptr + size; + + if (ptr + sizeof(uint32_t) > end) { + return {}; + } + + uint32_t count = 0; + std::memcpy(&count, ptr, sizeof(uint32_t)); + ptr += sizeof(uint32_t); + + std::vector result; + result.reserve(count); + + for (uint32_t i = 0; i < count; ++i) { + if (ptr + sizeof(uint32_t) > end) { + return {}; + } + uint32_t len = 0; + std::memcpy(&len, ptr, sizeof(uint32_t)); + ptr += sizeof(uint32_t); + + if (ptr + len > end) { + return {}; + } + result.emplace_back(reinterpret_cast(ptr), len); + ptr += len; + } + + return result; +} + +// --------------------------------------------------------------------------- +// Build a nvinfer1::Dims from an ExecuTorch tensor's shape +// --------------------------------------------------------------------------- +nvinfer1::Dims to_trt_dims(const exec_aten::Tensor& t) { + nvinfer1::Dims dims{}; + dims.nbDims = t.dim(); + for (int d = 0; d < t.dim(); ++d) { + dims.d[d] = static_cast(t.size(d)); + } + return dims; +} + +} // namespace + +// --------------------------------------------------------------------------- +// is_available +// --------------------------------------------------------------------------- +bool TensorRTBackend::is_available() const { + return true; +} + +// --------------------------------------------------------------------------- +// init +// +// Deserializes the processed blob into a TRTEngine and returns it as the +// opaque DelegateHandle. The engine is placement-new'd into memory +// provided by the ExecuTorch MemoryAllocator so that ExecuTorch owns the +// lifetime; destroy() calls the destructor explicitly. +// --------------------------------------------------------------------------- +Result TensorRTBackend::init(BackendInitContext& context, FreeableBuffer* processed) const { + if (processed == nullptr || processed->data() == nullptr) { + ET_LOG(Error, "TensorRTBackend::init: null processed buffer"); + return Error::InvalidArgument; + } + + auto serialized_info = deserialize_engine_info(processed->data(), processed->size()); + + if (serialized_info.empty()) { + ET_LOG(Error, "TensorRTBackend::init: failed to deserialize engine blob"); + return Error::InvalidArgument; + } + + // Validate the vector length before handing to TRTEngine + // (verify_serialization_fmt throws on mismatch) + core::runtime::TRTEngine::verify_serialization_fmt(serialized_info); + + MemoryAllocator* allocator = context.get_runtime_allocator(); + if (allocator == nullptr) { + ET_LOG(Error, "TensorRTBackend::init: null runtime allocator"); + return Error::InvalidState; + } + + // Allocate raw storage for TRTEngine from ExecuTorch's arena + core::runtime::TRTEngine* engine = allocator->allocateInstance(); + if (engine == nullptr) { + ET_LOG(Error, "TensorRTBackend::init: allocateInstance failed"); + return Error::MemoryAllocationFailed; + } + + // Construct in-place; TRTEngine(std::vector) deserializes the + // engine bytes, builds the IRuntime/ICudaEngine/IExecutionContext, and + // populates in_binding_names / out_binding_names / num_io. + new (engine) core::runtime::TRTEngine(std::move(serialized_info)); + + // Release the blob; we no longer need it + processed->Free(); + + ET_LOG( + Info, + "TensorRTBackend::init: engine '%s' ready (%zu inputs, %zu outputs)", + engine->name.c_str(), + engine->num_io.first, + engine->num_io.second); + + return static_cast(engine); +} + +// --------------------------------------------------------------------------- +// execute +// +// Binds the ExecuTorch input/output tensor data pointers directly to the +// TRT IExecutionContext and calls enqueueV3(). ExecuTorch pre-allocates +// all output tensors before calling execute(), so we only need to register +// their addresses; no separate output allocation is required. +// +// Args layout (mirroring the Python exporter): +// args[0 .. num_inputs-1] – input EValues +// args[num_inputs .. num_inputs+num_outputs-1] – output EValues +// --------------------------------------------------------------------------- +Error TensorRTBackend::execute(BackendExecutionContext& context, DelegateHandle* handle, Span args) const { + (void)context; + + if (handle == nullptr) { + ET_LOG(Error, "TensorRTBackend::execute: null delegate handle"); + return Error::InvalidArgument; + } + + auto* engine = static_cast(handle); + + const size_t num_inputs = engine->num_io.first; + const size_t num_outputs = engine->num_io.second; + + if (args.size() < num_inputs + num_outputs) { + ET_LOG( + Error, "TensorRTBackend::execute: expected at least %zu args, got %zu", num_inputs + num_outputs, args.size()); + return Error::InvalidArgument; + } + + // IExecutionContext::enqueueV3 is not thread-safe; use the engine mutex + std::unique_lock lock(engine->mu); + + nvinfer1::IExecutionContext* ctx = engine->exec_ctx.get(); + + // ------------------------------------------------------------------ + // 1. Bind input shapes and addresses + // ------------------------------------------------------------------ + for (size_t i = 0; i < num_inputs; ++i) { + EValue* arg = args[i]; + if (arg == nullptr || !arg->isTensor()) { + ET_LOG(Error, "TensorRTBackend::execute: input %zu is not a tensor", i); + return Error::InvalidArgument; + } + + exec_aten::Tensor et_in = arg->toTensor(); + const std::string& name = engine->in_binding_names[i]; + nvinfer1::Dims dims = to_trt_dims(et_in); + + if (!ctx->setInputShape(name.c_str(), dims)) { + ET_LOG(Error, "TensorRTBackend::execute: setInputShape failed for '%s'", name.c_str()); + return Error::InvalidState; + } + + void* ptr = et_in.mutable_data_ptr(); + // TRT requires a non-null address even for 0-element tensors + static char placeholder[16] = {}; + if (ptr == nullptr || et_in.numel() == 0) { + ptr = placeholder; + } + + if (!ctx->setTensorAddress(name.c_str(), ptr)) { + ET_LOG(Error, "TensorRTBackend::execute: setTensorAddress failed for input '%s'", name.c_str()); + return Error::InvalidState; + } + } + + // ------------------------------------------------------------------ + // 2. Infer output shapes (requires all input shapes to be set first) + // ------------------------------------------------------------------ + { + const int32_t io_size = engine->cuda_engine->getNbIOTensors(); + std::vector unresolved(static_cast(io_size), nullptr); + const int32_t n_unresolved = ctx->inferShapes(io_size, unresolved.data()); + if (n_unresolved != 0) { + ET_LOG(Error, "TensorRTBackend::execute: inferShapes could not resolve %d tensor(s)", n_unresolved); + return Error::InvalidState; + } + } + + // ------------------------------------------------------------------ + // 3. Bind output addresses (ExecuTorch pre-allocates the buffers) + // ------------------------------------------------------------------ + for (size_t o = 0; o < num_outputs; ++o) { + EValue* arg = args[num_inputs + o]; + if (arg == nullptr || !arg->isTensor()) { + ET_LOG(Error, "TensorRTBackend::execute: output %zu is not a tensor", o); + return Error::InvalidArgument; + } + + exec_aten::Tensor et_out = arg->toTensor(); + const std::string& name = engine->out_binding_names[o]; + void* ptr = et_out.mutable_data_ptr(); + + if (!ctx->setTensorAddress(name.c_str(), ptr)) { + ET_LOG(Error, "TensorRTBackend::execute: setTensorAddress failed for output '%s'", name.c_str()); + return Error::InvalidState; + } + } + + // ------------------------------------------------------------------ + // 4. Enqueue inference on the current CUDA stream + // ------------------------------------------------------------------ + cudaStream_t stream = c10::cuda::getCurrentCUDAStream(static_cast(engine->device_info.id)); + + if (!ctx->enqueueV3(stream)) { + ET_LOG(Error, "TensorRTBackend::execute: enqueueV3 failed"); + return Error::InvalidState; + } + + // Synchronize so that outputs are visible to downstream ExecuTorch ops + cudaStreamSynchronize(stream); + + return Error::Ok; +} + +// --------------------------------------------------------------------------- +// destroy +// +// Explicitly destructs the TRTEngine. The underlying memory was allocated +// by ExecuTorch's MemoryAllocator and will be reclaimed by the arena. +// --------------------------------------------------------------------------- +void TensorRTBackend::destroy(DelegateHandle* handle) const { + if (handle != nullptr) { + static_cast(handle)->~TRTEngine(); + } +} + +} // namespace executorch_backend +} // namespace torch_tensorrt + +// --------------------------------------------------------------------------- +// Static registration – links the name "TensorRTBackend" used in the .pte +// file to this implementation at program startup. +// --------------------------------------------------------------------------- +namespace { + +torch_tensorrt::executorch_backend::TensorRTBackend& get_backend() { + static torch_tensorrt::executorch_backend::TensorRTBackend backend; + return backend; +} + +const ::executorch::runtime::Backend kBackendId{"TensorRTBackend", &get_backend()}; +const auto kRegistered = ::executorch::runtime::register_backend(kBackendId); + +} // namespace diff --git a/core/runtime/executorch/TensorRTBackend.h b/core/runtime/executorch/TensorRTBackend.h new file mode 100644 index 0000000000..3855942251 --- /dev/null +++ b/core/runtime/executorch/TensorRTBackend.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * ExecuTorch backend delegate that runs TensorRT engines serialized by + * torch_tensorrt. The processed blob must be in the vector-of-strings wire + * format produced by + * py/torch_tensorrt/executorch/serialization.py::serialize_engine_info() + * which maps 1-to-1 to the std::vector accepted by + * core/runtime/TRTEngine::TRTEngine(std::vector). + */ +#pragma once + +#include + +namespace torch_tensorrt { +namespace executorch_backend { + +class TensorRTBackend final : public ::executorch::runtime::BackendInterface { + public: + bool is_available() const override; + + ::executorch::runtime::Result<::executorch::runtime::DelegateHandle*> init( + ::executorch::runtime::BackendInitContext& context, + ::executorch::runtime::FreeableBuffer* processed, + ::executorch::runtime::ArrayRef<::executorch::runtime::CompileSpec> compile_specs) const override; + + ::executorch::runtime::Error execute( + ::executorch::runtime::BackendExecutionContext& context, + ::executorch::runtime::DelegateHandle* handle, + ::executorch::runtime::Span<::executorch::runtime::EValue*> args) const override; + + void destroy(::executorch::runtime::DelegateHandle* handle) const override; +}; + +} // namespace executorch_backend +} // namespace torch_tensorrt diff --git a/examples/torchtrt_executorch_example/export_static_shape.py b/examples/torchtrt_executorch_example/export_static_shape.py new file mode 100644 index 0000000000..f6f9e9006d --- /dev/null +++ b/examples/torchtrt_executorch_example/export_static_shape.py @@ -0,0 +1,67 @@ +""" +.. _executorch_export: + +Saving a Torch-TensorRT Model in ExecuTorch Format (.pte) +========================================================= + +This example demonstrates how to compile a model with Torch-TensorRT and save it +as an ExecuTorch ``.pte`` file, which can be loaded by the ExecuTorch runtime +(e.g., on embedded or mobile devices with a TensorRT-capable backend). + +Prerequisites +------------- +Install ExecuTorch before running this example:: + + pip install executorch + +See https://pytorch.org/executorch/stable/getting-started-setup.html for details. +""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +import torch_tensorrt + + +class MyModel(torch.nn.Module): + def forward(self, x): + return x + 1 + + +# %% +# Compile with Torch-TensorRT +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Export the model, compile it with TensorRT, then save as .pte + +with torch.no_grad(): + model = MyModel().eval().cuda() + example_input = (torch.randn((2, 3, 4, 4)).cuda(),) + + exported_program = torch.export.export(model, example_input) + compile_settings = { + "arg_inputs": [ + torch_tensorrt.Input(shape=(2, 3, 4, 4), dtype=torch.float32), + ], + "min_block_size": 1, + } + trt_gm = torch_tensorrt.dynamo.compile(exported_program, **compile_settings) + + # %% + # Save as ExecuTorch .pte format (loadable by the ExecuTorch runtime) + # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + # The TensorRT engine is serialized inside the .pte using the same blob format + # as the Torch-TensorRT runtime (vector of strings), so one engine format for + # both ExecuTorch and non-ExecuTorch deployment. + # Use retrace=False so the legacy exporter is used; the engine is then available + # when ExecuTorch's partitioner runs the graph. + torch_tensorrt.save( + trt_gm, + "model.pte", + output_format="executorch", + arg_inputs=example_input, + retrace=False, + ) + + print("Saved model.pte successfully.") diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index c4dbb1c148..d0246710e2 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -653,7 +653,7 @@ def save( inputs (Union[torch.Tensor, torch_tensorrt.Input]): Torch input tensors or Input specifications arg_inputs (Tuple[Union[torch.Tensor, torch_tensorrt.Input], ...]): Same as inputs. Alias for better understanding with kwarg_inputs. kwarg_inputs (dict[str, Union[torch.Tensor, torch_tensorrt.Input]]): Optional, kwarg inputs to the module forward function. - output_format (str): Format to save the model. Options include exported_program | torchscript | aot_inductor. + output_format (str): Format to save the model. Options include exported_program | torchscript | aot_inductor | executorch. retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it. For TRT-compiled modules with dynamic shapes, both retrace=True and retrace=False are supported: @@ -726,7 +726,7 @@ def save( if isinstance(module, CudaGraphsTorchTensorRTModule): module = module.compiled_module module_type = _parse_module_type(module) - accepted_formats = {"exported_program", "torchscript", "aot_inductor"} + accepted_formats = {"exported_program", "torchscript", "aot_inductor", "executorch"} if arg_inputs is not None and not all( isinstance(input, (torch.Tensor, Input)) for input in arg_inputs ): @@ -847,12 +847,16 @@ def _extract_tensor(obj: Any) -> Any: if output_format not in accepted_formats: raise ValueError( - f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript" + f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript | aot_inductor | executorch" ) if output_format == "aot_inductor" and platform.system() != "Linux": raise ValueError( f"The AOT Inductor format is only supported on Linux, {platform.system()} is not a supported platform for this format" ) + if output_format == "executorch" and platform.system() != "Linux": + raise ValueError( + f"The executorch format is only supported on Linux, {platform.system()} is not a supported platform for this format" + ) if not file_path: raise ValueError("File path cannot be empty. Please provide a valid file path") @@ -906,6 +910,8 @@ def _extract_tensor(obj: Any) -> Any: inductor_configs=inductor_configs, package_path=file_path, ) + elif output_format == "executorch": + _save_as_executorch(module, file_path) else: raise RuntimeError( "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor" @@ -963,6 +969,8 @@ def _extract_tensor(obj: Any) -> Any: inductor_configs=inductor_configs, package_path=file_path, ) + elif output_format == "executorch": + _save_as_executorch(exp_program, file_path) else: raise RuntimeError( "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor" @@ -1014,7 +1022,6 @@ def _extract_tensor(obj: Any) -> Any: "Provided model is a torch.fx.GraphModule without existing shape metadata and retrace is True, however no inputs specs were provided. " "Please provide valid torch.Tensors or torch_tensorrt.Input objects as inputs to retrace and save the model" ) - exp_program = torch.export.export( module, args=tuple(arg_tensors), @@ -1042,12 +1049,51 @@ def _extract_tensor(obj: Any) -> Any: inductor_configs=inductor_configs, package_path=file_path, ) + elif output_format == "executorch": + _save_as_executorch(exp_program, file_path) else: raise RuntimeError( "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor" ) +def _save_as_executorch(exp_program: Any, file_path: str, **kwargs: Any) -> None: + """Save an ExportedProgram (with TensorRT execute_engine nodes) as an ExecuTorch .pte file. + + Partitions the graph by torch.ops.tensorrt.no_op_placeholder_for_execute_engine + (execute_engine is pre-converted to avoid schema type errors in edge passes), + serializes each engine to the same blob format as the TRT runtime (vector of + strings), and embeds it in the .pte. Requires the ``executorch`` package and + torch_tensorrt_runtime. See https://pytorch.org/executorch/stable/getting-started-setup.html + """ + if not ENABLED_FEATURES.torch_tensorrt_runtime: + raise RuntimeError( + "output_format='executorch' requires the Torch-TensorRT runtime " + "(torch_tensorrt_runtime). Reinstall torch_tensorrt with the runtime extension." + ) + try: + from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower + except ImportError: + raise ImportError( + "ExecuTorch is not installed. Please install it to use output_format='executorch'. " + "See https://pytorch.org/executorch/stable/getting-started-setup.html" + ) + import torch_tensorrt.dynamo.runtime.meta_ops.register_meta_ops # noqa: F401 + from torch_tensorrt.executorch import TensorRTPartitioner + + extra_partitioners = kwargs.get("partitioners", []) + partitioners = [TensorRTPartitioner()] + extra_partitioners + + edge_program = to_edge_transform_and_lower( + exp_program, + partitioner=partitioners, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + executorch_program = edge_program.to_executorch() + with open(file_path, "wb") as f: + executorch_program.write_to_file(f) + + def function_overload_with_kwargs( fn: Callable[..., Any], *args: Any, **kwargs: Any ) -> Any: diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index bc3cdc5721..79cd025925 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -9,6 +9,12 @@ import torch from torch.export import ExportedProgram + +# TODO: remove this in future, this is just for test executorch which uses torch 2.11 which has a bug in the leaf spec compat +# the bug has been fixed in the torch 2.12 in the upstream. +from torch_tensorrt.dynamo._leaf_spec_compat import _apply_leaf_spec_patch + +_apply_leaf_spec_patch() from torch.fx.node import Target from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype diff --git a/py/torch_tensorrt/dynamo/_leaf_spec_compat.py b/py/torch_tensorrt/dynamo/_leaf_spec_compat.py new file mode 100644 index 0000000000..7bbd3a97ad --- /dev/null +++ b/py/torch_tensorrt/dynamo/_leaf_spec_compat.py @@ -0,0 +1,60 @@ +""" +Compatibility shim for a PyTorch 2.11 bug where ``LeafSpec`` (frozen dataclass +with ``slots=True``) inherits the ``type`` slot from ``TreeSpec`` but never +initialises it, leaving the slot empty. This causes + + AttributeError: 'LeafSpec' object has no attribute 'type' + +inside ``ExportedProgram.run_decompositions()`` when a model returns a single +tensor (i.e. the output pytree spec is a leaf rather than a list/tuple). + +The fix is applied once at import time and is a no-op on versions that already +set the attribute correctly. + +Upstream fix: https://github.com/pytorch/pytorch/issues/ +""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + + +def _apply_leaf_spec_patch() -> None: + """Patch ``LeafSpec`` so its inherited ``type`` slot is always set to ``None``. + + Safe to call multiple times; the patch is idempotent. + """ + try: + from torch.utils._pytree import _LEAF_SPEC, LeafSpec + except ImportError: + return # too old / too new, nothing to do + + # Check whether the bug is present on the singleton instance + try: + _ = _LEAF_SPEC.type # noqa: F841 + return # attribute accessible — no patch needed + except AttributeError: + pass + + logger.debug( + "torch_tensorrt: applying LeafSpec.type compatibility patch " + "(PyTorch bug: frozen-dataclass slot not initialised in subclass)" + ) + + # Fix the pre-existing singleton that all pytree leaf specs share + object.__setattr__(_LEAF_SPEC, "type", None) + object.__setattr__(_LEAF_SPEC, "_context", None) + object.__setattr__(_LEAF_SPEC, "_children", []) + + # Patch __post_init__ so any new LeafSpec() instances are also fixed + _orig_post_init = LeafSpec.__post_init__ + + def _post_init_with_type(self: LeafSpec) -> None: + _orig_post_init(self) + object.__setattr__(self, "type", None) + object.__setattr__(self, "_context", None) + object.__setattr__(self, "_children", []) + + LeafSpec.__post_init__ = _post_init_with_type diff --git a/py/torch_tensorrt/executorch/__init__.py b/py/torch_tensorrt/executorch/__init__.py new file mode 100644 index 0000000000..81aa088610 --- /dev/null +++ b/py/torch_tensorrt/executorch/__init__.py @@ -0,0 +1,9 @@ +# ExecuTorch backend for Torch-TensorRT: save/load .pte with TensorRT delegate. + +from torch_tensorrt.executorch.backend import TensorRTBackend +from torch_tensorrt.executorch.partitioner import TensorRTPartitioner + +__all__ = [ + "TensorRTBackend", + "TensorRTPartitioner", +] diff --git a/py/torch_tensorrt/executorch/backend.py b/py/torch_tensorrt/executorch/backend.py new file mode 100644 index 0000000000..5ed8a01893 --- /dev/null +++ b/py/torch_tensorrt/executorch/backend.py @@ -0,0 +1,91 @@ +# ExecuTorch TensorRT backend: serialize engine to same blob format as TRT runtime. + +import base64 +from typing import Any, List, final + +import torch +from executorch.exir.backend.backend_details import ( + BackendDetails, + CompileSpec, + PreprocessResult, +) +from torch.export.exported_program import ExportedProgram +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ENGINE_IDX +from torch_tensorrt.executorch.serialization import serialize_engine_info + + +def _get_engine_info_from_edge_program(edge_program: ExportedProgram) -> List[Any]: + """Extract engine info (list of strings/bytes) from the partition's execute_engine node. + + The partition contains a single execute_engine node whose second argument is + either a get_attr node (engine on the graph module) or a placeholder node + (engine lifted into edge_program.constants by torch.export). Either way, + the engine object's __getstate__() returns the SERIALIZATION_LEN-item list + used by the TRT runtime blob format. + """ + gm = edge_program.graph_module + execute_engine_op = torch.ops.tensorrt.execute_engine.default + + for node in gm.graph.nodes: + if node.op != "call_function" or node.target is not execute_engine_op: + continue + + engine_node = node.args[1] + if engine_node.op == "get_attr": + engine_obj = getattr(gm, engine_node.target, None) + if engine_obj is None: + raise RuntimeError( + f"execute_engine node '{node.name}': get_attr target " + f"'{engine_node.target}' not found on graph module" + ) + elif engine_node.op == "placeholder": + constants = getattr(edge_program, "constants", {}) + engine_obj = constants.get(engine_node.name) or constants.get( + engine_node.target + ) + if engine_obj is None: + raise RuntimeError( + f"execute_engine node '{node.name}': placeholder engine " + f"'{engine_node.name}' not found in edge_program.constants" + ) + else: + raise RuntimeError( + f"execute_engine node '{node.name}': unexpected engine arg op " + f"'{engine_node.op}'" + ) + + return list(engine_obj.__getstate__()) + + raise RuntimeError( + "TensorRT ExecuTorch backend: no execute_engine node found in partition." + ) + + +@final +class TensorRTBackend(BackendDetails): # type: ignore[misc] + """Backend that serializes TensorRT engine to the same blob format as the TRT runtime. + + The partition contains a single execute_engine node; we extract the engine + and metadata and encode them as a vector of strings (same layout as + core/runtime/runtime.h SerializedInfoIndex) so the same blob works for + both ExecuTorch and non-ExecuTorch TRT runtime. + """ + + @staticmethod + def preprocess( + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + engine_info = _get_engine_info_from_edge_program(edge_program) + engine_info = list(engine_info) + serialized_engine = engine_info[ENGINE_IDX] + if isinstance(serialized_engine, str): + engine_info[ENGINE_IDX] = base64.b64decode( + serialized_engine.encode("utf-8") + ) + elif not isinstance(serialized_engine, (bytes, bytearray)): + engine_info[ENGINE_IDX] = bytes(serialized_engine) + if len(engine_info) > 7 and isinstance(engine_info[7], bytes): + engine_info[7] = engine_info[7].decode("utf-8", errors="replace") + blob = serialize_engine_info(engine_info) + return PreprocessResult(processed_bytes=blob) diff --git a/py/torch_tensorrt/executorch/operator_support.py b/py/torch_tensorrt/executorch/operator_support.py new file mode 100644 index 0000000000..32763665c2 --- /dev/null +++ b/py/torch_tensorrt/executorch/operator_support.py @@ -0,0 +1,26 @@ +# Operator support for ExecuTorch TensorRT partitioner: only execute_engine is supported. + +from typing import Dict + +import torch +from torch.fx.passes.operator_support import OperatorSupportBase + + +class TensorRTOperatorSupport(OperatorSupportBase): # type: ignore[misc] + """Supports only torch.ops.tensorrt.execute_engine for partitioning. + + Used so that TRT-compiled graphs (which already contain execute_engine nodes) + are partitioned per engine; each partition is then lowered to TensorRTBackend + which serializes the engine to the same blob format as the TRT runtime. + """ + + def __init__(self) -> None: + super().__init__() + self._execute_engine_op = torch.ops.tensorrt.execute_engine.default + + def is_node_supported( + self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + if node.op != "call_function": + return False + return node.target is self._execute_engine_op diff --git a/py/torch_tensorrt/executorch/partitioner.py b/py/torch_tensorrt/executorch/partitioner.py new file mode 100644 index 0000000000..9fcab9f709 --- /dev/null +++ b/py/torch_tensorrt/executorch/partitioner.py @@ -0,0 +1,63 @@ +# ExecuTorch partitioner: partition by execute_engine nodes. + +from typing import Callable, Dict, List, Optional, Tuple + +import torch +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.backend.utils import tag_constant_data +from torch.export import ExportedProgram +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch_tensorrt.executorch.backend import TensorRTBackend +from torch_tensorrt.executorch.operator_support import TensorRTOperatorSupport + + +class TensorRTPartitioner(Partitioner): # type: ignore[misc] + """Partitions the graph for TensorRT delegation. + + Only nodes that are torch.ops.tensorrt.execute_engine are supported; + each such node becomes its own partition so the backend can serialize + the engine to the same format as the TRT runtime. + """ + + def __init__( + self, + compile_specs: Optional[List[CompileSpec]] = None, + ) -> None: + super().__init__() + self.compile_specs = compile_specs or [] + self.delegation_spec = DelegationSpec( + backend_id=TensorRTBackend.__name__, + compile_specs=self.compile_specs, + ) + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + capability_partitioner = CapabilityBasedPartitioner( + exported_program.graph_module, + TensorRTOperatorSupport(), + allows_single_node_partition=True, + ) + partition_list = capability_partitioner.propose_partitions() + + partition_tags: Dict[str, DelegationSpec] = {} + for partition in partition_list: + tag = f"tensorrt_{partition.id}" + for node in partition.nodes: + node.meta["delegation_tag"] = tag + partition_tags[tag] = self.delegation_spec + + tag_constant_data(exported_program) + + return PartitionResult( + tagged_exported_program=exported_program, + partition_tags=partition_tags, + ) + + def ops_to_not_decompose( + self, ep: ExportedProgram + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + return ([], None) diff --git a/py/torch_tensorrt/executorch/serialization.py b/py/torch_tensorrt/executorch/serialization.py new file mode 100644 index 0000000000..742269973d --- /dev/null +++ b/py/torch_tensorrt/executorch/serialization.py @@ -0,0 +1,32 @@ +# Serialization for ExecuTorch TensorRT blob: same format as TRT runtime (vector of strings). +# Uses the same list format as TorchTensorRTModule._pack_engine_info, then encodes to bytes. +# Only valid when ENABLED_FEATURES.torch_tensorrt_runtime is True. + +import struct +from typing import List, Union + +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import SERIALIZATION_LEN + + +def serialize_engine_info(engine_info: List[Union[str, bytes]]) -> bytes: + """Encode engine info list (same format as TorchTensorRTModule._pack_engine_info) to bytes. + + Takes the list produced by _pack_engine_info (or equivalent) and writes it in the + TRT runtime vector format: 4-byte count (SERIALIZATION_LEN), then for each + entry 4-byte length (LE) + raw bytes. C++ can deserialize to std::vector + and pass to TRTEngine(std::vector serialized_info). + """ + if len(engine_info) < SERIALIZATION_LEN: + engine_info = list(engine_info) + [""] * (SERIALIZATION_LEN - len(engine_info)) + parts: List[bytes] = [] + for i in range(SERIALIZATION_LEN): + raw = engine_info[i] + if isinstance(raw, str): + raw = raw.encode("utf-8") + elif raw is None: + raw = b"" + else: + raw = bytes(raw) + parts.append(struct.pack(" Path: