Skip to content

Commit adb2569

Browse files
committed
Fix EinSum failure. Support more MatMul broadcasting via AreStridesCollapsible.
1 parent 6d8eef2 commit adb2569

File tree

2 files changed

+40
-32
lines changed

2 files changed

+40
-32
lines changed

onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -303,11 +303,15 @@ void TensorDesc::EnsureMaximumDimensionCount(uint32_t maximumDimensionCount, Ten
303303
{
304304
if (m_bufferTensorDesc.DimensionCount > maximumDimensionCount)
305305
{
306-
SetDimensionCount(maximumDimensionCount, alignment);
306+
SetDimensionCount(maximumDimensionCount, alignment, /*foldEndDimensions*/ true);
307307
}
308308
}
309309

310-
void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment)
310+
// Set a new dimension count, adding or removing dimensions as needed.
311+
// If the new rank is larger, any new dimensions are filled with size 1 and stride 0.
312+
// If the new rank is smaller and foldEndDimensions is true, then any removed dimensions are folded together.
313+
// Otherwise those dimensions (leading or trailing, depending on alignment) are simply truncated.
314+
void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment, bool foldEndDimensions)
311315
{
312316
ML_CHECK_VALID_ARGUMENT(newDimensionCount <= MaximumDimensionCount);
313317
ML_CHECK_VALID_ARGUMENT(alignment == TensorAxis::RightAligned || alignment == TensorAxis::LeftAligned);
@@ -322,17 +326,29 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
322326
int32_t fillOffset = oldDimensionCount;
323327
int32_t fillCount = std::max(0, difference);
324328

325-
// If shrinking the rank, fold dimensions into the first/last dimension.
329+
// If shrinking the rank and asked to fold dimensions, then collapse them into the first/last dimension.
326330
// e.g. Folding 4D dimensions [2,3,4,5] to 3D right-aligned yields [6,4,5]
327331
// e.g. Folding 6D dimensions [2,3,4,5,6,7] to 3D left-aligned yields [2,3,840]
328-
if (difference < 0 && newDimensionCount > 0)
332+
//
333+
// Otherwise dimensions are simply truncated (which may be desired if they were modified before calling).
334+
if (foldEndDimensions && difference < 0 && newDimensionCount > 0)
329335
{
330336
uint32_t dimensionCountRemoved = -difference;
331337
uint32_t dimensionCountFolded = dimensionCountRemoved + 1; // If 2 dimensions are removed, then 3 dimensions are folded into one.
332338
uint32_t targetDimensionIndex;
333339
uint32_t firstFoldedDimensionIndex;
334340

335341
// Determine the range to fold and which dimension to fold them into.
342+
// e.g. Right-aligned: was 4D [2, 3, 4, 5]
343+
// now 2D [24, 5]
344+
// fold <----->
345+
// target *
346+
//
347+
// Left-aligned: was 4D [2, 3, 4, 5]
348+
// now 2D [2, 60]
349+
// fold <----->
350+
// target *
351+
//
336352
if (alignment == TensorAxis::RightAligned)
337353
{
338354
targetDimensionIndex = dimensionCountRemoved; // Fold extra dimensions into the first dimension of the new size.
@@ -349,7 +365,7 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
349365
// Ensure no stride broadcasting is lost during the fold, which would silently give incorrect results.
350366
ML_CHECK_VALID_ARGUMENT(
351367
m_bufferTensorDesc.Strides == nullptr ||
352-
!HasBroadcastedDimensions(
368+
AreStridesCollapsible(
353369
{ sizeFoldBegin, sizeFoldEnd },
354370
{ &m_strides[firstFoldedDimensionIndex], dimensionCountFolded }
355371
)
@@ -433,29 +449,26 @@ void TensorDesc::EnsureStridesExist() noexcept
433449
m_bufferTensorDesc.Strides = m_strides;
434450
}
435451

436-
bool TensorDesc::HasBroadcastedDimensions(
437-
gsl::span<const uint32_t> dimensions,
438-
gsl::span<const uint32_t> strides
439-
) noexcept
452+
bool TensorDesc::AreStridesCollapsible(gsl::span<const uint32_t> sizes, gsl::span<const uint32_t> strides) const noexcept
440453
{
441-
assert(dimensions.size() == strides.size());
442-
for (uint32_t i = 0; i < dimensions.size(); ++i)
454+
if (strides.empty())
455+
{
456+
return true;
457+
}
458+
459+
assert(sizes.size() == strides.size());
460+
461+
uint32_t expectedStride = strides.back();
462+
for (size_t i = strides.size(); i-- > 0; )
443463
{
444-
// Note logical dimensions of size 1 (even when stride is 0) are not considered broadcasted.
445-
if (strides[i] == 0 && dimensions[i] != 1)
464+
if (sizes[i] != 1)
446465
{
447-
return true;
466+
if (strides[i] != expectedStride)
467+
{
468+
return false;
469+
}
470+
expectedStride *= sizes[i];
448471
}
449472
}
450-
return false;
451-
}
452-
453-
bool TensorDesc::HasBroadcastedDimensions() const noexcept
454-
{
455-
return IsValid()
456-
&& m_bufferTensorDesc.Strides != nullptr
457-
&& HasBroadcastedDimensions(
458-
{ m_sizes, m_bufferTensorDesc.DimensionCount },
459-
{ m_strides, m_bufferTensorDesc.DimensionCount }
460-
);
473+
return true;
461474
}

onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,15 @@ namespace Dml
4040

4141
inline bool IsValid() const noexcept { return m_tensorType != DML_TENSOR_TYPE_INVALID; }
4242
inline uint32_t GetDimensionCount() const { return m_bufferTensorDesc.DimensionCount; }
43-
void SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment);
43+
void SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignment, bool foldEndDimensions = false);
4444
void EnsureMinimumDimensionCount(uint32_t newDimensionCount, TensorAxis alignment);
4545
void EnsureMaximumDimensionCount(uint32_t maximumDimensionCount, TensorAxis alignment);
4646

4747
gsl::span<const uint32_t> GetSizes() const noexcept { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; }
4848
gsl::span<const uint32_t> GetStrides() const noexcept;
4949
void SetStrides(gsl::span<const uint32_t> strides);
5050
void EnsureStridesExist() noexcept;
51-
bool HasBroadcastedDimensions() const noexcept;
52-
static bool HasBroadcastedDimensions(
53-
gsl::span<const uint32_t> dimensions,
54-
gsl::span<const uint32_t> strides
55-
) noexcept;
56-
51+
bool AreStridesCollapsible(gsl::span<const uint32_t> sizes, gsl::span<const uint32_t> strides) const noexcept;
5752
void SetDimensionsAndStrides(gsl::span<const uint32_t> sizes, gsl::span<const uint32_t> strides);
5853

5954
// Rearranges existing m_sizes and m_strides by gathering axes from dimensionMapping.

0 commit comments

Comments
 (0)