@@ -143,6 +143,41 @@ def strides_from_shape(shape: UserShape) -> UserStrides:
143143 return tuple (reversed (layout [:- 1 ]))
144144
145145
146+ def normalize_slice (s : slice , dim_size : int ) -> Tuple [int , int , int ]:
147+ """
148+ Normalize a slice object to (start, stop, step) with proper bounds.
149+
150+ Args:
151+ s: slice object
152+ dim_size: size of the dimension being sliced
153+
154+ Returns:
155+ (start, stop, step) tuple with normalized values
156+ """
157+ step = s .step if s .step is not None else 1
158+ if step == 0 :
159+ raise IndexingError ("slice step cannot be zero" )
160+
161+ if step < 0 :
162+ start = s .start if s .start is not None else dim_size - 1
163+ stop = s .stop if s .stop is not None else - dim_size - 1
164+ else :
165+ start = s .start if s .start is not None else 0
166+ stop = s .stop if s .stop is not None else dim_size
167+
168+ if start < 0 :
169+ start = max (0 , dim_size + start )
170+ else :
171+ start = min (start , dim_size )
172+
173+ if stop < 0 :
174+ stop = max (- 1 if step < 0 else 0 , dim_size + stop )
175+ else :
176+ stop = min (stop , dim_size )
177+
178+ return start , stop , step
179+
180+
146181class TensorData :
147182 _storage : Storage
148183 _strides : Strides
@@ -175,7 +210,8 @@ def __init__(
175210 self .dims = len (strides )
176211 self .size = int (prod (shape ))
177212 self .shape = shape
178- assert len (self ._storage ) == self .size
213+ # Note: Storage can be larger than size for non-contiguous views
214+ # assert len(self._storage) == self.size
179215
180216 def to_cuda_ (self ) -> None : # pragma: no cover
181217 if not numba .cuda .is_cuda_array (self ._storage ):
@@ -260,6 +296,55 @@ def permute(self, *order: int) -> TensorData:
260296 new_strides = tuple (self .strides [i ] for i in order )
261297 return TensorData (self ._storage , new_shape , new_strides )
262298
299+ def slice (self , key : Union [int , slice , Sequence [Union [int , slice ]]]) -> TensorData :
300+ """
301+ Create a sliced view of the tensor.
302+
303+ Args:
304+ key: int, slice, or tuple of ints/slices for indexing
305+
306+ Returns:
307+ New TensorData representing the sliced view
308+ """
309+ if isinstance (key , (int , slice )):
310+ key = (key ,)
311+
312+ if len (key ) > len (self .shape ):
313+ raise IndexingError (f"Too many indices { len (key )} for tensor of dimension { len (self .shape )} " )
314+
315+ key = tuple (key ) + (slice (None ),) * (len (self .shape ) - len (key ))
316+
317+ new_shape = []
318+ new_strides = []
319+ offset = 0
320+
321+ for dim , (k , dim_size , stride ) in enumerate (zip (key , self .shape , self .strides )):
322+ if isinstance (k , int ):
323+ idx = k
324+ if idx < 0 :
325+ idx = dim_size + idx
326+ if idx < 0 or idx >= dim_size :
327+ raise IndexingError (f"Index { k } out of range for dimension { dim } with size { dim_size } " )
328+ offset += idx * stride
329+ elif isinstance (k , slice ):
330+ start , stop , step = normalize_slice (k , dim_size )
331+ if step > 0 :
332+ size = max (0 , (stop - start + step - 1 ) // step )
333+ else :
334+ size = max (0 , (stop - start + step + 1 ) // step )
335+
336+ new_shape .append (size )
337+ new_strides .append (stride * step )
338+ offset += start * stride
339+ else :
340+ raise IndexingError (f"Unsupported index type: { type (k )} " )
341+
342+ if len (new_shape ) == 0 :
343+ scalar_val = self ._storage [offset ]
344+ return TensorData ([scalar_val ], (1 ,), (1 ,))
345+
346+ return _make_tensor_data_view (self ._storage , tuple (new_shape ), tuple (new_strides ), offset )
347+
263348 def to_string (self ) -> str :
264349 s = ""
265350 for index in self .indices ():
@@ -283,3 +368,30 @@ def to_string(self) -> str:
283368 else :
284369 s += " "
285370 return s
371+
372+
373+ def _make_tensor_data_view (
374+ storage : Storage , shape : UserShape , strides : UserStrides , offset : int
375+ ) -> TensorData :
376+ """
377+ Create a TensorData view with an offset into the storage.
378+
379+ Args:
380+ storage: The underlying storage array
381+ shape: Shape of the view
382+ strides: Strides for the view
383+ offset: Offset into the storage where the view starts
384+
385+ Returns:
386+ TensorData representing the view
387+ """
388+ if len (shape ) == 0 or prod (shape ) == 0 :
389+ # Empty tensor
390+ return TensorData ([], shape , strides )
391+
392+ if offset > 0 :
393+ view_storage = storage [offset :]
394+ else :
395+ view_storage = storage
396+
397+ return TensorData (view_storage , shape , strides )
0 commit comments