[Mlir-commits] [mlir] [mlir][vector] Refactor multi-reduction patterns (NFC) (PR #183048)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Feb 24 05:35:11 PST 2026
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/183048
>From 74aa682c7b241ce0f38916b9ea412cff05713a9f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 24 Feb 2026 12:30:33 +0000
Subject: [PATCH 1/2] [mlir][vector] Refactor multi-reduction patterns (NFC)
Refactor the following patterns to inherit from `MaskableOpRewritePattern`:
* `TwoDimMultiReductionToReduction`
* `TwoDimMultiReductionToElementWise`
This improves code reuse, enables small simplifications, and unifies the
structure of the patterns. Add high-level comments to clarify the overall
lowering strategy.
Prepares for future refactoring (e.g. #182301) and helps maintain a
uniform implementation.
---
.../Transforms/LowerVectorMultiReduction.cpp | 156 ++++++++++--------
1 file changed, 88 insertions(+), 68 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2d6a49bad27bc..a8c57fa4811e4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -300,14 +300,40 @@ class ReduceMultiDimReductionRank
const bool useInnerDimsForReduction;
};
-/// Unrolls vector.multi_reduction with outermost reductions
-/// and combines results
+/// Lowers a 2-D vector.multi_reduction to a sequence of element-wise Arith ops.
+///
+/// The reduction dimension must be the outer-most dimension.
+///
+/// BEFORE:
+///
+/// %1 = vector.multi_reduction <mul>, %src, %acc [0] : vector<4x2xf32> to
+/// vector<2xf32>
+///
+/// AFTER:
+///
+/// // Prod 1.
+/// %vec_0 = vector.extract %src[0] : vector<2xf32> from vector<4x2xf32>
+/// %mul_0 = arith.mulf %vec_0, %acc : vector<2xf32>
+///
+/// // Prod 2.
+/// %vec_1 = vector.extract %src[1] : vector<2xf32> from vector<4x2xf32>
+/// %mul_1 = arith.mulf %vec_1, %mul_0 : vector<2xf32>
+///
+/// // Prod 3.
+/// %vec_2 = vector.extract %src[2] : vector<2xf32> from vector<4x2xf32>
+/// %mul_2 = arith.mulf %vec_2, %mul_1 : vector<2xf32>
+///
+/// // Prod 4.
+/// %vec_3 = vector.extract %src[3] : vector<2xf32> from vector<4x2xf32>
+/// %res = arith.mulf %vec_3, %mul_2 : vector<2xf32>
struct TwoDimMultiReductionToElementWise
- : public OpRewritePattern<vector::MultiDimReductionOp> {
- using Base::Base;
+ : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
- PatternRewriter &rewriter) const override {
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Rank-2 ["parallel", "reduce"] or bail.
if (srcRank != 2)
@@ -316,100 +342,94 @@ struct TwoDimMultiReductionToElementWise
if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
return failure();
+ Value mask = maskingOp ? maskingOp.getMask() : Value();
+
auto loc = multiReductionOp.getLoc();
+ Value source = multiReductionOp.getSource();
ArrayRef<int64_t> srcShape =
multiReductionOp.getSourceVectorType().getShape();
-
- Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
- if (!elementType.isIntOrIndexOrFloat())
- return failure();
-
- OpBuilder::InsertionGuard guard(rewriter);
- auto maskableOp =
- cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
- Operation *rootOp;
- Value mask = nullptr;
- if (maskableOp.isMasked()) {
- rewriter.setInsertionPoint(maskableOp.getMaskingOp());
- rootOp = maskableOp.getMaskingOp();
- mask = maskableOp.getMaskingOp().getMask();
- } else {
- rootOp = multiReductionOp;
- }
+ int outerDim = srcShape[0];
Value result = multiReductionOp.getAcc();
- for (int64_t i = 0; i < srcShape[0]; i++) {
- auto operand = vector::ExtractOp::create(rewriter, loc,
- multiReductionOp.getSource(), i);
- Value extractMask = nullptr;
- if (mask) {
- extractMask = vector::ExtractOp::create(rewriter, loc, mask, i);
- }
- result =
- makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
- result, /*fastmath=*/nullptr, extractMask);
+ for (int64_t i = 0; i < outerDim; i++) {
+ auto v = vector::ExtractOp::create(rewriter, loc, source, i);
+ Value m;
+ m = mask ? Value(vector::ExtractOp::create(rewriter, loc, mask, i))
+ : nullptr;
+ result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), v,
+ result, /*fastmath=*/nullptr, m);
}
- rewriter.replaceOp(rootOp, result);
- return success();
+ return result;
}
};
-/// Converts 2d vector.multi_reduction with inner most reduction dimension into
-/// a sequence of vector.reduction ops.
+/// Lowers a 2-D vector.multi_reduction to a sequence of vector.reduction ops.
+///
+/// The reduction dimension must be the inner-most dimension.
+///
+/// BEFORE:
+/// vector.multi_reduction <mul>, %src, %acc [1] : vector<2x4xf32> to
+/// vector<2xf32>
+///
+/// AFTER (%res is a zero-initialized vector<2xf32> constant):
+/// // 1st reduction
+/// %v_0 = vector.extract %src[0] : vector<4xf32> from vector<2x4xf32>
+/// %a_0 = vector.extract %acc[0] : f32 from vector<2xf32>
+/// %red_1 = vector.reduction <mul>, %v_0, %a_0 : vector<4xf32> into f32
+/// %res_tmp = vector.insert %red_1, %res [0] : f32 into vector<2xf32>
+///
+/// // 2nd reduction
+/// %v_1 = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>
+/// %a_1 = vector.extract %acc[1] : f32 from vector<2xf32>
+/// %red_2 = vector.reduction <mul>, %v_1, %a_1 : vector<4xf32> into f32
+/// %res_final = vector.insert %red_2, %red_2 [1] : f32 into vector<2xf32>
struct TwoDimMultiReductionToReduction
- : public OpRewritePattern<vector::MultiDimReductionOp> {
- using Base::Base;
+ : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
- PatternRewriter &rewriter) const override {
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+ // Rank-2 ["parallel", "reduce"] or bail.
if (srcRank != 2)
return failure();
if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
return failure();
- // Vector mask setup.
- OpBuilder::InsertionGuard guard(rewriter);
- auto maskableOp =
- cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
- Operation *rootOp;
- if (maskableOp.isMasked()) {
- rewriter.setInsertionPoint(maskableOp.getMaskingOp());
- rootOp = maskableOp.getMaskingOp();
- } else {
- rootOp = multiReductionOp;
- }
+ Value mask = maskingOp ? maskingOp.getMask() : nullptr;
auto loc = multiReductionOp.getLoc();
+ Value source = multiReductionOp.getSource();
+ Value acc = multiReductionOp.getAcc();
+ int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
+
Value result = arith::ConstantOp::create(
rewriter, loc, multiReductionOp.getDestType(),
rewriter.getZeroAttr(multiReductionOp.getDestType()));
- int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
- for (int i = 0; i < outerDim; ++i) {
- auto v = vector::ExtractOp::create(
- rewriter, loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
- auto acc = vector::ExtractOp::create(
- rewriter, loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
+ SmallVector<Value> vectors(outerDim);
+ Value m, v, a;
+ for (int64_t i = 0; i < outerDim; ++i) {
+ v = vector::ExtractOp::create(rewriter, loc, source, i);
+ a = vector::ExtractOp::create(rewriter, loc, acc, i);
+
Operation *reductionOp = vector::ReductionOp::create(
- rewriter, loc, multiReductionOp.getKind(), v, acc);
-
- // If masked, slice the mask and mask the new reduction operation.
- if (maskableOp.isMasked()) {
- Value mask = vector::ExtractOp::create(
- rewriter, loc, maskableOp.getMaskingOp().getMask(),
- ArrayRef<int64_t>{i});
- reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
+ rewriter, loc, multiReductionOp.getKind(), v, a);
+
+ if (mask) {
+ m = vector::ExtractOp::create(rewriter, loc, mask, i);
+ reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, m);
}
result = vector::InsertOp::create(rewriter, loc,
reductionOp->getResult(0), result, i);
}
- rewriter.replaceOp(rootOp, result);
- return success();
+ return result;
}
};
>From 60e87afb8ce493e1e55c9e3390a93e12ed715814 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 24 Feb 2026 13:34:49 +0000
Subject: [PATCH 2/2] Address PR comments
---
.../Transforms/LowerVectorMultiReduction.cpp | 14 ++++++--------
1 file changed, 6 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index a8c57fa4811e4..e04ed0309aeac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -353,9 +353,8 @@ struct TwoDimMultiReductionToElementWise
Value result = multiReductionOp.getAcc();
for (int64_t i = 0; i < outerDim; i++) {
auto v = vector::ExtractOp::create(rewriter, loc, source, i);
- Value m;
- m = mask ? Value(vector::ExtractOp::create(rewriter, loc, mask, i))
- : nullptr;
+ Value m = mask ? Value(vector::ExtractOp::create(rewriter, loc, mask, i))
+ : nullptr;
result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), v,
result, /*fastmath=*/nullptr, m);
}
@@ -383,7 +382,7 @@ struct TwoDimMultiReductionToElementWise
/// %v_1 = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>
/// %a_1 = vector.extract %acc[1] : f32 from vector<2xf32>
/// %red_2 = vector.reduction <mul>, %v_1, %a_1 : vector<4xf32> into f32
-/// %res_final = vector.insert %red_2, %red_2 [1] : f32 into vector<2xf32>
+/// %res_final = vector.insert %red_2, %res_tmp [1] : f32 into vector<2xf32>
struct TwoDimMultiReductionToReduction
: public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;
@@ -412,16 +411,15 @@ struct TwoDimMultiReductionToReduction
rewriter.getZeroAttr(multiReductionOp.getDestType()));
SmallVector<Value> vectors(outerDim);
- Value m, v, a;
for (int64_t i = 0; i < outerDim; ++i) {
- v = vector::ExtractOp::create(rewriter, loc, source, i);
- a = vector::ExtractOp::create(rewriter, loc, acc, i);
+ Value v = vector::ExtractOp::create(rewriter, loc, source, i);
+ Value a = vector::ExtractOp::create(rewriter, loc, acc, i);
Operation *reductionOp = vector::ReductionOp::create(
rewriter, loc, multiReductionOp.getKind(), v, a);
if (mask) {
- m = vector::ExtractOp::create(rewriter, loc, mask, i);
+ Value m = vector::ExtractOp::create(rewriter, loc, mask, i);
reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, m);
}
More information about the Mlir-commits
mailing list