[Mlir-commits] [mlir] [mlir][vector] Allow multi dim vectors in vector.scatter (PR #132217)

Kunwar Grover llvmlistbot at llvm.org
Mon Mar 24 05:30:39 PDT 2025


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/132217

>From 840e123e14449ef1b5ddd34a3fc65407ae1d291a Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 20 Mar 2025 14:04:28 +0000
Subject: [PATCH 1/2] [mlir][vector] Allow multi dim vectors in vector.scatter

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 12 +++----
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  5 ++-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  4 +--
 .../VectorToLLVM/vector-to-llvm.mlir          | 36 ++++++++++++++++++-
 mlir/test/Dialect/Vector/invalid.mlir         |  2 +-
 mlir/test/Dialect/Vector/ops.mlir             | 18 +++++-----
 6 files changed, 58 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..5fab2ee1194e8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2034,9 +2034,9 @@ def Vector_ScatterOp :
   Vector_Op<"scatter">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$valueToStore)> {
+               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+               VectorOfNonZeroRankOf<[I1]>:$mask,
+               AnyVectorOfNonZeroRank:$valueToStore)> {
 
   let summary = [{
     scatters elements from a vector into memory as defined by an index vector
@@ -2044,9 +2044,9 @@ def Vector_ScatterOp :
   }];
 
   let description = [{
-    The scatter operation stores elements from a 1-D vector into memory as
-    defined by a base with indices and an additional 1-D index vector, but
-    only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
+    The scatter operation stores elements from an n-D vector into memory as
+    defined by a base with indices and an additional n-D index vector, but
+    only if the corresponding bit in an n-D mask vector is set. Otherwise, no
     action is taken for that element. Informally the semantics are:
     ```
     if (mask[0]) base[index[0]] = value[0]
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 357152eba8003..15dede34e54f4 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -311,13 +311,16 @@ class VectorScatterOpConversion
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();
 
+    VectorType vType = scatter.getVectorType();
+    if (vType.getRank() > 1)
+      return failure();
+
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
-    VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Value ptrs =
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d4c1da30d498d..d006a1498f350 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5340,9 +5340,9 @@ LogicalResult ScatterOp::verify() {
     return emitOpError("base and valueToStore element type should match");
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
-  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
+  if (valueVType.getShape() != indVType.getShape())
     return emitOpError("expected valueToStore dim to match indices dim");
-  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+  if (valueVType.getShape() != maskVType.getShape())
     return emitOpError("expected valueToStore dim to match mask dim");
   return success();
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 5404fdda033ee..01c718796018d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1654,7 +1654,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 //===----------------------------------------------------------------------===//
 // vector.gather
-//
+// 
 // NOTE: vector.constant_mask won't lower with
 //  * --convert-to-llvm="filter-dialects=vector",
 // hence testing here.
@@ -1719,6 +1719,40 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+// Multi-Dimensional scatters are not supported yet. Check that we do not lower
+// them.
+
+func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
+  %0 = arith.constant 0: index
+  %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
+  vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_mask
+// CHECK: vector.scatter
+
+// -----
+
+func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
+  %0 = arith.constant 0: index
+  // vector.constant_mask only supports 'none set' or 'all set' scalable
+  // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
+  // width vectors above.
+  %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
+  vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_mask_scalable
+// CHECK: vector.scatter
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.interleave
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 57e348c7d5991..1b89e8eb5069b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1484,7 +1484,7 @@ func.func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi
 func.func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                             %mask: vector<16xi1>, %value: vector<2x16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.scatter' op operand #4 must be  of ranks 1, but got 'vector<2x16xf32>'}}
+  // expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
   vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
 }
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 67484e06f456d..279fd3e522775 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -882,6 +882,16 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
   return
 }
 
+// CHECK-LABEL: @gather_multi_dims
+func.func @gather_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  // CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+  vector.scatter %base[%c0][%v], %mask, %0 : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+  return %0 : vector<2x16xf32>
+}
+
 // CHECK-LABEL: @gather_on_tensor
 func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0 : index
@@ -890,14 +900,6 @@ func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vec
   return %0 : vector<16xf32>
 }
 
-// CHECK-LABEL: @gather_multi_dims
-func.func @gather_multi_dims(%base: tensor<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
-  %c0 = arith.constant 0 : index
-  // CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
-  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
-  return %0 : vector<2x16xf32>
-}
-
 // CHECK-LABEL: @expand_and_compress
 func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index

>From 81d517ad86a968a176e0eaac82d9ebccb8d8235c Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 24 Mar 2025 12:19:48 +0000
Subject: [PATCH 2/2] Address comments

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 31 ++++++++++++-------
 .../VectorToLLVM/vector-to-llvm.mlir          |  4 +--
 mlir/test/Dialect/Vector/ops.mlir             |  4 +--
 3 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 15dede34e54f4..213f7375b8d13 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -263,22 +263,25 @@ class VectorGatherOpConversion
   LogicalResult
   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = gather->getLoc();
     MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
     assert(memRefType && "The base should be bufferized");
 
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
-      return failure();
+      return rewriter.notifyMatchFailure(gather, "memref type not supported");
 
     VectorType vType = gather.getVectorType();
-    if (vType.getRank() > 1)
-      return failure();
-
-    Location loc = gather->getLoc();
+    if (vType.getRank() > 1) {
+      return rewriter.notifyMatchFailure(
+          gather, "only 1-D vectors can be lowered to LLVM");
+    }
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
-      return failure();
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
+      return rewriter.notifyMatchFailure(gather,
+                                         "could not resolve memref alignment");
+    }
 
     // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -309,16 +312,20 @@ class VectorScatterOpConversion
     MemRefType memRefType = scatter.getMemRefType();
 
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
-      return failure();
+      return rewriter.notifyMatchFailure(scatter, "memref type not supported");
 
     VectorType vType = scatter.getVectorType();
-    if (vType.getRank() > 1)
-      return failure();
+    if (vType.getRank() > 1) {
+      return rewriter.notifyMatchFailure(
+          scatter, "only 1-D vectors can be lowered to LLVM");
+    }
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
-      return failure();
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
+      return rewriter.notifyMatchFailure(scatter,
+                                         "could not resolve memref alignment");
+    }
 
     // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 01c718796018d..ba1da84719106 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1654,7 +1654,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 //===----------------------------------------------------------------------===//
 // vector.gather
-// 
+//
 // NOTE: vector.constant_mask won't lower with
 //  * --convert-to-llvm="filter-dialects=vector",
 // hence testing here.
@@ -1741,7 +1741,7 @@ func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2
 func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
   %0 = arith.constant 0: index
   // vector.constant_mask only supports 'none set' or 'all set' scalable
-  // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
+  // dimensions, hence [2, 3] rather than [2, 2] as in the example for fixed
   // width vectors above.
   %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
   vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 279fd3e522775..8ae1e9f9d0c64 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -882,8 +882,8 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
   return
 }
 
-// CHECK-LABEL: @gather_multi_dims
-func.func @gather_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
+// CHECK-LABEL: @gather_and_scatter_multi_dims
+func.func @gather_and_scatter_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
   %c0 = arith.constant 0 : index
   // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
   %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>



More information about the Mlir-commits mailing list