
Add dynamic shape support to index_put#4143

Open
narendasan wants to merge 2 commits into main from narendasan/push-trxznozvxnsq

Conversation


@narendasan narendasan commented Mar 20, 2026

Description

index_put appears in KV-cache implementations from Hugging Face; we were getting a broadcast error because no validator caught the unsupported case. This PR adds dynamic shape support to the op and properly guards its failure modes.

Fixes #4139
Fixes #4142
Fixes #3647
Fixes #2939
Fixes #3798
Fixes #3806

Issue / Description / Fix:

  • #4139 — index_put fails with dynamic non-indexed dimensions ("Dynamic shape in free dimensions not supported"). Fix: rewrote the converter to propagate dynamic dims through get_shape.
  • #4142 — index_copy_ / StaticCache fails with a shape broadcast error ("Cannot broadcast (1,8,1,128) to (1,1,8,128)"). Fix: corrected axis alignment in ND scatter index construction.
  • #3806 — index_add_ fails with dynamic shape. Fix: dynamic M/P handled in the scatter-add path.
  • #3798 — non-consecutive indices plus dynamic shape not supported. Fix: the converter now handles arbitrary non-contiguous index combinations with dynamic shapes.
  • #3777 — x[bool_3d_mask] = 0.0 crashes with ValueError: __len__() should return >= 0. Fix: expand_boolean_indices now correctly splits the TRT add_non_zero output (ndim, N) into per-dim (N,) tensors.
  • #2939 — accumulate=True with duplicate indices produces wrong results (scatter overwrites instead of accumulating).
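As a minimal illustration of the KV-cache pattern these issues exercise (shapes here are invented for the example, not taken from the linked models):

```python
import torch

# Hypothetical minimal repro of the KV-cache index_put pattern: write new
# states into a preallocated cache at given positions along the sequence dim.
cache = torch.zeros(1, 2, 8, 4)   # (batch, heads, max_seq, head_dim)
values = torch.ones(1, 2, 3, 4)   # states for 3 new positions
idx = torch.tensor([2, 3, 4])     # target positions on dim 2

# Only dim 2 is indexed; dims 0, 1, 3 are "free" and may be dynamic at runtime,
# which is exactly the case the converter previously rejected.
out = torch.ops.aten.index_put.default(cache, [None, None, idx], values)
assert out[0, 0, 2:5].sum().item() == 12.0  # 3 positions x 4 head dims of ones
```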

Accumulate is supported via an MxP implementation built from native TensorRT ops, which we expect to add roughly 10-15% overhead on reasonably sized problems but can be costly on very large ones.
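For reference, the accumulate semantics the converter must reproduce (a pure-PyTorch sketch of the #2939 failure mode, where a plain scatter would overwrite instead of summing):

```python
import torch

# With accumulate=True, values written to duplicate positions must be summed.
src = torch.zeros(6)
idx = torch.tensor([0, 0, 2, 2, 2])
vals = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

out = torch.ops.aten.index_put.default(src, [idx], vals, True)
# position 0 receives 1+2, position 2 receives 3+4+5
assert out.tolist() == [3.0, 0.0, 12.0, 0.0, 0.0, 0.0]
```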

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@meta-cla meta-cla bot added the cla signed label Mar 20, 2026
@narendasan narendasan force-pushed the narendasan/push-trxznozvxnsq branch from fef7530 to 17d88b1 on March 20, 2026 20:13
@github-actions github-actions bot added the component: tests, component: conversion, component: core, component: converters, component: api [Python], and component: dynamo labels Mar 20, 2026

@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2026-03-20 20:13:55.127307+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2026-03-20 20:14:14.887149+00:00
@@ -623,11 +623,16 @@
    for _fi, _fdim in enumerate(F):
        _s = input_tensor.shape[_fdim]
        if _s == DYNAMIC_DIM:
            F_shape_values.append(
                get_shape(
-                    ctx, target, source_ir, f"{name}_fshape_{_fdim}", input_tensor, _fdim
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_fshape_{_fdim}",
+                    input_tensor,
+                    _fdim,
                )
            )
        else:
            F_shape_values.append(_s)
    _has_dynamic_f = any(isinstance(_s, TRTTensor) for _s in F_shape_values)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-03-20 20:13:55.159208+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-03-20 20:14:16.694415+00:00
@@ -357,11 +357,10 @@
        )
        result = trt_engine(source_tensor, indices_tensor, value_tensor)

        torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)

-
    def test_kv_cache_dynamic_batch(self):
        """index_put with a dynamic free dimension (batch) — issue #4139.

        Pattern: cache[..., idx, :] = values  where dim-1 (batch) is dynamic
        and dim-2 (cache/time) is the indexed static dimension.
@@ -420,12 +419,12 @@
            use_explicit_typing=True,
            min_block_size=1,
        )

        result = trt_mod(cache.clone(), values, idx)
-        assert torch.allclose(result, torch_output, atol=1e-3, rtol=1e-3), (
-            f"KV-cache index_put mismatch: max diff = {(result - torch_output).abs().max()}"
-        )
+        assert torch.allclose(
+            result, torch_output, atol=1e-3, rtol=1e-3
+        ), f"KV-cache index_put mismatch: max diff = {(result - torch_output).abs().max()}"


if __name__ == "__main__":
    run_tests()

@narendasan narendasan force-pushed the narendasan/push-trxznozvxnsq branch from 17d88b1 to 858bf17 on March 20, 2026 22:24

@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-03-20 22:25:06.608217+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-03-20 22:25:28.809312+00:00
@@ -392,12 +392,16 @@
        )
        trt_mod = torchtrt.dynamo.compile(
            ep,
            arg_inputs=[
                torchtrt.Input(shape=(16,), dtype=torch.float32),
-                torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32),
-                torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32),
+                torchtrt.Input(
+                    min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32
+                ),
+                torchtrt.Input(
+                    min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32
+                ),
            ],
            min_block_size=1,
        )
        result = trt_mod(src.clone(), values, idx)
        assert torch.allclose(


@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2026-03-23 18:18:24.127073+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2026-03-23 18:18:42.849840+00:00
@@ -772,13 +772,15 @@
        ctx, target, source_ir, f"{name}_result_flat", src_flat, delta
    )

    # Rebuild the output shape (may contain dynamic dims)
    out_shape = tuple(
-        get_shape(ctx, target, source_ir, f"{name}_oshape_{i}", input_tensor, i)
-        if input_tensor.shape[i] == DYNAMIC_DIM
-        else input_tensor.shape[i]
+        (
+            get_shape(ctx, target, source_ir, f"{name}_oshape_{i}", input_tensor, i)
+            if input_tensor.shape[i] == DYNAMIC_DIM
+            else input_tensor.shape[i]
+        )
        for i in range(rank)
    )
    return impl.shuffle.reshape(
        ctx, target, source_ir, f"{name}_result", result_flat, out_shape
    )
@@ -871,11 +873,13 @@
    F = [i for i in range(rank) if indices[i] is None]  # Free dimensions
    I = [i for i in range(rank) if indices[i] is not None]  # Indexed dimensions
    K = len(I)
    # Determine the maximum size 'N' among the index tensors
    if K > 0:
-        index_shapes = []  # [tensor.shape[0] for tensor in indices if tensor is not None]
+        index_shapes = (
+            []
+        )  # [tensor.shape[0] for tensor in indices if tensor is not None]
        for _ni, idx_tensor in enumerate(indices):
            if idx_tensor is not None:
                if idx_tensor.shape[0] != DYNAMIC_DIM:
                    index_shapes.append(idx_tensor.shape[0])
                else:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-03-23 18:18:24.152947+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-03-23 18:18:45.151011+00:00
@@ -321,31 +321,45 @@
            # duplicate positions correctly (mirrors test_index_put_accumulate_duplicate_indices).
            param(
                test_name="1d_duplicate_indices_accumulate",
                source_tensor=torch.zeros([6], dtype=torch.float32),
                indices_tensor=(torch.tensor([0, 0, 2, 2, 2], dtype=torch.int64),),
-                value_tensor=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32),
+                value_tensor=torch.tensor(
+                    [1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32
+                ),
                accumulate=True,
            ),
            param(
                test_name="2d_indices_accumulate_True",
                source_tensor=torch.zeros([5, 5], dtype=torch.float32),
-                indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
+                indices_tensor=(
+                    torch.tensor([0, 0], dtype=torch.int32),
+                    torch.tensor([1, 1], dtype=torch.int32),
+                ),
                value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32),
                accumulate=True,
            ),
            param(
                test_name="3d_indices_accumulate_True",
                source_tensor=torch.zeros([3, 3, 3], dtype=torch.float32),
-                indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([2, 2], dtype=torch.int32)),
+                indices_tensor=(
+                    torch.tensor([0, 0], dtype=torch.int32),
+                    torch.tensor([1, 1], dtype=torch.int32),
+                    torch.tensor([2, 2], dtype=torch.int32),
+                ),
                value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32),
                accumulate=True,
            ),
            param(
                test_name="4d_indices_accumulate_True",
                source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.float32),
-                indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
+                indices_tensor=(
+                    torch.tensor([0, 0], dtype=torch.int32),
+                    torch.tensor([1, 1], dtype=torch.int32),
+                    torch.tensor([0, 0], dtype=torch.int32),
+                    torch.tensor([1, 1], dtype=torch.int32),
+                ),
                value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32),
                accumulate=True,
            ),
            # Negative indices with accumulate (mirrors test_index_put_accumulate_large_tensor).
            param(
@@ -357,29 +371,38 @@
            ),
            # bfloat16 + duplicate indices: computation stays in bfloat16 (no forced fp32 cast).
            param(
                test_name="accumulate_bfloat16_duplicate",
                source_tensor=torch.zeros([4, 4], dtype=torch.bfloat16),
-                indices_tensor=(torch.tensor([0, 0, 2], dtype=torch.int64), torch.tensor([1, 1, 3], dtype=torch.int64)),
+                indices_tensor=(
+                    torch.tensor([0, 0, 2], dtype=torch.int64),
+                    torch.tensor([1, 1, 3], dtype=torch.int64),
+                ),
                value_tensor=torch.tensor([1.0, 2.0, 4.0], dtype=torch.bfloat16),
                accumulate=True,
            ),
            # float16 + duplicate indices.
            param(
                test_name="accumulate_float16_duplicate",
                source_tensor=torch.zeros([4, 4], dtype=torch.float16),
-                indices_tensor=(torch.tensor([1, 1, 3], dtype=torch.int64), torch.tensor([0, 0, 2], dtype=torch.int64)),
+                indices_tensor=(
+                    torch.tensor([1, 1, 3], dtype=torch.int64),
+                    torch.tensor([0, 0, 2], dtype=torch.int64),
+                ),
                value_tensor=torch.tensor([2.0, 3.0, 5.0], dtype=torch.float16),
                accumulate=True,
            ),
            # Partial broadcast: one index covers a single position on dim-1 while
            # dim-0 has multiple positions — mirrors test_index_put_accumulate_expanded_values
            # (t[tensor([0,1,2,3]), tensor([1])] += 1.0).
            param(
                test_name="accumulate_partial_dim1_broadcast",
                source_tensor=torch.zeros([5, 2], dtype=torch.float32),
-                indices_tensor=(torch.tensor([0, 1, 2, 3], dtype=torch.int64), torch.tensor([1], dtype=torch.int64)),
+                indices_tensor=(
+                    torch.tensor([0, 1, 2, 3], dtype=torch.int64),
+                    torch.tensor([1], dtype=torch.int64),
+                ),
                value_tensor=torch.tensor([1.0], dtype=torch.float32),
                accumulate=True,
            ),
        ]
    )
@@ -512,12 +535,16 @@
        )
        trt_mod = torchtrt.dynamo.compile(
            ep,
            arg_inputs=[
                torchtrt.Input(shape=(16,), dtype=torch.float32),
-                torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32),
-                torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32),
+                torchtrt.Input(
+                    min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32
+                ),
+                torchtrt.Input(
+                    min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32
+                ),
            ],
            min_block_size=1,
        )
        result = trt_mod(src.clone(), values, idx)
        assert torch.allclose(
@@ -587,11 +614,10 @@

        result = trt_mod(cache.clone(), values, idx)
        assert torch.allclose(
            result, torch_output, atol=1e-3, rtol=1e-3
        ), f"KV-cache index_put mismatch: max diff = {(result - torch_output).abs().max()}"
-

    def test_accumulate_random_walk_duplicate_indices(self):
        """accumulate=True on 1-D input where indices are generated by a random walk
        (many duplicates interleaved).  Mirrors PyTorch's
        test_index_put_accumulate_duplicate_indices, scaled to a TRT-friendly size.
@@ -664,13 +690,13 @@
                torchtrt.Input(shape=(1, 2), dtype=torch.int64),
            ],
            min_block_size=1,
        )
        result = trt_mod(src.clone(), values, i0, i1, i2)
-        assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
-            f"3D expand accumulate mismatch: max diff = {(result - torch_output).abs().max()}"
-        )
+        assert torch.allclose(
+            result, torch_output, atol=1e-4, rtol=1e-4
+        ), f"3D expand accumulate mismatch: max diff = {(result - torch_output).abs().max()}"

    def test_empty_index_no_op(self):
        """index_put with an empty index tensor is a no-op — output equals input.

        Mirrors PyTorch's test_empty_index: x[empty_idx] = values leaves x unchanged.
@@ -697,13 +723,13 @@
                torchtrt.Input(shape=(0,), dtype=torch.int64),
            ],
            min_block_size=1,
        )
        result = trt_mod(src.clone(), values, idx)
-        assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
-            f"Empty-index no-op mismatch: {result} vs {torch_output}"
-        )
+        assert torch.allclose(
+            result, torch_output, atol=1e-4, rtol=1e-4
+        ), f"Empty-index no-op mismatch: {result} vs {torch_output}"

    def test_index_ind_dtype_int_vs_long(self):
        """int32 and int64 index tensors must produce identical results.

        Mirrors PyTorch's test_index_ind_dtype.
@@ -725,27 +751,41 @@
        assert torch.allclose(ref_long, ref_int), "CPU int32 vs int64 mismatch"

        ep_long = torch.export.export(model, args=(src, values, idx_long))
        ep_int = torch.export.export(model, args=(src, values, idx_int))

-        trt_long = torchtrt.dynamo.compile(ep_long, arg_inputs=[
-            torchtrt.Input(shape=(4, 4), dtype=torch.float32),
-            torchtrt.Input(shape=(4,), dtype=torch.float32),
-            torchtrt.Input(shape=(4,), dtype=torch.int64),
-        ], min_block_size=1)
-        trt_int = torchtrt.dynamo.compile(ep_int, arg_inputs=[
-            torchtrt.Input(shape=(4, 4), dtype=torch.float32),
-            torchtrt.Input(shape=(4,), dtype=torch.float32),
-            torchtrt.Input(shape=(4,), dtype=torch.int32),
-        ], min_block_size=1)
+        trt_long = torchtrt.dynamo.compile(
+            ep_long,
+            arg_inputs=[
+                torchtrt.Input(shape=(4, 4), dtype=torch.float32),
+                torchtrt.Input(shape=(4,), dtype=torch.float32),
+                torchtrt.Input(shape=(4,), dtype=torch.int64),
+            ],
+            min_block_size=1,
+        )
+        trt_int = torchtrt.dynamo.compile(
+            ep_int,
+            arg_inputs=[
+                torchtrt.Input(shape=(4, 4), dtype=torch.float32),
+                torchtrt.Input(shape=(4,), dtype=torch.float32),
+                torchtrt.Input(shape=(4,), dtype=torch.int32),
+            ],
+            min_block_size=1,
+        )

        out_long = trt_long(src.clone(), values, idx_long)
        out_int = trt_int(src.clone(), values, idx_int)

-        assert torch.allclose(out_long, ref_long, atol=1e-4, rtol=1e-4), "TRT int64 mismatch"
-        assert torch.allclose(out_int, ref_int, atol=1e-4, rtol=1e-4), "TRT int32 mismatch"
-        assert torch.allclose(out_long, out_int, atol=1e-4, rtol=1e-4), "TRT int32 vs int64 inconsistency"
+        assert torch.allclose(
+            out_long, ref_long, atol=1e-4, rtol=1e-4
+        ), "TRT int64 mismatch"
+        assert torch.allclose(
+            out_int, ref_int, atol=1e-4, rtol=1e-4
+        ), "TRT int32 mismatch"
+        assert torch.allclose(
+            out_long, out_int, atol=1e-4, rtol=1e-4
+        ), "TRT int32 vs int64 inconsistency"

    def test_accumulate_non_contiguous_source(self):
        """accumulate=True on a non-contiguous (sliced) source tensor.

        Mirrors PyTorch's test_index_put_accumulate_non_contiguous.
@@ -780,26 +820,24 @@
                torchtrt.Input(shape=(2,), dtype=torch.int64),
            ],
            min_block_size=1,
        )
        result = trt_mod(src_slice.contiguous(), values, idx)
-        assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
-            f"Non-contiguous accumulate mismatch: max diff = {(result - torch_output).abs().max()}"
-        )
+        assert torch.allclose(
+            result, torch_output, atol=1e-4, rtol=1e-4
+        ), f"Non-contiguous accumulate mismatch: max diff = {(result - torch_output).abs().max()}"

    def test_accumulate_expanded_values_broadcast(self):
        """accumulate=True with value broadcasting — 0D scalar and 1D values
        broadcast across unique indexed positions.

        Mirrors PyTorch's test_index_put_accumulate_expanded_values (unique indices only).
        """

        class AccumBroadcast(torch.nn.Module):
            def forward(self, src, values, idx0, idx1):
-                return torch.ops.aten.index_put.default(
-                    src, [idx0, idx1], values, True
-                )
+                return torch.ops.aten.index_put.default(src, [idx0, idx1], values, True)

        src = torch.zeros(5, 2, dtype=torch.float32, device="cuda")
        idx0 = torch.tensor([0, 1, 2, 3], dtype=torch.int64, device="cuda")
        idx1 = torch.tensor([0, 1, 0, 1], dtype=torch.int64, device="cuda")
        values_1d = torch.tensor([1.0], dtype=torch.float32, device="cuda")
@@ -817,12 +855,12 @@
                torchtrt.Input(shape=(4,), dtype=torch.int64),
            ],
            min_block_size=1,
        )
        result = trt_mod(src.clone(), values_1d, idx0, idx1)
-        assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
-            f"Accumulate broadcast mismatch: max diff = {(result - torch_output).abs().max()}"
-        )
+        assert torch.allclose(
+            result, torch_output, atol=1e-4, rtol=1e-4
+        ), f"Accumulate broadcast mismatch: max diff = {(result - torch_output).abs().max()}"


if __name__ == "__main__":
    run_tests()

@narendasan narendasan force-pushed the narendasan/push-trxznozvxnsq branch from 8eb87a5 to 3402e30 on March 23, 2026 18:19
@narendasan narendasan requested a review from zewenli98 March 23, 2026 18:25
