[Mlir-commits] [mlir] [MLIR] Extend vector.scatter to accept tensor as base (PR #165548)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 29 05:04:07 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Ryutaro Okada (sakupan102)
<details>
<summary>Changes</summary>
In addition to memref, accept a ranked tensor as the base operand of vector.scatter, similar to vector.transfer_write.
This is needed to complete the map_scatter decomposition functionality. The full discussion can be found here: https://github.com/iree-org/iree/issues/21135#event-20534280334
---
Full diff: https://github.com/llvm/llvm-project/pull/165548.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+27-25)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+2-1)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+15-4)
- (modified) mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp (+47)
- (modified) mlir/test/Dialect/Vector/bufferize.mlir (+20)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+2-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 6e15b1e7df606..db1b9e169608b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2150,22 +2150,24 @@ def Vector_GatherOp :
];
}
-def Vector_ScatterOp :
- Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
- Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
- Variadic<Index>:$offsets,
- VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
- VectorOfNonZeroRankOf<[I1]>:$mask,
- AnyVectorOfNonZeroRank:$valueToStore,
- OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)> {
+def Vector_ScatterOp
+ : Vector_Op<"scatter", [DeclareOpInterfaceMethods<
+ MemorySpaceCastConsumerOpInterface>]>,
+ Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemWrite]>:$base,
+ Variadic<Index>:$offsets,
+ VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
+ VectorOfNonZeroRankOf<[I1]>:$mask,
+ AnyVectorOfNonZeroRank:$valueToStore,
+ OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>,
+ Results<(outs Optional<AnyRankedTensor>:$result)> {
let summary = [{
- scatters elements from a vector into memory as defined by an index vector
+ scatters elements from a vector into memory or ranked tensor as defined by an index vector
and a mask vector
}];
let description = [{
- The scatter operation stores elements from a n-D vector into memory as
+ The scatter operation stores elements from a n-D vector into memory or ranked tensor as
defined by a base with indices and an additional n-D index vector, but
only if the corresponding bit in a n-D mask vector is set. Otherwise, no
action is taken for that element. Informally the semantics are:
@@ -2208,31 +2210,31 @@ def Vector_ScatterOp :
}];
let extraClassDeclaration = [{
- MemRefType getMemRefType() { return getBase().getType(); }
+ ShapedType getBaseType() { return getBase().getType(); }
VectorType getIndexVectorType() { return getIndices().getType(); }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getVectorType() { return getValueToStore().getType(); }
}];
- let assemblyFormat =
- "$base `[` $offsets `]` `[` $indices `]` `,` "
- "$mask `,` $valueToStore attr-dict `:` type($base) `,` "
- "type($indices) `,` type($mask) `,` type($valueToStore)";
+ let assemblyFormat = "$base `[` $offsets `]` `[` $indices `]` `,` "
+ "$mask `,` $valueToStore attr-dict `:` type($base) `,` "
+ "type($indices) `,` type($mask) `,` "
+ "type($valueToStore) (`->` type($result)^)?";
let hasCanonicalizer = 1;
let hasVerifier = 1;
- let builders = [
- OpBuilder<(ins "Value":$base,
- "ValueRange":$indices,
- "Value":$index_vec,
- "Value":$mask,
- "Value":$valueToStore,
- CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">: $alignment), [{
- return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
+ let builders =
+ [OpBuilder<(ins "Type":$resultType, "Value":$base, "ValueRange":$indices,
+ "Value":$index_vec, "Value":$mask, "Value":$valueToStore,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment),
+ [{
+ return build($_builder, $_state, resultType, base, indices, index_vec, mask, valueToStore,
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
nullptr);
- }]>
- ];
+ }]>,
+ OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$index_vec,
+ "Value":$mask, "Value":$valueToStore,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment)>];
}
def Vector_ExpandLoadOp :
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 41d8d532757ad..cb65d787ea854 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -345,7 +345,8 @@ class VectorScatterOpConversion
matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = scatter->getLoc();
- MemRefType memRefType = scatter.getMemRefType();
+ MemRefType memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
+ assert(memRefType && "The base should be bufferized");
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return rewriter.notifyMatchFailure(scatter, "memref type not supported");
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ad8255a95cb4e..b4a8737107c8d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6058,12 +6058,15 @@ LogicalResult ScatterOp::verify() {
VectorType indVType = getIndexVectorType();
VectorType maskVType = getMaskVectorType();
VectorType valueVType = getVectorType();
- MemRefType memType = getMemRefType();
+ ShapedType baseType = getBaseType();
- if (valueVType.getElementType() != memType.getElementType())
+ if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
+ return emitOpError("requires base to be a memref or ranked tensor type");
+
+ if (valueVType.getElementType() != baseType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getOffsets()) != memType.getRank())
- return emitOpError("requires ") << memType.getRank() << " indices";
+ if (llvm::size(getOffsets()) != baseType.getRank())
+ return emitOpError("requires ") << baseType.getRank() << " indices";
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
if (valueVType.getShape() != maskVType.getShape())
@@ -6071,6 +6074,14 @@ LogicalResult ScatterOp::verify() {
return success();
}
+void ScatterOp::build(OpBuilder &builder, OperationState &result, Value base,
+ ValueRange indices, Value index_vec, Value mask,
+ Value valueToStore, llvm::MaybeAlign alignment) {
+ Type resultType = llvm::dyn_cast<RankedTensorType>(base.getType());
+ build(builder, result, resultType, base, indices, index_vec, mask,
+ valueToStore, alignment);
+}
+
namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 546099ca975b7..eb11253ec647c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
using namespace mlir;
using namespace mlir::bufferization;
@@ -126,6 +127,51 @@ struct TransferWriteOpInterface
}
};
+/// Bufferization of vector.scatter. Replaced with a new vector.scatter that
+/// operates on a memref.
+struct ScatterOpInterface
+ : public BufferizableOpInterface::ExternalModel<ScatterOpInterface,
+ vector::ScatterOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
+ "only tensor types expected");
+ return true;
+ }
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
+ "only tensor types expected");
+ return true;
+ }
+ AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ auto scatterOp = cast<vector::ScatterOp>(op);
+ if (&opOperand != &scatterOp.getBaseMutable())
+ return {};
+ if (op->getNumResults() == 0)
+ return {};
+ return {{scatterOp.getResult(), BufferRelation::Equivalent}};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
+ ScatterOp scatterOp = cast<vector::ScatterOp>(op);
+ assert(isa<TensorType>(scatterOp.getBaseType()) &&
+ "only tensor types expected");
+ FailureOr<Value> buffer =
+ getBuffer(rewriter, scatterOp.getBase(), options, state);
+ if (failed(buffer))
+ return failure();
+ vector::ScatterOp::create(rewriter, scatterOp.getLoc(), *buffer,
+ scatterOp.getOffsets(), scatterOp.getIndices(),
+ scatterOp.getMask(), scatterOp.getValueToStore());
+ replaceOpWithBufferizedValues(rewriter, op, *buffer);
+ return success();
+ }
+};
+
/// Bufferization of vector.gather. Replaced with a new vector.gather that
/// operates on a memref.
struct GatherOpInterface
@@ -335,5 +381,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels(
GatherOp::attachInterface<GatherOpInterface>(*ctx);
MaskOp::attachInterface<MaskOpInterface>(*ctx);
YieldOp::attachInterface<YieldOpInterface>(*ctx);
+ ScatterOp::attachInterface<ScatterOpInterface>(*ctx);
});
}
diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir
index 887fb941cc651..70adefd0dc4ec 100644
--- a/mlir/test/Dialect/Vector/bufferize.mlir
+++ b/mlir/test/Dialect/Vector/bufferize.mlir
@@ -32,6 +32,26 @@ func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index,
// -----
+// CHECK-LABEL: func @scatter(
+// CHECK-SAME: %[[base:.*]]: tensor<16x16xf32>, %[[v:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[value:.*]]: vector<16xf32>) -> tensor<16x16xf32>
+// CHECK: %[[buf:.*]] = bufferization.to_buffer %[[base]] : tensor<16x16xf32> to memref<16x16xf32>
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<16x16xf32>
+// CHECK: memref.copy %[[buf]], %[[alloc]] : memref<16x16xf32> to memref<16x16xf32>
+// CHECK: vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[alloc]] : memref<16x16xf32> to tensor<16x16xf32>
+// CHECK: return %[[tensor]] : tensor<16x16xf32>
+func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.scatter %base[%c0, %c0][%v], %mask, %value
+ : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
+ return %0 : tensor<16x16xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @gather(
// CHECK-SAME: %[[base:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<16xi32>,
// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>)
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5f035e35a1b86..79b09e172145b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1491,9 +1491,9 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error at +2 {{custom op 'vector.scatter' invalid kind of type specified}}
+ // expected-error at +1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
vector.scatter %base[%c0][%indices], %mask, %pass_thru
- : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
}
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/165548
More information about the Mlir-commits
mailing list