diff --git a/.github/scripts/generate-release-matrix.py b/.github/scripts/generate-release-matrix.py index b5926b2f2e..b92fb7d2e5 100644 --- a/.github/scripts/generate-release-matrix.py +++ b/.github/scripts/generate-release-matrix.py @@ -12,7 +12,7 @@ "wheel": ["3.10", "3.11", "3.12", "3.13"], "tarball": ["3.11"], } -sbsa_container_image: str = "quay.io/pypa/manylinux_2_34_aarch64" +sbsa_container_image: str = "quay.io/pypa/manylinux_2_39_aarch64" CXX11_TARBALL_CONTAINER_IMAGE = { "cu130": "pytorch/libtorch-cxx11-builder:cuda13.0-main", diff --git a/.github/workflows/build_linux.yml b/.github/workflows/build_linux.yml index 9168f11c6f..ceec73887d 100644 --- a/.github/workflows/build_linux.yml +++ b/.github/workflows/build_linux.yml @@ -413,5 +413,5 @@ jobs: PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }} concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{inputs.is-release-wheel}}-${{inputs.is-release-tarball}}-${{inputs.use-rtx}}-${{inputs.architecture}}-${{inputs.is-jetpack}}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{inputs.is-release-wheel}}-${{inputs.is-release-tarball}}-${{inputs.use-rtx}}-${{inputs.architecture}}-${{inputs.is-jetpack}}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ startsWith(github.ref, 'refs/tags/') && github.ref_name || 'no-tag' }} cancel-in-progress: true diff --git a/.github/workflows/build_windows.yml b/.github/workflows/build_windows.yml index cd9a926913..89db2ad1d9 100644 --- a/.github/workflows/build_windows.yml +++ b/.github/workflows/build_windows.yml @@ -438,5 +438,5 @@ jobs: architecture: ${{ inputs.architecture }} concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ inputs.repository }}-${{ inputs.is-release-wheel }}-${{ inputs.is-release-tarball }}-${{ github.event_name == 
'workflow_dispatch' }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ inputs.repository }}-${{ inputs.is-release-wheel }}-${{ inputs.is-release-tarball }}-${{ github.event_name == 'workflow_dispatch' }}-${{ startsWith(github.ref, 'refs/tags/') && github.ref_name || 'no-tag' }} cancel-in-progress: true \ No newline at end of file diff --git a/.github/workflows/release-linux-aarch64.yml b/.github/workflows/release-linux-aarch64.yml index b1faa4d668..35defce330 100644 --- a/.github/workflows/release-linux-aarch64.yml +++ b/.github/workflows/release-linux-aarch64.yml @@ -1,6 +1,7 @@ name: Release aarch64 Linux wheels and tarball artifacts on: + pull_request: push: tags: # NOTE: Binary build pipelines should only get triggered on release candidate builds @@ -128,5 +129,5 @@ jobs: architecture: "aarch64" concurrency: - group: ${{ github.workflow }}-aarch64-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }} + group: ${{ github.workflow }}-aarch64-release-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }} cancel-in-progress: true \ No newline at end of file diff --git a/.github/workflows/release-linux-x86_64.yml b/.github/workflows/release-linux-x86_64.yml index 60a6abd9bf..d04445aa32 100644 --- a/.github/workflows/release-linux-x86_64.yml +++ b/.github/workflows/release-linux-x86_64.yml @@ -126,5 +126,5 @@ jobs: is-release-wheel: true concurrency: - group: ${{ github.workflow }}-x86_64-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }} + group: ${{ github.workflow }}-x86_64-release-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ 
inputs.job-name }} cancel-in-progress: true \ No newline at end of file diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index f316adc160..d122e00c9e 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -142,6 +142,9 @@ TRTEngine::TRTEngine( } TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context"); + // Pre-allocate placeholder for empty tensors (TensorRT requires non-null addresses) + cudaMalloc(&empty_tensor_placeholder, 1); + runtime_states.old_cudagraphs = CUDAGRAPHS_MODE; runtime_states.old_pre_allocated_outputs = false; runtime_states.context_changed = false; @@ -264,6 +267,9 @@ TRTEngine::~TRTEngine() { trt_engine_profiler.reset(); exec_ctx.reset(); cuda_engine.reset(); + if (empty_tensor_placeholder) { + cudaFree(empty_tensor_placeholder); + } rt.reset(); } @@ -315,7 +321,7 @@ void TRTEngine::set_profile_format(std::string format) { } std::string TRTEngine::get_engine_layer_info() { - auto inspector = cuda_engine->createEngineInspector(); + auto inspector = make_trt(cuda_engine->createEngineInspector()); return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON); } diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 5a69fe9754..bf95740bae 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -187,6 +187,9 @@ struct TRTEngine : torch::CustomClassHolder { bool use_pre_allocated_outputs = false; std::vector pre_allocated_outputs; + // Single placeholder buffer for empty tensor inputs (allocated once, reused) + void* empty_tensor_placeholder = nullptr; + // Output Allocator-Related Functionality bool requires_output_allocator = false; // engine requires output allocator bool use_output_allocator_outputs = false; // users specify to use output allocator diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 8338fde257..3d5975f049 100644 --- a/core/runtime/execute_engine.cpp +++ 
b/core/runtime/execute_engine.cpp @@ -149,18 +149,26 @@ void setup_input_tensors( TORCHTRT_CHECK( compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape"); + at::Tensor final_input; if (cudagraphs_enabled) { // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true); - TORCHTRT_CHECK( - compiled_engine->exec_ctx->setTensorAddress(name.c_str(), compiled_engine->input_buffers[i].data_ptr()), - "Error while setting the input tensor address for inputs"); + final_input = compiled_engine->input_buffers[i]; } else { // Otherwise use the formatted buffer directly - TORCHTRT_CHECK( - compiled_engine->exec_ctx->setTensorAddress(name.c_str(), formatted_inputs.back().data_ptr()), - "Error while setting the input tensor address for inputs"); + final_input = formatted_inputs.back(); } + + // Get tensor address, using placeholder for empty tensors + // TensorRT requires non-null address even if numel() = 0 + // empty_tensor_placeholder is pre-allocated in TRTEngine constructor + void* input_addr = (final_input.numel() == 0 || final_input.data_ptr() == nullptr) + ? 
"""Example: compiling a model with a dynamic batch dimension via three Torch-TensorRT paths.

The model prepends a learned CLS token (whose expand() depends on the runtime
batch size) and reshapes the QKV projection, exercising dynamic-shape handling
in expand/cat/reshape. The same module is compiled three ways and the outputs
are cross-checked.
"""
import logging

import torch
import torch.nn as nn
import torch_tensorrt

logging.basicConfig(level=logging.DEBUG)

# Fixed seed so the random weights/inputs (and thus the allclose checks) are reproducible.
torch.manual_seed(0)


class ExpandReshapeModel(nn.Module):
    """Toy ViT-style block: CLS-token expand + concat + fused QKV projection + reshape."""

    def __init__(self, embed_dim: int):
        super().__init__()
        # Shape (1, 1, embed_dim); broadcast over the batch dimension in forward().
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.embed_dim = embed_dim
        # Single projection producing Q, K and V concatenated (3 * embed_dim).
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        # expand() ties the CLS token to the (dynamic) batch size.
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.qkv_proj(x)
        # Split the fused projection into (batch, seq, 3, heads=12, head_dim).
        reshaped_qkv = x.reshape(batch_size, x.size(1), 3, 12, -1)
        return reshaped_qkv


model = ExpandReshapeModel(embed_dim=768).cuda().eval()
x = torch.randn(4, 196, 768).cuda()

# 1. JIT: torch.compile with the tensorrt backend; the batch dim is marked
#    dynamic on the input tensor itself.
x1 = x.clone()
torch._dynamo.mark_dynamic(x1, index=0, min=2, max=32)
trt_module = torch.compile(model, backend="tensorrt")
out1 = trt_module(x1)

# 2. AOT: torch_tensorrt.compile with an explicit min/opt/max Input profile.
x2 = x.clone()
example_input = torch_tensorrt.Input(
    min_shape=[1, 196, 768],
    opt_shape=[4, 196, 768],
    max_shape=[32, 196, 768],
    dtype=torch.float32,
)
trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=example_input)
out2 = trt_module(x2)

# 3. AOT: torch.export with a symbolic batch Dim, then dynamo.compile on the
#    exported program.
x3 = x.clone()
bs = torch.export.Dim("bs", min=1, max=32)
dynamic_shapes = {"x": {0: bs}}
exp_program = torch.export.export(model, (x3,), dynamic_shapes=dynamic_shapes)
trt_module = torch_tensorrt.dynamo.compile(exp_program, (x3,))
out3 = trt_module(x3)

# All three compilation paths must agree numerically (default allclose tolerances).
assert torch.allclose(out1, out2)
assert torch.allclose(out1, out3)
assert torch.allclose(out2, out3)
Any) -> None: } self.shape_mode = Input._ShapeMode.DYNAMIC + # Warn if min_shape has any 0 dimension (empty tensor) - TensorRT doesn't support this + # @apbose: Is this warning necessary? + if any(dim == 0 for dim in self.shape["min_shape"]): + logger.warning( + f"min_shape contains a 0 dimension: {self.shape['min_shape']}. " + "TensorRT does not support dynamic shapes with min dimension of 0 (empty tensors). " + "TensorRT will internally clamp min dimensions to 1, which may cause runtime errors " + "if you try to run inference with empty tensor inputs." + ) + else: raise ValueError( f"Unexpected number of positional arguments for class Input \n Found {len(args)} arguments, expected either zero or a single positional arguments" @@ -384,7 +397,7 @@ def example_tensor( dtype=self.dtype.to(torch.dtype, use_default=True) ) else: - RuntimeError( + raise RuntimeError( f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})" ) else: @@ -412,4 +425,3 @@ def example_tensor( raise ValueError( "Requested an example tensor from a dynamic shaped input but did not specific which profile field to use." 
def is_orin() -> bool:
    """Return True if the current CUDA device is an NVIDIA Orin (compute capability 8.7).

    Mirrors the neighboring capability checks (e.g. ``is_thor``), which gate
    platform-specific behavior on the device's (major, minor) SM version.
    """
    # Membership in a single-element list is just a verbose equality test;
    # compare the (major, minor) tuple directly.
    return torch.cuda.get_device_capability() == (8, 7)
def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
    """Decide whether this aten.cat call is safe to lower to TensorRT.

    PyTorch permits concatenating ``torch.tensor([])`` (shape ``(0,)``) with
    higher-rank tensors, but TensorRT requires every concat input to share
    the same rank. Only that specific mismatch is rejected (falling back to
    PyTorch); everything else — including same-rank, properly shaped empty
    tensors such as ``cat([(3, 4), (0, 4)], dim=0)`` — is accepted.
    """
    # parse_cat_args normalizes both positional- and keyword-style call sites.
    cat_inputs, _ = parse_cat_args(node.args, node.kwargs)

    if len(cat_inputs) < 2:
        return True

    # Gather one shape per input; bail out permissively when metadata is missing.
    shapes = []
    for member in cat_inputs:
        if isinstance(member, TRTTensor):
            # TRTTensors expose their shape directly.
            shapes.append(member.shape)
            continue
        tensor_meta = getattr(member, "meta", {}).get("tensor_meta")
        if tensor_meta is None:
            # Nothing to validate against — let the converter attempt it.
            return True
        shapes.append(tuple(tensor_meta.shape))

    # Uniform rank: fine for both PyTorch and TensorRT.
    if len({len(s) for s in shapes}) == 1:
        return True

    # Ranks differ — reject if any input is a rank-1 empty tensor, i.e. the
    # torch.tensor([]) case PyTorch tolerates but TensorRT does not.
    for position, s in enumerate(shapes):
        if s == (0,) or (len(s) == 1 and s[0] == 0):
            _LOGGER.debug(
                f"Concatenation rejected by TRT, torch.tensor([]) or 1D empty tensor at position {position} "
                f"PyTorch allows this but TensorRT requires all inputs to have the same rank. "
                f"Use torch.empty((0, ...)) with explicit dimensions matching other inputs instead. Falling back to Pytorch"
            )
            return False
    return True
support empty tensors at runtime. + # For empty constants, we create a larger constant and slice it to empty. + if torch_value.numel() == 0: + empty_shape = list(torch_value.shape) + + # Create a placeholder shape where each dim is max(1, target_dim) + # This ensures we can slice to the target empty shape + # e.g., target [0, 4] -> placeholder [1, 4] -> slice to [0, 4] + placeholder_shape = [max(1, d) for d in empty_shape] + placeholder_numel = 1 + for d in placeholder_shape: + placeholder_numel *= d + + # Create placeholder constant with the required number of elements + placeholder_value = torch.zeros(placeholder_numel, dtype=torch_value.dtype) + placeholder_weights = to_trt_weights( + ctx, placeholder_value, f"{name}_placeholder", "CONSTANT", "CONSTANT" + ) + placeholder_constant = ctx.net.add_constant( + tuple(placeholder_shape), placeholder_weights + ) + placeholder_constant.name = f"{name}_placeholder" + + # Slice to get the empty shape (at least one dimension is 0) + start = [0] * len(empty_shape) + stride = [1] * len(empty_shape) + slice_layer = ctx.net.add_slice( + placeholder_constant.get_output(0), + start=start, + shape=empty_shape, + stride=stride, + ) + slice_layer.name = f"{name}_empty_slice" + + return slice_layer.get_output(0) # Convert the torch.Tensor to a trt.Weights object trt_weights = to_trt_weights(ctx, torch_value, name, "CONSTANT", "CONSTANT") @@ -516,7 +552,8 @@ def get_trt_tensor( # If the input is 64-bit, cast it to 32-bit for TRT freezing if isinstance(input_val, torch.Tensor) and ctx.compilation_settings.truncate_double: if input_val.dtype == torch.float64: - input_val = input_val.to(torch.float32) + with unset_fake_temporarily(): + input_val = input_val.to(torch.float32) elif isinstance(input_val, np.ndarray) and ctx.compilation_settings.truncate_double: if input_val.dtype == np.float64: input_val = input_val.astype(np.float32) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py 
def hardtanh(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    min_val: float = -1.0,
    max_val: float = 1.0,
) -> TRTTensor:
    """Lower aten.hardtanh through TensorRT's CLIP activation.

    CLIP saturates each element to the closed interval [alpha, beta], which is
    exactly hardtanh's contract, so the whole op maps to a single activation
    layer. (Ported from fx/converters/impl/activation.py; the dyn_range_fn
    hook was dropped because dynamo's convert_activation base does not use it.)
    """
    return convert_activation(
        ctx,
        target,
        source_ir,
        name,
        trt.ActivationType.CLIP,
        input_val,
        alpha=min_val,
        beta=max_val,
    )
def fused_rms_norm(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: trt.ITensor,
    normalized_shape: List[int],
    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
    eps: Optional[float],
) -> Tuple[trt.ITensor, Optional[torch.Tensor]]:
    """
    RMS Normalization: output = input / sqrt(mean(input^2) + eps) * weight

    Args:
        ctx: ConversionContext containing the TensorRT network
        target: Target of calling node
        source_ir: SourceIR of calling converter
        name: Name of the calling layer
        input: Input tensor to normalize
        normalized_shape: Shape over which to normalize (list of ints)
        weight: Optional weight/scale parameter
        eps: Epsilon for numerical stability (default: 1e-5)

    Returns:
        Tuple of (normalized_output, rstd_placeholder). The rstd (reciprocal
        standard deviation) slot is always None here, so the annotation is
        Optional[torch.Tensor]; PyTorch's op returns a real rstd, but it is
        typically unused by callers.
    """
    if eps is None:
        eps = 1e-5

    # normalized_shape names the trailing dimensions to reduce over,
    # mirroring layer_norm's convention.
    # (The original computed get_axes_for_reduce_op(dims) into an unused
    # local; that dead assignment has been removed.)
    dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))

    # Square the input
    input_squared = impl.elementwise.mul(
        ctx, target, source_ir, f"{name}_input_squared", input, input
    )

    # Compute mean of squared values over the normalized dimensions
    mean_squared = impl.reduce.mean(
        ctx,
        target,
        source_ir,
        f"{name}_mean_squared",
        input_squared,
        dim=dims,
        keepdim=True,
    )

    # Add epsilon for numerical stability
    eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps", input.dtype)
    mean_squared_eps = impl.elementwise.add(
        ctx, target, source_ir, f"{name}_mean_squared_eps", mean_squared, eps_tensor
    )

    # RMS = sqrt(mean(input^2) + eps)
    rms = impl.unary.sqrt(ctx, target, source_ir, f"{name}_rms", mean_squared_eps)

    # Normalize: input / rms
    normalized = impl.elementwise.div(
        ctx, target, source_ir, f"{name}_normalized", input, rms
    )

    # Apply the optional scale parameter
    if weight is not None:
        weight_trt = get_trt_tensor(ctx, weight, f"{name}_weight")

        # Cast weight to match input dtype
        weight_trt = cast_trt_tensor(
            ctx, weight_trt, input.dtype, f"{name}_weight_cast", target, source_ir
        )

        # Expand weight to match input shape if needed (broadcast over the
        # leading, non-normalized dimensions)
        if tuple(input.shape) != tuple(weight_trt.shape):
            weight_trt = impl.slice.expand(
                ctx, target, source_ir, f"{name}_expand_weight", weight_trt, input.shape
            )

        output = impl.elementwise.mul(
            ctx, target, source_ir, f"{name}_output", normalized, weight_trt
        )
    else:
        output = normalized

    # rstd is intentionally a None placeholder (see docstring).
    return output, None
def replace_symint_with_sym_size(
    gm: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
    """Replace SymInt placeholders with sym_size nodes.

    Each SymInt placeholder that dynamo inserted for a dynamic dimension is
    rewired to an ``aten.sym_size`` call on the tensor placeholder it came
    from, so the SymInt inputs become derivable from the tensor inputs and
    the placeholders can be dropped by the caller (``remove_sym_nodes``).
    """
    # Pass 1: collect every SymInt placeholder and record which tensor
    # argument (by local name) and which dimension index it refers to.
    symint_node_arg_dict = {}
    for node in gm.graph.nodes:
        if (
            node.op == "placeholder"
            and isinstance(node.type, type)
            and issubclass(node.type, torch.SymInt)
        ):
            # NOTE(review): relies on dynamo-internal metadata — assumes
            # node.meta["grapharg"].source is a TensorPropertySource whose
            # .base names the tensor argument and .idx is the dim; confirm
            # against the dynamo version in use.
            ga = node.meta.get("grapharg", None)
            if ga is not None:
                src = ga.source  # TensorPropertySource
                symint_node_arg_dict[node] = (src.base.local_name, src.idx)

    # Pass 2: for every tensor input placeholder, find SymInt placeholders
    # that refer to it and redirect their uses to a fresh sym_size node.
    for node in gm.graph.nodes:
        if (
            node.op == "placeholder"
            and isinstance(node.type, type)
            and issubclass(node.type, torch.Tensor)
        ):
            ga = node.meta.get("grapharg", None)
            if ga is not None:
                src = ga.source
                # Only genuine user inputs carry a local_name / is_input source.
                if hasattr(src, "local_name") and getattr(src, "is_input", False):
                    node_local_name = src.local_name
                    for symint_node, (
                        symint_local_name,
                        idx,
                    ) in symint_node_arg_dict.items():
                        if node_local_name == symint_local_name:
                            # Insert right after the tensor placeholder so the
                            # sym_size node dominates all former SymInt uses.
                            with gm.graph.inserting_after(node):
                                size_node = gm.graph.call_function(
                                    torch.ops.aten.sym_size, args=(node, idx)
                                )
                                symint_node.replace_all_uses_with(size_node)
                                logger.debug(
                                    f"The SymInt node {symint_node} is replaced with the sym_size node {size_node}"
                                )
                            # the symint_node is not used anymore, but it cannot be directly erased here
                            # because it will cause the number of positional arguments mismatch error.
                            # The node will be removed in the outside of the function

    # Validate graph invariants and regenerate forward() before returning.
    gm.graph.lint()
    gm.recompile()
    logger.debug(f"Added sym_size nodes for SymInt placeholders:\n{gm.graph}")

    return gm
self._input_buffers[i].copy_(contiguous_inputs[i]) self.context.set_tensor_address( @@ -422,7 +432,7 @@ def setup_input_tensors( ) else: self.context.set_tensor_address( - input_name, contiguous_inputs[i].data_ptr() + input_name, tensor_to_bind.data_ptr() ) def create_output_tensors(self) -> List[torch.Tensor]: diff --git a/tests/py/dynamo/conversion/test_hardtanh_aten.py b/tests/py/dynamo/conversion/test_hardtanh_aten.py index d71cd3d6dc..6d696391f9 100644 --- a/tests/py/dynamo/conversion/test_hardtanh_aten.py +++ b/tests/py/dynamo/conversion/test_hardtanh_aten.py @@ -9,10 +9,6 @@ from .harness import DispatchTestCase -@unittest.skipIf( - torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx, - "hardtanh is implemented in fx, need to move to dynamo, skip for TensorRT-RTX for now", -) class TestHardTanHConverter(DispatchTestCase): def test_hardtanh(self): class TestModule(nn.Module): diff --git a/tests/py/dynamo/conversion/test_rms_norm_aten.py b/tests/py/dynamo/conversion/test_rms_norm_aten.py new file mode 100644 index 0000000000..868994829b --- /dev/null +++ b/tests/py/dynamo/conversion/test_rms_norm_aten.py @@ -0,0 +1,286 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestFusedRMSNormConverter(DispatchTestCase): + """ + Tests for the aten._fused_rms_norm.default converter. 
+ + RMS Normalization formula: output = input / sqrt(mean(input^2) + eps) * weight + + The operation signature is: _fused_rms_norm(input, normalized_shape, weight, eps) + Returns: (output, rstd) - where rstd is the reciprocal standard deviation + """ + + @parameterized.expand( + [ + # Test normalizing over last dimension + ("1d_last_dim", (2, 4, 8), [8]), + # Test normalizing over last 2 dimensions + ("2d_last_two_dims", (2, 4, 8), [4, 8]), + # Test normalizing over all dimensions + ("3d_all_dims", (2, 4, 8), [2, 4, 8]), + # Test with 4D tensor, last dimension + ("4d_last_dim", (2, 3, 4, 8), [8]), + # Test with 4D tensor, last 2 dimensions + ("4d_last_two_dims", (2, 3, 4, 8), [4, 8]), + # Test with 4D tensor, last 3 dimensions + ("4d_last_three_dims", (2, 3, 4, 8), [3, 4, 8]), + ] + ) + def test_rms_norm_with_weight(self, name, input_shape, normalized_shape): + """ + Test RMS norm with weight parameter across various tensor shapes. + This tests: + - Correct dimension calculation for normalization + - Weight broadcasting/expansion to match input shape + - Output correctness vs PyTorch reference + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, normalized_shape, weight, 1e-5 + )[0] # Return only the normalized output, not rstd + + inputs = [ + torch.randn(input_shape), + torch.randn(normalized_shape), + ] + self.run_test( + RMSNorm(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + @parameterized.expand( + [ + # Test without weight (None) + ("1d_no_weight", (2, 4, 8), [8]), + ("2d_no_weight", (2, 4, 8), [4, 8]), + ("4d_no_weight", (2, 3, 4, 8), [8]), + ] + ) + def test_rms_norm_without_weight(self, name, input_shape, normalized_shape): + """ + Test RMS norm without weight parameter (weight=None). + This ensures the converter handles optional weight correctly. 
+ """ + + class RMSNorm(torch.nn.Module): + def forward(self, x): + return torch.ops.aten._fused_rms_norm.default( + x, normalized_shape, None, 1e-5 + )[0] + + inputs = [torch.randn(input_shape)] + self.run_test( + RMSNorm(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + @parameterized.expand( + [ + # Test different epsilon values + ("eps_1e5", (2, 4, 8), [8], 1e-5), + ("eps_1e6", (2, 4, 8), [8], 1e-6), + ("eps_1e4", (2, 4, 8), [8], 1e-4), + ] + ) + def test_rms_norm_different_eps(self, name, input_shape, normalized_shape, eps): + """ + Test RMS norm with different epsilon values. + Epsilon is critical for numerical stability, especially with small values. + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, normalized_shape, weight, eps + )[0] + + inputs = [ + torch.randn(input_shape), + torch.randn(normalized_shape), + ] + self.run_test( + RMSNorm(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + def test_rms_norm_with_dynamic_shape_batch(self): + """ + Test RMS norm with dynamic batch dimension. + This is common in inference scenarios where batch size varies. + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, [128], weight, 1e-6 + )[0] + + input_specs = [ + Input( + shape=(-1, 128), + dtype=torch.float32, + shape_ranges=[((1, 128), (4, 128), (8, 128))], + ), + Input( + shape=(128,), + dtype=torch.float32, + ), + ] + + self.run_test_with_dynamic_shape( + RMSNorm(), + input_specs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + def test_rms_norm_with_dynamic_shape_sequence(self): + """ + Test RMS norm with dynamic sequence length. + This is critical for transformer models with variable sequence lengths. 
+ """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, [256], weight, 1e-5 + )[0] + + input_specs = [ + Input( + shape=(2, -1, 256), + dtype=torch.float32, + shape_ranges=[((2, 16, 256), (2, 64, 256), (2, 128, 256))], + ), + Input( + shape=(256,), + dtype=torch.float32, + ), + ] + + self.run_test_with_dynamic_shape( + RMSNorm(), + input_specs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + def test_rms_norm_with_dynamic_shape_multi_dim(self): + """ + Test RMS norm with multiple dynamic dimensions. + Tests both batch and sequence length being dynamic simultaneously. + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, [64], weight, 1e-6 + )[0] + + input_specs = [ + Input( + shape=(-1, -1, 64), + dtype=torch.float32, + shape_ranges=[((1, 8, 64), (4, 16, 64), (8, 32, 64))], + ), + Input( + shape=(64,), + dtype=torch.float32, + ), + ] + + self.run_test_with_dynamic_shape( + RMSNorm(), + input_specs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + def test_rms_norm_2d_input(self): + """ + Test RMS norm with 2D input (batch, features). + Common in MLP layers or simple feedforward networks. + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, [512], weight, 1e-5 + )[0] + + inputs = [ + torch.randn(32, 512), + torch.randn(512), + ] + self.run_test( + RMSNorm(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + def test_rms_norm_large_hidden_dim(self): + """ + Test RMS norm with larger hidden dimensions typical in modern LLMs. + Tests numerical stability and performance with realistic model sizes. 
+ """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + return torch.ops.aten._fused_rms_norm.default( + x, [4096], weight, 1e-6 + )[0] + + inputs = [ + torch.randn(2, 8, 4096), + torch.randn(4096), + ] + self.run_test( + RMSNorm(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + def test_rms_norm_flux_pattern(self): + """ + Test RMS norm with pattern similar to FLUX and modern diffusion models. + This tests the actual use case that motivated the converter implementation. + """ + + class RMSNorm(torch.nn.Module): + def forward(self, x, weight): + # FLUX-style: normalize over last dimension with small epsilon + normalized_shape = [x.shape[-1]] + return torch.ops.aten._fused_rms_norm.default( + x, normalized_shape, weight, 1e-6 + )[0] + + inputs = [ + torch.randn(1, 16, 3072), # Typical FLUX dimensions + torch.randn(3072), + ] + self.run_test( + RMSNorm(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_sym_numel_aten.py b/tests/py/dynamo/conversion/test_sym_numel_aten.py new file mode 100644 index 0000000000..fdf43cd91c --- /dev/null +++ b/tests/py/dynamo/conversion/test_sym_numel_aten.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestSymNumelConverter(DispatchTestCase): + @parameterized.expand( + [ + ("1d", (6,)), + ("2d", (3, 4)), + ("3d", (2, 3, 4)), + ("4d", (2, 3, 4, 5)), + ] + ) + def test_sym_numel(self, _, input_shape): + class NumelModel(nn.Module): + def forward(self, x): + return torch.ops.aten.sym_numel.default(x) + + inputs = [torch.randn(input_shape)] + self.run_test( + NumelModel(), + inputs, + ) + + @parameterized.expand( + [ + ("1d_dynamic", (2,), (4,), (8,)), + ("2d_dynamic_batch", (1, 4), (3, 4), (6, 4)), + 
("2d_dynamic_all", (2, 2), (4, 4), (8, 8)), + ("3d_dynamic", (1, 2, 4), (2, 3, 4), (4, 4, 4)), + ] + ) + def test_sym_numel_dynamic_shape(self, _, min_shape, opt_shape, max_shape): + class NumelModel(nn.Module): + def forward(self, x): + return torch.ops.aten.sym_numel.default(x) + + input_specs = [ + Input( + dtype=torch.float32, + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + self.run_test_with_dynamic_shape( + NumelModel(), + input_specs, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/llm/test_llm_models.py b/tests/py/dynamo/llm/test_llm_models.py index 3899d5dd93..f133d4016a 100644 --- a/tests/py/dynamo/llm/test_llm_models.py +++ b/tests/py/dynamo/llm/test_llm_models.py @@ -1,5 +1,7 @@ +import importlib import os import sys +import unittest import pytest import torch @@ -9,13 +11,17 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../tools/llm")) import argparse -from run_llm import compile_torchtrt -from torchtrt_ext import register_sdpa - @pytest.mark.unit @pytest.mark.parametrize("precision", ["FP16", "BF16", "FP32"]) +@unittest.skipIf( + not importlib.util.find_spec("modelopt"), + "ModelOpt is required to run this test", +) def test_llm_decoder_layer(precision): + from run_llm import compile_torchtrt + from torchtrt_ext import register_sdpa + if torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx and precision == "BF16": pytest.skip("TensorRT-RTX does not support bfloat16, skipping test") with torch.inference_mode(): diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py index fb8b8633ac..ccfbf06268 100644 --- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py +++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -248,5 +248,35 @@ def forward(self, x): torch._dynamo.reset() +class TestRemoveSymIntNodes(TestCase): + def test_remove_sym_nodes(self): + class ModelContainSymIntNodes(torch.nn.Module): + 
def __init__(self, embed_dim: int): + super().__init__() + self.cls_token = torch.nn.Parameter(torch.randn(1, 1, embed_dim)) + self.embed_dim = embed_dim + self.qkv_proj = torch.nn.Linear(self.embed_dim, self.embed_dim * 3) + + def forward(self, x: torch.Tensor): + batch_size = x.shape[0] + cls_token = self.cls_token.expand(batch_size, -1, -1) + x = torch.cat([cls_token, x], dim=1) + x = self.qkv_proj(x) + reshaped_qkv = x.reshape(batch_size, x.size(1), 3, 12, -1) + return reshaped_qkv + + model = ModelContainSymIntNodes(embed_dim=768).cuda().eval() + inputs = torch.randn(4, 196, 768).cuda() + torch._dynamo.mark_dynamic(inputs, index=0, min=2, max=32) + trt_module = torch.compile( + model, + backend="tensorrt", + options={"use_python_runtime": False, "min_block_size": 1}, + ) + out = trt_module(inputs) + # if the model can be successfully compiled, we regard the test as passed + self.assertTrue(True) + + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/partitioning/conftest.py b/tests/py/dynamo/partitioning/conftest.py new file mode 100644 index 0000000000..01c8e67a6c --- /dev/null +++ b/tests/py/dynamo/partitioning/conftest.py @@ -0,0 +1,53 @@ +import copy + +import pytest +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + DYNAMO_ATEN_CONVERTERS, + DYNAMO_CONVERTERS, +) +from torch_tensorrt.dynamo.partitioning._atomic_subgraphs import trace_atomic_graph + + +@pytest.fixture(autouse=True) +def reset_torch_tensorrt_state(): + """ + Ensure test isolation by restoring converter registry state and clearing caches. + This prevents earlier tests from mutating global state (e.g., disallowed targets) + which can cause different partitioning outcomes when running multiple tests. 
+ """ + # Snapshot current global state + original_registry = {k: list(v) for k, v in DYNAMO_ATEN_CONVERTERS.items()} + original_disallowed = set(getattr(DYNAMO_CONVERTERS, "disallowed_targets", set())) + original_settings = getattr(DYNAMO_CONVERTERS, "compilation_settings", None) + + # Clear caches before running each test + try: + trace_atomic_graph.cache_clear() + except Exception: + pass + + try: + yield + finally: + # Restore converter registry + DYNAMO_ATEN_CONVERTERS.clear() + DYNAMO_ATEN_CONVERTERS.update( + {k: list(v) for k, v in original_registry.items()} + ) + + # Restore disallowed targets and compilation settings + try: + DYNAMO_CONVERTERS.set_disallowed_targets(original_disallowed) + except Exception: + pass + if original_settings is not None: + try: + DYNAMO_CONVERTERS.set_compilation_settings(original_settings) + except Exception: + pass + + # Clear caches again to avoid stale state carrying forward + try: + trace_atomic_graph.cache_clear() + except Exception: + pass diff --git a/tests/py/dynamo/runtime/test_004_weight_streaming.py b/tests/py/dynamo/runtime/test_004_weight_streaming.py index fe0ae649bc..5954a7d4d4 100644 --- a/tests/py/dynamo/runtime/test_004_weight_streaming.py +++ b/tests/py/dynamo/runtime/test_004_weight_streaming.py @@ -6,6 +6,7 @@ import torch_tensorrt as torchtrt from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._utils import is_orin from torch_tensorrt.dynamo.utils import prepare_inputs INPUT_SIZE = (64, 100) @@ -294,6 +295,9 @@ def test_weight_streaming_cudagraphs(self, _, use_python_runtime): @unittest.skipIf( torchtrt.ENABLED_FEATURES.tensorrt_rtx, "TensorRT-RTX has bug on cudagraphs" ) + @unittest.skipIf( + is_orin(), "There is a bug on Orin platform, skip for now until bug is fixed" + ) def test_runtime_state_change(self, _, use_python_runtime): class SampleModel(torch.nn.Module): def __init__(self): diff --git 
a/tests/py/dynamo/runtime/test_empty_input.py b/tests/py/dynamo/runtime/test_empty_input.py new file mode 100644 index 0000000000..793eafb82c --- /dev/null +++ b/tests/py/dynamo/runtime/test_empty_input.py @@ -0,0 +1,260 @@ +import pytest +import torch +import torch.nn as nn +import torch_tensorrt as torchtrt +from parameterized import parameterized +from torch.testing._internal.common_utils import TestCase, run_tests + +DECIMALS_OF_AGREEMENT = 5 # for output comparison + + +# We provide non null address to TRT +class ConcatEmptyModel(nn.Module): + def __init__(self, dim=0): + super().__init__() + self.dim = dim + + def forward(self, x, y): + return torch.cat([x, y], dim=self.dim) + + +# TRT will handle +class ConcatEmptyModelEmptyConstant(nn.Module): + def __init__(self, dim=0): + super().__init__() + self.dim = dim + + def forward(self, x): + y = torch.empty((0, 4), dtype=torch.float).cuda() + return torch.cat([x, y], dim=self.dim) + + +# makes use of validator +class ConcatEmptyModelEmptyConstantMisMatchDim(nn.Module): + def __init__(self, dim=0): + super().__init__() + self.dim = dim + + def forward(self, x): + y = torch.tensor([], device="cuda") + return torch.cat([x, y], dim=self.dim) + + +class TestConcatEmptyTensor(TestCase): + + @parameterized.expand( + [ + ( + "python_runtime_model_one_empty_0", + True, + ConcatEmptyModel, + "two_inputs", + (0,), + ), + ( + "cpp_runtime_model_one_empty_0", + False, + ConcatEmptyModel, + "two_inputs", + (0,), + ), + ( + "python_runtime_model_one_empty_0_4", + True, + ConcatEmptyModel, + "two_inputs", + (0, 4), + ), + ( + "cpp_runtime_model_one_empty_0_4", + False, + ConcatEmptyModel, + "two_inputs", + (0, 4), + ), + ( + "python_runtime_model_two_empty_0_4", + True, + ConcatEmptyModelEmptyConstant, + "one_input", + (0, 4), + ), + ( + "cpp_runtime_model_two_empty_0_4", + False, + ConcatEmptyModelEmptyConstant, + "one_input", + (0, 4), + ), + ( + "python_runtime_model_three_empty_0", + True, + 
ConcatEmptyModelEmptyConstantMisMatchDim, + "one_input", + (0,), + ), + ( + "cpp_runtime_model_three_empty_0", + False, + ConcatEmptyModelEmptyConstantMisMatchDim, + "one_input", + (0,), + ), + ] + ) + def test_concat_empty_with_nonempty( + self, _, use_python_runtime, model_class, input_type, empty_shape + ): + """ + Test concatenation of empty tensor with non-empty tensor + along a specific dimension using Torch-TensorRT compiled model. + """ + # Create model + model = model_class(dim=0).eval().cuda() + + # Inputs: prepare based on model requirements + empty_input = torch.empty(empty_shape, dtype=torch.float).cuda() + non_empty_input = torch.randn((3, 4), dtype=torch.float).cuda() + + if input_type == "two_inputs": + inputs = [empty_input, non_empty_input] + else: # one_input + inputs = [non_empty_input] + + # Compile with Torch-TensorRT + compiled_model = torchtrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=use_python_runtime, + ) + + # Run reference model + ref_out = model(*inputs) + # Run compiled model + trt_out = compiled_model(*inputs) + + # Assertions + self.assertEqual(ref_out.shape, trt_out.shape) + self.assertAlmostEqual( + float(torch.max(torch.abs(ref_out - trt_out))), + 0, + DECIMALS_OF_AGREEMENT, + msg="Concat with empty tensor output mismatch", + ) + + @parameterized.expand( + [ + ("python_runtime_empty_0", True, (0,)), + ("cpp_runtime_empty_0", False, (0,)), + ("python_runtime_empty_0_4", True, (0, 4)), + ("cpp_runtime_empty_0_4", False, (0, 4)), + ] + ) + def test_concat_nonempty_with_empty(self, _, use_python_runtime, empty_shape): + """ + Concatenate non-empty tensor with empty tensor (opposite order) + """ + model = ConcatEmptyModel(dim=0).eval().cuda() + + non_empty_input = torch.randn((3, 4), dtype=torch.float).cuda() + empty_input = torch.empty(empty_shape, dtype=torch.float).cuda() + inputs = [non_empty_input, empty_input] + + compiled_model = torchtrt.compile( + model, + "dynamo", + inputs, + 
min_block_size=1, + use_python_runtime=use_python_runtime, + ) + + ref_out = model(*inputs) + trt_out = compiled_model(*inputs) + + self.assertEqual(ref_out.shape, trt_out.shape) + self.assertAlmostEqual( + float(torch.max(torch.abs(ref_out - trt_out))), + 0, + DECIMALS_OF_AGREEMENT, + msg="Concat with empty tensor (opposite order) output mismatch", + ) + + +class TestEmptyTensorMemoryLeak(TestCase): + """ + Tests to verify that repeated inferences with empty tensors + do not cause memory leaks and produce correct results. + """ + + @parameterized.expand( + [ + ("cpp_runtime", False), + ("python_runtime", True), + ] + ) + def test_repeated_empty_tensor_no_leak_and_correct(self, _, use_python_runtime): + """ + Run many inferences with empty tensor input to verify: + 1. Memory doesn't grow (placeholder is reused, not reallocated) + 2. Outputs are correct (placeholder doesn't corrupt results) + """ + model = ConcatEmptyModel(dim=0).eval().cuda() + + empty_input = torch.empty((0, 4), dtype=torch.float).cuda() + non_empty_input = torch.randn((3, 4), dtype=torch.float).cuda() + inputs = [empty_input, non_empty_input] + + compiled_model = torchtrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=use_python_runtime, + ) + + # Record initial GPU memory + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + initial_memory = torch.cuda.memory_allocated() + + # Run many inferences with empty tensor + num_iterations = 1000 + for i in range(num_iterations): + # Use different non_empty data each iteration to test correctness + non_empty_input = torch.randn((3, 4), dtype=torch.float).cuda() + inputs = [empty_input, non_empty_input] + + ref_out = model(*inputs) + trt_out = compiled_model(*inputs) + + # Verify correctness every 100 iterations (to keep test fast) + if i % 100 == 0: + self.assertEqual(ref_out.shape, trt_out.shape) + self.assertAlmostEqual( + float(torch.max(torch.abs(ref_out - trt_out))), + 0, + DECIMALS_OF_AGREEMENT, + 
msg=f"Output mismatch at iteration {i}", + ) + + torch.cuda.synchronize() + final_memory = torch.cuda.memory_allocated() + + # Memory growth should be minimal (not proportional to num_iterations) + memory_growth = final_memory - initial_memory + max_allowed_growth = 1024 * 1024 # 1 MB max threshold + + print(f"Memory growth: {memory_growth} bytes") + + self.assertLess( + memory_growth, + max_allowed_growth, + msg=f"Memory grew by {memory_growth} bytes after {num_iterations} iterations. " + f"Possible memory leak with empty tensor handling.", + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/ts/api/test_classes.py b/tests/py/ts/api/test_classes.py index 796f57e046..5e1d50ddd9 100644 --- a/tests/py/ts/api/test_classes.py +++ b/tests/py/ts/api/test_classes.py @@ -365,9 +365,14 @@ def test_get_layer_info(self): TestTorchTensorRTModule._get_trt_mod(via_ts=True), ): trt_json = json.loads(trt_mod.get_layer_info()) - [self.assertTrue(k in trt_json.keys()) for k in ["Layers", "Bindings"]] - self.assertTrue(len(trt_json["Layers"]) == num_layers) - self.assertTrue(len(trt_json["Bindings"]) == 2) + [ + self.assertTrue(k in trt_json.keys(), f"Key {k} is missing") + for k in ["Layers", "Bindings"] + ] + self.assertTrue( + len(trt_json["Layers"]) == num_layers + ), "Not enough layers found" + self.assertTrue(len(trt_json["Bindings"]) == 2, "Not enough bindings found") if __name__ == "__main__": diff --git a/uv.lock b/uv.lock index 723dc7a9f3..1d6dae25d7 100644 --- a/uv.lock +++ b/uv.lock @@ -446,7 +446,7 @@ wheels = [ [[package]] name = "cuda-toolkit" version = "13.0.0" -source = { registry = "https://pypi.nvidia.com/" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu130" } resolution-markers = [ "(python_full_version >= '3.13' and platform_machine != 'AMD64' and sys_platform == 'win32') or (python_full_version >= '3.14' and platform_machine == 'AMD64' and sys_platform == 'win32')", "python_full_version == '3.13.*' and platform_machine == 
'AMD64' and sys_platform == 'win32'", @@ -469,7 +469,7 @@ cudart = [ [[package]] name = "cuda-toolkit" version = "13.0.2" -source = { registry = "https://pypi.nvidia.com/" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu130" } resolution-markers = [ "(python_full_version >= '3.13' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'linux') or (python_full_version >= '3.13' and platform_machine == 'aarch64' and platform_python_implementation != 'CPython' and sys_platform == 'linux') or (python_full_version >= '3.15' and platform_machine == 'aarch64' and platform_python_implementation == 'CPython' and sys_platform == 'linux') or (python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'linux')", "python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", @@ -1892,13 +1892,13 @@ name = "onnxruntime" version = "1.22.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "coloredlogs", marker = "(python_full_version != '3.13.*' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" }, - { name = "flatbuffers", marker = "(python_full_version != '3.13.*' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" }, - { name = "numpy", version = "2.2.6", source = { registry = "https://download.pytorch.org/whl/nightly/cu130" }, marker = "(python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'win32')" }, - { name = "numpy", version = "2.3.5", source = { registry = "https://download.pytorch.org/whl/nightly/cu130" }, marker = "(python_full_version >= '3.11' and python_full_version < '3.13' and sys_platform == 'linux') or (python_full_version >= '3.11' and platform_machine != 'x86_64' and sys_platform == 'linux') or (python_full_version >= '3.14' and 
platform_machine == 'x86_64' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, - { name = "packaging", marker = "(python_full_version != '3.13.*' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" }, - { name = "protobuf", marker = "(python_full_version != '3.13.*' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" }, - { name = "sympy", marker = "(python_full_version != '3.13.*' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" }, + { name = "coloredlogs", marker = "(python_full_version >= '3.14' and sys_platform == 'linux') or (python_full_version >= '3.14' and sys_platform == 'win32') or (platform_machine != 'x86_64' and sys_platform == 'linux') or (platform_machine != 'AMD64' and sys_platform == 'win32')" }, + { name = "flatbuffers", marker = "(python_full_version >= '3.14' and sys_platform == 'linux') or (python_full_version >= '3.14' and sys_platform == 'win32') or (platform_machine != 'x86_64' and sys_platform == 'linux') or (platform_machine != 'AMD64' and sys_platform == 'win32')" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://download.pytorch.org/whl/nightly/cu130" }, marker = "(python_full_version < '3.11' and platform_machine != 'x86_64' and sys_platform == 'linux') or (python_full_version < '3.11' and platform_machine != 'AMD64' and sys_platform == 'win32')" }, + { name = "numpy", version = "2.3.5", source = { registry = "https://download.pytorch.org/whl/nightly/cu130" }, marker = "(python_full_version >= '3.11' and platform_machine != 'x86_64' and sys_platform == 'linux') or (python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'linux') or (python_full_version >= '3.11' and platform_machine != 'AMD64' and sys_platform == 'win32') or 
(python_full_version >= '3.14' and platform_machine == 'AMD64' and sys_platform == 'win32')" }, + { name = "packaging", marker = "(python_full_version >= '3.14' and sys_platform == 'linux') or (python_full_version >= '3.14' and sys_platform == 'win32') or (platform_machine != 'x86_64' and sys_platform == 'linux') or (platform_machine != 'AMD64' and sys_platform == 'win32')" }, + { name = "protobuf", marker = "(python_full_version >= '3.14' and sys_platform == 'linux') or (python_full_version >= '3.14' and sys_platform == 'win32') or (platform_machine != 'x86_64' and sys_platform == 'linux') or (platform_machine != 'AMD64' and sys_platform == 'win32')" }, + { name = "sympy", marker = "(python_full_version >= '3.14' and sys_platform == 'linux') or (python_full_version >= '3.14' and sys_platform == 'win32') or (platform_machine != 'x86_64' and sys_platform == 'linux') or (platform_machine != 'AMD64' and sys_platform == 'win32')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/b9/64/bc7221e92c994931024e22b22401b962c299e991558c3d57f7e34538b4b9/onnxruntime-1.22.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b89ddfdbbdaf7e3a59515dee657f6515601d55cb21a0f0f48c81aefc54ff1b73", size = 14472246, upload-time = "2025-07-10T19:15:19.403Z" }, @@ -3012,8 +3012,8 @@ name = "tensorrt-cu13-libs" version = "10.14.1.48.post1" source = { registry = "https://pypi.nvidia.com/" } dependencies = [ - { name = "cuda-toolkit", version = "13.0.0", source = { registry = "https://pypi.nvidia.com/" }, extra = ["cudart"], marker = "sys_platform == 'win32'" }, - { name = "cuda-toolkit", version = "13.0.2", source = { registry = "https://pypi.nvidia.com/" }, extra = ["cudart"], marker = "sys_platform == 'linux'" }, + { name = "cuda-toolkit", version = "13.0.0", source = { registry = "https://download.pytorch.org/whl/nightly/cu130" }, extra = ["cudart"], marker = "sys_platform == 'win32'" }, + { name = "cuda-toolkit", version = "13.0.2", source = { 
registry = "https://download.pytorch.org/whl/nightly/cu130" }, extra = ["cudart"], marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://pypi.nvidia.com/tensorrt-cu13-libs/tensorrt_cu13_libs-10.14.1.48.post1-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:f55d59e9f93ebe0967c4bc108fb4068e74cdbc50bef3e6c9936e92f21cf11352" }, @@ -3132,20 +3132,20 @@ dependencies = [ { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp310-cp310-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp310-cp310-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp311-cp311-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp311-cp311-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp312-cp312-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp312-cp312-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp313-cp313-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp313-cp313-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp313-cp313t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp313-cp313t-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp314-cp314-manylinux_2_28_x86_64.whl" }, - { url = 
"https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp314-cp314-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp314-cp314t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp314-cp314t-win_amd64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp310-cp310-manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp310-cp310-win_amd64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp311-cp311-manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp311-cp311-win_amd64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp312-cp312-manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp312-cp312-win_amd64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp313-cp313-manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp313-cp313-win_amd64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp313-cp313t-manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp313-cp313t-win_amd64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp314-cp314-manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp314-cp314-win_amd64.whl" }, + { url = 
"https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp314-cp314t-manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu130/torch-2.11.0.dev20260108%2Bcu130-cp314-cp314t-win_amd64.whl" }, ] [[package]] @@ -3302,27 +3302,27 @@ dependencies = [ { name = "torch", marker = "(python_full_version < '3.13' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'linux') or (python_full_version < '3.13' and platform_machine == 'aarch64' and platform_python_implementation != 'CPython' and sys_platform == 'linux') or (python_full_version < '3.14' and platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp310-cp310-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp311-cp311-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp312-cp312-win_amd64.whl" }, - { url = 
"https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp313-cp313-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp313-cp313t-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp314-cp314-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp314-cp314t-win_amd64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl" }, + { url = 
"https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp310-cp310-win_amd64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp311-cp311-win_amd64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp312-cp312-win_amd64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp313-cp313-win_amd64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp313-cp313t-win_amd64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl" }, + { url = 
"https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp314-cp314-win_amd64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/cu128/torchvision-0.25.0.dev20260108%2Bcu128-cp314-cp314t-win_amd64.whl" }, ] [[package]] @@ -3363,20 +3363,20 @@ name = "triton" version = "3.6.0+git9844da95" source = { registry = "https://download.pytorch.org/whl/nightly/cu130" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp310-cp310-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp311-cp311-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp312-cp312-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp313-cp313-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp313-cp313t-linux_aarch64.whl" }, - { url = 
"https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp314-cp314-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp314-cp314t-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp310-cp310-linux_aarch64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp311-cp311-linux_aarch64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp312-cp312-linux_aarch64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp313-cp313-linux_aarch64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp313-cp313t-linux_aarch64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, + { url = 
"https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp314-cp314-linux_aarch64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp314-cp314t-linux_aarch64.whl" }, + { url = "https://download-r2.pytorch.org/whl/nightly/triton-3.6.0%2Bgit9844da95-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, ] [[package]]