[Mlir-commits] [mlir] [MLIR] Extend vector.scatter to accept tensor as base (PR #165548)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 29 05:04:07 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Ryutaro Okada (sakupan102)

<details>
<summary>Changes</summary>

In addition to memrefs, accept ranked tensors as the base operand of vector.scatter, similar to vector.transfer_write.

This change is worthwhile because it completes the functionality of the map_scatter decomposition. The full discussion can be found here: https://github.com/iree-org/iree/issues/21135#event-20534280334

---
Full diff: https://github.com/llvm/llvm-project/pull/165548.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+27-25) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+2-1) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+15-4) 
- (modified) mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp (+47) 
- (modified) mlir/test/Dialect/Vector/bufferize.mlir (+20) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 6e15b1e7df606..db1b9e169608b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2150,22 +2150,24 @@ def Vector_GatherOp :
   ];
 }
 
-def Vector_ScatterOp :
-  Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
-    Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
-               Variadic<Index>:$offsets,
-               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
-               VectorOfNonZeroRankOf<[I1]>:$mask,
-               AnyVectorOfNonZeroRank:$valueToStore,
-               OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)> {
+def Vector_ScatterOp
+    : Vector_Op<"scatter", [DeclareOpInterfaceMethods<
+                               MemorySpaceCastConsumerOpInterface>]>,
+      Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemWrite]>:$base,
+          Variadic<Index>:$offsets,
+          VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
+          VectorOfNonZeroRankOf<[I1]>:$mask,
+          AnyVectorOfNonZeroRank:$valueToStore,
+          OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>,
+      Results<(outs Optional<AnyRankedTensor>:$result)> {
 
   let summary = [{
-    scatters elements from a vector into memory as defined by an index vector
+    scatters elements from a vector into memory or ranked tensor as defined by an index vector
     and a mask vector
   }];
 
   let description = [{
-    The scatter operation stores elements from a n-D vector into memory as
+    The scatter operation stores elements from a n-D vector into memory or ranked tensor as
     defined by a base with indices and an additional n-D index vector, but
     only if the corresponding bit in a n-D mask vector is set. Otherwise, no
     action is taken for that element. Informally the semantics are:
@@ -2208,31 +2210,31 @@ def Vector_ScatterOp :
   }];
 
   let extraClassDeclaration = [{
-    MemRefType getMemRefType() { return getBase().getType(); }
+    ShapedType getBaseType() { return getBase().getType(); }
     VectorType getIndexVectorType() { return getIndices().getType(); }
     VectorType getMaskVectorType() { return getMask().getType(); }
     VectorType getVectorType() { return getValueToStore().getType(); }
   }];
 
-  let assemblyFormat =
-      "$base `[` $offsets `]` `[` $indices `]` `,` "
-      "$mask `,` $valueToStore attr-dict `:` type($base) `,` "
-      "type($indices)  `,` type($mask) `,` type($valueToStore)";
+  let assemblyFormat = "$base `[` $offsets `]` `[` $indices `]` `,` "
+                       "$mask `,` $valueToStore attr-dict `:` type($base) `,` "
+                       "type($indices)  `,` type($mask) `,` "
+                       "type($valueToStore) (`->` type($result)^)?";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
 
-  let builders = [
-    OpBuilder<(ins "Value":$base,
-                   "ValueRange":$indices,
-                   "Value":$index_vec,
-                   "Value":$mask,
-                   "Value":$valueToStore,
-                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">: $alignment), [{
-      return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
+  let builders =
+      [OpBuilder<(ins "Type":$resultType, "Value":$base, "ValueRange":$indices,
+                     "Value":$index_vec, "Value":$mask, "Value":$valueToStore,
+                     CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment),
+                 [{
+      return build($_builder, $_state, resultType, base, indices, index_vec, mask, valueToStore,
                    alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
                                     nullptr);
-    }]>
-  ];
+    }]>,
+       OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$index_vec,
+           "Value":$mask, "Value":$valueToStore,
+           CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment)>];
 }
 
 def Vector_ExpandLoadOp :
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 41d8d532757ad..cb65d787ea854 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -345,7 +345,8 @@ class VectorScatterOpConversion
   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = scatter->getLoc();
-    MemRefType memRefType = scatter.getMemRefType();
+    MemRefType memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
+    assert(memRefType && "The base should be bufferized");
 
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return rewriter.notifyMatchFailure(scatter, "memref type not supported");
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ad8255a95cb4e..b4a8737107c8d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6058,12 +6058,15 @@ LogicalResult ScatterOp::verify() {
   VectorType indVType = getIndexVectorType();
   VectorType maskVType = getMaskVectorType();
   VectorType valueVType = getVectorType();
-  MemRefType memType = getMemRefType();
+  ShapedType baseType = getBaseType();
 
-  if (valueVType.getElementType() != memType.getElementType())
+  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
+    return emitOpError("requires base to be a memref or ranked tensor type");
+
+  if (valueVType.getElementType() != baseType.getElementType())
     return emitOpError("base and valueToStore element type should match");
-  if (llvm::size(getOffsets()) != memType.getRank())
-    return emitOpError("requires ") << memType.getRank() << " indices";
+  if (llvm::size(getOffsets()) != baseType.getRank())
+    return emitOpError("requires ") << baseType.getRank() << " indices";
   if (valueVType.getShape() != indVType.getShape())
     return emitOpError("expected valueToStore dim to match indices dim");
   if (valueVType.getShape() != maskVType.getShape())
@@ -6071,6 +6074,14 @@ LogicalResult ScatterOp::verify() {
   return success();
 }
 
+void ScatterOp::build(OpBuilder &builder, OperationState &result, Value base,
+                      ValueRange indices, Value index_vec, Value mask,
+                      Value valueToStore, llvm::MaybeAlign alignment) {
+  // Only a ranked-tensor base yields a result (of the base's type).
+  build(builder, result, llvm::dyn_cast<RankedTensorType>(base.getType()),
+        base, indices, index_vec, mask, valueToStore, alignment);
+}
+
 namespace {
 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
 public:
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 546099ca975b7..eb11253ec647c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
 
 using namespace mlir;
 using namespace mlir::bufferization;
@@ -126,6 +127,51 @@ struct TransferWriteOpInterface
   }
 };
 
+/// Bufferization of vector.scatter. Replaced with a result-less vector.scatter
+/// on the base's buffer; the optional alignment attribute is carried over.
+struct ScatterOpInterface
+    : public BufferizableOpInterface::ExternalModel<ScatterOpInterface,
+                                                    vector::ScatterOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // Elements the scatter does not write keep their old values, so it reads.
+    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
+           "only tensor types expected");
+    return true;
+  }
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
+           "only tensor types expected");
+    return true;
+  }
+  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+                                      const AnalysisState &state) const {
+    auto scatterOp = cast<vector::ScatterOp>(op);
+    if (&opOperand != &scatterOp.getBaseMutable() || op->getNumResults() == 0)
+      return {};
+    return {{scatterOp.getResult(), BufferRelation::Equivalent}};
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
+    auto scatterOp = cast<vector::ScatterOp>(op);
+    assert(isa<TensorType>(scatterOp.getBaseType()) &&
+           "only tensor types expected");
+    FailureOr<Value> buffer =
+        getBuffer(rewriter, scatterOp.getBase(), options, state);
+    if (failed(buffer))
+      return failure();
+    // A null result type yields the result-less memref form; keep alignment.
+    vector::ScatterOp::create(rewriter, scatterOp.getLoc(), Type(), *buffer,
+        scatterOp.getOffsets(), scatterOp.getIndices(), scatterOp.getMask(),
+        scatterOp.getValueToStore(), scatterOp.getAlignmentAttr());
+    replaceOpWithBufferizedValues(rewriter, op, *buffer);
+    return success();
+  }
+};
+
 /// Bufferization of vector.gather. Replaced with a new vector.gather that
 /// operates on a memref.
 struct GatherOpInterface
@@ -335,5 +381,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels(
     GatherOp::attachInterface<GatherOpInterface>(*ctx);
     MaskOp::attachInterface<MaskOpInterface>(*ctx);
     YieldOp::attachInterface<YieldOpInterface>(*ctx);
+    ScatterOp::attachInterface<ScatterOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir
index 887fb941cc651..70adefd0dc4ec 100644
--- a/mlir/test/Dialect/Vector/bufferize.mlir
+++ b/mlir/test/Dialect/Vector/bufferize.mlir
@@ -32,6 +32,26 @@ func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index,
 
 // -----
 
+// CHECK-LABEL: func @scatter(
+//  CHECK-SAME:     %[[base:.*]]: tensor<16x16xf32>, %[[v:.*]]: vector<16xi32>,
+//  CHECK-SAME:     %[[mask:.*]]: vector<16xi1>, %[[value:.*]]: vector<16xf32>) -> tensor<16x16xf32>
+//       CHECK:   %[[buf:.*]] = bufferization.to_buffer %[[base]] : tensor<16x16xf32> to memref<16x16xf32>
+//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<16x16xf32>
+//       CHECK:   memref.copy %[[buf]], %[[alloc]] : memref<16x16xf32> to memref<16x16xf32>
+//       CHECK:   vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+//       CHECK:   %[[tensor:.*]] = bufferization.to_tensor %[[alloc]] : memref<16x16xf32> to tensor<16x16xf32>
+//       CHECK:   return %[[tensor]] : tensor<16x16xf32>
+func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>,
+                  %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.scatter %base[%c0, %c0][%v], %mask, %value
+      : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
+  return %0 : tensor<16x16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @gather(
 //  CHECK-SAME:     %[[base:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<16xi32>,
 //  CHECK-SAME:     %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>)
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5f035e35a1b86..79b09e172145b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1491,9 +1491,9 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve
 func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
                              %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
+  // expected-error@+1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
   vector.scatter %base[%c0][%indices], %mask, %pass_thru
-    : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+    : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 }
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/165548


More information about the Mlir-commits mailing list