@@ -303,11 +303,15 @@ void TensorDesc::EnsureMaximumDimensionCount(uint32_t maximumDimensionCount, Ten
303303{
304304 if (m_bufferTensorDesc.DimensionCount > maximumDimensionCount)
305305 {
306- SetDimensionCount (maximumDimensionCount, alignment);
306+ SetDimensionCount (maximumDimensionCount, alignment, /* foldEndDimensions */ true );
307307 }
308308}
309309
310- void TensorDesc::SetDimensionCount (uint32_t newDimensionCount, TensorAxis alignment)
310+ // Set a new dimension count, adding or removing dimensions as needed.
311+ // If the new rank is larger, any new dimensions are filled with size 1 and stride 0.
312+ // If the new rank is smaller and foldEndDimensions is true, then any removed dimensions are folded together.
313+ // Otherwise those dimensions (leading or trailing, depending on alignment) are simply truncated.
314+ void TensorDesc::SetDimensionCount (uint32_t newDimensionCount, TensorAxis alignment, bool foldEndDimensions)
311315{
312316 ML_CHECK_VALID_ARGUMENT (newDimensionCount <= MaximumDimensionCount);
313317 ML_CHECK_VALID_ARGUMENT (alignment == TensorAxis::RightAligned || alignment == TensorAxis::LeftAligned);
@@ -322,17 +326,29 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
322326 int32_t fillOffset = oldDimensionCount;
323327 int32_t fillCount = std::max (0 , difference);
324328
325- // If shrinking the rank, fold dimensions into the first/last dimension.
329+ // If shrinking the rank and asked to fold dimensions, then collapse them into the first/last dimension.
326330 // e.g. Folding 4D dimensions [2,3,4,5] to 3D right-aligned yields [6,4,5]
327331 // e.g. Folding 6D dimensions [2,3,4,5,6,7] to 3D left-aligned yields [2,3,840]
328- if (difference < 0 && newDimensionCount > 0 )
332+ //
333+ // Otherwise dimensions are simply truncated (which may be desired if they were modified before calling).
334+ if (foldEndDimensions && difference < 0 && newDimensionCount > 0 )
329335 {
330336 uint32_t dimensionCountRemoved = -difference;
331337 uint32_t dimensionCountFolded = dimensionCountRemoved + 1 ; // If 2 dimensions are removed, then 3 dimensions are folded into one.
332338 uint32_t targetDimensionIndex;
333339 uint32_t firstFoldedDimensionIndex;
334340
335341 // Determine the range to fold and which dimension to fold them into.
342+ // e.g. Right-aligned: was 4D [2, 3, 4, 5]
343+ // now 2D [12, 5]
344+ // fold <----->
345+ // target *
346+ //
347+ // Left-aligned: was 4D [2, 3, 4, 5]
348+ // now 2D [2, 60]
349+ // fold <----->
350+ // target *
351+ //
336352 if (alignment == TensorAxis::RightAligned)
337353 {
338354 targetDimensionIndex = dimensionCountRemoved; // Fold extra dimensions into the first dimension of the new size.
@@ -349,7 +365,7 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
349365 // Ensure no stride broadcasting is lost during the fold, which would silently give incorrect results.
350366 ML_CHECK_VALID_ARGUMENT (
351367 m_bufferTensorDesc.Strides == nullptr ||
352- ! HasBroadcastedDimensions (
368+ AreStridesCollapsible (
353369 { sizeFoldBegin, sizeFoldEnd },
354370 { &m_strides[firstFoldedDimensionIndex], dimensionCountFolded }
355371 )
@@ -433,29 +449,26 @@ void TensorDesc::EnsureStridesExist() noexcept
433449 m_bufferTensorDesc.Strides = m_strides;
434450}
435451
436- bool TensorDesc::HasBroadcastedDimensions (
437- gsl::span<const uint32_t > dimensions,
438- gsl::span<const uint32_t > strides
439- ) noexcept
452+ bool TensorDesc::AreStridesCollapsible (gsl::span<const uint32_t > sizes, gsl::span<const uint32_t > strides) const noexcept
440453{
441- assert (dimensions.size () == strides.size ());
442- for (uint32_t i = 0 ; i < dimensions.size (); ++i)
454+ if (strides.empty ())
455+ {
456+ return true ;
457+ }
458+
459+ assert (sizes.size () == strides.size ());
460+
461+ uint32_t expectedStride = strides.back ();
462+ for (size_t i = strides.size (); i-- > 0 ; )
443463 {
444- // Note logical dimensions of size 1 (even when stride is 0) are not considered broadcasted.
445- if (strides[i] == 0 && dimensions[i] != 1 )
464+ if (sizes[i] != 1 )
446465 {
447- return true ;
466+ if (strides[i] != expectedStride)
467+ {
468+ return false ;
469+ }
470+ expectedStride *= sizes[i];
448471 }
449472 }
450- return false ;
451- }
452-
453- bool TensorDesc::HasBroadcastedDimensions () const noexcept
454- {
455- return IsValid ()
456- && m_bufferTensorDesc.Strides != nullptr
457- && HasBroadcastedDimensions (
458- { m_sizes, m_bufferTensorDesc.DimensionCount },
459- { m_strides, m_bufferTensorDesc.DimensionCount }
460- );
473+ return true ;
461474}
0 commit comments