[Mlir-commits] [mlir] [mlir][vector] Allow lowering multi-dim scatters to LLVM (PR #132227)

Kunwar Grover llvmlistbot at llvm.org
Mon Mar 24 06:08:39 PDT 2025


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/132227

>From cfc0377ee1fb28a97b68c397f89481938a81c9ef Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 20 Mar 2025 14:51:13 +0000
Subject: [PATCH] [mlir][vector] Allow lowering multi-dim scatters to LLVM

---
 .../Vector/Transforms/LoweringPatterns.h      |  8 ++-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  5 +-
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  2 +-
 .../TransformOps/VectorTransformOps.cpp       |  2 +-
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  2 +-
 ...ather.cpp => LowerVectorGatherScatter.cpp} | 69 ++++++++++++++++---
 .../VectorToLLVM/vector-to-llvm.mlir          |  6 +-
 .../Dialect/Vector/TestVectorTransforms.cpp   |  2 +-
 8 files changed, 74 insertions(+), 22 deletions(-)
 rename mlir/lib/Dialect/Vector/Transforms/{LowerVectorGather.cpp => LowerVectorGatherScatter.cpp} (80%)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 14cff4ff893b5..528de2340f7b7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -244,8 +244,12 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
 /// [UnrollGather]
 /// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
 /// outermost dimension.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+///
+/// [UnrollScatter]
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension.
+void populateVectorGatherScatterLoweringPatterns(RewritePatternSet &patterns,
+                                                 PatternBenefit benefit = 1);
 
 /// Populate the pattern set with the following patterns:
 ///
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 213f7375b8d13..f77b7f2895b3a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -286,10 +286,9 @@ class VectorGatherOpConversion
     // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    Value base = adaptor.getBase();
     Value ptrs =
         getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
-                       base, ptr, adaptor.getIndexVec(), vType);
+                       adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
 
     // Replace with the gather intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
@@ -308,7 +307,7 @@ class VectorScatterOpConversion
   LogicalResult
   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = scatter->getLoc();
+    Location loc = scatter->getLoc();
     MemRefType memRefType = scatter.getMemRefType();
 
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..dfa188bdfc5cc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -81,7 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorInsertExtractStridedSliceTransforms(patterns);
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
-    populateVectorGatherLoweringPatterns(patterns);
+    populateVectorGatherScatterLoweringPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 20c577273d786..623b9aa83fff3 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -138,7 +138,7 @@ void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
 
 void transform::ApplyLowerGatherPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::populateVectorGatherLoweringPatterns(patterns);
+  vector::populateVectorGatherScatterLoweringPatterns(patterns);
 }
 
 void transform::ApplyLowerScanPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..8abaa6ac527eb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -3,7 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorBitCast.cpp
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp
-  LowerVectorGather.cpp
+  LowerVectorGatherScatter.cpp
   LowerVectorInterleave.cpp
   LowerVectorMask.cpp
   LowerVectorMultiReduction.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp
similarity index 80%
rename from mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
rename to mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp
index 3000204c8ce17..2c6f237457609 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp
@@ -38,6 +38,7 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
+
 /// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
 /// outermost dimension. For example:
 /// ```
@@ -81,19 +82,14 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
     VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
 
     for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
-      int64_t thisIdx[1] = {i};
-
-      Value indexSubVec =
-          rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
-      Value maskSubVec =
-          rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
+      Value indexSubVec = rewriter.create<vector::ExtractOp>(loc, indexVec, i);
+      Value maskSubVec = rewriter.create<vector::ExtractOp>(loc, maskVec, i);
       Value passThruSubVec =
-          rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
+          rewriter.create<vector::ExtractOp>(loc, passThruVec, i);
       Value subGather = rewriter.create<vector::GatherOp>(
           loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
           passThruSubVec);
-      result =
-          rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
+      result = rewriter.create<vector::InsertOp>(loc, subGather, result, i);
     }
 
     rewriter.replaceOp(op, result);
@@ -101,6 +97,57 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
   }
 };
 
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// vector.scatter %base[%c0][%v], %mask, %valueToStore : ...
+/// vector<2x3xf32>
+///
+/// ==>
+///
+/// %g0  = vector.extract %valueToStore[0] : vector<3xf32> from vector<2x3xf32>
+///        vector.scatter %base[%c0][%v0], %mask0, %g0
+/// %g1  = vector.extract %valueToStore[1] : vector<3xf32> from vector<2x3xf32>
+///        vector.scatter %base[%c0][%v1], %mask1, %g1
+/// ```
+/// where `%vN` / `%maskN` are the N-th slices extracted from `%v` / `%mask`.
+/// When applied exhaustively, this will produce a sequence of 1-d scatter ops.
+///
+/// Supports vector types with a fixed leading dimension.
+struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vectorTy = op.getVectorType();
+    if (vectorTy.getRank() < 2)
+      return rewriter.notifyMatchFailure(op, "already 1-D");
+
+    // Unrolling doesn't take vscale into account. Pattern is disabled for
+    // vectors with leading scalable dim(s).
+    if (vectorTy.getScalableDims().front())
+      return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
+    Location loc = op.getLoc();
+    Value indexVec = op.getIndexVec();
+    Value maskVec = op.getMask();
+    Value valueToStoreVec = op.getValueToStore();
+
+    for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) {
+      Value indexSubVec = rewriter.create<vector::ExtractOp>(loc, indexVec, i);
+      Value maskSubVec = rewriter.create<vector::ExtractOp>(loc, maskVec, i);
+      Value valueToStoreSubVec =
+          rewriter.create<vector::ExtractOp>(loc, valueToStoreVec, i);
+      rewriter.create<vector::ScatterOp>(loc, op.getBase(), op.getIndices(),
+                                         indexSubVec, maskSubVec,
+                                         valueToStoreSubVec);
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
 /// MemRef with updated indices that model the strided access.
 ///
@@ -268,9 +315,9 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 };
 } // namespace
 
-void mlir::vector::populateVectorGatherLoweringPatterns(
+void mlir::vector::populateVectorGatherScatterLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<UnrollGather>(patterns.getContext(), benefit);
+  patterns.add<UnrollGather, UnrollScatter>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ba1da84719106..24f01cb63097d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1734,7 +1734,8 @@ func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2
 }
 
 // CHECK-LABEL: func @scatter_with_mask
-// CHECK: vector.scatter
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr>
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr>
 
 // -----
 
@@ -1749,7 +1750,8 @@ func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]x
 }
 
 // CHECK-LABEL: func @scatter_with_mask_scalable
-// CHECK: vector.scatter
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
 
 // -----
 
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..cec3d2c424fa4 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -781,7 +781,7 @@ struct TestVectorGatherLowering
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateVectorGatherLoweringPatterns(patterns);
+    populateVectorGatherScatterLoweringPatterns(patterns);
     populateVectorGatherToConditionalLoadPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }



More information about the Mlir-commits mailing list