diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 7b8a77dd56..7b4bb2ef19 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -2,10 +2,9 @@ from typing import List, Optional, Sequence, Union import numpy as np -import tensorrt as trt import torch -from tensorrt import ITensor as TRTTensor from torch.fx.node import Target +from torch_tensorrt._enums import dtype from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -22,6 +21,9 @@ from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape from torch_tensorrt.dynamo.utils import DYNAMIC_DIM +import tensorrt as trt +from tensorrt import ITensor + _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -30,11 +32,11 @@ def select( target: Target, source_ir: Optional[SourceIR], name: str, - input: TRTTensor, + input: ITensor, dim: int, index: int, -) -> TRTTensor: - if not isinstance(input, TRTTensor): +) -> ITensor: + if not isinstance(input, ITensor): raise RuntimeError( f"slice_tensor received input {input} that is not part " "of the TensorRT region!" 
@@ -52,13 +54,13 @@ def select( def is_boolean_tensor( - tensor: Union[TRTTensor, np.ndarray, torch.Tensor, torch.fx.Node], + tensor: Union[ITensor, np.ndarray, torch.Tensor, torch.fx.Node], ) -> bool: if isinstance(tensor, torch.Tensor): return bool(tensor.dtype == torch.bool) elif isinstance(tensor, np.ndarray): return bool(tensor.dtype == np.bool_) - elif isinstance(tensor, TRTTensor): + elif isinstance(tensor, ITensor): return bool(tensor.dtype == trt.DataType.BOOL) # when index is a node else: @@ -74,9 +76,9 @@ def expand_boolean_indices( target: Target, source_ir: Optional[SourceIR], name: str, - input: TRTTensor, - indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], -) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]: + input: ITensor, + indices: Sequence[Union[ITensor, np.ndarray, torch.Tensor]], +) -> Sequence[Union[ITensor, np.ndarray, torch.Tensor]]: new_indices = [] for i, ind in enumerate(indices): if ind is not None and is_boolean_tensor(ind): @@ -89,24 +91,53 @@ def expand_boolean_indices( set_layer_name( nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir ) - nonzero_indices = nonzero_layer.get_output(0) - - # nonzero returns shape [N, dims], we need to extract dim i - if len(indices) == 1: - # x[mask] — 1D mask + # TRT add_non_zero returns shape (ndim, N): row d holds the d-th + # coordinate of every nonzero element. This is the transpose of + # PyTorch's nonzero() which returns (N, ndim). + # Ref: https://docs.nvidia.com/deeplearning/tensorrt/latest/reference/python-api/infer/Graph/Layers.html#tensorrt.INetworkDefinition.add_non_zero + nonzero_indices = nonzero_layer.get_output(0) # (mask_ndim, N) + + mask_ndim = len(ind.shape) if hasattr(ind, "shape") else 1 + + if len(indices) == 1 and mask_ndim > 1: + # x[bool_nd] = v — single N-D boolean mask. + # Extract row d (axis=0) from (mask_ndim, N) → (N,) per dim. 
+ for d in range(mask_ndim): + gather_layer = ctx.net.add_gather( + nonzero_indices, + get_trt_tensor(ctx, d, name + f"_bool_nz_dim_{i}_{d}"), + axis=0, + ) + set_layer_name( + gather_layer, + target, + name + f"_bool_nonzero_row_{i}_{d}", + source_ir, + ) + row = gather_layer.get_output(0) # (N,) + sq = ctx.net.add_shuffle(row) + sq.reshape_dims = (-1,) + set_layer_name( + sq, target, name + f"_bool_row_sq_{i}_{d}", source_ir + ) + new_indices.append(sq.get_output(0)) + continue # already appended all per-dim indices; skip append below + elif len(indices) == 1: + # x[bool_1d] = v — 1D mask: nonzero → (1, N), flatten to (N,). to_squeeze = nonzero_indices else: - # Advanced multi-axis mask: extract index i from shape [N, D] - gather_axis = 1 # dim index + # Multi-index bool (1-D bool at position i): extract row i from + # (1, N) — i.e. gather row 0 along axis=0. gather_layer = ctx.net.add_gather( nonzero_indices, - get_trt_tensor(ctx, i, name + f"_dim_index_{i}"), - gather_axis, + get_trt_tensor(ctx, 0, name + f"_dim_index_{i}"), + axis=0, ) set_layer_name( gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir ) to_squeeze = gather_layer.get_output(0) + squeeze_layer = ctx.net.add_shuffle(to_squeeze) squeeze_layer.reshape_dims = (-1,) set_layer_name( @@ -127,9 +158,9 @@ def index( target: Target, source_ir: Optional[SourceIR], name: str, - input: TRTTensor, - indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], -) -> TRTTensor: + input: ITensor, + indices: Sequence[Union[ITensor, np.ndarray, torch.Tensor]], +) -> ITensor: adv_indx_indices = [] tensor_indices = [] # is_numpy is a flag to specify if all the indices are numpy or torchTensor. 
@@ -152,20 +183,24 @@ def index( adv_indx_indices.append(i) # torch.nn.parameter.Parameter=> numpy array # numpy array is kept as numpy - # other cases are kept as TRTTensor + # other cases are kept as ITensor if is_numpy: ind = to_numpy(ind) else: ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}") if last_index is not None: - assert broadcastable( - ind, last_index - ), "The indices should be broadcastable!" + assert broadcastable(ind, last_index), ( + f"Index tensors must be broadcastable with each other, but index {i} " + f"has shape {tuple(ind.shape)} which is not broadcastable with the " + f"previous index shape {tuple(last_index.shape)}. " + "All advanced (integer/boolean) indices must follow NumPy style broadcasting rules. " + "See https://numpy.org/doc/stable/user/basics.broadcasting.html" + ) last_index = ind tensor_indices.append(ind) if not tensor_indices: - cast_layer = ctx.net.add_cast(input, trt.int32) + cast_layer = ctx.net.add_cast(input, dtype.i32.to(trt.DataType)) set_layer_name(cast_layer, target, name + "_index_casted", source_ir) return cast_layer.get_output(0) elif len(tensor_indices) == 1: @@ -469,10 +504,10 @@ def index_select( target: Target, source_ir: Optional[SourceIR], name: str, - input: TRTTensor, + input: ITensor, dim: int, - index: TRTTensor, -) -> TRTTensor: + index: ITensor, +) -> ITensor: # The axis parameter specifies the dimension along which to index. 
dim = get_positive_dim(dim, len(input.shape)) gather_layer = ctx.net.add_gather(input, index, axis=dim) @@ -487,11 +522,11 @@ def scatter( target: Target, source_ir: Optional[SourceIR], name: str, - input: TRTTensor, + input: ITensor, dim: int, - index: Union[TRTTensor, np.ndarray, torch.Tensor], - src: Union[TRTTensor, int, float], -) -> TRTTensor: + index: Union[ITensor, np.ndarray, torch.Tensor], + src: Union[ITensor, int, float], +) -> ITensor: input_shape = input.shape index_shape = index.shape index_shape_list = list(index_shape) @@ -524,7 +559,7 @@ def scatter( ctx, src_tensor, input.dtype, name + "_cast_value_tensor" ) # scatter.src - elif not (isinstance(src, TRTTensor)): + elif not (isinstance(src, ITensor)): src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor") scatter_layer = ctx.net.add_scatter( @@ -541,10 +576,10 @@ def gather( target: Target, source_ir: Optional[SourceIR], name: str, - input: TRTTensor, + input: ITensor, dim: int, - index: Union[TRTTensor, np.ndarray, torch.Tensor], -) -> TRTTensor: + index: Union[ITensor, np.ndarray, torch.Tensor], +) -> ITensor: input_shape = input.shape dim = get_positive_dim(dim, len(input_shape)) index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor") @@ -555,26 +590,222 @@ def gather( return out +def _index_put_scatter_add( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: ITensor, + indices_cat: ITensor, + flattened_values: ITensor, +) -> ITensor: + """Scatter-add via indicator-matrix matmul. + + Correctly accumulates into duplicate index positions, unlike + ``ScatterMode.ND`` which overwrites on collision. + + Algorithm + --------- + Given ``P`` scatter positions and ``M`` total elements in ``input_tensor``: + + 1. Linearise the ND indices (shape ``(P, rank)``) to flat indices (shape + ``(P,)``) using the row-major strides of ``input_tensor``. + 2. 
Build a boolean indicator matrix of shape ``(M, P)`` where entry + ``[i, j]`` is True iff ``flat_idx[j] == i``. + 3. Compute ``delta = indicator @ values`` (shape ``(M,)``). Because the + matmul sums over ``j``, duplicate positions are accumulated exactly. + 4. Return ``flatten(input) + delta`` reshaped to ``input_tensor.shape``. + + Memory cost: O(M * P). Suitable for inference use-cases where both M and P + are small (e.g. KV-cache updates, small buffer writes). + """ + rank = len(input_tensor.shape) + + # ------------------------------------------------------------------ + # Step 1: per-dimension sizes (int when static, ITensor when dynamic) + # ------------------------------------------------------------------ + dims: List[Union[int, ITensor]] = [] + for i in range(rank): + s = input_tensor.shape[i] + if s == DYNAMIC_DIM: + dims.append( + get_shape(ctx, target, source_ir, f"{name}_dim_{i}", input_tensor, i) + ) + else: + dims.append(s) + + # ------------------------------------------------------------------ + # Step 2: row-major strides (stride[k] = product(dims[k+1:])) + # ------------------------------------------------------------------ + strides: List[Union[int, ITensor]] = [1] * rank + for k in range(rank - 2, -1, -1): + d = dims[k + 1] + prev = strides[k + 1] + if isinstance(d, ITensor) or isinstance(prev, ITensor): + strides[k] = impl.elementwise.mul( + ctx, target, source_ir, f"{name}_stride_{k}", prev, d + ) + else: + strides[k] = prev * d + + # ------------------------------------------------------------------ + # Step 3: M = total number of elements + # ------------------------------------------------------------------ + M: Union[int, ITensor] = dims[0] + for i in range(1, rank): + d = dims[i] + if isinstance(M, ITensor) or isinstance(d, ITensor): + M = impl.elementwise.mul(ctx, target, source_ir, f"{name}_M_{i}", M, d) + else: + M = M * d + + # ------------------------------------------------------------------ + # Step 4: linearise indices_cat (P, 
rank) -> flat_idx (P,) + # flat_idx = sum_k( indices_cat[:, k] * strides[k] ) + # ------------------------------------------------------------------ + flat_idx: Optional[ITensor] = None + for k in range(rank): + # Extract column k from indices_cat via a gather on axis=1 + k_t = get_trt_tensor(ctx, k, f"{name}_col_idx_{k}", min_rank=0) + gather_l = ctx.net.add_gather(indices_cat, k_t, axis=1) + set_layer_name(gather_l, target, f"{name}_gather_col_{k}", source_ir) + col_k = gather_l.get_output(0) # shape (P,) + + # Normalize negative indices: idx < 0 → idx + dim_size + dim_k = dims[k] + if isinstance(dim_k, ITensor): + dim_k_t = dim_k + else: + dim_k_t = get_trt_tensor(ctx, dim_k, f"{name}_dim_val_{k}", min_rank=0) + zero_k = get_trt_tensor(ctx, 0, f"{name}_zero_{k}", min_rank=0) + is_neg = impl.elementwise.lt( + ctx, target, source_ir, f"{name}_is_neg_{k}", col_k, zero_k + ) + col_k_shifted = impl.elementwise.add( + ctx, target, source_ir, f"{name}_col_shifted_{k}", col_k, dim_k_t + ) + sel_l = ctx.net.add_select(is_neg, col_k_shifted, col_k) + set_layer_name(sel_l, target, f"{name}_sel_neg_{k}", source_ir) + col_k = sel_l.get_output(0) + + stride_k = strides[k] + if isinstance(stride_k, ITensor): + contrib = impl.elementwise.mul( + ctx, target, source_ir, f"{name}_contrib_{k}", col_k, stride_k + ) + elif stride_k == 1: + contrib = col_k + else: + stride_t = get_trt_tensor( + ctx, + stride_k, + f"{name}_stride_val_{k}", + min_rank=0, + ) + contrib = impl.elementwise.mul( + ctx, target, source_ir, f"{name}_contrib_{k}", col_k, stride_t + ) + + flat_idx = ( + contrib + if flat_idx is None + else impl.elementwise.add( + ctx, target, source_ir, f"{name}_flat_idx_{k}", flat_idx, contrib + ) + ) + + # ------------------------------------------------------------------ + # Step 5: indicator matrix (M, P) = (arange(M)[:,None] == flat_idx[None,:]) + # ------------------------------------------------------------------ + arange_M = impl.arange.arange( + ctx, target, source_ir, 
f"{name}_arange_M", 0, M, 1 + ) # (M,) int32 + arange_M = cast_trt_tensor(ctx, arange_M, dtype.i32, f"{name}_arange_int32") + + # Reshape for broadcast: (M, 1) and (1, P) + arange_col = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, f"{name}_arange_col", arange_M, 1 + ) # (M, 1) + flat_idx_row = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, f"{name}_flat_idx_row", flat_idx, 0 + ) # (1, P) + + indicator = impl.elementwise.eq( + ctx, target, source_ir, f"{name}_indicator", arange_col, flat_idx_row + ) # (M, P) bool + + # ------------------------------------------------------------------ + # Step 6: delta = indicator @ values (M,) + # ------------------------------------------------------------------ + # TRT matmul requires a floating-point dtype. Use the input's own dtype + # when it is already a float type so precision is preserved natively + # (matrix_multiply handles the fp16 fp32-acc path internally). + # Fall back to float32 only for non-floating-point inputs (e.g. int32). + _float_dtypes = (trt.float32, trt.float16, trt.bfloat16) + compute_dtype = ( + input_tensor.dtype if input_tensor.dtype in _float_dtypes else trt.float32 + ) + + indicator_f = cast_trt_tensor( + ctx, indicator, compute_dtype, f"{name}_indicator_cast" + ) # (M, P) + values_f = cast_trt_tensor( + ctx, flattened_values, compute_dtype, f"{name}_values_cast" + ) # (P,) + + # matrix_multiply treats the 1-D `values_f` as a column vector (VECTOR + # mode) and returns shape (M,). 
+ delta = impl.matmul.matrix_multiply( + ctx, target, source_ir, f"{name}_delta", indicator_f, values_f + ) # (M,) + + # Cast delta back to the original input dtype (no-op for float inputs) + delta = cast_trt_tensor(ctx, delta, input_tensor.dtype, f"{name}_delta_cast") + + # ------------------------------------------------------------------ + # Step 7: flatten input, add delta, reshape back + # ------------------------------------------------------------------ + src_flat = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_src_flat", input_tensor, (-1,) + ) + result_flat = impl.elementwise.add( + ctx, target, source_ir, f"{name}_result_flat", src_flat, delta + ) + + # Rebuild the output shape (may contain dynamic dims) + out_shape = tuple( + ( + get_shape(ctx, target, source_ir, f"{name}_oshape_{i}", input_tensor, i) + if input_tensor.shape[i] == DYNAMIC_DIM + else input_tensor.shape[i] + ) + for i in range(rank) + ) + return impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_result", result_flat, out_shape + ) + + def index_put_converter( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, - input_tensor: TRTTensor, - input_indices: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray, int, None]], - values: TRTTensor, + input_tensor: ITensor, + input_indices: Sequence[Union[ITensor, torch.Tensor, np.ndarray, int, None]], + values: ITensor, accumulate: bool = False, -) -> TRTTensor: +) -> ITensor: # Convert 'input_indices' to TRT tensors (or keep None as is) input_indices = expand_boolean_indices( ctx, target, source_ir, name, input_tensor, input_indices ) - indices: List[Optional[Union[TRTTensor, None]]] = [] + indices: List[Optional[Union[ITensor, None]]] = [] for i, idx in enumerate(input_indices): if idx is None: indices.append(None) else: - if not isinstance(idx, TRTTensor): + if not isinstance(idx, ITensor): idx = get_trt_tensor(ctx, idx, f"{name}_index_{i}", min_rank=1) if len(idx.shape) == 0 or not idx.shape: # 
Reshape a scalar to (1,) idx = impl.shuffle.reshape( @@ -582,6 +813,58 @@ def index_put_converter( ) indices.append(idx) + # Normalize multi-dimensional index tensors. + # PyTorch allows mesh-style indices like [arange(3)[:,None], arange(2)[None,:]] + # which broadcast together before scattering. The rest of the converter + # pipeline assumes every non-None index is 1-D, so we broadcast all + # non-None indices to their common shape and flatten each to (N,) here. + _non_none_idx = [(pos, idx) for pos, idx in enumerate(indices) if idx is not None] + if _non_none_idx and any(len(idx.shape) > 1 for _, idx in _non_none_idx): + _max_ndim = max(len(idx.shape) for _, idx in _non_none_idx) + # Compute the static broadcast shape (dynamic mesh indices unsupported). + _bcast: List[int] = [1] * _max_ndim + for _, idx in _non_none_idx: + _padded = (1,) * (_max_ndim - len(idx.shape)) + tuple( + int(s) for s in idx.shape + ) + for j, (a, b) in enumerate(zip(_bcast, _padded)): + if a == 1: + _bcast[j] = b + elif b != 1 and b != a: + raise ValueError( + f"index_put: cannot broadcast index shapes {[idx.shape for _, idx in _non_none_idx]}" + ) + # Expand each non-None index to _bcast then flatten to 1-D. + for pos, idx in _non_none_idx: + if len(idx.shape) < _max_ndim: + for _d in range(_max_ndim - len(idx.shape)): + idx = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, f"{name}_idx_ndpad_{pos}_{_d}", idx, 0 + ) + idx = impl.slice.expand( + ctx, target, source_ir, f"{name}_idx_ndbcast_{pos}", idx, tuple(_bcast) + ) + idx = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_idx_ndflat_{pos}", idx, (-1,) + ) + indices[pos] = idx + + # Also pre-broadcast values to the mesh shape and flatten so they match N. + # e.g. values (2,) + mesh (3,2) → expand to (3,2) → flatten to (6,). 
+ if not isinstance(values, ITensor): + values = get_trt_tensor(ctx, values, f"{name}_values_nd", min_rank=0) + if len(values.shape) < _max_ndim: + for _d in range(_max_ndim - len(values.shape)): + values = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, f"{name}_val_ndpad_{_d}", values, 0 + ) + values = impl.slice.expand( + ctx, target, source_ir, f"{name}_val_ndbcast", values, tuple(_bcast) + ) + values = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_val_ndflat", values, (-1,) + ) + rank = len(input_tensor.shape) # Pad the 'indices' list with None for remaining dimensions indices = list(indices) + [None] * (rank - len(indices)) @@ -595,7 +878,7 @@ def index_put_converter( index_shapes = ( [] ) # [tensor.shape[0] for tensor in indices if tensor is not None] - for idx_tensor in indices: + for _ni, idx_tensor in enumerate(indices): if idx_tensor is not None: if idx_tensor.shape[0] != DYNAMIC_DIM: index_shapes.append(idx_tensor.shape[0]) @@ -605,19 +888,53 @@ def index_put_converter( ctx, target, source_ir, - name + "idx_shape_dim_0", + name + f"idx_shape_dim_0_{_ni}", idx_tensor, 0, ) ) - N = max(index_shapes) if index_shapes else 1 + # When any index has a dynamic size, use the first dynamic value + # (all valid indices are guaranteed to have the same N after broadcasting). + # Python's max() cannot compare ITensors, so we avoid it when dynamic. + if any(isinstance(s, ITensor) for s in index_shapes): + N = next(s for s in index_shapes if isinstance(s, ITensor)) + else: + N = max(index_shapes) if index_shapes else 1 else: N = 1 - # Compute shapes and volume for the free dimensions + # Compute shapes and volume for the free dimensions. + # F_shapes: static ints (-1 for dynamic dims), used where static ints are required. + # F_shape_values: per-free-dim size as int (static) or ITensor (dynamic). + # F_volume: product of F_shape_values, int if all static else ITensor. 
F_shapes = [input_tensor.shape[i] for i in F] - assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported" - F_volume = trt.volume(F_shapes) if F_shapes else 1 + F_shape_values: List[Union[int, ITensor]] = [] + for _fi, _fdim in enumerate(F): + _s = input_tensor.shape[_fdim] + if _s == DYNAMIC_DIM: + F_shape_values.append( + get_shape( + ctx, + target, + source_ir, + f"{name}_fshape_{_fdim}", + input_tensor, + _fdim, + ) + ) + else: + F_shape_values.append(_s) + _has_dynamic_f = any(isinstance(_s, ITensor) for _s in F_shape_values) + # Can't figure out a better way to calculate the volume at runtime + if _has_dynamic_f: + _fvol: Union[int, ITensor] = 1 + for _i, _s in enumerate(F_shape_values): + _fvol = impl.elementwise.mul( + ctx, target, source_ir, f"{name}_fvol_{_i}", _fvol, _s + ) + F_volume: Union[int, ITensor] = _fvol + else: + F_volume = trt.volume(F_shapes) if F_shapes else 1 # Process indexed dimensions (I) I_tensors = [] @@ -640,8 +957,8 @@ def index_put_converter( # Create a meshgrid for free dimensions (F) if len(F) > 0: arange_tensors = [] - for dim in F: - dim_size = input_tensor.shape[dim] + for _fi2, dim in enumerate(F): + dim_size = F_shape_values[_fi2] # int or ITensor arange_tensor = impl.arange.arange( ctx, target, source_ir, f"{name}_arange_{dim}", 0, dim_size, 1 ) @@ -653,8 +970,8 @@ def index_put_converter( else: meshgrid_tensors = [] for i, arange in enumerate(arange_tensors): - reshape_shape = [1] * len(F) - reshape_shape[i] = F_shapes[i] + reshape_shape: List[Union[int, ITensor]] = [1] * len(F) + reshape_shape[i] = F_shape_values[i] arange_reshaped = impl.shuffle.reshape( ctx, target, @@ -669,7 +986,7 @@ def index_put_converter( source_ir, f"{name}_expand_arange_F_{F[i]}", arange_reshaped, - tuple(F_shapes), + tuple(F_shape_values), ) meshgrid_tensors.append(expanded_arange) @@ -685,7 +1002,7 @@ def index_put_converter( source_ir, f"{name}_reshape_mesh_{i}", t, - (*F_shapes, 1), + (*F_shape_values, 1), ) for i, t in 
enumerate(meshgrid_tensors) ], @@ -715,7 +1032,6 @@ def index_put_converter( # Combine all indexed dimensions (I) if K > 0: - I_combined = [ impl.shuffle.reshape( ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1) @@ -736,19 +1052,27 @@ def index_put_converter( ii_list.append(idx_tensor) i_idx += 1 else: - start = [0, 0, f_idx] - shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1] - stride = [1, 1, 1] - mesh_tensor = impl.slice.slice( + # Extract the f_idx-th column along the last dim (static len(F)) of + # meshgrid_expanded (shape: N×F_volume×len(F)). Using gather+unsqueeze + # avoids passing F_volume (potentially a ITensor) as a slice shape. + f_idx_t = get_trt_tensor( + ctx, + np.array(f_idx, dtype=np.int32), + f"{name}_f_idx_t_{unique_suffix}", + ) + gather_l = ctx.net.add_gather(meshgrid_expanded, f_idx_t, axis=2) + set_layer_name( + gather_l, target, f"{name}_gather_mesh_{unique_suffix}", source_ir + ) + mesh_tensor = gather_l.get_output(0) # (N, F_volume) + mesh_tensor = impl.unsqueeze.unsqueeze( ctx, target, source_ir, - f"{name}_slice_F_dim_{unique_suffix}", - meshgrid_expanded, - start, - shape, - stride, - ) + f"{name}_unsq_mesh_{unique_suffix}", + mesh_tensor, + 2, + ) # (N, F_volume, 1) ii_list.append(mesh_tensor) f_idx += 1 @@ -767,13 +1091,14 @@ def index_put_converter( (-1, rank), ) - if not isinstance(values, TRTTensor): + if not isinstance(values, ITensor): values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0) - # Define the expected shape based on (N,) + F_shapes - expected_shape = ( - (-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes) - ) + # Define the expected shape based on (N,) + F_shape_values. + # N may be a ITensor when it comes from a dynamic source (e.g. nonzero). + # Pass it directly — impl.slice.expand handles ITensor shape elements and + # will emit a stride-0 broadcast when expanding a size-1 dim to a dynamic N. 
+ expected_shape = (N,) + tuple(F_shape_values) # Broadcast 'values' to match the expected shape if len(values.shape) == 0 or values.shape == (1,): # Scalar case @@ -790,12 +1115,33 @@ def index_put_converter( ) else: # Non-scalar case values_shape = list(values.shape) - if ( + if K == 1 and len(values.shape) == rank: + # For a single indexed dimension where values has the same rank as + # the input, permute values from input layout + # (dim0, ..., I[0], ..., dimN-1) → (I[0], F[0], ..., F[k-1]). + # This gives expected_shape = (N, *F_shape_values) directly and + # correctly handles non-contiguous free dims and dynamic batch dims. + # When values has fewer dims than rank it is being broadcast, so + # fall through to the discontinuous path which handles that via padding. + perm_order = I + F + values_permuted = impl.permutation.permute( + ctx, target, source_ir, f"{name}_permute_values", values, perm_order + ) + # Expand any size-1 dims to match expected_shape (handles broadcasting). + values_expanded = impl.slice.expand( + ctx, + target, + source_ir, + f"{name}_expand_values", + values_permuted, + expected_shape, + ) + elif ( K > 0 and N in values_shape and (len(F) > 1 and max(F) - min(F) + 1 == len(F)) ): - # Continuous case + # Continuous case (K > 0, N found in values_shape, more than one free dim, F dims contiguous) n_idx = values_shape.index(N) permute_order = [n_idx] + [ i for i in range(len(values_shape)) if i != n_idx @@ -841,7 +1187,7 @@ def index_put_converter( tuple(broadcast_shape), ) else: - # Discontinuous case + # Discontinuous case (fallback for every case not matched above, incl. K == 0) values_shape_padded = [1] * ( len(expected_shape) - len(values.shape) ) + list(values.shape) @@ -876,40 +1222,12 @@ def index_put_converter( ) indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32") if accumulate: - zero_tensor = impl.full.full( - ctx, - target, - source_ir, - f"{name}_zero_tensor", - [ - get_shape( - ctx, - target, - source_ir, - name + f"input_tensor_shape_dim_{i}", - input_tensor, - i, - ) - for i in 
range(len(input_tensor.shape)) - ], - 0.0, - dtype=input_tensor.dtype, + # Use indicator-matrix matmul for correct scatter-add semantics. + # ScatterMode.ND overwrites on duplicate indices; _index_put_scatter_add + # accumulates them via (M, P) @ (P,) matmul. + return _index_put_scatter_add( + ctx, target, source_ir, name, input_tensor, indices_cat, flattened_values ) - # Perform Scatter ND operation - scatter_layer = ctx.net.add_scatter( - zero_tensor, - indices_cat, - flattened_values, - trt.ScatterMode.ND, - ) - set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir) - - scatter_out = scatter_layer.get_output(0) - result = impl.elementwise.add( - ctx, target, source_ir, f"{name}_add", scatter_out, input_tensor - ) - return result - else: scatter_layer = ctx.net.add_scatter( input_tensor, diff --git a/tests/py/dynamo/conversion/test_index_put_aten.py b/tests/py/dynamo/conversion/test_index_put_aten.py index a7b58d0d3c..1a8a4d1d18 100644 --- a/tests/py/dynamo/conversion/test_index_put_aten.py +++ b/tests/py/dynamo/conversion/test_index_put_aten.py @@ -5,6 +5,11 @@ from .harness import DispatchTestCase +# NOTE: TensorRT's ScatterMode.ND overwrites on collision — there is no native +# scatter_add reduction mode. accumulate=True is therefore lowered through +# _index_put_scatter_add (indicator-matrix matmul), which sums duplicate indices +# correctly at O(M * P) memory cost, so accumulate test shapes are kept small.
+ class TestIndexPutConverter(DispatchTestCase): @parameterized.expand( @@ -201,6 +206,13 @@ class TestIndexPutConverter(DispatchTestCase): # indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)), # value_tensor=torch.randn([2, 3, 3], dtype=torch.float32), # ), + param( + test_name="trailing_none_after_tensor", + # K=1: indexed dim first, trailing free dims as None + source_tensor=torch.zeros([4, 3, 2], dtype=torch.float32), + indices_tensor=(torch.tensor([1, 3], dtype=torch.int64), None, None), + value_tensor=torch.ones([2, 3, 2], dtype=torch.float32), + ), param( test_name="discontinuous_test", source_tensor=torch.zeros([2, 4, 4], dtype=torch.float32), @@ -233,27 +245,165 @@ class TestIndexPutConverter(DispatchTestCase): ), value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32), ), - # param( - # test_name="2d_indices_accumulate_True", - # source_tensor=torch.zeros([5, 5], dtype=torch.int32), - # indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)), - # value_tensor=torch.tensor([1, 2], dtype=torch.int32), - # accumulate=True, - # ), - # param( - # test_name="3d_indices_accumulate_True", - # source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32), - # indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([2, 2], dtype=torch.int32)), - # value_tensor=torch.tensor([1, 2], dtype=torch.int32), - # accumulate=True, - # ), - # param( - # test_name="4d_indices_accumulate_True", - # source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32), - # indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)), - # value_tensor=torch.tensor([1, 2], dtype=torch.int32), - # accumulate=True, - # ), + # --- dtype coverage (mirrors PyTorch's test_index_put_src_datatype) --- + param( + test_name="bfloat16_1d_single", + 
source_tensor=torch.zeros([5], dtype=torch.bfloat16), + indices_tensor=(torch.tensor([0, 3], dtype=torch.int32),), + value_tensor=torch.tensor([1.5, 3.5], dtype=torch.bfloat16), + ), + param( + test_name="bfloat16_2d_multiple", + source_tensor=torch.zeros([5, 5], dtype=torch.bfloat16), + indices_tensor=( + torch.tensor([0, 2], dtype=torch.int32), + torch.tensor([2, 0], dtype=torch.int32), + ), + value_tensor=torch.tensor([1.5, 3.5], dtype=torch.bfloat16), + ), + param( + test_name="float16_1d_single", + source_tensor=torch.zeros([5], dtype=torch.float16), + indices_tensor=(torch.tensor([1, 4], dtype=torch.int32),), + value_tensor=torch.tensor([2.0, 4.0], dtype=torch.float16), + ), + param( + test_name="float16_2d_multiple", + source_tensor=torch.zeros([4, 4], dtype=torch.float16), + indices_tensor=( + torch.tensor([0, 3], dtype=torch.int32), + torch.tensor([1, 2], dtype=torch.int32), + ), + value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float16), + ), + # --- index dtype: int64 (mirrors PyTorch's test_index_ind_dtype) --- + param( + test_name="int64_indices_2d", + source_tensor=torch.zeros([4, 4], dtype=torch.float32), + indices_tensor=( + torch.tensor([0, 1, 2, 3], dtype=torch.int64), + torch.tensor([0, 1, 2, 3], dtype=torch.int64), + ), + value_tensor=torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32), + ), + # --- accumulate=True, unique indices (safe for TRT scatter) --- + # Mirrors PyTorch's test_index_put_accumulate_expanded_values (no duplicates). 
+ param( + test_name="accumulate_true_unique_indices_1d", + source_tensor=torch.ones([6], dtype=torch.float32), + indices_tensor=(torch.tensor([0, 2, 4], dtype=torch.int64),), + value_tensor=torch.tensor([10.0, 20.0, 30.0], dtype=torch.float32), + accumulate=True, + ), + param( + test_name="accumulate_true_unique_indices_2d", + source_tensor=torch.ones([4, 3], dtype=torch.float32), + indices_tensor=( + torch.tensor([0, 2], dtype=torch.int64), + torch.tensor([1, 2], dtype=torch.int64), + ), + value_tensor=torch.tensor([5.0, 7.0], dtype=torch.float32), + accumulate=True, + ), + # Broadcast: single value written to multiple unique positions. + param( + test_name="accumulate_true_broadcast_scalar_value", + source_tensor=torch.zeros([5, 2], dtype=torch.float32), + indices_tensor=( + torch.tensor([0, 1, 3], dtype=torch.int64), + torch.tensor([0, 1, 0], dtype=torch.int64), + ), + value_tensor=torch.tensor([1.0], dtype=torch.float32), + accumulate=True, + ), + # --- accumulate=True with duplicate indices (uses _index_put_scatter_add) --- + # These exercise the matmul-based scatter_add path which handles + # duplicate positions correctly (mirrors test_index_put_accumulate_duplicate_indices). 
+ param( + test_name="1d_duplicate_indices_accumulate", + source_tensor=torch.zeros([6], dtype=torch.float32), + indices_tensor=(torch.tensor([0, 0, 2, 2, 2], dtype=torch.int64),), + value_tensor=torch.tensor( + [1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32 + ), + accumulate=True, + ), + param( + test_name="2d_indices_accumulate_True", + source_tensor=torch.zeros([5, 5], dtype=torch.float32), + indices_tensor=( + torch.tensor([0, 0], dtype=torch.int32), + torch.tensor([1, 1], dtype=torch.int32), + ), + value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32), + accumulate=True, + ), + param( + test_name="3d_indices_accumulate_True", + source_tensor=torch.zeros([3, 3, 3], dtype=torch.float32), + indices_tensor=( + torch.tensor([0, 0], dtype=torch.int32), + torch.tensor([1, 1], dtype=torch.int32), + torch.tensor([2, 2], dtype=torch.int32), + ), + value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32), + accumulate=True, + ), + param( + test_name="4d_indices_accumulate_True", + source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.float32), + indices_tensor=( + torch.tensor([0, 0], dtype=torch.int32), + torch.tensor([1, 1], dtype=torch.int32), + torch.tensor([0, 0], dtype=torch.int32), + torch.tensor([1, 1], dtype=torch.int32), + ), + value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32), + accumulate=True, + ), + # Negative indices with accumulate (mirrors test_index_put_accumulate_large_tensor). + param( + test_name="accumulate_negative_indices", + source_tensor=torch.zeros([6], dtype=torch.float32), + indices_tensor=(torch.tensor([-1, -1, -3], dtype=torch.int64),), + value_tensor=torch.tensor([5.0, 7.0, 3.0], dtype=torch.float32), + accumulate=True, + ), + # bfloat16 + duplicate indices: computation stays in bfloat16 (no forced fp32 cast). 
+ param( + test_name="accumulate_bfloat16_duplicate", + source_tensor=torch.zeros([4, 4], dtype=torch.bfloat16), + indices_tensor=( + torch.tensor([0, 0, 2], dtype=torch.int64), + torch.tensor([1, 1, 3], dtype=torch.int64), + ), + value_tensor=torch.tensor([1.0, 2.0, 4.0], dtype=torch.bfloat16), + accumulate=True, + ), + # float16 + duplicate indices. + param( + test_name="accumulate_float16_duplicate", + source_tensor=torch.zeros([4, 4], dtype=torch.float16), + indices_tensor=( + torch.tensor([1, 1, 3], dtype=torch.int64), + torch.tensor([0, 0, 2], dtype=torch.int64), + ), + value_tensor=torch.tensor([2.0, 3.0, 5.0], dtype=torch.float16), + accumulate=True, + ), + # Partial broadcast: one index covers a single position on dim-1 while + # dim-0 has multiple positions — mirrors test_index_put_accumulate_expanded_values + # (t[tensor([0,1,2,3]), tensor([1])] += 1.0). + param( + test_name="accumulate_partial_dim1_broadcast", + source_tensor=torch.zeros([5, 2], dtype=torch.float32), + indices_tensor=( + torch.tensor([0, 1, 2, 3], dtype=torch.int64), + torch.tensor([1], dtype=torch.int64), + ), + value_tensor=torch.tensor([1.0], dtype=torch.float32), + accumulate=True, + ), ] ) def test_index_put( @@ -359,6 +509,358 @@ def forward(self, source_tensor, indices_tensor, value_tensor): torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4) + def test_index_put_dynamic_index_length(self): + """index_put where the index tensor itself has a dynamic length (N dynamic). + + Pattern: src[idx] = values — no free dims, K=rank=1, index length dynamic. 
+ """ + + class IndexPutDynN(torch.nn.Module): + def forward(self, src, values, idx): + return torch.ops.aten.index_put.default(src, [idx], values) + + src = torch.zeros(16, dtype=torch.float32, device="cuda") + n_dim = torch.export.Dim("n", min=1, max=16) + + model = IndexPutDynN().eval().cuda() + # concrete inputs for reference + idx = torch.tensor([0, 2, 4], dtype=torch.int32, device="cuda") + values = torch.ones(3, dtype=torch.float32, device="cuda") + torch_output = model(src.clone(), values, idx) + + ep = torch.export.export( + model, + args=(src, values, idx), + dynamic_shapes={"src": {}, "values": {0: n_dim}, "idx": {0: n_dim}}, + ) + trt_mod = torchtrt.dynamo.compile( + ep, + arg_inputs=[ + torchtrt.Input(shape=(16,), dtype=torch.float32), + torchtrt.Input( + min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32 + ), + torchtrt.Input( + min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32 + ), + ], + min_block_size=1, + ) + result = trt_mod(src.clone(), values, idx) + assert torch.allclose( + result, torch_output, atol=1e-4, rtol=1e-4 + ), f"Dynamic index length mismatch: max diff = {(result - torch_output).abs().max()}" + + def test_kv_cache_dynamic_batch(self): + """index_put with a dynamic free dimension (batch) — issue #4139. + + Pattern: cache[..., idx, :] = values where dim-1 (batch) is dynamic + and dim-2 (cache/time) is the indexed static dimension. 
+ """ + + class KVCacheModel(torch.nn.Module): + def forward(self, cache, values, idx): + cache[..., idx, :] = values + return cache + + N = 4 + max_ctx = 256 + L = 1 + H = 512 + + cache = torch.zeros(2, N, max_ctx, H, dtype=torch.float16, device="cuda") + values = torch.randn(2, N, L, H, dtype=torch.float16, device="cuda") + idx = torch.tensor([3], dtype=torch.long, device="cuda") + + model = KVCacheModel().eval().cuda() + torch_output = model(cache.clone(), values, idx) + + batch_dim = torch.export.Dim("batch", min=1, max=64) + ep = torch.export.export( + model, + args=(cache, values, idx), + dynamic_shapes={ + "cache": {1: batch_dim}, + "values": {1: batch_dim}, + "idx": {}, + }, + ) + + trt_mod = torchtrt.dynamo.compile( + ep, + arg_inputs=[ + torchtrt.Input( + min_shape=(2, 1, max_ctx, H), + opt_shape=(2, N, max_ctx, H), + max_shape=(2, 64, max_ctx, H), + dtype=torch.float16, + ), + torchtrt.Input( + min_shape=(2, 1, L, H), + opt_shape=(2, N, L, H), + max_shape=(2, 64, L, H), + dtype=torch.float16, + ), + torchtrt.Input( + min_shape=(L,), + opt_shape=(L,), + max_shape=(L,), + dtype=torch.long, + ), + ], + use_explicit_typing=True, + min_block_size=1, + ) + + result = trt_mod(cache.clone(), values, idx) + assert torch.allclose( + result, torch_output, atol=1e-3, rtol=1e-3 + ), f"KV-cache index_put mismatch: max diff = {(result - torch_output).abs().max()}" + + def test_accumulate_random_walk_duplicate_indices(self): + """accumulate=True on 1-D input where indices are generated by a random walk + (many duplicates interleaved). Mirrors PyTorch's + test_index_put_accumulate_duplicate_indices, scaled to a TRT-friendly size. 
+ """ + import random + + torch.manual_seed(42) + random.seed(42) + + class AccumDup(torch.nn.Module): + def forward(self, src, values, idx): + return torch.ops.aten.index_put.default(src, [idx], values, True) + + for trial in range(5): + n = random.randint(8, 32) + delta = torch.empty(n, dtype=torch.float32).uniform_(-1, 1) + idx = delta.cumsum(0).long() + src_size = int(idx.abs().max().item()) + 1 + src = torch.randn(src_size, dtype=torch.float32, device="cuda") + values = torch.randn(n, dtype=torch.float32, device="cuda") + idx_cuda = idx.cuda() + + model = AccumDup().eval().cuda() + torch_output = model(src.clone(), values, idx_cuda) + + ep = torch.export.export(model, args=(src, values, idx_cuda)) + trt_mod = torchtrt.dynamo.compile( + ep, + arg_inputs=[ + torchtrt.Input(shape=(src_size,), dtype=torch.float32), + torchtrt.Input(shape=(n,), dtype=torch.float32), + torchtrt.Input(shape=(n,), dtype=torch.int64), + ], + min_block_size=1, + ) + result = trt_mod(src.clone(), values, idx_cuda) + assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), ( + f"Trial {trial}: random-walk accumulate mismatch, " + f"max diff = {(result - torch_output).abs().max()}" + ) + + def test_accumulate_expanded_values_3d(self): + """accumulate=True with 3-D source and broadcast mesh indices — mirrors the + second half of PyTorch's test_index_put_accumulate_expanded_values. 
+ + Pattern: t[tensor([0]), arange(3)[:,None], arange(2)[None,:]] += values + """ + + class AccumMesh(torch.nn.Module): + def forward(self, src, values, i0, i1, i2): + return torch.ops.aten.index_put.default(src, [i0, i1, i2], values, True) + + src = torch.zeros(4, 3, 2, dtype=torch.float32, device="cuda") + i0 = torch.tensor([0], dtype=torch.int64, device="cuda") + i1 = torch.arange(3, dtype=torch.int64, device="cuda").view(3, 1) + i2 = torch.arange(2, dtype=torch.int64, device="cuda").view(1, 2) + values = torch.tensor([-1.0, -2.0], dtype=torch.float32, device="cuda") + + model = AccumMesh().eval().cuda() + torch_output = model(src.clone(), values, i0, i1, i2) + + ep = torch.export.export(model, args=(src, values, i0, i1, i2)) + trt_mod = torchtrt.dynamo.compile( + ep, + arg_inputs=[ + torchtrt.Input(shape=(4, 3, 2), dtype=torch.float32), + torchtrt.Input(shape=(2,), dtype=torch.float32), + torchtrt.Input(shape=(1,), dtype=torch.int64), + torchtrt.Input(shape=(3, 1), dtype=torch.int64), + torchtrt.Input(shape=(1, 2), dtype=torch.int64), + ], + min_block_size=1, + ) + result = trt_mod(src.clone(), values, i0, i1, i2) + assert torch.allclose( + result, torch_output, atol=1e-4, rtol=1e-4 + ), f"3D expand accumulate mismatch: max diff = {(result - torch_output).abs().max()}" + + def test_empty_index_no_op(self): + """index_put with an empty index tensor is a no-op — output equals input. + + Mirrors PyTorch's test_empty_index: x[empty_idx] = values leaves x unchanged. + Uses static shapes (torch.export rejects Dim(min=0) with a concrete 0-size tensor). 
+ """ + + class EmptyIndexPut(torch.nn.Module): + def forward(self, src, values, idx): + return torch.ops.aten.index_put.default(src, [idx], values) + + src = torch.arange(8, dtype=torch.float32, device="cuda") + idx = torch.tensor([], dtype=torch.int64, device="cuda") + values = torch.tensor([], dtype=torch.float32, device="cuda") + + model = EmptyIndexPut().eval().cuda() + torch_output = model(src.clone(), values, idx) + + ep = torch.export.export(model, args=(src, values, idx)) + trt_mod = torchtrt.dynamo.compile( + ep, + arg_inputs=[ + torchtrt.Input(shape=(8,), dtype=torch.float32), + torchtrt.Input(shape=(0,), dtype=torch.float32), + torchtrt.Input(shape=(0,), dtype=torch.int64), + ], + min_block_size=1, + ) + result = trt_mod(src.clone(), values, idx) + assert torch.allclose( + result, torch_output, atol=1e-4, rtol=1e-4 + ), f"Empty-index no-op mismatch: {result} vs {torch_output}" + + def test_index_ind_dtype_int_vs_long(self): + """int32 and int64 index tensors must produce identical results. + + Mirrors PyTorch's test_index_ind_dtype. 
+ """ + + class IndexPutIntIdx(torch.nn.Module): + def forward(self, src, values, idx): + return torch.ops.aten.index_put.default(src, [idx], values) + + src = torch.zeros(4, 4, dtype=torch.float32, device="cuda") + values = torch.tensor([1.0, 2.0, 3.0, 4.0], device="cuda") + + idx_long = torch.arange(4, dtype=torch.int64, device="cuda") + idx_int = idx_long.to(torch.int32) + + model = IndexPutIntIdx().eval().cuda() + ref_long = model(src.clone(), values, idx_long) + ref_int = model(src.clone(), values, idx_int) + assert torch.allclose(ref_long, ref_int), "CPU int32 vs int64 mismatch" + + ep_long = torch.export.export(model, args=(src, values, idx_long)) + ep_int = torch.export.export(model, args=(src, values, idx_int)) + + trt_long = torchtrt.dynamo.compile( + ep_long, + arg_inputs=[ + torchtrt.Input(shape=(4, 4), dtype=torch.float32), + torchtrt.Input(shape=(4,), dtype=torch.float32), + torchtrt.Input(shape=(4,), dtype=torch.int64), + ], + min_block_size=1, + ) + trt_int = torchtrt.dynamo.compile( + ep_int, + arg_inputs=[ + torchtrt.Input(shape=(4, 4), dtype=torch.float32), + torchtrt.Input(shape=(4,), dtype=torch.float32), + torchtrt.Input(shape=(4,), dtype=torch.int32), + ], + min_block_size=1, + ) + + out_long = trt_long(src.clone(), values, idx_long) + out_int = trt_int(src.clone(), values, idx_int) + + assert torch.allclose( + out_long, ref_long, atol=1e-4, rtol=1e-4 + ), "TRT int64 mismatch" + assert torch.allclose( + out_int, ref_int, atol=1e-4, rtol=1e-4 + ), "TRT int32 mismatch" + assert torch.allclose( + out_long, out_int, atol=1e-4, rtol=1e-4 + ), "TRT int32 vs int64 inconsistency" + + def test_accumulate_non_contiguous_source(self): + """accumulate=True on a non-contiguous (sliced) source tensor. + + Mirrors PyTorch's test_index_put_accumulate_non_contiguous. + Uses unique indices so TRT scatter is correct. 
+ """ + + class AccumNonContig(torch.nn.Module): + def forward(self, src, values, idx): + # src is already a non-contiguous slice passed in + return torch.ops.aten.index_put.default(src, [idx], values, True) + + base = torch.zeros(5, 2, 2, dtype=torch.float32, device="cuda") + # take a non-contiguous slice: shape (5, 2), stride (4, 1) + src_slice = base[:, 0, :] + assert not src_slice.is_contiguous() + + idx = torch.tensor([0, 2], dtype=torch.int64, device="cuda") + values = torch.ones(2, 2, dtype=torch.float32, device="cuda") + + model = AccumNonContig().eval().cuda() + torch_output = model(src_slice.clone(), values, idx) + + ep = torch.export.export( + model, + args=(src_slice.contiguous(), values, idx), + ) + trt_mod = torchtrt.dynamo.compile( + ep, + arg_inputs=[ + torchtrt.Input(shape=(5, 2), dtype=torch.float32), + torchtrt.Input(shape=(2, 2), dtype=torch.float32), + torchtrt.Input(shape=(2,), dtype=torch.int64), + ], + min_block_size=1, + ) + result = trt_mod(src_slice.contiguous(), values, idx) + assert torch.allclose( + result, torch_output, atol=1e-4, rtol=1e-4 + ), f"Non-contiguous accumulate mismatch: max diff = {(result - torch_output).abs().max()}" + + def test_accumulate_expanded_values_broadcast(self): + """accumulate=True with value broadcasting — 0D scalar and 1D values + broadcast across unique indexed positions. + + Mirrors PyTorch's test_index_put_accumulate_expanded_values (unique indices only). 
+ """ + + class AccumBroadcast(torch.nn.Module): + def forward(self, src, values, idx0, idx1): + return torch.ops.aten.index_put.default(src, [idx0, idx1], values, True) + + src = torch.zeros(5, 2, dtype=torch.float32, device="cuda") + idx0 = torch.tensor([0, 1, 2, 3], dtype=torch.int64, device="cuda") + idx1 = torch.tensor([0, 1, 0, 1], dtype=torch.int64, device="cuda") + values_1d = torch.tensor([1.0], dtype=torch.float32, device="cuda") + + model = AccumBroadcast().eval().cuda() + torch_output = model(src.clone(), values_1d, idx0, idx1) + + ep = torch.export.export(model, args=(src, values_1d, idx0, idx1)) + trt_mod = torchtrt.dynamo.compile( + ep, + arg_inputs=[ + torchtrt.Input(shape=(5, 2), dtype=torch.float32), + torchtrt.Input(shape=(1,), dtype=torch.float32), + torchtrt.Input(shape=(4,), dtype=torch.int64), + torchtrt.Input(shape=(4,), dtype=torch.int64), + ], + min_block_size=1, + ) + result = trt_mod(src.clone(), values_1d, idx0, idx1) + assert torch.allclose( + result, torch_output, atol=1e-4, rtol=1e-4 + ), f"Accumulate broadcast mismatch: max diff = {(result - torch_output).abs().max()}" + if __name__ == "__main__": run_tests()