[MLIR] Extend vector.scatter to accept tensor as base #165548
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository; in that case, you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment "Ping". The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord, or on the forums.
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Ryutaro Okada (sakupan102)

Changes: In addition to memref, accept a ranked tensor as the base operand of vector.scatter, similar to vector.transfer_write. This helps complete the functionality of map_scatter decomposition. Full discussion can be found here: iree-org/iree#21135 (comment)

Full diff: https://github.com/llvm/llvm-project/pull/165548.diff

6 Files Affected:
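For illustration, a minimal sketch of the new tensor form (it mirrors the bufferize.mlir test added further down; value names and shapes are illustrative):

```mlir
// With a ranked tensor base, vector.scatter now returns the updated tensor.
%0 = vector.scatter %base[%c0, %c0][%v], %mask, %value
    : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
```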
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 6e15b1e7df606..db1b9e169608b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2150,22 +2150,24 @@ def Vector_GatherOp :
];
}
-def Vector_ScatterOp :
- Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
- Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
- Variadic<Index>:$offsets,
- VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
- VectorOfNonZeroRankOf<[I1]>:$mask,
- AnyVectorOfNonZeroRank:$valueToStore,
- OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)> {
+def Vector_ScatterOp
+ : Vector_Op<"scatter", [DeclareOpInterfaceMethods<
+ MemorySpaceCastConsumerOpInterface>]>,
+ Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemWrite]>:$base,
+ Variadic<Index>:$offsets,
+ VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
+ VectorOfNonZeroRankOf<[I1]>:$mask,
+ AnyVectorOfNonZeroRank:$valueToStore,
+ OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>,
+ Results<(outs Optional<AnyRankedTensor>:$result)> {
let summary = [{
- scatters elements from a vector into memory as defined by an index vector
+ scatters elements from a vector into memory or ranked tensor as defined by an index vector
and a mask vector
}];
let description = [{
- The scatter operation stores elements from a n-D vector into memory as
+ The scatter operation stores elements from a n-D vector into memory or ranked tensor as
defined by a base with indices and an additional n-D index vector, but
only if the corresponding bit in a n-D mask vector is set. Otherwise, no
action is taken for that element. Informally the semantics are:
@@ -2208,31 +2210,31 @@ def Vector_ScatterOp :
}];
let extraClassDeclaration = [{
- MemRefType getMemRefType() { return getBase().getType(); }
+ ShapedType getBaseType() { return getBase().getType(); }
VectorType getIndexVectorType() { return getIndices().getType(); }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getVectorType() { return getValueToStore().getType(); }
}];
- let assemblyFormat =
- "$base `[` $offsets `]` `[` $indices `]` `,` "
- "$mask `,` $valueToStore attr-dict `:` type($base) `,` "
- "type($indices) `,` type($mask) `,` type($valueToStore)";
+ let assemblyFormat = "$base `[` $offsets `]` `[` $indices `]` `,` "
+ "$mask `,` $valueToStore attr-dict `:` type($base) `,` "
+ "type($indices) `,` type($mask) `,` "
+ "type($valueToStore) (`->` type($result)^)?";
let hasCanonicalizer = 1;
let hasVerifier = 1;
- let builders = [
- OpBuilder<(ins "Value":$base,
- "ValueRange":$indices,
- "Value":$index_vec,
- "Value":$mask,
- "Value":$valueToStore,
- CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">: $alignment), [{
- return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
+ let builders =
+ [OpBuilder<(ins "Type":$resultType, "Value":$base, "ValueRange":$indices,
+ "Value":$index_vec, "Value":$mask, "Value":$valueToStore,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment),
+ [{
+ return build($_builder, $_state, resultType, base, indices, index_vec, mask, valueToStore,
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
nullptr);
- }]>
- ];
+ }]>,
+ OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$index_vec,
+ "Value":$mask, "Value":$valueToStore,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment)>];
}
def Vector_ExpandLoadOp :
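As a quick sketch of what the updated assembly format accepts (illustrative shapes): a memref base keeps the exact pre-existing syntax with no trailing result, while a ranked tensor base appends the optional `->` result type shown in the tests below.

```mlir
// Memref base: in-place write, no result (unchanged syntax).
vector.scatter %m[%c0, %c0][%idx], %mask, %val
    : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
```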
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 41d8d532757ad..cb65d787ea854 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -345,7 +345,8 @@ class VectorScatterOpConversion
matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = scatter->getLoc();
- MemRefType memRefType = scatter.getMemRefType();
+ MemRefType memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
+ assert(memRefType && "The base should be bufferized");
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return rewriter.notifyMatchFailure(scatter, "memref type not supported");
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ad8255a95cb4e..b4a8737107c8d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6058,12 +6058,15 @@ LogicalResult ScatterOp::verify() {
VectorType indVType = getIndexVectorType();
VectorType maskVType = getMaskVectorType();
VectorType valueVType = getVectorType();
- MemRefType memType = getMemRefType();
+ ShapedType baseType = getBaseType();
- if (valueVType.getElementType() != memType.getElementType())
+ if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
+ return emitOpError("requires base to be a memref or ranked tensor type");
+
+ if (valueVType.getElementType() != baseType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getOffsets()) != memType.getRank())
- return emitOpError("requires ") << memType.getRank() << " indices";
+ if (llvm::size(getOffsets()) != baseType.getRank())
+ return emitOpError("requires ") << baseType.getRank() << " indices";
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
if (valueVType.getShape() != maskVType.getShape())
@@ -6071,6 +6074,14 @@ LogicalResult ScatterOp::verify() {
return success();
}
+void ScatterOp::build(OpBuilder &builder, OperationState &result, Value base,
+ ValueRange indices, Value index_vec, Value mask,
+ Value valueToStore, llvm::MaybeAlign alignment) {
+ Type resultType = llvm::dyn_cast<RankedTensorType>(base.getType());
+ build(builder, result, resultType, base, indices, index_vec, mask,
+ valueToStore, alignment);
+}
+
namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 546099ca975b7..eb11253ec647c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
using namespace mlir;
using namespace mlir::bufferization;
@@ -126,6 +127,51 @@ struct TransferWriteOpInterface
}
};
+/// Bufferization of vector.scatter. Replaced with a new vector.scatter that
+/// operates on a memref.
+struct ScatterOpInterface
+ : public BufferizableOpInterface::ExternalModel<ScatterOpInterface,
+ vector::ScatterOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
+ "only tensor types expected");
+ return true;
+ }
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
+ "only tensor types expected");
+ return true;
+ }
+ AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ auto scatterOp = cast<vector::ScatterOp>(op);
+ if (&opOperand != &scatterOp.getBaseMutable())
+ return {};
+ if (op->getNumResults() == 0)
+ return {};
+ return {{scatterOp.getResult(), BufferRelation::Equivalent}};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
+ ScatterOp scatterOp = cast<vector::ScatterOp>(op);
+ assert(isa<TensorType>(scatterOp.getBaseType()) &&
+ "only tensor types expected");
+ FailureOr<Value> buffer =
+ getBuffer(rewriter, scatterOp.getBase(), options, state);
+ if (failed(buffer))
+ return failure();
+ vector::ScatterOp::create(rewriter, scatterOp.getLoc(), *buffer,
+ scatterOp.getOffsets(), scatterOp.getIndices(),
+ scatterOp.getMask(), scatterOp.getValueToStore());
+ replaceOpWithBufferizedValues(rewriter, op, *buffer);
+ return success();
+ }
+};
+
/// Bufferization of vector.gather. Replaced with a new vector.gather that
/// operates on a memref.
struct GatherOpInterface
@@ -335,5 +381,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels(
GatherOp::attachInterface<GatherOpInterface>(*ctx);
MaskOp::attachInterface<MaskOpInterface>(*ctx);
YieldOp::attachInterface<YieldOpInterface>(*ctx);
+ ScatterOp::attachInterface<ScatterOpInterface>(*ctx);
});
}
diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir
index 887fb941cc651..70adefd0dc4ec 100644
--- a/mlir/test/Dialect/Vector/bufferize.mlir
+++ b/mlir/test/Dialect/Vector/bufferize.mlir
@@ -32,6 +32,26 @@ func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index,
// -----
+// CHECK-LABEL: func @scatter(
+// CHECK-SAME: %[[base:.*]]: tensor<16x16xf32>, %[[v:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[value:.*]]: vector<16xf32>) -> tensor<16x16xf32>
+// CHECK: %[[buf:.*]] = bufferization.to_buffer %[[base]] : tensor<16x16xf32> to memref<16x16xf32>
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<16x16xf32>
+// CHECK: memref.copy %[[buf]], %[[alloc]] : memref<16x16xf32> to memref<16x16xf32>
+// CHECK: vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[alloc]] : memref<16x16xf32> to tensor<16x16xf32>
+// CHECK: return %[[tensor]] : tensor<16x16xf32>
+func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.scatter %base[%c0, %c0][%v], %mask, %value
+ : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
+ return %0 : tensor<16x16xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @gather(
// CHECK-SAME: %[[base:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<16xi32>,
// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>)
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5f035e35a1b86..79b09e172145b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1491,9 +1491,9 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
+ // expected-error@+1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
vector.scatter %base[%c0][%indices], %mask, %pass_thru
- : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
}
// -----
@hanhanW

hanhanW left a comment:

Hey folks, I'm the one who proposed the idea, and the author went ahead and implemented it. It'd be weird if I both approved and landed the change. Does the direction look okay to you? IMO, it looks okay because the vector.gather op also has tensor semantics. (On the other hand, I'm quite busy with IREE integration work; I can help review next week. Many thanks!)
@Groverkss can you review?

Aligning

In addition to memref, accept ranked tensor as the base operand of vector.scatter, similar to vector.transfer_write. This helps complete the functionality of map_scatter decomposition. Full discussion can be found here: iree-org/iree#21135 (comment)
Signed-off-by: Ryutaro Okada <[email protected]>

Force-pushed from f4c4451 to 482aeb4 (Compare)
keshavvinayak01 left a comment:

LGTM!
hanhanW left a comment:

This does not only add tensor semantics to the op; it also implements the bufferization. Please update the PR description to reflect all the changes.

@matthias-springer can you help review the bufferization logic?
Inline comment on mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp, at:

    LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                            const BufferizationOptions &options,
                            BufferizationState &state) const {
      ScatterOp scatterOp = cast<vector::ScatterOp>(op);

nit: auto

(Resolved; the commented code is now outdated.)
matthias-springer left a comment:

Bufferization part looks good.
Signed-off-by: Ryutaro Okada <[email protected]>
Signed-off-by: Ryutaro Okada <[email protected]>
hanhanW left a comment:

LGTM, please wait for an approval from other vector dialect maintainers.
dcaballe left a comment:

Thanks for the change! LGTM in general. Could you add some tensor-based tests to Vector/ops.mlir?

Ideally, we might have wanted two separate PRs for this: one modifying and testing the new op semantics, and another introducing bufferization support. However, the change is not huge, so it looks OK to me.
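For concreteness, a hypothetical roundtrip test of the kind requested; the function name and CHECK lines below are illustrative sketches, not taken from the PR:

```mlir
// CHECK-LABEL: @scatter_on_tensor
func.func @scatter_on_tensor(%base: tensor<16x16xf32>, %v: vector<16xi32>,
                             %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
  %c0 = arith.constant 0 : index
  // CHECK: vector.scatter {{.*}} -> tensor<16x16xf32>
  %0 = vector.scatter %base[%c0, %c0][%v], %mask, %value
      : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
  return %0 : tensor<16x16xf32>
}
```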
Signed-off-by: Ryutaro Okada <[email protected]>
@dcaballe
Signed-off-by: Ryutaro Okada <[email protected]>
Force-pushed from 607313e to 217661d (Compare)
@sakupan102 Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done!
This PR makes the following improvements to vector.scatter and its lowering pipeline:

- In addition to memref, accept a ranked tensor as the base operand of vector.scatter, similar to vector.transfer_write.
- Implement bufferization for vector.scatter, so that tensor-based scatter ops can be fully lowered to memref-based forms.

This helps complete the functionality of map_scatter decomposition. Full discussion can be found here: iree-org/iree#21135