Open
Conversation
fef7530 to
17d88b1
Compare
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-03-20 20:13:55.127307+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-03-20 20:14:14.887149+00:00
@@ -623,11 +623,16 @@
for _fi, _fdim in enumerate(F):
_s = input_tensor.shape[_fdim]
if _s == DYNAMIC_DIM:
F_shape_values.append(
get_shape(
- ctx, target, source_ir, f"{name}_fshape_{_fdim}", input_tensor, _fdim
+ ctx,
+ target,
+ source_ir,
+ f"{name}_fshape_{_fdim}",
+ input_tensor,
+ _fdim,
)
)
else:
F_shape_values.append(_s)
_has_dynamic_f = any(isinstance(_s, TRTTensor) for _s in F_shape_values)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-03-20 20:13:55.159208+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-03-20 20:14:16.694415+00:00
@@ -357,11 +357,10 @@
)
result = trt_engine(source_tensor, indices_tensor, value_tensor)
torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)
-
def test_kv_cache_dynamic_batch(self):
"""index_put with a dynamic free dimension (batch) — issue #4139.
Pattern: cache[..., idx, :] = values where dim-1 (batch) is dynamic
and dim-2 (cache/time) is the indexed static dimension.
@@ -420,12 +419,12 @@
use_explicit_typing=True,
min_block_size=1,
)
result = trt_mod(cache.clone(), values, idx)
- assert torch.allclose(result, torch_output, atol=1e-3, rtol=1e-3), (
- f"KV-cache index_put mismatch: max diff = {(result - torch_output).abs().max()}"
- )
+ assert torch.allclose(
+ result, torch_output, atol=1e-3, rtol=1e-3
+ ), f"KV-cache index_put mismatch: max diff = {(result - torch_output).abs().max()}"
if __name__ == "__main__":
run_tests()
17d88b1 to
858bf17
Compare
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-03-20 22:25:06.608217+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-03-20 22:25:28.809312+00:00
@@ -392,12 +392,16 @@
)
trt_mod = torchtrt.dynamo.compile(
ep,
arg_inputs=[
torchtrt.Input(shape=(16,), dtype=torch.float32),
- torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32),
- torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32),
+ torchtrt.Input(
+ min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32
+ ),
+ torchtrt.Input(
+ min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32
+ ),
],
min_block_size=1,
)
result = trt_mod(src.clone(), values, idx)
assert torch.allclose(
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-03-23 18:18:24.127073+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-03-23 18:18:42.849840+00:00
@@ -772,13 +772,15 @@
ctx, target, source_ir, f"{name}_result_flat", src_flat, delta
)
# Rebuild the output shape (may contain dynamic dims)
out_shape = tuple(
- get_shape(ctx, target, source_ir, f"{name}_oshape_{i}", input_tensor, i)
- if input_tensor.shape[i] == DYNAMIC_DIM
- else input_tensor.shape[i]
+ (
+ get_shape(ctx, target, source_ir, f"{name}_oshape_{i}", input_tensor, i)
+ if input_tensor.shape[i] == DYNAMIC_DIM
+ else input_tensor.shape[i]
+ )
for i in range(rank)
)
return impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_result", result_flat, out_shape
)
@@ -871,11 +873,13 @@
F = [i for i in range(rank) if indices[i] is None] # Free dimensions
I = [i for i in range(rank) if indices[i] is not None] # Indexed dimensions
K = len(I)
# Determine the maximum size 'N' among the index tensors
if K > 0:
- index_shapes = [] # [tensor.shape[0] for tensor in indices if tensor is not None]
+ index_shapes = (
+ []
+ ) # [tensor.shape[0] for tensor in indices if tensor is not None]
for _ni, idx_tensor in enumerate(indices):
if idx_tensor is not None:
if idx_tensor.shape[0] != DYNAMIC_DIM:
index_shapes.append(idx_tensor.shape[0])
else:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-03-23 18:18:24.152947+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-03-23 18:18:45.151011+00:00
@@ -321,31 +321,45 @@
# duplicate positions correctly (mirrors test_index_put_accumulate_duplicate_indices).
param(
test_name="1d_duplicate_indices_accumulate",
source_tensor=torch.zeros([6], dtype=torch.float32),
indices_tensor=(torch.tensor([0, 0, 2, 2, 2], dtype=torch.int64),),
- value_tensor=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32),
+ value_tensor=torch.tensor(
+ [1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32
+ ),
accumulate=True,
),
param(
test_name="2d_indices_accumulate_True",
source_tensor=torch.zeros([5, 5], dtype=torch.float32),
- indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
+ indices_tensor=(
+ torch.tensor([0, 0], dtype=torch.int32),
+ torch.tensor([1, 1], dtype=torch.int32),
+ ),
value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32),
accumulate=True,
),
param(
test_name="3d_indices_accumulate_True",
source_tensor=torch.zeros([3, 3, 3], dtype=torch.float32),
- indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([2, 2], dtype=torch.int32)),
+ indices_tensor=(
+ torch.tensor([0, 0], dtype=torch.int32),
+ torch.tensor([1, 1], dtype=torch.int32),
+ torch.tensor([2, 2], dtype=torch.int32),
+ ),
value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32),
accumulate=True,
),
param(
test_name="4d_indices_accumulate_True",
source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.float32),
- indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
+ indices_tensor=(
+ torch.tensor([0, 0], dtype=torch.int32),
+ torch.tensor([1, 1], dtype=torch.int32),
+ torch.tensor([0, 0], dtype=torch.int32),
+ torch.tensor([1, 1], dtype=torch.int32),
+ ),
value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32),
accumulate=True,
),
# Negative indices with accumulate (mirrors test_index_put_accumulate_large_tensor).
param(
@@ -357,29 +371,38 @@
),
# bfloat16 + duplicate indices: computation stays in bfloat16 (no forced fp32 cast).
param(
test_name="accumulate_bfloat16_duplicate",
source_tensor=torch.zeros([4, 4], dtype=torch.bfloat16),
- indices_tensor=(torch.tensor([0, 0, 2], dtype=torch.int64), torch.tensor([1, 1, 3], dtype=torch.int64)),
+ indices_tensor=(
+ torch.tensor([0, 0, 2], dtype=torch.int64),
+ torch.tensor([1, 1, 3], dtype=torch.int64),
+ ),
value_tensor=torch.tensor([1.0, 2.0, 4.0], dtype=torch.bfloat16),
accumulate=True,
),
# float16 + duplicate indices.
param(
test_name="accumulate_float16_duplicate",
source_tensor=torch.zeros([4, 4], dtype=torch.float16),
- indices_tensor=(torch.tensor([1, 1, 3], dtype=torch.int64), torch.tensor([0, 0, 2], dtype=torch.int64)),
+ indices_tensor=(
+ torch.tensor([1, 1, 3], dtype=torch.int64),
+ torch.tensor([0, 0, 2], dtype=torch.int64),
+ ),
value_tensor=torch.tensor([2.0, 3.0, 5.0], dtype=torch.float16),
accumulate=True,
),
# Partial broadcast: one index covers a single position on dim-1 while
# dim-0 has multiple positions — mirrors test_index_put_accumulate_expanded_values
# (t[tensor([0,1,2,3]), tensor([1])] += 1.0).
param(
test_name="accumulate_partial_dim1_broadcast",
source_tensor=torch.zeros([5, 2], dtype=torch.float32),
- indices_tensor=(torch.tensor([0, 1, 2, 3], dtype=torch.int64), torch.tensor([1], dtype=torch.int64)),
+ indices_tensor=(
+ torch.tensor([0, 1, 2, 3], dtype=torch.int64),
+ torch.tensor([1], dtype=torch.int64),
+ ),
value_tensor=torch.tensor([1.0], dtype=torch.float32),
accumulate=True,
),
]
)
@@ -512,12 +535,16 @@
)
trt_mod = torchtrt.dynamo.compile(
ep,
arg_inputs=[
torchtrt.Input(shape=(16,), dtype=torch.float32),
- torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32),
- torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32),
+ torchtrt.Input(
+ min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32
+ ),
+ torchtrt.Input(
+ min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32
+ ),
],
min_block_size=1,
)
result = trt_mod(src.clone(), values, idx)
assert torch.allclose(
@@ -587,11 +614,10 @@
result = trt_mod(cache.clone(), values, idx)
assert torch.allclose(
result, torch_output, atol=1e-3, rtol=1e-3
), f"KV-cache index_put mismatch: max diff = {(result - torch_output).abs().max()}"
-
def test_accumulate_random_walk_duplicate_indices(self):
"""accumulate=True on 1-D input where indices are generated by a random walk
(many duplicates interleaved). Mirrors PyTorch's
test_index_put_accumulate_duplicate_indices, scaled to a TRT-friendly size.
@@ -664,13 +690,13 @@
torchtrt.Input(shape=(1, 2), dtype=torch.int64),
],
min_block_size=1,
)
result = trt_mod(src.clone(), values, i0, i1, i2)
- assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
- f"3D expand accumulate mismatch: max diff = {(result - torch_output).abs().max()}"
- )
+ assert torch.allclose(
+ result, torch_output, atol=1e-4, rtol=1e-4
+ ), f"3D expand accumulate mismatch: max diff = {(result - torch_output).abs().max()}"
def test_empty_index_no_op(self):
"""index_put with an empty index tensor is a no-op — output equals input.
Mirrors PyTorch's test_empty_index: x[empty_idx] = values leaves x unchanged.
@@ -697,13 +723,13 @@
torchtrt.Input(shape=(0,), dtype=torch.int64),
],
min_block_size=1,
)
result = trt_mod(src.clone(), values, idx)
- assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
- f"Empty-index no-op mismatch: {result} vs {torch_output}"
- )
+ assert torch.allclose(
+ result, torch_output, atol=1e-4, rtol=1e-4
+ ), f"Empty-index no-op mismatch: {result} vs {torch_output}"
def test_index_ind_dtype_int_vs_long(self):
"""int32 and int64 index tensors must produce identical results.
Mirrors PyTorch's test_index_ind_dtype.
@@ -725,27 +751,41 @@
assert torch.allclose(ref_long, ref_int), "CPU int32 vs int64 mismatch"
ep_long = torch.export.export(model, args=(src, values, idx_long))
ep_int = torch.export.export(model, args=(src, values, idx_int))
- trt_long = torchtrt.dynamo.compile(ep_long, arg_inputs=[
- torchtrt.Input(shape=(4, 4), dtype=torch.float32),
- torchtrt.Input(shape=(4,), dtype=torch.float32),
- torchtrt.Input(shape=(4,), dtype=torch.int64),
- ], min_block_size=1)
- trt_int = torchtrt.dynamo.compile(ep_int, arg_inputs=[
- torchtrt.Input(shape=(4, 4), dtype=torch.float32),
- torchtrt.Input(shape=(4,), dtype=torch.float32),
- torchtrt.Input(shape=(4,), dtype=torch.int32),
- ], min_block_size=1)
+ trt_long = torchtrt.dynamo.compile(
+ ep_long,
+ arg_inputs=[
+ torchtrt.Input(shape=(4, 4), dtype=torch.float32),
+ torchtrt.Input(shape=(4,), dtype=torch.float32),
+ torchtrt.Input(shape=(4,), dtype=torch.int64),
+ ],
+ min_block_size=1,
+ )
+ trt_int = torchtrt.dynamo.compile(
+ ep_int,
+ arg_inputs=[
+ torchtrt.Input(shape=(4, 4), dtype=torch.float32),
+ torchtrt.Input(shape=(4,), dtype=torch.float32),
+ torchtrt.Input(shape=(4,), dtype=torch.int32),
+ ],
+ min_block_size=1,
+ )
out_long = trt_long(src.clone(), values, idx_long)
out_int = trt_int(src.clone(), values, idx_int)
- assert torch.allclose(out_long, ref_long, atol=1e-4, rtol=1e-4), "TRT int64 mismatch"
- assert torch.allclose(out_int, ref_int, atol=1e-4, rtol=1e-4), "TRT int32 mismatch"
- assert torch.allclose(out_long, out_int, atol=1e-4, rtol=1e-4), "TRT int32 vs int64 inconsistency"
+ assert torch.allclose(
+ out_long, ref_long, atol=1e-4, rtol=1e-4
+ ), "TRT int64 mismatch"
+ assert torch.allclose(
+ out_int, ref_int, atol=1e-4, rtol=1e-4
+ ), "TRT int32 mismatch"
+ assert torch.allclose(
+ out_long, out_int, atol=1e-4, rtol=1e-4
+ ), "TRT int32 vs int64 inconsistency"
def test_accumulate_non_contiguous_source(self):
"""accumulate=True on a non-contiguous (sliced) source tensor.
Mirrors PyTorch's test_index_put_accumulate_non_contiguous.
@@ -780,26 +820,24 @@
torchtrt.Input(shape=(2,), dtype=torch.int64),
],
min_block_size=1,
)
result = trt_mod(src_slice.contiguous(), values, idx)
- assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
- f"Non-contiguous accumulate mismatch: max diff = {(result - torch_output).abs().max()}"
- )
+ assert torch.allclose(
+ result, torch_output, atol=1e-4, rtol=1e-4
+ ), f"Non-contiguous accumulate mismatch: max diff = {(result - torch_output).abs().max()}"
def test_accumulate_expanded_values_broadcast(self):
"""accumulate=True with value broadcasting — 0D scalar and 1D values
broadcast across unique indexed positions.
Mirrors PyTorch's test_index_put_accumulate_expanded_values (unique indices only).
"""
class AccumBroadcast(torch.nn.Module):
def forward(self, src, values, idx0, idx1):
- return torch.ops.aten.index_put.default(
- src, [idx0, idx1], values, True
- )
+ return torch.ops.aten.index_put.default(src, [idx0, idx1], values, True)
src = torch.zeros(5, 2, dtype=torch.float32, device="cuda")
idx0 = torch.tensor([0, 1, 2, 3], dtype=torch.int64, device="cuda")
idx1 = torch.tensor([0, 1, 0, 1], dtype=torch.int64, device="cuda")
values_1d = torch.tensor([1.0], dtype=torch.float32, device="cuda")
@@ -817,12 +855,12 @@
torchtrt.Input(shape=(4,), dtype=torch.int64),
],
min_block_size=1,
)
result = trt_mod(src.clone(), values_1d, idx0, idx1)
- assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
- f"Accumulate broadcast mismatch: max diff = {(result - torch_output).abs().max()}"
- )
+ assert torch.allclose(
+ result, torch_output, atol=1e-4, rtol=1e-4
+ ), f"Accumulate broadcast mismatch: max diff = {(result - torch_output).abs().max()}"
if __name__ == "__main__":
run_tests()
8eb87a5 to
3402e30
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
Description
Index put appears in KV cache implementations from huggingface, we get a broadcast error because there was no validator catching this. This PR adds support for dynamic shape in the op and properly guards failure modes
Fixes #4139
Fixes #4142
Fixes #3647
Fixes #2939
Fixes #3798
Fixes #3806
- `index_put` fails with dynamic non-indexed dimensions ("Dynamic shape in free dimensions not supported") — fixed using `get_shape`
- `index_copy_` / `StaticCache` fails with a shape broadcast error ("Cannot broadcast (1,8,1,128) to (1,1,8,128)")
- `index_add_` fails with dynamic shape
- `x[bool_3d_mask] = 0.0` crashes with `ValueError: __len__() should return >= 0` — fixed `expand_boolean_indices` to correctly split the TRT `add_non_zero` output `(ndim, N)` into per-dim `(N,)` tensors
- `accumulate=True` with duplicate indices produces wrong results (scatter overwrites instead of accumulating)

Accumulate is supported by an MxP implementation of native TensorRT ops, which we expect to have ~10–15% overhead in reasonably sized problems but can be costly in very large tasks.
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: