diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 43172ff2082df..f91d2b6404c9b 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2160,25 +2160,25 @@ def Vector_GatherOp : ]; } -def Vector_ScatterOp : - Vector_Op<"scatter", [ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods - ]>, - Arguments<(ins Arg:$base, - Variadic:$offsets, - VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices, - VectorOfNonZeroRankOf<[I1]>:$mask, - AnyVectorOfNonZeroRank:$valueToStore, - OptionalAttr>: $alignment)> { +def Vector_ScatterOp + : Vector_Op<"scatter", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>, + Arguments<(ins Arg, "", [MemWrite]>:$base, + Variadic:$offsets, + VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices, + VectorOfNonZeroRankOf<[I1]>:$mask, + AnyVectorOfNonZeroRank:$valueToStore, + OptionalAttr>:$alignment)>, + Results<(outs Optional:$result)> { let summary = [{ - scatters elements from a vector into memory as defined by an index vector + scatters elements from a vector into memory or ranked tensor as defined by an index vector and a mask vector }]; let description = [{ - The scatter operation stores elements from a n-D vector into memory as + The scatter operation stores elements from a n-D vector into memory or ranked tensor as defined by a base with indices and an additional n-D index vector, but only if the corresponding bit in a n-D mask vector is set. Otherwise, no action is taken for that element. Informally the semantics are: @@ -2221,31 +2221,28 @@ def Vector_ScatterOp : }]; let extraClassDeclaration = [{ - MemRefType getMemRefType() { return getBase().getType(); } + ShapedType getBaseType() { return getBase().getType(); } VectorType getIndexVectorType() { return getIndices().getType(); } VectorType getMaskVectorType() { return getMask().getType(); } VectorType getVectorType() { return getValueToStore().getType(); } }]; - let assemblyFormat = - "$base `[` $offsets `]` `[` $indices `]` `,` " - "$mask `,` $valueToStore attr-dict `:` type($base) `,` " - "type($indices) `,` type($mask) `,` type($valueToStore)"; + let assemblyFormat = "$base `[` $offsets `]` `[` $indices `]` `,` " + "$mask `,` $valueToStore attr-dict `:` type($base) `,` " + "type($indices) `,` type($mask) `,` " + "type($valueToStore) (`->` type($result)^)?"; let hasCanonicalizer = 1; let hasVerifier = 1; - let builders = [ - OpBuilder<(ins "Value":$base, - "ValueRange":$indices, - "Value":$index_vec, - "Value":$mask, - "Value":$valueToStore, - CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">: $alignment), [{ - return build($_builder, $_state, base, indices, index_vec, mask, valueToStore, + let builders = [OpBuilder< + (ins "Type":$resultType, "Value":$base, "ValueRange":$indices, + "Value":$index_vec, "Value":$mask, "Value":$valueToStore, + CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), + [{ + return build($_builder, $_state, resultType, base, indices, index_vec, mask, valueToStore, alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) : nullptr); - }]> - ]; + }]>]; } def Vector_ExpandLoadOp : diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 69a317ecd101f..40a166f20c085 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -345,7 +345,8 @@ class VectorScatterOpConversion matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = scatter->getLoc(); - MemRefType memRefType = scatter.getMemRefType(); + auto memRefType = dyn_cast(scatter.getBaseType()); + assert(memRefType && "The base should be bufferized"); if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) return rewriter.notifyMatchFailure(scatter, "memref type not supported"); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index febec6d2d2ce4..23436a68535fc 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -132,8 +132,8 @@ static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem, SmallVector scalarArgs(idxs); Value indexVec = idxs.back(); scalarArgs.back() = constantIndex(rewriter, loc, 0); - vector::ScatterOp::create(rewriter, loc, mem, scalarArgs, indexVec, vmask, - rhs); + vector::ScatterOp::create(rewriter, loc, /*resultType=*/nullptr, mem, + scalarArgs, indexVec, vmask, rhs); return; } vector::MaskedStoreOp::create(rewriter, loc, mem, idxs, vmask, rhs); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index daef0ba02100a..a97d0cd7f755b 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6066,19 +6066,21 @@ LogicalResult ScatterOp::verify() { VectorType indVType = getIndexVectorType(); VectorType maskVType = getMaskVectorType(); VectorType valueVType = getVectorType(); - MemRefType memType = getMemRefType(); + ShapedType baseType = getBaseType(); - if (valueVType.getElementType() != memType.getElementType()) + if (!llvm::isa(baseType)) + return emitOpError("requires base to be a memref or ranked tensor type"); + + if (valueVType.getElementType() != baseType.getElementType()) return emitOpError("base and valueToStore element type should match"); - if (llvm::size(getOffsets()) != memType.getRank()) - return emitOpError("requires ") << memType.getRank() << " indices"; + if (llvm::size(getOffsets()) != baseType.getRank()) + return emitOpError("requires ") << baseType.getRank() << " indices"; if (valueVType.getShape() != indVType.getShape()) return emitOpError("expected valueToStore dim to match indices dim"); if (valueVType.getShape() != maskVType.getShape()) return emitOpError("expected valueToStore dim to match mask dim"); return success(); } - namespace { class ScatterFolder final : public OpRewritePattern { public: diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index 546099ca975b7..352f477a8746e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" using namespace mlir; using namespace mlir::bufferization; @@ -126,6 +127,54 @@ struct TransferWriteOpInterface } }; +/// Bufferization of vector.scatter. Replaced with a new vector.scatter that +/// operates on a memref. +struct ScatterOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + assert(isa(opOperand.get().getType()) && + "only tensor types expected"); + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + assert(isa(opOperand.get().getType()) && + "only tensor types expected"); + return true; + } + + AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + assert(isa(opOperand.get().getType()) && + "only tensor types expected"); + auto scatterOp = cast(op); + if (&opOperand != &scatterOp.getBaseMutable()) + return {}; + return {{scatterOp.getResult(), BufferRelation::Equivalent}}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options, + BufferizationState &state) const { + auto scatterOp = cast(op); + assert(isa(scatterOp.getBaseType()) && + "only tensor types expected"); + FailureOr buffer = + getBuffer(rewriter, scatterOp.getBase(), options, state); + if (failed(buffer)) + return failure(); + vector::ScatterOp::create(rewriter, scatterOp.getLoc(), + /*resultType=*/nullptr, *buffer, + scatterOp.getOffsets(), scatterOp.getIndices(), + scatterOp.getMask(), scatterOp.getValueToStore()); + replaceOpWithBufferizedValues(rewriter, op, *buffer); + return success(); + } +}; + /// Bufferization of vector.gather. Replaced with a new vector.gather that /// operates on a memref. struct GatherOpInterface @@ -335,5 +384,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels( GatherOp::attachInterface(*ctx); MaskOp::attachInterface(*ctx); YieldOp::attachInterface(*ctx); + ScatterOp::attachInterface(*ctx); }); } diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir index 887fb941cc651..70adefd0dc4ec 100644 --- a/mlir/test/Dialect/Vector/bufferize.mlir +++ b/mlir/test/Dialect/Vector/bufferize.mlir @@ -32,6 +32,26 @@ func.func @transfer_write(%t: tensor, %o1: index, // ----- +// CHECK-LABEL: func @scatter( +// CHECK-SAME: %[[base:.*]]: tensor<16x16xf32>, %[[v:.*]]: vector<16xi32>, +// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[value:.*]]: vector<16xf32>) -> tensor<16x16xf32> +// CHECK: %[[buf:.*]] = bufferization.to_buffer %[[base]] : tensor<16x16xf32> to memref<16x16xf32> +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<16x16xf32> +// CHECK: memref.copy %[[buf]], %[[alloc]] : memref<16x16xf32> to memref<16x16xf32> +// CHECK: vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> +// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[alloc]] : memref<16x16xf32> to tensor<16x16xf32> +// CHECK: return %[[tensor]] : tensor<16x16xf32> +func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>, + %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> { + %c0 = arith.constant 0 : index + %0 = vector.scatter %base[%c0, %c0][%v], %mask, %value + : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32> + return %0 : tensor<16x16xf32> +} + +// ----- + // CHECK-LABEL: func @gather( // CHECK-SAME: %[[base:.*]]: tensor, %[[v:.*]]: vector<16xi32>, // CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>) diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 5f035e35a1b86..79b09e172145b 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1491,9 +1491,9 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { %c0 = arith.constant 0 : index - // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}} + // expected-error@+1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}} vector.scatter %base[%c0][%indices], %mask, %pass_thru - : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index da9a1a8180a05..de620221944de 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -1160,3 +1160,17 @@ func.func @step() { %1 = vector.step : vector<[4]xindex> return } + +// CHECK-LABEL: func @scatter_tensor( +// CHECK-SAME: %[[BASE:.*]]: tensor<16x16xf32>, %[[V:.*]]: vector<16xi32>, +// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16x16xf32> +func.func @scatter_tensor(%base: tensor<16x16xf32>, %v: vector<16xi32>, + %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK: %[[RESULT:.*]] = vector.scatter %[[BASE]][%[[C0]], %[[C0]]] [%[[V]]], %[[MASK]], %[[VALUE]] + %0 = vector.scatter %base[%c0, %c0] [%v], %mask, %value + : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32> + // CHECK: return %[[RESULT]] : tensor<16x16xf32> + return %0 : tensor<16x16xf32> +}