From 2aa6271198f28f7dc44608c2abf243292c437303 Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Mon, 2 Feb 2026 12:52:51 -0800 Subject: [PATCH 1/2] Remove exception; trt-rtx supports bf16. --- py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index d4735baa12..e1eaed1a35 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -202,8 +202,7 @@ def _all_precisions_supported(enabled_precisions: Set[dtype]) -> bool: def validate_compile_settings(self) -> None: if ENABLED_FEATURES.tensorrt_rtx: - if dtype.bfloat16 in self.compilation_settings.enabled_precisions: - raise RuntimeError("TensorRT-RTX does not support bfloat16!") + # The below checks are not relevant for TensorRT-RTX return if ( From 2cc8e308a43aa33ff1df1889d43cc50c592e02de Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Mon, 2 Feb 2026 16:30:56 -0800 Subject: [PATCH 2/2] Reenable tests for bf16 --- .../dynamo/conversion/test_binary_ops_aten.py | 4 ---- tests/py/dynamo/conversion/test_casts.py | 4 ---- tests/py/dynamo/llm/test_llm_models.py | 2 -- tests/py/dynamo/models/test_dtype_support.py | 4 ---- tests/py/dynamo/models/test_dyn_models.py | 3 --- tests/py/dynamo/models/test_models.py | 21 ------------------- tests/py/dynamo/models/test_models_export.py | 6 ------ 7 files changed, 44 deletions(-) diff --git a/tests/py/dynamo/conversion/test_binary_ops_aten.py b/tests/py/dynamo/conversion/test_binary_ops_aten.py index 16b82b9858..d7c7a554c0 100644 --- a/tests/py/dynamo/conversion/test_binary_ops_aten.py +++ b/tests/py/dynamo/conversion/test_binary_ops_aten.py @@ -237,10 +237,6 @@ def forward(self, x, y): if op[0].__name__ not in ["pow.Tensor_Tensor", "fmod.Tensor"] ] ) - @unittest.skipIf( - torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx, - "bf16 is not supported for tensorrt_rtx", - ) def test_elementwise_ops_bf16(self, _, orig_op): class TestModule(nn.Module): def __init__(self, orig_op): diff --git a/tests/py/dynamo/conversion/test_casts.py b/tests/py/dynamo/conversion/test_casts.py index 62920c9610..cb79001f4f 100644 --- a/tests/py/dynamo/conversion/test_casts.py +++ b/tests/py/dynamo/conversion/test_casts.py @@ -67,10 +67,6 @@ def forward(self, x): precision=torch.float, ) - @unittest.skipIf( - torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx, - "bf16 is not supported for tensorrt_rtx", - ) def test_to_copy_bfloat16(self): class ToCopyBFloat16(nn.Module): def forward(self, x): diff --git a/tests/py/dynamo/llm/test_llm_models.py b/tests/py/dynamo/llm/test_llm_models.py index 73811572f9..054c2536d7 100644 --- a/tests/py/dynamo/llm/test_llm_models.py +++ b/tests/py/dynamo/llm/test_llm_models.py @@ -16,8 +16,6 @@ @pytest.mark.unit @pytest.mark.parametrize("precision", ["FP16", "BF16", "FP32"]) def test_llm_decoder_layer(precision): - if torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx and precision == "BF16": - pytest.skip("TensorRT-RTX does not support bfloat16, skipping test") with torch.inference_mode(): args = argparse.Namespace() args.debug = False diff --git a/tests/py/dynamo/models/test_dtype_support.py b/tests/py/dynamo/models/test_dtype_support.py index 6c02db6b68..42507968f7 100644 --- a/tests/py/dynamo/models/test_dtype_support.py +++ b/tests/py/dynamo/models/test_dtype_support.py @@ -200,10 +200,6 @@ def forward(self, x): ), "Platform does not have BF16 support", ) -@unittest.skipIf( - torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx, - "bf16 is not supported for tensorrt_rtx", -) class TestBF16Support(TestCase): @unittest.skipIf( not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py index 28d72433b7..b54e67c2fb 100644 --- a/tests/py/dynamo/models/test_dyn_models.py +++ b/tests/py/dynamo/models/test_dyn_models.py @@ -189,9 +189,6 @@ def test_resnet_dynamic(ir, dtype): """ Tests the Resnet18 model (which is fully convertible) with dynamic shapes """ - if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16: - pytest.skip("TensorRT-RTX does not support bfloat16") - import torchvision.models as models model = models.resnet18(pretrained=True).eval().to("cuda").to(dtype) diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index b95968809d..d7f9da4302 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -195,9 +195,6 @@ def test_resnet18_torch_exec_ops(ir): "torchvision is not installed", ) def test_mobilenet_v2(ir, dtype): - if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16: - pytest.skip("TensorRT-RTX does not support bfloat16") - model = models.mobilenet_v2(pretrained=True).eval().to("cuda").to(dtype) input = torch.randn((1, 3, 224, 224)).to("cuda").to(dtype) @@ -237,9 +234,6 @@ def test_mobilenet_v2(ir, dtype): "timm or torchvision not installed", ) def test_efficientnet_b0(ir, dtype): - if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16: - pytest.skip("TensorRT-RTX does not support bfloat16") - model = ( timm.create_model("efficientnet_b0", pretrained=True) .eval() @@ -284,9 +278,6 @@ def test_efficientnet_b0(ir, dtype): "transformers is required to run this test", ) def test_bert_base_uncased(ir, dtype): - if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16: - pytest.skip("TensorRT-RTX does not support bfloat16") - from transformers import BertModel model = BertModel.from_pretrained("bert-base-uncased").cuda().eval().to(dtype) @@ -425,10 +416,6 @@ def test_resnet18_half(ir): @pytest.mark.unit -@unittest.skipIf( - torchtrt.ENABLED_FEATURES.tensorrt_rtx, - "tensorrt_rtx does not support bfloat16", -) def test_cosmos_true_div(ir): class CosmosLearnablePositionalEmbed(torch.nn.Module): def __init__( @@ -527,10 +514,6 @@ def forward( @pytest.mark.unit -@unittest.skipIf( - torchtrt.ENABLED_FEATURES.tensorrt_rtx, - "bf16 is not supported for tensorrt_rtx", -) @pytest.mark.critical def test_bf16_model(ir): class MyModule(torch.nn.Module): @@ -576,10 +559,6 @@ def forward(self, x): @pytest.mark.unit -@unittest.skipIf( - torchtrt.ENABLED_FEATURES.tensorrt_rtx, - "bf16 is not supported for tensorrt_rtx", -) @pytest.mark.critical def test_bf16_fallback_model(ir): class MyModule(torch.nn.Module): diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index ce4a61a876..baf42fc7d5 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -408,9 +408,6 @@ def test_base_int8(ir, dtype): import modelopt.torch.quantization as mtq from modelopt.torch.quantization.utils import export_torch_mode - if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16: - pytest.skip("TensorRT-RTX does not support bfloat16") - class SimpleNetwork(torch.nn.Module): def __init__(self): super(SimpleNetwork, self).__init__() @@ -469,9 +466,6 @@ def test_base_int8_dynamic_shape(ir, dtype): import modelopt.torch.quantization as mtq from modelopt.torch.quantization.utils import export_torch_mode - if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16: - pytest.skip("TensorRT-RTX does not support bfloat16") - class SimpleNetwork(torch.nn.Module): def __init__(self): super(SimpleNetwork, self).__init__()