-
Notifications
You must be signed in to change notification settings - Fork 895
[LinalgExt] Enable tensor map_scatter lowering #22931
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,9 @@ | |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
| #include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
| #include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
| #include "mlir/IR/OpDefinition.h" | ||
| #include "mlir/IR/TypeRange.h" | ||
| #include "mlir/IR/Types.h" | ||
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
|
||
| #define DEBUG_TYPE "iree-decompose-map-scatter" | ||
|
|
@@ -143,7 +146,7 @@ struct VectorizationResult { | |
| /// The vectorized mask values for conditional stores. | ||
| Value maskVector; | ||
| /// The flattened 1D output buffer. | ||
| Value flatOutputBuffer; | ||
| Value flatOutput; | ||
| }; | ||
|
|
||
| /// Vectorize the index computation and mask evaluation for a `map_scatter` op. | ||
|
|
@@ -161,11 +164,22 @@ performIndexAndMaskVectorization(MapScatterOp mapScatterOp, | |
| Location loc = mapScatterOp.getLoc(); | ||
| OpBuilder::InsertionGuard g(rewriter); | ||
| rewriter.setInsertionPoint(mapScatterOp); | ||
| SmallVector<OpFoldResult> outputSizes = | ||
| memref::getMixedSizes(rewriter, loc, mapScatterOp.getOutput()); | ||
| Value flatOutput; | ||
| SmallVector<Value> strides; | ||
| Value flatOutputBuffer = createFlatOutputBuffer( | ||
| rewriter, loc, mapScatterOp.getOutput(), outputSizes, strides); | ||
| SmallVector<OpFoldResult> outputSizes = | ||
| getDims(rewriter, loc, mapScatterOp.getOutput()); | ||
| if (mapScatterOp.hasPureBufferSemantics()) { | ||
| flatOutput = createFlatOutputBuffer(rewriter, loc, mapScatterOp.getOutput(), | ||
| outputSizes, strides); | ||
| } else { | ||
| // For tensor outputs, create a flat output buffer as an empty tensor. | ||
| auto outputType = cast<TensorType>(mapScatterOp.getOutputType()); | ||
| SmallVector<ReassociationIndices> reassociations; | ||
| reassociations.push_back( | ||
| llvm::to_vector(llvm::seq<int64_t>(outputType.getRank()))); | ||
| flatOutput = tensor::CollapseShapeOp::create( | ||
| rewriter, loc, mapScatterOp.getOutput(), reassociations); | ||
|
Comment on lines
+180
to
+181
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is fine for now, since we don't have any concrete use cases yet, but adding collapse_shape ops on tensors can have more implications than it does for memrefs. It may be fine, but bufferization can often get tripped up by reshapes, which could end up creating extra large allocations. This is part of the challenge with vectorizing map_scatter on tensors, and we need to try it e2e to see if it works or not. That said, I'm not sure what the best way to represent it will be, so we can reevaluate once we are able to test this in an e2e path. |
||
| } | ||
| auto inputType = cast<VectorType>(mapScatterOp.getInputType()); | ||
| auto bodyBuilder = [&](OpBuilder &b, Location nestedLoc, ValueRange args) { | ||
| auto inlineBodyBuilder = [&](OpBuilder inlineBuilder, Location inlineLoc, | ||
|
|
@@ -258,7 +272,7 @@ performIndexAndMaskVectorization(MapScatterOp mapScatterOp, | |
| rewriter.eraseOp(indexWriteOp); | ||
| rewriter.eraseOp(maskWriteOp); | ||
| rewriter.eraseOp(genericOp); | ||
| return VectorizationResult{indexVector, maskVector, flatOutputBuffer}; | ||
| return VectorizationResult{indexVector, maskVector, flatOutput}; | ||
| } | ||
|
|
||
| /// Decompose the `map_scatter` into a sequence of `vector.extract` and | ||
|
|
@@ -275,7 +289,7 @@ static LogicalResult decomposeToLoadStore(MapScatterOp mapScatterOp, | |
| } | ||
| Value indexVector = vectorizationResult->indexVector; | ||
| Value maskVector = vectorizationResult->maskVector; | ||
| Value flatOutputBuffer = vectorizationResult->flatOutputBuffer; | ||
| Value flatOutputBuffer = vectorizationResult->flatOutput; | ||
|
|
||
| // Flatten all the index and mask vectors, since the scatter op lowering | ||
| // expects 1D vectors. | ||
|
|
@@ -335,7 +349,7 @@ static LogicalResult decomposeToScatter(MapScatterOp mapScatterOp, | |
| } | ||
| Value indexVector = vectorizationResult->indexVector; | ||
| Value maskVector = vectorizationResult->maskVector; | ||
| Value flatOutputBuffer = vectorizationResult->flatOutputBuffer; | ||
| Value flatOutput = vectorizationResult->flatOutput; | ||
|
|
||
| // Flatten all the vectors, since the scatter op lowering expects 1D vectors. | ||
| auto inputType = cast<VectorType>(mapScatterOp.getInputType()); | ||
|
|
@@ -356,15 +370,32 @@ static LogicalResult decomposeToScatter(MapScatterOp mapScatterOp, | |
|
|
||
| SmallVector<Value> offsets = { | ||
| arith::ConstantIndexOp::create(rewriter, loc, 0)}; | ||
| SmallVector<Value> operands = {flatOutputBuffer, offsets[0], indexVector, | ||
| SmallVector<Value> operands = {flatOutput, offsets[0], indexVector, | ||
| maskVector, inputVector}; | ||
| rewriter.replaceOpWithNewOp<vector::ScatterOp>( | ||
| mapScatterOp, /*resultTypes=*/TypeRange{}, operands); | ||
|
|
||
| if (mapScatterOp.hasPureBufferSemantics()) { | ||
| rewriter.replaceOpWithNewOp<vector::ScatterOp>( | ||
| mapScatterOp, /*resultTypes=*/TypeRange{}, operands); | ||
| return success(); | ||
| } | ||
|
|
||
| // For tensor outputs, expand the result back to the original shape. | ||
| auto scatterOp = | ||
| vector::ScatterOp::create(rewriter, loc, flatOutput.getType(), operands); | ||
| SmallVector<ReassociationIndices> reassociations; | ||
| reassociations.push_back(llvm::to_vector( | ||
| llvm::seq<int64_t>(mapScatterOp.getOutputType().getRank()))); | ||
| SmallVector<OpFoldResult> outputSizes = | ||
| tensor::getMixedSizes(rewriter, loc, mapScatterOp.getOutput()); | ||
| auto expandOp = tensor::ExpandShapeOp::create( | ||
| rewriter, loc, mapScatterOp.getOutputType(), scatterOp.getResult(), | ||
| reassociations, outputSizes); | ||
| rewriter.replaceOp(mapScatterOp, expandOp.getResult()); | ||
| return success(); | ||
| } | ||
|
|
||
| /// Decompose an `iree_linalg_ext.map_scatter` op with vector input and memref | ||
| /// output. This is the main dispatch function that analyzes the `map_scatter` | ||
| /// Decompose an `iree_linalg_ext.map_scatter` op with vector input. | ||
| /// This is the main dispatch function that analyzes the `map_scatter` | ||
| /// operation and chooses the most appropriate decomposition strategy. | ||
| /// | ||
| /// Decomposition strategies (in order of preference): | ||
|
|
@@ -388,7 +419,8 @@ static LogicalResult decomposeMapScatter(MapScatterOp mapScatterOp, | |
| const bool isMaskForwardSlice = maskOp && slice.contains(maskOp); | ||
| const bool isUnitFunctionOfInnermostInputIdx = | ||
| isUnitFunctionOf(innermostOutputIdx, innermostInputIdx); | ||
| if (!isMaskForwardSlice && isUnitFunctionOfInnermostInputIdx) { | ||
| if (!isMaskForwardSlice && isUnitFunctionOfInnermostInputIdx && | ||
| mapScatterOp.hasPureBufferSemantics()) { | ||
|
Comment on lines
-391
to
+423
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It shouldn't make a difference whether it is memref or tensor here. Can you support this case too (and add tests for it)?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry, I missed this response earlier. I think you can support the tensor case if you use
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you also support the decomposeToLoadStore case by using vector.transfer_write? |
||
| return decomposeToLoadStore(mapScatterOp, rewriter); | ||
| } | ||
| // In case of a sub-byte map_scatter that hasn't been decomposed into a | ||
|
|
@@ -428,14 +460,10 @@ struct DecomposeMapScatterPass final | |
| return; | ||
| } | ||
|
|
||
| // Decomposition is only supported for map_scatter ops that are both | ||
| // vectorized and bufferized. Bufferization is a requirement because | ||
| // vector.scatter only takes memref destinations. | ||
| // TODO(#21135): Allow tensor outputs when vector.scatter supports tensor | ||
| // destinations. | ||
| // Decomposition is only supported for map_scatter ops that are vectorized. | ||
| SmallVector<MapScatterOp> candidates; | ||
| funcOp->walk([&](MapScatterOp op) { | ||
| if (isa<VectorType>(op.getInputType()) && op.hasPureBufferSemantics()) { | ||
| if (isa<VectorType>(op.getInputType())) { | ||
| candidates.push_back(op); | ||
| } | ||
| }); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,6 +33,52 @@ func.func @identity_map_scatter( | |
|
|
||
| // ----- | ||
|
sakupan102 marked this conversation as resolved.
|
||
|
|
||
| func.func @identity_map_scatter_tensor( | ||
| %input: vector<4x16xf32>, %output: tensor<4x16xf32> | ||
| ) -> tensor<4x16xf32> { | ||
| %res = iree_linalg_ext.map_scatter %input into %output { | ||
| ^bb0(%idx0: index, %idx1: index): | ||
| %mask = arith.constant true | ||
| iree_linalg_ext.yield %idx0, %idx1, %mask : index, index, i1 | ||
| } : vector<4x16xf32> into tensor<4x16xf32> -> tensor<4x16xf32> | ||
| return %res : tensor<4x16xf32> | ||
| } | ||
|
|
||
| // CHECK-LABEL: func.func @identity_map_scatter_tensor( | ||
| // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]] | ||
| // CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]] | ||
| // CHECK-DAG: %[[CST:.+]] = arith.constant dense<16> : vector<4x16xindex> | ||
| // CHECK-DAG: %[[CST_TRUE:.+]] = arith.constant dense<true> : vector<64xi1> | ||
| // CHECK-DAG: %[[FLAT_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{.*}} tensor<4x16xf32> into tensor<64xf32> | ||
| // CHECK: %[[SCATTER:.+]] = vector.scatter | ||
| // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SCATTER]] {{.*}} : tensor<64xf32> into tensor<4x16xf32> | ||
| // CHECK: return %[[RESULT]] : tensor<4x16xf32> | ||
|
|
||
|
|
||
| // ----- | ||
|
|
||
| func.func @identity_map_scatter_tensor( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Give this test a different name (like identity_map_scatter_tensor_dynamic). |
||
| %input: vector<4x16xf32>, %output: tensor<?x?xf32> | ||
| ) -> tensor<?x?xf32> { | ||
| %res = iree_linalg_ext.map_scatter %input into %output { | ||
| ^bb0(%idx0: index, %idx1: index): | ||
| %mask = arith.constant true | ||
| iree_linalg_ext.yield %idx0, %idx1, %mask : index, index, i1 | ||
| } : vector<4x16xf32> into tensor<?x?xf32> -> tensor<?x?xf32> | ||
| return %res : tensor<?x?xf32> | ||
| } | ||
|
|
||
| // CHECK-LABEL: func.func @identity_map_scatter_tensor( | ||
| // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]] | ||
| // CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]] | ||
| // CHECK-DAG: %[[CST_TRUE:.+]] = arith.constant dense<true> : vector<64xi1> | ||
| // CHECK-DAG: %[[FLAT_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{.*}} tensor<?x?xf32> into tensor<?xf32> | ||
| // CHECK: %[[SCATTER:.+]] = vector.scatter | ||
| // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SCATTER]] {{.*}} : tensor<?xf32> into tensor<?x?xf32> | ||
| // CHECK: return %[[RESULT]] : tensor<?x?xf32> | ||
|
|
||
| // ----- | ||
|
|
||
| // This test checks all index and mask computations for the `map_scatter` to `vector.scatter` path. | ||
| // Other tests shouldn't check this to avoid maintenance burden. | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.