Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 25 additions & 28 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2160,25 +2160,25 @@ def Vector_GatherOp :
];
}

def Vector_ScatterOp :
Vector_Op<"scatter", [
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$offsets,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$valueToStore,
OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)> {
def Vector_ScatterOp
: Vector_Op<"scatter",
[DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemWrite]>:$base,
Variadic<Index>:$offsets,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$valueToStore,
OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>,
Results<(outs Optional<AnyRankedTensor>:$result)> {

let summary = [{
scatters elements from a vector into memory as defined by an index vector
scatters elements from a vector into memory or ranked tensor as defined by an index vector
and a mask vector
}];

let description = [{
The scatter operation stores elements from a n-D vector into memory as
The scatter operation stores elements from a n-D vector into memory or ranked tensor as
defined by a base with indices and an additional n-D index vector, but
only if the corresponding bit in a n-D mask vector is set. Otherwise, no
action is taken for that element. Informally the semantics are:
Expand Down Expand Up @@ -2221,31 +2221,28 @@ def Vector_ScatterOp :
}];

let extraClassDeclaration = [{
MemRefType getMemRefType() { return getBase().getType(); }
ShapedType getBaseType() { return getBase().getType(); }
VectorType getIndexVectorType() { return getIndices().getType(); }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getVectorType() { return getValueToStore().getType(); }
}];

let assemblyFormat =
"$base `[` $offsets `]` `[` $indices `]` `,` "
"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
"type($indices) `,` type($mask) `,` type($valueToStore)";
let assemblyFormat = "$base `[` $offsets `]` `[` $indices `]` `,` "
"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
"type($indices) `,` type($mask) `,` "
"type($valueToStore) (`->` type($result)^)?";
let hasCanonicalizer = 1;
let hasVerifier = 1;

let builders = [
OpBuilder<(ins "Value":$base,
"ValueRange":$indices,
"Value":$index_vec,
"Value":$mask,
"Value":$valueToStore,
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">: $alignment), [{
return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
let builders = [OpBuilder<
(ins "Type":$resultType, "Value":$base, "ValueRange":$indices,
"Value":$index_vec, "Value":$mask, "Value":$valueToStore,
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment),
[{
return build($_builder, $_state, resultType, base, indices, index_vec, mask, valueToStore,
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
nullptr);
}]>
];
}]>];
}

def Vector_ExpandLoadOp :
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,8 @@ class VectorScatterOpConversion
matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = scatter->getLoc();
MemRefType memRefType = scatter.getMemRefType();
auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
assert(memRefType && "The base should be bufferized");

if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return rewriter.notifyMatchFailure(scatter, "memref type not supported");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
SmallVector<Value> scalarArgs(idxs);
Value indexVec = idxs.back();
scalarArgs.back() = constantIndex(rewriter, loc, 0);
vector::ScatterOp::create(rewriter, loc, mem, scalarArgs, indexVec, vmask,
rhs);
vector::ScatterOp::create(rewriter, loc, /*resultType=*/nullptr, mem,
scalarArgs, indexVec, vmask, rhs);
return;
}
vector::MaskedStoreOp::create(rewriter, loc, mem, idxs, vmask, rhs);
Expand Down
12 changes: 7 additions & 5 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6066,19 +6066,21 @@ LogicalResult ScatterOp::verify() {
VectorType indVType = getIndexVectorType();
VectorType maskVType = getMaskVectorType();
VectorType valueVType = getVectorType();
MemRefType memType = getMemRefType();
ShapedType baseType = getBaseType();

if (valueVType.getElementType() != memType.getElementType())
if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
return emitOpError("requires base to be a memref or ranked tensor type");

if (valueVType.getElementType() != baseType.getElementType())
return emitOpError("base and valueToStore element type should match");
if (llvm::size(getOffsets()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (llvm::size(getOffsets()) != baseType.getRank())
return emitOpError("requires ") << baseType.getRank() << " indices";
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
if (valueVType.getShape() != maskVType.getShape())
return emitOpError("expected valueToStore dim to match mask dim");
return success();
}

namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

using namespace mlir;
using namespace mlir::bufferization;
Expand Down Expand Up @@ -126,6 +127,54 @@ struct TransferWriteOpInterface
}
};

/// Bufferization of vector.scatter. Replaced with a new vector.scatter that
/// operates on a memref.
struct ScatterOpInterface
: public BufferizableOpInterface::ExternalModel<ScatterOpInterface,
vector::ScatterOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
assert(isa<RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
return true;
}

bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
assert(isa<RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
return true;
}

AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
assert(isa<RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
auto scatterOp = cast<vector::ScatterOp>(op);
if (&opOperand != &scatterOp.getBaseMutable())
return {};
return {{scatterOp.getResult(), BufferRelation::Equivalent}};
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) const {
auto scatterOp = cast<vector::ScatterOp>(op);
assert(isa<TensorType>(scatterOp.getBaseType()) &&
"only tensor types expected");
FailureOr<Value> buffer =
getBuffer(rewriter, scatterOp.getBase(), options, state);
if (failed(buffer))
return failure();
vector::ScatterOp::create(rewriter, scatterOp.getLoc(),
/*resultType=*/nullptr, *buffer,
scatterOp.getOffsets(), scatterOp.getIndices(),
scatterOp.getMask(), scatterOp.getValueToStore());
replaceOpWithBufferizedValues(rewriter, op, *buffer);
return success();
}
};

/// Bufferization of vector.gather. Replaced with a new vector.gather that
/// operates on a memref.
struct GatherOpInterface
Expand Down Expand Up @@ -335,5 +384,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels(
GatherOp::attachInterface<GatherOpInterface>(*ctx);
MaskOp::attachInterface<MaskOpInterface>(*ctx);
YieldOp::attachInterface<YieldOpInterface>(*ctx);
ScatterOp::attachInterface<ScatterOpInterface>(*ctx);
});
}
20 changes: 20 additions & 0 deletions mlir/test/Dialect/Vector/bufferize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,26 @@ func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index,

// -----

// CHECK-LABEL: func @scatter(
// CHECK-SAME: %[[base:.*]]: tensor<16x16xf32>, %[[v:.*]]: vector<16xi32>,
// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[value:.*]]: vector<16xf32>) -> tensor<16x16xf32>
// CHECK: %[[buf:.*]] = bufferization.to_buffer %[[base]] : tensor<16x16xf32> to memref<16x16xf32>
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<16x16xf32>
// CHECK: memref.copy %[[buf]], %[[alloc]] : memref<16x16xf32> to memref<16x16xf32>
// CHECK: vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[alloc]] : memref<16x16xf32> to tensor<16x16xf32>
// CHECK: return %[[tensor]] : tensor<16x16xf32>
func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
%c0 = arith.constant 0 : index
%0 = vector.scatter %base[%c0, %c0][%v], %mask, %value
: tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}

// -----

// CHECK-LABEL: func @gather(
// CHECK-SAME: %[[base:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<16xi32>,
// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>)
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1491,9 +1491,9 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
// expected-error@+1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
vector.scatter %base[%c0][%indices], %mask, %pass_thru
: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
}

// -----
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Vector/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1160,3 +1160,17 @@ func.func @step() {
%1 = vector.step : vector<[4]xindex>
return
}

// CHECK-LABEL: func @scatter_tensor(
// CHECK-SAME: %[[BASE:.*]]: tensor<16x16xf32>, %[[V:.*]]: vector<16xi32>,
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16x16xf32>
func.func @scatter_tensor(%base: tensor<16x16xf32>, %v: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK: %[[RESULT:.*]] = vector.scatter %[[BASE]][%[[C0]], %[[C0]]] [%[[V]]], %[[MASK]], %[[VALUE]]
%0 = vector.scatter %base[%c0, %c0] [%v], %mask, %value
: tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
// CHECK: return %[[RESULT]] : tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}
Loading