@@ -16,6 +16,9 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-decompose-map-scatter"
@@ -143,7 +146,7 @@ struct VectorizationResult {
/// The vectorized mask values for conditional stores.
Value maskVector;
/// The flattened 1D output buffer.
Value flatOutputBuffer;
Value flatOutput;
};

/// Vectorize the index computation and mask evaluation for a `map_scatter` op.
@@ -161,11 +164,22 @@ performIndexAndMaskVectorization(MapScatterOp mapScatterOp,
Location loc = mapScatterOp.getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(mapScatterOp);
SmallVector<OpFoldResult> outputSizes =
memref::getMixedSizes(rewriter, loc, mapScatterOp.getOutput());
Value flatOutput;
SmallVector<Value> strides;
Value flatOutputBuffer = createFlatOutputBuffer(
rewriter, loc, mapScatterOp.getOutput(), outputSizes, strides);
SmallVector<OpFoldResult> outputSizes =
getDims(rewriter, loc, mapScatterOp.getOutput());
if (mapScatterOp.hasPureBufferSemantics()) {
flatOutput = createFlatOutputBuffer(rewriter, loc, mapScatterOp.getOutput(),
outputSizes, strides);
} else {
// For tensor outputs, flatten the output tensor to 1-D with a tensor.collapse_shape.
auto outputType = cast<TensorType>(mapScatterOp.getOutputType());
SmallVector<ReassociationIndices> reassociations;
reassociations.push_back(
llvm::to_vector(llvm::seq<int64_t>(outputType.getRank())));
flatOutput = tensor::CollapseShapeOp::create(
rewriter, loc, mapScatterOp.getOutput(), reassociations);
Comment on lines +180 to +181
Contributor:
This is fine for now, since we don't have any concrete use cases yet, but adding collapse_shape ops on tensors can have more implications than it does for memrefs. It may be fine, but bufferization can often get tripped up by reshapes, which could end up creating extra large allocations. This is part of the challenge with vectorizing map_scatter on tensors, and we need to try it e2e to see if it works or not.

That said, I'm not sure what the best way to represent it will be, so we can reevaluate once we are able to test this in an e2e path.
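For reference, a minimal MLIR sketch of the reshape pair this patch wraps around the flattened scatter when the output is a tensor (shapes match the static test below; the 1-D scatter itself is elided, since only the reshapes are at issue here):

```mlir
// Flatten the 2-D destination tensor to 1-D so the scatter can use linear indices.
%flat = tensor.collapse_shape %output [[0, 1]]
    : tensor<4x16xf32> into tensor<64xf32>
// ... 1-D vector.scatter into %flat, yielding %scattered ...
// Restore the original 2-D shape for the op's tensor result.
%result = tensor.expand_shape %scattered [[0, 1]] output_shape [4, 16]
    : tensor<64xf32> into tensor<4x16xf32>
```

Whether bufferization folds this collapse/expand pair away or materializes an extra allocation is the open question raised above.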

}
auto inputType = cast<VectorType>(mapScatterOp.getInputType());
auto bodyBuilder = [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
auto inlineBodyBuilder = [&](OpBuilder inlineBuilder, Location inlineLoc,
@@ -258,7 +272,7 @@ performIndexAndMaskVectorization(MapScatterOp mapScatterOp,
rewriter.eraseOp(indexWriteOp);
rewriter.eraseOp(maskWriteOp);
rewriter.eraseOp(genericOp);
return VectorizationResult{indexVector, maskVector, flatOutputBuffer};
return VectorizationResult{indexVector, maskVector, flatOutput};
}

/// Decompose the `map_scatter` into a sequence of `vector.extract` and
@@ -275,7 +289,7 @@ static LogicalResult decomposeToLoadStore(MapScatterOp mapScatterOp,
}
Value indexVector = vectorizationResult->indexVector;
Value maskVector = vectorizationResult->maskVector;
Value flatOutputBuffer = vectorizationResult->flatOutputBuffer;
Value flatOutputBuffer = vectorizationResult->flatOutput;

// Flatten all the index and mask vectors, since the scatter op lowering
// expects 1D vectors.
@@ -335,7 +349,7 @@ static LogicalResult decomposeToScatter(MapScatterOp mapScatterOp,
}
Value indexVector = vectorizationResult->indexVector;
Value maskVector = vectorizationResult->maskVector;
Value flatOutputBuffer = vectorizationResult->flatOutputBuffer;
Value flatOutput = vectorizationResult->flatOutput;

// Flatten all the vectors, since the scatter op lowering expects 1D vectors.
auto inputType = cast<VectorType>(mapScatterOp.getInputType());
@@ -356,15 +370,32 @@

SmallVector<Value> offsets = {
arith::ConstantIndexOp::create(rewriter, loc, 0)};
SmallVector<Value> operands = {flatOutputBuffer, offsets[0], indexVector,
SmallVector<Value> operands = {flatOutput, offsets[0], indexVector,
maskVector, inputVector};
rewriter.replaceOpWithNewOp<vector::ScatterOp>(
mapScatterOp, /*resultTypes=*/TypeRange{}, operands);

if (mapScatterOp.hasPureBufferSemantics()) {
rewriter.replaceOpWithNewOp<vector::ScatterOp>(
mapScatterOp, /*resultTypes=*/TypeRange{}, operands);
return success();
}

// For tensor outputs, expand the result back to the original shape.
auto scatterOp =
vector::ScatterOp::create(rewriter, loc, flatOutput.getType(), operands);
SmallVector<ReassociationIndices> reassociations;
reassociations.push_back(llvm::to_vector(
llvm::seq<int64_t>(mapScatterOp.getOutputType().getRank())));
SmallVector<OpFoldResult> outputSizes =
tensor::getMixedSizes(rewriter, loc, mapScatterOp.getOutput());
auto expandOp = tensor::ExpandShapeOp::create(
rewriter, loc, mapScatterOp.getOutputType(), scatterOp.getResult(),
reassociations, outputSizes);
rewriter.replaceOp(mapScatterOp, expandOp.getResult());
return success();
}

/// Decompose an `iree_linalg_ext.map_scatter` op with vector input and memref
/// output. This is the main dispatch function that analyzes the `map_scatter`
/// Decompose an `iree_linalg_ext.map_scatter` op with vector input.
/// This is the main dispatch function that analyzes the `map_scatter`
/// operation and chooses the most appropriate decomposition strategy.
///
/// Decomposition strategies (in order of preference):
@@ -388,7 +419,8 @@ static LogicalResult decomposeMapScatter(MapScatterOp mapScatterOp,
const bool isMaskForwardSlice = maskOp && slice.contains(maskOp);
const bool isUnitFunctionOfInnermostInputIdx =
isUnitFunctionOf(innermostOutputIdx, innermostInputIdx);
if (!isMaskForwardSlice && isUnitFunctionOfInnermostInputIdx) {
if (!isMaskForwardSlice && isUnitFunctionOfInnermostInputIdx &&
mapScatterOp.hasPureBufferSemantics()) {
Comment on lines -391 to +423
Contributor:
It shouldn't make a difference whether it is memref or tensor here. Can you support this case too (and add tests for it)?

Contributor Author:
Since vector.store only supports memref destinations, we would first need vector.store to support tensor outputs.
I think it is better to disable the tensor case for now.

Contributor:
Sorry, I missed this response earlier. I think you can support the tensor case if you use vector.transfer_write instead of vector.store. The transfer_write should eventually lower into vector.store after vector lowering.

Contributor:
Can you also support the decomposeToLoadStore case by using vector.transfer_write?
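A minimal sketch of the suggested alternative (value names are hypothetical; this illustrates the reviewer's proposal, not what the current patch emits):

```mlir
%c0 = arith.constant 0 : index
// vector.store requires a memref destination:
vector.store %vec, %buf[%c0] : memref<64xf32>, vector<64xf32>
// vector.transfer_write also accepts a tensor destination and returns the
// updated tensor; after bufferization, vector lowering can rewrite it into
// a plain vector.store on the resulting memref:
%updated = vector.transfer_write %vec, %dest[%c0] {in_bounds = [true]}
    : vector<64xf32>, tensor<64xf32>
```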

return decomposeToLoadStore(mapScatterOp, rewriter);
}
// In case of a sub-byte map_scatter that hasn't been decomposed into a
@@ -428,14 +460,10 @@ struct DecomposeMapScatterPass final
return;
}

// Decomposition is only supported for map_scatter ops that are both
// vectorized and bufferized. Bufferization is a requirement because
// vector.scatter only takes memref destinations.
// TODO(#21135): Allow tensor outputs when vector.scatter supports tensor
// destinations.
// Decomposition is only supported for map_scatter ops that are vectorized.
SmallVector<MapScatterOp> candidates;
funcOp->walk([&](MapScatterOp op) {
if (isa<VectorType>(op.getInputType()) && op.hasPureBufferSemantics()) {
if (isa<VectorType>(op.getInputType())) {
candidates.push_back(op);
}
});
@@ -33,6 +33,52 @@ func.func @identity_map_scatter(

// -----

func.func @identity_map_scatter_tensor(
%input: vector<4x16xf32>, %output: tensor<4x16xf32>
) -> tensor<4x16xf32> {
%res = iree_linalg_ext.map_scatter %input into %output {
^bb0(%idx0: index, %idx1: index):
%mask = arith.constant true
iree_linalg_ext.yield %idx0, %idx1, %mask : index, index, i1
} : vector<4x16xf32> into tensor<4x16xf32> -> tensor<4x16xf32>
return %res : tensor<4x16xf32>
}

// CHECK-LABEL: func.func @identity_map_scatter_tensor(
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]
// CHECK-DAG: %[[CST:.+]] = arith.constant dense<16> : vector<4x16xindex>
// CHECK-DAG: %[[CST_TRUE:.+]] = arith.constant dense<true> : vector<64xi1>
// CHECK-DAG: %[[FLAT_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{.*}} tensor<4x16xf32> into tensor<64xf32>
// CHECK: %[[SCATTER:.+]] = vector.scatter
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SCATTER]] {{.*}} : tensor<64xf32> into tensor<4x16xf32>
// CHECK: return %[[RESULT]] : tensor<4x16xf32>


// -----

func.func @identity_map_scatter_tensor(
Contributor:
Give this test a different name (like identity_map_scatter_tensor_dynamic).

%input: vector<4x16xf32>, %output: tensor<?x?xf32>
) -> tensor<?x?xf32> {
%res = iree_linalg_ext.map_scatter %input into %output {
^bb0(%idx0: index, %idx1: index):
%mask = arith.constant true
iree_linalg_ext.yield %idx0, %idx1, %mask : index, index, i1
} : vector<4x16xf32> into tensor<?x?xf32> -> tensor<?x?xf32>
return %res : tensor<?x?xf32>
}

// CHECK-LABEL: func.func @identity_map_scatter_tensor(
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]
// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]
// CHECK-DAG: %[[CST_TRUE:.+]] = arith.constant dense<true> : vector<64xi1>
// CHECK-DAG: %[[FLAT_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{.*}} tensor<?x?xf32> into tensor<?xf32>
// CHECK: %[[SCATTER:.+]] = vector.scatter
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SCATTER]] {{.*}} : tensor<?xf32> into tensor<?x?xf32>
// CHECK: return %[[RESULT]] : tensor<?x?xf32>

// -----

// This test checks all index and mask computations for the `map_scatter` to `vector.scatter` path.
// Other tests shouldn't check this to avoid maintenance burden.
