56 changes: 28 additions & 28 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2160,25 +2160,25 @@ def Vector_GatherOp :
];
}

-def Vector_ScatterOp :
-Vector_Op<"scatter", [
-DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
-DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
-]>,
-Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
-Variadic<Index>:$offsets,
-VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
-VectorOfNonZeroRankOf<[I1]>:$mask,
-AnyVectorOfNonZeroRank:$valueToStore,
-OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)> {
+def Vector_ScatterOp
+: Vector_Op<"scatter",
+[DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
+DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]>,
+Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemWrite]>:$base,
+Variadic<Index>:$offsets,
+VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
+VectorOfNonZeroRankOf<[I1]>:$mask,
+AnyVectorOfNonZeroRank:$valueToStore,
+OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>,
+Results<(outs Optional<AnyRankedTensor>:$result)> {

let summary = [{
-scatters elements from a vector into memory as defined by an index vector
+scatters elements from a vector into memory or a ranked tensor as defined by an index vector
and a mask vector
}];

let description = [{
-The scatter operation stores elements from an n-D vector into memory as
+The scatter operation stores elements from an n-D vector into memory or a ranked tensor as
defined by a base with indices and an additional n-D index vector, but
only if the corresponding bit in an n-D mask vector is set. Otherwise, no
action is taken for that element. Informally the semantics are:
@@ -2221,31 +2221,31 @@ def Vector_ScatterOp :
}];

let extraClassDeclaration = [{
-MemRefType getMemRefType() { return getBase().getType(); }
+ShapedType getBaseType() { return getBase().getType(); }
VectorType getIndexVectorType() { return getIndices().getType(); }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getVectorType() { return getValueToStore().getType(); }
}];

-let assemblyFormat =
-"$base `[` $offsets `]` `[` $indices `]` `,` "
-"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
-"type($indices) `,` type($mask) `,` type($valueToStore)";
+let assemblyFormat = "$base `[` $offsets `]` `[` $indices `]` `,` "
+"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
+"type($indices) `,` type($mask) `,` "
+"type($valueToStore) (`->` type($result)^)?";
let hasCanonicalizer = 1;
let hasVerifier = 1;

-let builders = [
-OpBuilder<(ins "Value":$base,
-"ValueRange":$indices,
-"Value":$index_vec,
-"Value":$mask,
-"Value":$valueToStore,
-CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">: $alignment), [{
-return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
+let builders =
+[OpBuilder<(ins "Type":$resultType, "Value":$base, "ValueRange":$indices,
+"Value":$index_vec, "Value":$mask, "Value":$valueToStore,
+CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment),
+[{
+return build($_builder, $_state, resultType, base, indices, index_vec, mask, valueToStore,
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
nullptr);
-}]>
-];
+}]>,
+OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$index_vec,
+"Value":$mask, "Value":$valueToStore,
+CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment)>];
}
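For illustration, a minimal example of the new tensor form this definition accepts (a sketch mirroring the bufferize.mlir test added below; the shapes and names are arbitrary):

func.func @scatter_example(%base: tensor<16x16xf32>, %idx: vector<16xi32>,
                           %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
  %c0 = arith.constant 0 : index
  // With a ranked tensor base, the op returns the updated tensor.
  %0 = vector.scatter %base[%c0, %c0][%idx], %mask, %value
      : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
  return %0 : tensor<16x16xf32>
}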

def Vector_ExpandLoadOp :
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -345,7 +345,8 @@ class VectorScatterOpConversion
matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = scatter->getLoc();
-MemRefType memRefType = scatter.getMemRefType();
+MemRefType memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
+assert(memRefType && "The base should be bufferized");

if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return rewriter.notifyMatchFailure(scatter, "memref type not supported");
19 changes: 15 additions & 4 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6066,19 +6066,30 @@ LogicalResult ScatterOp::verify() {
VectorType indVType = getIndexVectorType();
VectorType maskVType = getMaskVectorType();
VectorType valueVType = getVectorType();
-MemRefType memType = getMemRefType();
+ShapedType baseType = getBaseType();

-if (valueVType.getElementType() != memType.getElementType())
+if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
+return emitOpError("requires base to be a memref or ranked tensor type");
+
+if (valueVType.getElementType() != baseType.getElementType())
return emitOpError("base and valueToStore element type should match");
-if (llvm::size(getOffsets()) != baseType.getRank())
-return emitOpError("requires ") << memType.getRank() << " indices";
+if (llvm::size(getOffsets()) != baseType.getRank())
+return emitOpError("requires ") << baseType.getRank() << " indices";
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
if (valueVType.getShape() != maskVType.getShape())
return emitOpError("expected valueToStore dim to match mask dim");
return success();
}

void ScatterOp::build(OpBuilder &builder, OperationState &result, Value base,
ValueRange indices, Value index_vec, Value mask,
Value valueToStore, llvm::MaybeAlign alignment) {
Type resultType = llvm::dyn_cast<RankedTensorType>(base.getType());
build(builder, result, resultType, base, indices, index_vec, mask,
valueToStore, alignment);
}
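A rough sketch of how a caller might use the value-based overload above (the builder `b`, location `loc`, and the operand values are placeholders, not part of this patch): when the base is a ranked tensor the result type is inferred from it, so the created op yields the updated tensor; with a memref base the op is built with no result.

// Hypothetical caller code; b, loc, and the operand Values are assumed to
// exist in the surrounding context.
auto scatter = vector::ScatterOp::create(
    b, loc, /*base=*/tensorBase, /*indices=*/ValueRange{c0, c0},
    /*index_vec=*/idxVec, /*mask=*/mask, /*valueToStore=*/value);
// Tensor result of the scatter (only present when the base is a tensor).
Value updated = scatter.getResult();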

namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

using namespace mlir;
using namespace mlir::bufferization;
@@ -126,6 +127,51 @@ struct TransferWriteOpInterface
}
};

/// Bufferization of vector.scatter. Replaced with a new vector.scatter that
/// operates on a memref.
struct ScatterOpInterface
: public BufferizableOpInterface::ExternalModel<ScatterOpInterface,
vector::ScatterOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
assert(isa<RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
assert(isa<RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
return true;
}
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
auto scatterOp = cast<vector::ScatterOp>(op);
if (&opOperand != &scatterOp.getBaseMutable())
return {};
if (op->getNumResults() == 0)
return {};
return {{scatterOp.getResult(), BufferRelation::Equivalent}};
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) const {
ScatterOp scatterOp = cast<vector::ScatterOp>(op);
[Inline review comment (Contributor) on the cast above: nit: auto]
assert(isa<TensorType>(scatterOp.getBaseType()) &&
"only tensor types expected");
FailureOr<Value> buffer =
getBuffer(rewriter, scatterOp.getBase(), options, state);
if (failed(buffer))
return failure();
vector::ScatterOp::create(rewriter, scatterOp.getLoc(), *buffer,
scatterOp.getOffsets(), scatterOp.getIndices(),
scatterOp.getMask(), scatterOp.getValueToStore());
replaceOpWithBufferizedValues(rewriter, op, *buffer);
return success();
}
};

/// Bufferization of vector.gather. Replaced with a new vector.gather that
/// operates on a memref.
struct GatherOpInterface
@@ -335,5 +381,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels(
GatherOp::attachInterface<GatherOpInterface>(*ctx);
MaskOp::attachInterface<MaskOpInterface>(*ctx);
YieldOp::attachInterface<YieldOpInterface>(*ctx);
ScatterOp::attachInterface<ScatterOpInterface>(*ctx);
});
}
20 changes: 20 additions & 0 deletions mlir/test/Dialect/Vector/bufferize.mlir
@@ -32,6 +32,26 @@ func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index,

// -----

// CHECK-LABEL: func @scatter(
// CHECK-SAME: %[[base:.*]]: tensor<16x16xf32>, %[[v:.*]]: vector<16xi32>,
// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[value:.*]]: vector<16xf32>) -> tensor<16x16xf32>
// CHECK: %[[buf:.*]] = bufferization.to_buffer %[[base]] : tensor<16x16xf32> to memref<16x16xf32>
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<16x16xf32>
// CHECK: memref.copy %[[buf]], %[[alloc]] : memref<16x16xf32> to memref<16x16xf32>
// CHECK: vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[alloc]] : memref<16x16xf32> to tensor<16x16xf32>
// CHECK: return %[[tensor]] : tensor<16x16xf32>
func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
%c0 = arith.constant 0 : index
%0 = vector.scatter %base[%c0, %c0][%v], %mask, %value
: tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}

// -----

// CHECK-LABEL: func @gather(
// CHECK-SAME: %[[base:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<16xi32>,
// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>)
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Vector/invalid.mlir
@@ -1491,9 +1491,9 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
-// expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
+// expected-error@+1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
vector.scatter %base[%c0][%indices], %mask, %pass_thru
-: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
}

// -----