[Mlir-commits] [mlir] [mlir][vector] Allow multi dim vectors in vector.scatter (PR #132217)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 20 07:19:25 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

<details>
<summary>Changes</summary>

This patch aligns the definition of vector.scatter with that of its counterpart, vector.gather.

All of the changes in this patch make vector.scatter match vector.gather's multi-dimensional definition.

Unrolling for vector.scatter will be implemented in subsequent patches.

---

Patch is 21.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132217.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+6-6) 
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+6-2) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+18-32) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+1) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+2-2) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+12-6) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (-46) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+37-3) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+1-1) 
- (modified) mlir/test/Dialect/Vector/ops.mlir (+10-8) 
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..5fab2ee1194e8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2034,9 +2034,9 @@ def Vector_ScatterOp :
   Vector_Op<"scatter">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$valueToStore)> {
+               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+               VectorOfNonZeroRankOf<[I1]>:$mask,
+               AnyVectorOfNonZeroRank:$valueToStore)> {
 
   let summary = [{
     scatters elements from a vector into memory as defined by an index vector
@@ -2044,9 +2044,9 @@ def Vector_ScatterOp :
   }];
 
   let description = [{
-    The scatter operation stores elements from a 1-D vector into memory as
-    defined by a base with indices and an additional 1-D index vector, but
-    only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
+    The scatter operation stores elements from an n-D vector into memory as
+    defined by a base with indices and an additional n-D index vector, but
+    only if the corresponding bit in an n-D mask vector is set. Otherwise, no
     action is taken for that element. Informally the semantics are:
     ```
     if (mask[0]) base[index[0]] = value[0]
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 601a65333d026..77d8b82b2bad0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -244,13 +244,17 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
 /// [FlattenGather]
 /// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
 /// outermost dimension.
+void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
 ///
 /// [Gather1DToConditionalLoads]
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns,
+                                                   PatternBenefit benefit = 1);
 
 /// Populates instances of `MaskOpRewritePattern` to lower masked operations
 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 94efec61a466c..4127f5b065bc8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -269,6 +269,10 @@ class VectorGatherOpConversion
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();
 
+    VectorType vType = gather.getVectorType();
+    if (vType.getRank() > 1)
+      return failure();
+
     auto loc = gather->getLoc();
 
     // Resolve alignment.
@@ -276,42 +280,21 @@ class VectorGatherOpConversion
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
+    // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Value base = adaptor.getBase();
 
-    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
     // Handle the simple case of 1-D vector.
-    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
-      auto vType = gather.getVectorType();
-      // Resolve address.
-      Value ptrs =
-          getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
-                         base, ptr, adaptor.getIndexVec(), vType);
-      // Replace with the gather intrinsic.
-      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
-          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
-          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
-      return success();
-    }
-
-    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
-    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
-                     &typeConverter](Type llvm1DVectorTy,
-                                     ValueRange vectorOperands) {
-      // Resolve address.
-      Value ptrs = getIndexedPtrs(
-          rewriter, loc, typeConverter, memRefType, base, ptr,
-          /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
-      // Create the gather intrinsic.
-      return rewriter.create<LLVM::masked_gather>(
-          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
-          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
-    };
-    SmallVector<Value> vectorOperands = {
-        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
-    return LLVM::detail::handleMultidimensionalVectors(
-        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
+    // Resolve address.
+    Value ptrs =
+        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+                       base, ptr, adaptor.getIndexVec(), vType);
+    // Replace with the gather intrinsic.
+    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+    return success();
   }
 };
 
@@ -330,13 +313,16 @@ class VectorScatterOpConversion
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return failure();
 
+    VectorType vType = scatter.getVectorType();
+    if (vType.getRank() > 1)
+      return failure();
+
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
-    VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Value ptrs =
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index eb1555df5d574..7082b92c95d1d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -81,6 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorInsertExtractStridedSliceTransforms(patterns);
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
+    populateVectorGatherLoweringPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..59da2ebe4aae0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5340,9 +5340,9 @@ LogicalResult ScatterOp::verify() {
     return emitOpError("base and valueToStore element type should match");
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
-  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
+  if (valueVType.getShape() != indVType.getShape())
     return emitOpError("expected valueToStore dim to match indices dim");
-  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+  if (valueVType.getShape() != maskVType.getShape())
     return emitOpError("expected valueToStore dim to match mask dim");
   return success();
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 3b38505becd18..eff8ee0e9de7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -38,7 +38,7 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// Unrolls 2 or more dimensional `vector.gather` ops along the
 /// outermost dimension. For example:
 /// ```
 /// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
@@ -56,14 +56,14 @@ namespace {
 /// When applied exhaustively, this will produce a sequence of 1-d gather ops.
 ///
 /// Supports vector types with a fixed leading dimension.
-struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+struct UnrollGather : OpRewritePattern<vector::GatherOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::GatherOp op,
                                 PatternRewriter &rewriter) const override {
     VectorType resultTy = op.getType();
     if (resultTy.getRank() < 2)
-      return rewriter.notifyMatchFailure(op, "already flat");
+      return rewriter.notifyMatchFailure(op, "already 1-D");
 
     // Unrolling doesn't take vscale into account. Pattern is disabled for
     // vectors with leading scalable dim(s).
@@ -107,7 +107,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
 /// ```mlir
 ///   %subview = memref.subview %M (...)
 ///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+///   %gather = vector.gather %subview[%idxs] (...)
+///     : memref<100xf32, strided<[3]>>
 /// ```
 /// ==>
 /// ```mlir
@@ -269,6 +270,11 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
 void mlir::vector::populateVectorGatherLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
-               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
+  patterns.add<UnrollGather>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
+      patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..44b4a25a051f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2074,52 +2074,6 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
 
 // -----
 
-func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
-  return %1 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-
-// -----
-
-func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
-  return %1 : vector<2x[3]xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d_scalable
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-
-// -----
-
 
 func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..f5c722e29420c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1654,7 +1654,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 //===----------------------------------------------------------------------===//
 // vector.gather
-//
+//
 // NOTE: vector.constant_mask won't lower with
 //  * --convert-to-llvm="filter-dialects=vector",
 // hence testing here.
@@ -1663,7 +1663,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
   %0 = arith.constant 0: index
-  %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
+  %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
   %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
   return %2 : vector<2x3xf32>
 }
@@ -1679,7 +1679,7 @@ func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi
   // vector.constant_mask only supports 'none set' or 'all set' scalable
   // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
   // width vectors above.
-  %1 = vector.constant_mask [1, 3] : vector<2x[3]xi1>
+  %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
   %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
   return %2 : vector<2x[3]xf32>
 }
@@ -1719,6 +1719,40 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+// Multi-dimensional scatters are not supported yet. Check that we do not lower
+// them.
+
+func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
+  %0 = arith.constant 0: index
+  %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
+  vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_mask
+// CHECK: vector.scatter
+
+// -----
+
+func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
+  %0 = arith.constant 0: index
+  // vector.constant_mask only supports 'none set' or 'all set' scalable
+  // dimensions, hence [2, 3] rather than [2, 2] as in the example for fixed
+  // width vectors above.
+  %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
+  vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_mask_scalable
+// CHECK: vector.scatter
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.interleave
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 57e348c7d5991..1b89e8eb5069b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1484,7 +1484,7 @@ func.func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi
 func.func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                             %mask: vector<16xi1>, %value: vector<2x16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.scatter' op operand #4 must be  of ranks 1, but got 'vector<2x16xf32>'}}
+  // expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
   vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
 }
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 67484e06f456d..279fd3e522775 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -882,6 +882,16 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
   return
 }
 
+// CHECK-LABEL: @gather_multi_dims
+func.func @gather_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  // CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+  vector.scatter %base[%c0][%v], %mask, %0 : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+  return %0 : vector<2x16xf32>
+}
+
 // CHECK-LABEL: @gather_on_tensor
 func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf3...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/132217


More information about the Mlir-commits mailing list