[Mlir-commits] [mlir] 24a8e18 - [mlir][vector] Allow multi dim vectors in vector.scatter (#132217)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 24 05:52:49 PDT 2025


Author: Kunwar Grover
Date: 2025-03-24T12:52:46Z
New Revision: 24a8e18f5a1dfddf7c9a0704a1ccb96a235d3767

URL: https://github.com/llvm/llvm-project/commit/24a8e18f5a1dfddf7c9a0704a1ccb96a235d3767
DIFF: https://github.com/llvm/llvm-project/commit/24a8e18f5a1dfddf7c9a0704a1ccb96a235d3767.diff

LOG: [mlir][vector] Allow multi dim vectors in vector.scatter (#132217)

This patch updates the definition of vector.scatter to match that of its
counterpart, vector.gather.

All of the changes in this patch make vector.scatter match
vector.gather's multi-dimensional definition.

Unrolling for vector.scatter will be implemented in subsequent patches.

Discourse Discussion:
https://discourse.llvm.org/t/rfc-improving-gather-codegen-for-vector-dialect/85011/13

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..5fab2ee1194e8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2034,9 +2034,9 @@ def Vector_ScatterOp :
   Vector_Op<"scatter">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$valueToStore)> {
+               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+               VectorOfNonZeroRankOf<[I1]>:$mask,
+               AnyVectorOfNonZeroRank:$valueToStore)> {
 
   let summary = [{
     scatters elements from a vector into memory as defined by an index vector
@@ -2044,9 +2044,9 @@ def Vector_ScatterOp :
   }];
 
   let description = [{
-    The scatter operation stores elements from a 1-D vector into memory as
-    defined by a base with indices and an additional 1-D index vector, but
-    only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
+    The scatter operation stores elements from an n-D vector into memory as
+    defined by a base with indices and an additional n-D index vector, but
+    only if the corresponding bit in an n-D mask vector is set. Otherwise, no
     action is taken for that element. Informally the semantics are:
     ```
     if (mask[0]) base[index[0]] = value[0]

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 357152eba8003..213f7375b8d13 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -263,22 +263,25 @@ class VectorGatherOpConversion
   LogicalResult
   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = gather->getLoc();
     MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
     assert(memRefType && "The base should be bufferized");
 
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
-      return failure();
+      return rewriter.notifyMatchFailure(gather, "memref type not supported");
 
     VectorType vType = gather.getVectorType();
-    if (vType.getRank() > 1)
-      return failure();
-
-    Location loc = gather->getLoc();
+    if (vType.getRank() > 1) {
+      return rewriter.notifyMatchFailure(
+          gather, "only 1-D vectors can be lowered to LLVM");
+    }
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
-      return failure();
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
+      return rewriter.notifyMatchFailure(gather,
+                                         "could not resolve memref alignment");
+    }
 
     // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -309,15 +312,22 @@ class VectorScatterOpConversion
     MemRefType memRefType = scatter.getMemRefType();
 
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
-      return failure();
+      return rewriter.notifyMatchFailure(scatter, "memref type not supported");
+
+    VectorType vType = scatter.getVectorType();
+    if (vType.getRank() > 1) {
+      return rewriter.notifyMatchFailure(
+          scatter, "only 1-D vectors can be lowered to LLVM");
+    }
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
-      return failure();
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
+      return rewriter.notifyMatchFailure(scatter,
+                                         "could not resolve memref alignment");
+    }
 
     // Resolve address.
-    VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Value ptrs =

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d4c1da30d498d..d006a1498f350 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5340,9 +5340,9 @@ LogicalResult ScatterOp::verify() {
     return emitOpError("base and valueToStore element type should match");
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
-  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
+  if (valueVType.getShape() != indVType.getShape())
     return emitOpError("expected valueToStore dim to match indices dim");
-  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+  if (valueVType.getShape() != maskVType.getShape())
     return emitOpError("expected valueToStore dim to match mask dim");
   return success();
 }

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 5404fdda033ee..ba1da84719106 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1719,6 +1719,40 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+// Multi-Dimensional scatters are not supported yet. Check that we do not lower
+// them.
+
+func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
+  %0 = arith.constant 0: index
+  %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
+  vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_mask
+// CHECK: vector.scatter
+
+// -----
+
+func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
+  %0 = arith.constant 0: index
+  // vector.constant_mask only supports 'none set' or 'all set' scalable
+  // dimensions, hence [2, 3] rather than [2, 2] as in the example for fixed
+  // width vectors above.
+  %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
+  vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_mask_scalable
+// CHECK: vector.scatter
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.interleave
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 57e348c7d5991..1b89e8eb5069b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1484,7 +1484,7 @@ func.func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi
 func.func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                             %mask: vector<16xi1>, %value: vector<2x16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.scatter' op operand #4 must be  of ranks 1, but got 'vector<2x16xf32>'}}
+  // expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
   vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
 }

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 67484e06f456d..8ae1e9f9d0c64 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -882,6 +882,16 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
   return
 }
 
+// CHECK-LABEL: @gather_and_scatter_multi_dims
+func.func @gather_and_scatter_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  // CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+  vector.scatter %base[%c0][%v], %mask, %0 : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+  return %0 : vector<2x16xf32>
+}
+
 // CHECK-LABEL: @gather_on_tensor
 func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0 : index
@@ -890,14 +900,6 @@ func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vec
   return %0 : vector<16xf32>
 }
 
-// CHECK-LABEL: @gather_multi_dims
-func.func @gather_multi_dims(%base: tensor<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
-  %c0 = arith.constant 0 : index
-  // CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
-  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
-  return %0 : vector<2x16xf32>
-}
-
 // CHECK-LABEL: @expand_and_compress
 func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list