[Mlir-commits] [mlir] [mlir][vector] Refactor multi-reduction patterns (NFC) (PR #183048)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 24 04:37:40 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

Refactor the following patterns to inherit from `MaskableOpRewritePattern`:
  * `TwoDimMultiReductionToReduction`
  * `TwoDimMultiReductionToElementWise`

This improves code reuse, enables small simplifications, and unifies the
structure of the patterns. Add high-level comments to clarify the overall
lowering strategy.

This prepares for future refactoring (e.g. #182301) and helps maintain a
uniform implementation.


---
Full diff: https://github.com/llvm/llvm-project/pull/183048.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+88-68) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2d6a49bad27bc..a8c57fa4811e4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -300,14 +300,40 @@ class ReduceMultiDimReductionRank
   const bool useInnerDimsForReduction;
 };
 
-/// Unrolls vector.multi_reduction with outermost reductions
-/// and combines results
+/// Lowers a 2D vector.multi_reduction to a sequence of Arith ops.
+///
+/// The reduction dimension must be the outer-most dimension.
+///
+/// BEFORE:
+///
+///  %1 = vector.multi_reduction <mul>, %src, %acc [0] : vector<4x2xf32> to
+///  vector<2xf32>
+///
+/// AFTER:
+///
+///   // Prod 1.
+///   %vec_0 = vector.extract %src[0] : vector<2xf32> from vector<4x2xf32>
+///   %mul_0 = arith.mulf %vec_0, %acc : vector<2xf32>
+///
+///   // Prod 2.
+///   %vec_1 = vector.extract %src[1] : vector<2xf32> from vector<4x2xf32>
+///   %mul_1 = arith.mulf %vec_1, %mul_0 : vector<2xf32>
+///
+///   // Prod 3.
+///   %vec_2 = vector.extract %src[2] : vector<2xf32> from vector<4x2xf32>
+///   %mul_2 = arith.mulf %vec_2, %mul_1 : vector<2xf32>
+///
+///   // Prod 4.
+///   %vec_3 = vector.extract %src[3] : vector<2xf32> from vector<4x2xf32>
+///   %1 = arith.mulf %vec_3, %mul_2 : vector<2xf32>
 struct TwoDimMultiReductionToElementWise
-    : public OpRewritePattern<vector::MultiDimReductionOp> {
-  using Base::Base;
+    : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     // Rank-2 ["parallel", "reduce"] or bail.
     if (srcRank != 2)
@@ -316,100 +342,94 @@ struct TwoDimMultiReductionToElementWise
     if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
       return failure();
 
+    Value mask = maskingOp ? maskingOp.getMask() : Value();
+
     auto loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
     ArrayRef<int64_t> srcShape =
         multiReductionOp.getSourceVectorType().getShape();
-
-    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
-    if (!elementType.isIntOrIndexOrFloat())
-      return failure();
-
-    OpBuilder::InsertionGuard guard(rewriter);
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
-    Operation *rootOp;
-    Value mask = nullptr;
-    if (maskableOp.isMasked()) {
-      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-      rootOp = maskableOp.getMaskingOp();
-      mask = maskableOp.getMaskingOp().getMask();
-    } else {
-      rootOp = multiReductionOp;
-    }
+    int outerDim = srcShape[0];
 
     Value result = multiReductionOp.getAcc();
-    for (int64_t i = 0; i < srcShape[0]; i++) {
-      auto operand = vector::ExtractOp::create(rewriter, loc,
-                                               multiReductionOp.getSource(), i);
-      Value extractMask = nullptr;
-      if (mask) {
-        extractMask = vector::ExtractOp::create(rewriter, loc, mask, i);
-      }
-      result =
-          makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
-                             result, /*fastmath=*/nullptr, extractMask);
+    for (int64_t i = 0; i < outerDim; i++) {
+      auto v = vector::ExtractOp::create(rewriter, loc, source, i);
+      Value m;
+      m = mask ? Value(vector::ExtractOp::create(rewriter, loc, mask, i))
+               : nullptr;
+      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), v,
+                                  result, /*fastmath=*/nullptr, m);
     }
 
-    rewriter.replaceOp(rootOp, result);
-    return success();
+    return result;
   }
 };
 
-/// Converts 2d vector.multi_reduction with inner most reduction dimension into
-/// a sequence of vector.reduction ops.
+/// Lowers a 2D vector.multi_reduction to a sequence of vector.reduction ops.
+///
+/// The reduction dimension must be the inner-most dimension.
+///
+/// BEFORE:
+///  vector.multi_reduction <mul>, %src, %acc [1] : vector<2x4xf32> to
+///  vector<2xf32>
+///
+/// AFTER:
+///   %zero = arith.constant dense<0.000000e+00> : vector<2xf32>
+///   // 1st reduction
+///   %v_0 = vector.extract %src[0] : vector<4xf32> from vector<2x4xf32>
+///   %a_0 = vector.extract %acc[0] : f32 from vector<2xf32>
+///   %red_0 = vector.reduction <mul>, %v_0, %a_0 : vector<4xf32> into f32
+///   %res_0 = vector.insert %red_0, %zero [0] : f32 into vector<2xf32>
+///   // 2nd reduction
+///   %v_1 = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>
+///   %a_1 = vector.extract %acc[1] : f32 from vector<2xf32>
+///   %red_1 = vector.reduction <mul>, %v_1, %a_1 : vector<4xf32> into f32
+///   %res = vector.insert %red_1, %res_0 [1] : f32 into vector<2xf32>
 struct TwoDimMultiReductionToReduction
-    : public OpRewritePattern<vector::MultiDimReductionOp> {
-  using Base::Base;
+    : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    // Rank-2 ["parallel", "reduce"] or bail.
     if (srcRank != 2)
       return failure();
 
     if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
       return failure();
 
-    // Vector mask setup.
-    OpBuilder::InsertionGuard guard(rewriter);
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
-    Operation *rootOp;
-    if (maskableOp.isMasked()) {
-      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-      rootOp = maskableOp.getMaskingOp();
-    } else {
-      rootOp = multiReductionOp;
-    }
+    Value mask = maskingOp ? maskingOp.getMask() : nullptr;
 
     auto loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+    Value acc = multiReductionOp.getAcc();
+    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
+
     Value result = arith::ConstantOp::create(
         rewriter, loc, multiReductionOp.getDestType(),
         rewriter.getZeroAttr(multiReductionOp.getDestType()));
-    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
 
-    for (int i = 0; i < outerDim; ++i) {
-      auto v = vector::ExtractOp::create(
-          rewriter, loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
-      auto acc = vector::ExtractOp::create(
-          rewriter, loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
+    SmallVector<Value> vectors(outerDim);
+    Value m, v, a;
+    for (int64_t i = 0; i < outerDim; ++i) {
+      v = vector::ExtractOp::create(rewriter, loc, source, i);
+      a = vector::ExtractOp::create(rewriter, loc, acc, i);
+
       Operation *reductionOp = vector::ReductionOp::create(
-          rewriter, loc, multiReductionOp.getKind(), v, acc);
-
-      // If masked, slice the mask and mask the new reduction operation.
-      if (maskableOp.isMasked()) {
-        Value mask = vector::ExtractOp::create(
-            rewriter, loc, maskableOp.getMaskingOp().getMask(),
-            ArrayRef<int64_t>{i});
-        reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
+          rewriter, loc, multiReductionOp.getKind(), v, a);
+
+      if (mask) {
+        m = vector::ExtractOp::create(rewriter, loc, mask, i);
+        reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, m);
       }
 
       result = vector::InsertOp::create(rewriter, loc,
                                         reductionOp->getResult(0), result, i);
     }
 
-    rewriter.replaceOp(rootOp, result);
-    return success();
+    return result;
   }
 };
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/183048


More information about the Mlir-commits mailing list