[Mlir-commits] [mlir] [mlir][vector] Separate multi_reduce lowering into invariants, flattening, and unrolling. (PR #178974)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Fri Jan 30 15:57:50 PST 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/178974
>From c7e0fa9bb52ffecb21cd9bec4b51a16b02b73000 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 29 Jan 2026 10:29:59 -0500
Subject: [PATCH 01/14] [mlir][vector] Finer pattern selection for
VectorMultiReduction (NFC).
* Adds populateVectorMultiReductionInnerOuterDimPatterns which populates
InnerOuterDimReductionConversion and OneDimMultiReductionToTwoDim.
These patterns establish the invariants that the subsequent patterns rely on.
* Adds populateVectorMultiReductionFlatteningPatterns which populates
ReduceMultiDimReductionRank, TwoDimMultiReductionToElementWise, and
TwoDimMultiReductionToReduction.
---
.../Vector/Transforms/LoweringPatterns.h | 40 +++++++++++++++++++
.../TransformOps/VectorTransformOps.cpp | 2 +
.../Transforms/LowerVectorMultiReduction.cpp | 36 +++++++++++++++--
3 files changed, 74 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 7bd96c8a6d1a1..6d39183d3e66d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -60,6 +60,46 @@ void populateVectorContractLoweringPatterns(
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Collect a set of patterns to set invariants for vector.multi_reduction's
+/// conversion. The patterns comprise:
+///
+/// [InnerOuterDimReductionConversion]
+/// Rewrites vector.multi_reduction such that all reduction dimensions are
+/// either innermost or outermost, by adding the proper vector.transpose
+/// operations.
+///
+/// [OneDimMultiReductionToTwoDim]
+/// For cases that reduce to 1-D vector<k> reduction (and are thus missing
+/// either a parallel or a reduction), we lift them back up to 2-D with a simple
+/// vector.shape_cast to vector<1xk> so that the other patterns can kick in,
+/// thus fully exiting out of the vector.multi_reduction abstraction.
+void populateVectorMultiReductionInnerOuterDimPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1);
+
+/// Collect a set of patterns to convert vector.multi_reduction op into
+/// a sequence of vector.reduction ops. The patterns comprise:
+///
+/// [ReduceMultiDimReductionRank]
+/// Once in innermost or outermost reduction
+/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
+/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
+/// back.
+///
+/// [TwoDimMultiReductionToElementWise]
+/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
+/// ops. This also has an opportunity for tree-reduction (in the future).
+///
+/// [TwoDimMultiReductionToReduction]
+/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of extract +
+/// vector.reduction + insert. This can further lower to horizontal reduction
+/// ops.
+void populateVectorMultiReductionFlatteningPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1);
+
/// Collect a set of patterns to convert vector.multi_reduction op into
/// a sequence of vector.reduction ops. The patterns comprise:
///
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 7faa222a9e574..f651e74873a91 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -130,6 +130,8 @@ void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+ vector::populateVectorMultiReductionInnerOuterDimPatterns(
+ patterns, vectorTransformOptions.vectorMultiReductionLowering);
vector::populateVectorMultiReductionLoweringPatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index e86e2a97038db..993b6a7f6e08f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -496,11 +496,18 @@ struct LowerVectorMultiReductionPass
Operation *op = getOperation();
MLIRContext *context = op->getContext();
- RewritePatternSet loweringPatterns(context);
- populateVectorMultiReductionLoweringPatterns(loweringPatterns,
- this->loweringStrategy);
+ RewritePatternSet innerOuterPatterns(context);
+ populateVectorMultiReductionInnerOuterDimPatterns(innerOuterPatterns,
+ this->loweringStrategy);
- if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
+ if (failed(applyPatternsGreedily(op, std::move(innerOuterPatterns))))
+ signalPassFailure();
+
+ RewritePatternSet flatteningPatterns(context);
+ populateVectorMultiReductionFlatteningPatterns(flatteningPatterns,
+ this->loweringStrategy);
+
+ if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
signalPassFailure();
}
@@ -511,6 +518,27 @@ struct LowerVectorMultiReductionPass
} // namespace
+void mlir::vector::populateVectorMultiReductionInnerOuterDimPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit) {
+ patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
+ patterns.add<InnerOuterDimReductionConversion>(patterns.getContext(), options,
+ benefit);
+}
+
+void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit) {
+ patterns.add<ReduceMultiDimReductionRank>(patterns.getContext(), options,
+ benefit);
+ if (options == VectorMultiReductionLowering ::InnerReduction)
+ patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
+ benefit);
+ else
+ patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
+ benefit);
+}
+
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
>From 7a0e185022bd04ed2d878f4574fee560347a823f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 29 Jan 2026 10:37:16 -0500
Subject: [PATCH 02/14] [mlir][vector] Finer pattern selection for
VectorMultiReduction (NFC)
* Removes the following patterns from
populateVectorMultiReductionFlatteningPatterns:
* TwoDimMultiReductionToElementWise
* TwoDimMultiReductionToReduction
* Adds populateVectorMultiReductionUnrollingPatterns with patterns
TwoDimMultiReductionToElementWise, TwoDimMultiReductionToReduction
---
.../Dialect/Vector/Transforms/LoweringPatterns.h | 8 +++++++-
.../Vector/Transforms/LowerVectorMultiReduction.cpp | 12 ++++++++++++
2 files changed, 19 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 6d39183d3e66d..1beb8e25cb683 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -85,6 +85,12 @@ void populateVectorMultiReductionInnerOuterDimPatterns(
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
/// back.
+void populateVectorMultiReductionFlatteningPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1);
+
+/// Collect a set of patterns to unroll 2-D vector.multi_reduction ops into
+/// vector.reduction or elementwise ops. The patterns comprise:
///
/// [TwoDimMultiReductionToElementWise]
/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
@@ -96,7 +102,7 @@ void populateVectorMultiReductionInnerOuterDimPatterns(
/// dimension, unroll the outer dimension to obtain a sequence of extract +
/// vector.reduction + insert. This can further lower to horizontal reduction
/// ops.
-void populateVectorMultiReductionFlatteningPatterns(
+void populateVectorMultiReductionUnrollingPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 993b6a7f6e08f..2eb8048bb4a69 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -509,6 +509,13 @@ struct LowerVectorMultiReductionPass
if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
signalPassFailure();
+
+ RewritePatternSet unrollingPatterns(context);
+ populateVectorMultiReductionUnrollingPatterns(unrollingPatterns,
+ this->loweringStrategy);
+
+ if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
+ signalPassFailure();
}
void getDependentDialects(DialectRegistry ®istry) const override {
@@ -531,6 +538,11 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
PatternBenefit benefit) {
patterns.add<ReduceMultiDimReductionRank>(patterns.getContext(), options,
benefit);
+}
+
+void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit) {
if (options == VectorMultiReductionLowering ::InnerReduction)
patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
benefit);
>From 94dc8fd6cefd1ae166622950f6c6bfc8a6df18bf Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 11:13:52 -0500
Subject: [PATCH 03/14] [mlir][vector] Add rank-reducing unrolling for
vector.multi_reduction.
---
.../Transforms/LowerVectorMultiReduction.cpp | 199 +++++++++++++++++-
.../vector-multi-reduction-pass-lowering.mlir | 6 +-
2 files changed, 198 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2eb8048bb4a69..b2b01031587cf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -485,6 +485,195 @@ struct OneDimMultiReductionToTwoDim
}
};
+/// Unrolls the outermost dimension of a vector.multi_reduction that reduces it.
+/// These patterns only match operations whose outermost dimension (dim 0) is
+/// a reduction dimension; operations for which the outermost dimension is
+/// parallel are left untouched.
+///
+/// There are two cases to consider:
+/// 1. Base case (UnrollMultiReductionOuterBaseCase): the outermost dimension
+///    is the only reduction dimension.
+/// 2. General case (UnrollMultiReductionOuterGeneralCase): the outermost
+///    dimension is one of several reduction dimensions.
+///
+/// The base case transformation:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0] : vector<NxMx...xf32>
+///   to vector<Mx...xf32>
+/// ```
+///
+/// extracts the N slices of %src and folds them into %acc elementwise:
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
+/// ...
+/// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
+/// ```
+///
+/// For the general case, we still extract N slices, but chain N smaller
+/// vector.multi_reduction ops instead of elementwise operations.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, [[REDUCTION_DIMS]] ] :
+///   vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+/// becomes (REDUCTION_DIMS - 1: the remaining dims, shifted down by one):
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+///   vector<NxMx...xf32>
+/// %red0 = vector.multi_reduction <add>, %0, %acc [ [[REDUCTION_DIMS - 1]] ]
+///   : vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction <add>, %Nminus1, %redNminus2
+///   [ [[REDUCTION_DIMS - 1]] ] : vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
+struct UnrollMultiReductionOuterBaseCase
+ : public OpRewritePattern<vector::MultiDimReductionOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+ PatternRewriter &rewriter) const override {
+ if (!multiReductionOp.isReducedDim(0))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected outermost dimension to be reduced dimension.");
+
+ Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+ if (!elementType.isIntOrIndexOrFloat())
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected integer or float element type.");
+
+ ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+ if (reductionDims.size() > 1) // Several reduced dims: general case.
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected only one reduction dimension.");
+
+ Location loc = multiReductionOp.getLoc();
+ Value source = multiReductionOp.getSource();
+
+ ArrayRef<int64_t> srcShape =
+ multiReductionOp.getSourceVectorType().getShape();
+ int64_t numElementwiseOps = srcShape.front(); // NOTE(review): assumes dim 0 is not scalable; confirm.
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+ bool isMasked = maskableOp.isMasked();
+ Operation *rootOp; // Op to replace: the vector.mask wrapper if masked.
+ Value mask = nullptr;
+ if (isMasked) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp()); // Emit outside the mask region.
+ rootOp = maskableOp.getMaskingOp();
+ mask = maskableOp.getMaskingOp().getMask();
+ } else {
+ rootOp = multiReductionOp;
+ }
+
+ SmallVector<Value> vectors;
+ for (int64_t i = 0; i < numElementwiseOps; ++i)
+ vectors.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+ SmallVector<Value> masks;
+ for (int64_t i = 0; i < numElementwiseOps; ++i)
+ if (isMasked)
+ masks.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+ else
+ masks.push_back(nullptr);
+
+ Value result = multiReductionOp.getAcc(); // Fold slices 0..N-1 into acc.
+ for (auto [innerVector, innerMask] : llvm::zip(vectors, masks))
+ result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
+ innerVector, result, /*fastmath=*/nullptr,
+ innerMask);
+
+ rewriter.replaceOp(rootOp, result);
+ return success();
+ }
+};
+
+struct UnrollMultiReductionOuterGeneralCase // See the doc comment on UnrollMultiReductionOuterBaseCase.
+ : public OpRewritePattern<vector::MultiDimReductionOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+ PatternRewriter &rewriter) const override {
+ if (!multiReductionOp.isReducedDim(0))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected outermost dimension to be reduced dimension.");
+
+ Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+ if (!elementType.isIntOrIndexOrFloat())
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected integer or float element type.");
+
+ ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+ if (reductionDims.size() <= 1) // Single reduced dim: handled by the base case.
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected more than one reduction dimension.");
+
+ Location loc = multiReductionOp.getLoc();
+ Value source = multiReductionOp.getSource();
+
+ ArrayRef<int64_t> srcShape =
+ multiReductionOp.getSourceVectorType().getShape();
+ int64_t numElementwiseOps = srcShape.front(); // NOTE(review): assumes dim 0 is not scalable; confirm.
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+ bool isMasked = maskableOp.isMasked();
+ Operation *rootOp; // Op to replace: the vector.mask wrapper if masked.
+ Value mask = nullptr;
+ if (isMasked) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ mask = maskableOp.getMaskingOp().getMask();
+ } else {
+ rootOp = multiReductionOp;
+ }
+
+ SmallVector<Value> vectors;
+ for (int64_t i = 0; i < numElementwiseOps; ++i)
+ vectors.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+ SmallVector<Value> masks;
+ for (int64_t i = 0; i < numElementwiseOps; ++i)
+ if (isMasked)
+ masks.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+ else
+ masks.push_back(nullptr);
+ // getReductionMask() returns by value: own a copy so no ArrayRef dangles.
+ SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
+ reductionMask.erase(reductionMask.begin()); // Dim 0 is unrolled away.
+ Value result = multiReductionOp.getAcc();
+ for (auto [innerVector, innerMask] : llvm::zip(vectors, masks)) {
+ // Chain: each slice is reduced into the previous partial result.
+ auto reductionOp = vector::MultiDimReductionOp::create(
+ rewriter, loc, innerVector, result, reductionMask,
+ multiReductionOp.getKind());
+
+ if (isMasked) { // Re-wrap with the matching mask slice.
+ auto maskOp = vector::maskOperation(rewriter, reductionOp, innerMask);
+ result = maskOp->getResult(0);
+ } else {
+ result = reductionOp.getResult();
+ }
+ }
+
+ rewriter.replaceOp(rootOp, result);
+ return success();
+ }
+};
+
struct LowerVectorMultiReductionPass
: public vector::impl::LowerVectorMultiReductionBase<
LowerVectorMultiReductionPass> {
@@ -543,12 +732,14 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
- if (options == VectorMultiReductionLowering ::InnerReduction)
+ if (options == VectorMultiReductionLowering ::InnerReduction) {
patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
benefit);
- else
- patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
- benefit);
+ } else {
+ patterns.add<UnrollMultiReductionOuterBaseCase,
+ UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
+ benefit);
+ }
}
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index ddbc5c7bdb2c0..e01bf446eb83c 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -21,12 +21,12 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
// INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// INNER-PARALLEL: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
// INNER-PARALLEL: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
// INNER-PARALLEL: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
// INNER-PARALLEL: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+// INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+// INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+// INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
// INNER-PARALLEL: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
// INNER-PARALLEL: return %[[RESULT_VEC]] : vector<2xf32>
>From 2ce43b870f8bc2c156a8f1436704c27ae8edb81d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 12:07:30 -0500
Subject: [PATCH 04/14] [mlir][vector] Add tests for multi_reduction unrolling.
---
.../Vector/TransformOps/VectorTransformOps.td | 20 ++++
.../Vector/Transforms/LoweringPatterns.h | 24 +++++
.../TransformOps/VectorTransformOps.cpp | 8 ++
.../Transforms/LowerVectorMultiReduction.cpp | 16 ++++
.../Vector/td/unroll-multi-reduction.mlir | 24 +++++
.../Vector/unroll-vector-multi-reduction.mlir | 92 +++++++++++++++++++
6 files changed, 184 insertions(+)
create mode 100644 mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
create mode 100644 mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 03d25505dc65c..373f9a4210c41 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -243,6 +243,26 @@ def ApplyLowerMultiReductionPatternsOp : Op<Transform_Dialect,
}];
}
+def ApplyUnrollMultiReductionPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.unroll_multi_reduction",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Unrolls vector.multi_reduction operations by progressively reducing rank
+ along the outermost dimension.
+
+ This is an alternative to the flattening-based lowering that preserves
+ the n-D structure during progressive lowering.
+ }];
+
+ let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+ "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+ );
+
+ let assemblyFormat = [{
+ (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+ }];
+}
+
def ApplyLowerOuterProductPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_outerproduct",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 1beb8e25cb683..1e78983ce082d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -140,6 +140,30 @@ void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
+/// Collect a set of patterns to unroll vector.multi_reduction ops by
+/// progressively reducing rank along the outermost dimension.
+///
+/// For OuterReduction (outermost dim is reduction):
+/// [UnrollMultiReductionOuterBaseCase]
+/// When the outermost dimension is the only reduction dimension, unroll to
+/// produce elementwise arithmetic operations.
+///
+/// [UnrollMultiReductionOuterGeneralCase]
+/// When the outermost dimension is one of multiple reduction dimensions,
+/// unroll to produce smaller multi_reduction operations.
+///
+/// For InnerReduction (innermost dim is reduction):
+/// [UnrollMultiReductionInnerBaseCase]
+/// When the innermost dimension is the only reduction dimension, unroll along
+/// the outermost parallel dimension.
+///
+/// [UnrollMultiReductionInnerGeneralCase]
+/// When the innermost dimension is one of multiple reduction dimensions,
+/// unroll along the outermost parallel dimension.
+void populateVectorUnrollMultiReduction(RewritePatternSet &patterns,
+ VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1);
+
/// Populate the pattern set with the following patterns:
///
/// [TransferReadToVectorLoadLowering]
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index f651e74873a91..4562393559c94 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -136,6 +136,14 @@ void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
+void transform::ApplyUnrollMultiReductionPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::VectorTransformsOptions vectorTransformOptions;
+ vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+ vector::populateVectorUnrollMultiReduction(
+ patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorOuterProductLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index b2b01031587cf..8548e8c2d311d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -742,6 +742,22 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
}
}
+void mlir::vector::populateVectorUnrollMultiReduction(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit) {
+ if (options == VectorMultiReductionLowering::InnerReduction) {
+ // TODO: Add UnrollMultiReductionInnerBaseCase and
+ // UnrollMultiReductionInnerGeneralCase patterns here once implemented.
+ // For now, fall back to the existing 2-D based lowering.
+ patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
+ benefit);
+ } else {
+ patterns.add<UnrollMultiReductionOuterBaseCase,
+ UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
+ benefit);
+ }
+}
+
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir b/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
new file mode 100644
index 0000000000000..96a68723266d3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
@@ -0,0 +1,24 @@
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @unroll_multi_reduction(%module_op: !transform.any_op {transform.readonly}) {
+
+ %func_op = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
+ // Test patterns
+ transform.apply_patterns.vector.unroll_multi_reduction
+ } : !transform.any_op
+
+ transform.yield
+ }
+ transform.named_sequence @unroll_multi_reduction_inner(%module_op: !transform.any_op {transform.readonly}) {
+
+ %func_op = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
+ // Test patterns
+ transform.apply_patterns.vector.unroll_multi_reduction lowering_strategy = "innerreduction"
+ } : !transform.any_op
+
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
new file mode 100644
index 0000000000000..79086e2b0b9ad
--- /dev/null
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt --split-input-file %s -transform-preload-library='transform-library-paths=%p/td/unroll-multi-reduction.mlir' \
+// RUN: -transform-interpreter=entry-point=unroll_multi_reduction | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test UnrollVectorMultiReduction
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @unroll_vector_multi_reduction(%source: vector<2x3x5xf32>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
+ // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_0]] : vector<3x5xf32>
+ %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
+
+ // CHECK: return %[[RES_1]]
+ return %1 : vector<3x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @unroll_vector_multi_reduction_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+ // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
+ // CHECK: %[[RES_MASKED_0:.+]] = arith.select %[[MASK_0]], %[[RES_0]], %[[ACC]] : vector<3x5xi1>, vector<3x5xf32>
+
+ // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_MASKED_0]] : vector<3x5xf32>
+ // CHECK: %[[RES_MASKED_1:.+]] = arith.select %[[MASK_1]], %[[RES_1]], %[[RES_MASKED_0]] : vector<3x5xi1>, vector<3x5xf32>
+
+ %0 = vector.mask %mask {
+ %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
+ } : vector<2x3x5xi1> -> vector<3x5xf32>
+
+ // CHECK: return %[[RES_MASKED_1]]
+ return %0 : vector<3x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_general(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
+func.func @unroll_vector_multi_reduction_general(%source: vector<2x3x5xf32>, %acc: vector<3xf32>) -> (vector<3xf32>) {
+
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK: %[[RES_0:.+]] = vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32>
+ // CHECK: %[[RES_1:.+]] = vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32>
+
+ %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
+
+ // CHECK: return %[[RES_1]]
+ return %1 : vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_general_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
+func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3xf32>) -> (vector<3xf32>) {
+
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+ // CHECK: %[[RES_0:.+]] = vector.mask %[[MASK_0]] { vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
+ // CHECK: %[[RES_1:.+]] = vector.mask %[[MASK_1]] { vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
+
+ %0 = vector.mask %mask {
+ %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
+ } : vector<2x3x5xi1> -> vector<3xf32>
+
+ // CHECK: return %[[RES_1]]
+ return %0 : vector<3xf32>
+}
>From 7561c5be684af765fd722addaaee230ee79b3166 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 12:49:09 -0500
Subject: [PATCH 05/14] [mlir][vector] Add unrolling pattern for innermost
reduction base case
---
.../Transforms/LowerVectorMultiReduction.cpp | 138 +++++++++++++++++-
.../unroll-vector-multi-reduction-inner.mlir | 106 ++++++++++++++
2 files changed, 239 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 8548e8c2d311d..5e0d6da82f834 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -674,6 +674,135 @@ struct UnrollMultiReductionOuterGeneralCase
}
};
+/// Unrolls innermost dimension for vector.multi_reduction when the innermost
+/// dimension is the only reduction dimension.
+///
+/// This pattern matches operations where:
+/// - The innermost dimension is a reduction dimension
+/// - The outermost dimension is a parallel dimension
+/// - There is exactly one reduction dimension
+///
+/// The transformation unrolls along the outermost parallel dimension:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add> %src, %acc [N-1]
+/// : vector<AxBx...xMxf32> to vector<AxBx...xf32>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %result = arith.constant dense<0.0> : vector<AxBx...xf32>
+/// %0 = vector.extract %src[0] : vector<Bx...xMxf32> from vector<AxBx...xMxf32>
+/// %acc0 = vector.extract %acc[0] : vector<Bx...xf32> from vector<AxBx...xf32>
+/// %red0 = vector.multi_reduction <add>, %0, %acc0 [N-2]
+/// : vector<Bx...xMxf32> to vector<Bx...xf32>
+/// %res0 = vector.insert %red0, %result[0]
+/// : vector<Bx...xf32> into vector<AxBx...xf32>
+/// // ... repeat for indices 1 to A-1
+/// ```
+struct UnrollMultiReductionInnerBaseCase
+ : public OpRewritePattern<vector::MultiDimReductionOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+ PatternRewriter &rewriter) const override {
+ auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+ if (srcRank < 2)
+ return rewriter.notifyMatchFailure(multiReductionOp,
+ "expected source rank >= 2.");
+
+ if (!multiReductionOp.isReducedDim(srcRank - 1))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected innermost dimension to be a reduction dimension.");
+
+ if (multiReductionOp.isReducedDim(0))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected outermost dimension to be a parallel dimension.");
+
+ ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+ if (reductionDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected exactly one reduction dimension.");
+
+ Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+ if (!elementType.isIntOrIndexOrFloat())
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected integer or float element type.");
+
+ // The unrolling below emits exactly srcShape.front() slices. For a
+ // scalable leading dimension that is only the minimum size (the runtime
+ // size is a vscale multiple of it), so the remaining result positions would
+ // silently keep the zero constant. Refuse to match instead.
+ if (multiReductionOp.getSourceVectorType().getScalableDims().front())
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected outermost dimension to be non-scalable.");
+
+ Location loc = multiReductionOp.getLoc();
+ Value source = multiReductionOp.getSource();
+ Value acc = multiReductionOp.getAcc();
+
+ ArrayRef<int64_t> srcShape =
+ multiReductionOp.getSourceVectorType().getShape();
+ int64_t numSlices = srcShape.front();
+
+ // For a masked op the enclosing vector.mask is the op being replaced, so
+ // the new ops are created in front of it and its mask is sliced as well.
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+ bool isMasked = maskableOp.isMasked();
+ Operation *rootOp;
+ Value mask = nullptr;
+ if (isMasked) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ mask = maskableOp.getMaskingOp().getMask();
+ } else {
+ rootOp = multiReductionOp;
+ }
+
+ SmallVector<Value> srcSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ srcSlices.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+ SmallVector<Value> accSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ accSlices.push_back(vector::ExtractOp::create(rewriter, loc, acc, i));
+
+ SmallVector<Value> maskSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ if (isMasked)
+ maskSlices.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+ else
+ maskSlices.push_back(nullptr);
+
+ // Compute new reduction mask: for the extracted slice (rank srcRank-1),
+ // the innermost dimension is still the reduction dimension.
+ // New mask has srcRank-1 elements, with the last one being true.
+ SmallVector<bool> newReductionMask(srcRank - 1, false);
+ newReductionMask.back() = true;
+
+ SmallVector<Value> reductionResults;
+ for (auto [srcSlice, accSlice, maskSlice] :
+ llvm::zip(srcSlices, accSlices, maskSlices)) {
+ Operation *newReductionOp = vector::MultiDimReductionOp::create(
+ rewriter, loc, srcSlice, accSlice, newReductionMask,
+ multiReductionOp.getKind());
+
+ if (isMasked)
+ newReductionOp =
+ mlir::vector::maskOperation(rewriter, newReductionOp, maskSlice);
+
+ reductionResults.push_back(newReductionOp->getResult(0));
+ }
+
+ // The zero splat is only a placeholder: the loop below overwrites every
+ // index of the (non-scalable) leading dimension.
+ Value result = arith::ConstantOp::create(
+ rewriter, loc, multiReductionOp.getDestType(),
+ rewriter.getZeroAttr(multiReductionOp.getDestType()));
+
+ for (int64_t i = 0; i < numSlices; ++i)
+ result = vector::InsertOp::create(rewriter, loc, reductionResults[i],
+ result, i);
+
+ rewriter.replaceOp(rootOp, result);
+ return success();
+ }
+};
+
struct LowerVectorMultiReductionPass
: public vector::impl::LowerVectorMultiReductionBase<
LowerVectorMultiReductionPass> {
@@ -746,11 +875,10 @@ void mlir::vector::populateVectorUnrollMultiReduction(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
if (options == VectorMultiReductionLowering::InnerReduction) {
- // TODO: Add UnrollMultiReductionInnerBaseCase and
- // UnrollMultiReductionInnerGeneralCase patterns here once implemented.
- // For now, fall back to the existing 2-D based lowering.
- patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
- benefit);
+ // TODO: Add UnrollMultiReductionInnerGeneralCase pattern here once
+ // implemented.
+ patterns.add<UnrollMultiReductionInnerBaseCase>(patterns.getContext(),
+ benefit);
} else {
patterns.add<UnrollMultiReductionOuterBaseCase,
UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
new file mode 100644
index 0000000000000..89af4580bf627
--- /dev/null
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -0,0 +1,106 @@
+// RUN: mlir-opt --split-input-file %s -transform-preload-library='transform-library-paths=%p/td/unroll-multi-reduction.mlir' \
+// RUN: -transform-interpreter=entry-point=unroll_multi_reduction_inner | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test UnrollVectorMultiReduction for Inner Reduction (Base Case)
+//===----------------------------------------------------------------------===//
+
+// The pattern recursively reduces rank until we reach 1D multi_reductions.
+// For vector<2x3x5xf32> with reduction on dim 2:
+// - First pass: unrolls along dim 0 (size 2), creating vector<3x5xf32> multi_reductions
+// - Second pass: unrolls along dim 0 (size 3), creating vector<5xf32> multi_reductions
+//
+// The generated IR groups operations by phase:
+// extracts (source) → extracts (acc) → multi_reductions → inserts
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2x3xf32>
+func.func @unroll_vector_multi_reduction_inner(%source: vector<2x3x5xf32>, %acc: vector<2x3xf32>) -> (vector<2x3xf32>) {
+ // Two rounds of unrolling are expected: dim 0 (size 2) of the input, then
+ // dim 0 (size 3) of each resulting rank-2 op. The CHECKs expect the nested
+ // slices as single two-index extracts of %[[SOURCE]] and %[[ACC]].
+ // First slice [0, ...]: extracts → reductions → inserts
+ // CHECK: vector.extract %[[SOURCE]][0, 0]
+ // CHECK: vector.extract %[[SOURCE]][0, 1]
+ // CHECK: vector.extract %[[SOURCE]][0, 2]
+ // CHECK: vector.extract %[[ACC]][0, 0]
+ // CHECK: vector.extract %[[ACC]][0, 1]
+ // CHECK: vector.extract %[[ACC]][0, 2]
+ // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
+ // Second slice [1, ...]: extracts → reductions → inserts
+ // CHECK: vector.extract %[[SOURCE]][1, 0]
+ // CHECK: vector.extract %[[SOURCE]][1, 1]
+ // CHECK: vector.extract %[[SOURCE]][1, 2]
+ // CHECK: vector.extract %[[ACC]][1, 0]
+ // CHECK: vector.extract %[[ACC]][1, 1]
+ // CHECK: vector.extract %[[ACC]][1, 2]
+ // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
+ // Final inserts to assemble result
+ // CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
+ // CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
+ // No original multi_reduction with [2] remains
+ // CHECK-NOT: vector.multi_reduction <add>, {{.*}} [2]
+ %1 = vector.multi_reduction <add>, %source, %acc [2] : vector<2x3x5xf32> to vector<2x3xf32>
+
+ return %1 : vector<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2x3xf32>
+func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2x3xf32>) -> (vector<2x3xf32>) {
+ // Same expansion as the unmasked case. In addition, the mask is sliced
+ // alongside the source and accumulator, and each rank-1 reduction is
+ // wrapped in its own vector.mask.
+ // First slice [0, ...]: extracts (source, acc, mask) → masked reductions → inserts
+ // CHECK: vector.extract %[[SOURCE]][0, 0]
+ // CHECK: vector.extract %[[SOURCE]][0, 1]
+ // CHECK: vector.extract %[[SOURCE]][0, 2]
+ // CHECK: vector.extract %[[ACC]][0, 0]
+ // CHECK: vector.extract %[[ACC]][0, 1]
+ // CHECK: vector.extract %[[ACC]][0, 2]
+ // CHECK: vector.extract %[[MASK]][0, 0]
+ // CHECK: vector.extract %[[MASK]][0, 1]
+ // CHECK: vector.extract %[[MASK]][0, 2]
+ // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
+ // Second slice [1, ...]: extracts → masked reductions → inserts
+ // CHECK: vector.extract %[[SOURCE]][1, 0]
+ // CHECK: vector.extract %[[SOURCE]][1, 1]
+ // CHECK: vector.extract %[[SOURCE]][1, 2]
+ // CHECK: vector.extract %[[ACC]][1, 0]
+ // CHECK: vector.extract %[[ACC]][1, 1]
+ // CHECK: vector.extract %[[ACC]][1, 2]
+ // CHECK: vector.extract %[[MASK]][1, 0]
+ // CHECK: vector.extract %[[MASK]][1, 1]
+ // CHECK: vector.extract %[[MASK]][1, 2]
+ // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
+ // Final inserts to assemble result
+ // CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
+ // CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
+ // No original multi_reduction with [2] remains
+ // CHECK-NOT: vector.multi_reduction <add>, {{.*}} [2]
+
+ %0 = vector.mask %mask {
+ %1 = vector.multi_reduction <add>, %source, %acc [2] : vector<2x3x5xf32> to vector<2x3xf32>
+ } : vector<2x3x5xi1> -> vector<2x3xf32>
+
+ return %0 : vector<2x3xf32>
+}
>From 966b4cb70efd50f309dd3e5592bba37ddcc06d12 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 13:16:22 -0500
Subject: [PATCH 06/14] [mlir][vector] Implement general pattern
---
.../Transforms/LowerVectorMultiReduction.cpp | 46 ++++++--------
.../unroll-vector-multi-reduction-inner.mlir | 60 +++++++++++++++++++
2 files changed, 79 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 5e0d6da82f834..99b766eb9ad7c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -674,34 +674,34 @@ struct UnrollMultiReductionOuterGeneralCase
}
};
-/// Unrolls innermost dimension for vector.multi_reduction when the innermost
-/// dimension is the only reduction dimension.
+/// Unrolls vector.multi_reduction along the outermost parallel dimension
+/// when the innermost dimension is a reduction dimension.
///
/// This pattern matches operations where:
/// - The innermost dimension is a reduction dimension
/// - The outermost dimension is a parallel dimension
-/// - There is exactly one reduction dimension
///
-/// The transformation unrolls along the outermost parallel dimension:
+/// The transformation extracts slices along the outermost parallel dimension,
+/// creates smaller multi_reductions, and assembles the results:
///
/// ```mlir
-/// %res = vector.multi_reduction <add> %src, %acc [N-1]
-/// : vector<AxBx...xMxf32> to vector<AxBx...xf32>
+/// %res = vector.multi_reduction <add>, %src, %acc [1, 3]
+/// : vector<AxBxCxDxf32> to vector<AxCxf32>
/// ```
///
/// becomes:
///
/// ```mlir
-/// %result = arith.constant dense<0.0> : vector<AxBx...xf32>
-/// %0 = vector.extract %src[0] : vector<Bx...xMxf32> from vector<AxBx...xMxf32>
-/// %acc0 = vector.extract %acc[0] : vector<Bx...xf32> from vector<AxBx...xf32>
-/// %red0 = vector.multi_reduction <add>, %0, %acc0 [N-2]
-/// : vector<Bx...xMxf32> to vector<Bx...xf32>
+/// %result = arith.constant dense<0.0> : vector<AxCxf32>
+/// %0 = vector.extract %src[0] : vector<BxCxDxf32> from vector<AxBxCxDxf32>
+/// %acc0 = vector.extract %acc[0] : vector<Cxf32> from vector<AxCxf32>
+/// %red0 = vector.multi_reduction <add>, %0, %acc0 [0, 2]
+/// : vector<BxCxDxf32> to vector<Cxf32>
/// %res0 = vector.insert %red0, %result[0]
-/// : vector<Bx...xf32> into vector<AxBx...xf32>
+/// : vector<Cxf32> into vector<AxCxf32>
/// // ... repeat for indices 1 to A-1
/// ```
-struct UnrollMultiReductionInnerBaseCase
+struct UnrollMultiReductionInner
: public OpRewritePattern<vector::MultiDimReductionOp> {
using Base::Base;
@@ -723,11 +723,6 @@ struct UnrollMultiReductionInnerBaseCase
multiReductionOp,
"expected outermost dimension to be a parallel dimension.");
- ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
- if (reductionDims.size() != 1)
- return rewriter.notifyMatchFailure(
- multiReductionOp, "expected exactly one reduction dimension.");
-
Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
if (!elementType.isIntOrIndexOrFloat())
return rewriter.notifyMatchFailure(
@@ -770,11 +765,11 @@ struct UnrollMultiReductionInnerBaseCase
else
maskSlices.push_back(nullptr);
- // Compute new reduction mask: for the extracted slice (rank srcRank-1),
- // the innermost dimension is still the reduction dimension.
- // New mask has srcRank-1 elements, with the last one being true.
- SmallVector<bool> newReductionMask(srcRank - 1, false);
- newReductionMask.back() = true;
+ // Compute new reduction mask by dropping the first element (dimension 0).
+ // Since dimension 0 is parallel (not reduced), all reduction indices shift
+ // down by 1. getReductionMask() builds and returns a fresh SmallVector, so
+ // the tail is copied into owning storage here: an ArrayRef into that
+ // temporary would dangle as soon as this statement ends.
+ SmallVector<bool> newReductionMask(
+ ArrayRef<bool>(multiReductionOp.getReductionMask()).drop_front());
SmallVector<Value> reductionResults;
for (auto [srcSlice, accSlice, maskSlice] :
@@ -875,10 +870,7 @@ void mlir::vector::populateVectorUnrollMultiReduction(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
if (options == VectorMultiReductionLowering::InnerReduction) {
- // TODO: Add UnrollMultiReductionInnerGeneralCase pattern here once
- // implemented.
- patterns.add<UnrollMultiReductionInnerBaseCase>(patterns.getContext(),
- benefit);
+ patterns.add<UnrollMultiReductionInner>(patterns.getContext(), benefit);
} else {
patterns.add<UnrollMultiReductionOuterBaseCase,
UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
index 89af4580bf627..fcb97967f088e 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -104,3 +104,63 @@ func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>
return %0 : vector<2x3xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test UnrollVectorMultiReduction for Inner Reduction (General Case)
+//===----------------------------------------------------------------------===//
+
+// The general case handles multiple reduction dimensions.
+// For vector<2x3x5xf32> with reduction on dims [1, 2]:
+// - Unrolls along dim 0 (size 2), creating vector<3x5xf32> multi_reductions
+// - Each new multi_reduction has reduction dims [0, 1] (shifted from [1, 2])
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_vector_multi_reduction_inner_general(%source: vector<2x3x5xf32>, %acc: vector<2xf32>) -> (vector<2xf32>) {
+ // Only dim 0 is unrolled. Each vector<3x5xf32> slice is fully reduced to
+ // a scalar, and the scalars are inserted back into a vector<2xf32>.
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+ // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+
+ // CHECK: %[[RED_0:.+]] = vector.multi_reduction <add>, %[[VEC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32
+ // CHECK: %[[RED_1:.+]] = vector.multi_reduction <add>, %[[VEC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32
+ // CHECK: %[[INSERT_0:.+]] = vector.insert %[[RED_0]], %{{.*}} [0] : f32 into vector<2xf32>
+ // CHECK: %[[INSERT_1:.+]] = vector.insert %[[RED_1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+ %1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+
+ // CHECK: return %[[INSERT_1]]
+ return %1 : vector<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2xf32>) -> (vector<2xf32>) {
+ // Masked variant of the general case. The mask is sliced along dim 0 like
+ // the source, and each per-slice reduction is wrapped in vector.mask.
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+ // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+
+ // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+ // CHECK: %[[RED_0:.+]] = vector.mask %[[MASK_0]] { vector.multi_reduction <add>, %[[VEC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32 } : vector<3x5xi1> -> f32
+ // CHECK: %[[RED_1:.+]] = vector.mask %[[MASK_1]] { vector.multi_reduction <add>, %[[VEC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32 } : vector<3x5xi1> -> f32
+ // CHECK: %[[INSERT_0:.+]] = vector.insert %[[RED_0]], %{{.*}} [0] : f32 into vector<2xf32>
+ // CHECK: %[[INSERT_1:.+]] = vector.insert %[[RED_1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+
+ %0 = vector.mask %mask {
+ %1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+ } : vector<2x3x5xi1> -> vector<2xf32>
+
+ // CHECK: return %[[INSERT_1]]
+ return %0 : vector<2xf32>
+}
>From cb0d3fb8bc3a2982b9223372ba2d434c8a42fb63 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 14:05:20 -0500
Subject: [PATCH 07/14] [mlir][vector] Skip matching multi_reduction on rank 1
---
.../Transforms/LowerVectorMultiReduction.cpp | 5 +++++
.../unroll-vector-multi-reduction-inner.mlir | 19 +++++++++++++++++++
.../Vector/unroll-vector-multi-reduction.mlir | 19 +++++++++++++++++++
3 files changed, 43 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 99b766eb9ad7c..a79b6ea10d543 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -541,6 +541,11 @@ struct UnrollMultiReductionOuterBaseCase
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
+ auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+ if (srcRank < 2)
+ return rewriter.notifyMatchFailure(multiReductionOp,
+ "expected source rank >= 2.");
+
if (!multiReductionOp.isReducedDim(0))
return rewriter.notifyMatchFailure(
multiReductionOp,
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
index fcb97967f088e..d6b3091fb9463 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -164,3 +164,22 @@ func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x
// CHECK: return %[[INSERT_1]]
return %0 : vector<2xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Negative Test: Rank-1 multi_reduction should NOT be matched by UnrollMultiReductionInner
+//===----------------------------------------------------------------------===//
+
+// UnrollMultiReductionInner requires srcRank >= 2, so rank-1 should not match.
+// The op should remain unchanged.
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_rank1_negative(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[ACC:.+]]: f32
+func.func @unroll_vector_multi_reduction_inner_rank1_negative(%source: vector<8xf32>, %acc: f32) -> f32 {
+ // The op must be left as-is, so the CHECK below matches the input form.
+ // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[SOURCE]], %[[ACC]] [0] : vector<8xf32> to f32
+ %0 = vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
+ // CHECK: return %[[RESULT]]
+ return %0 : f32
+}
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
index 79086e2b0b9ad..8c36430bb6eee 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -90,3 +90,22 @@ func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf3
// CHECK: return %[[RES_1]]
return %0 : vector<3xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Negative Test: Rank-1 multi_reduction should NOT be matched by unroll patterns
+//===----------------------------------------------------------------------===//
+
+// UnrollMultiReductionOuterBaseCase and UnrollMultiReductionOuterGeneralCase
+// should not match rank-1 multi_reduction ops. The op should remain unchanged.
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_rank1_negative(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[ACC:.+]]: f32
+func.func @unroll_vector_multi_reduction_rank1_negative(%source: vector<8xf32>, %acc: f32) -> f32 {
+ // The op must be left as-is, so the CHECK below matches the input form.
+ // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[SOURCE]], %[[ACC]] [0] : vector<8xf32> to f32
+ %0 = vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
+ // CHECK: return %[[RESULT]]
+ return %0 : f32
+}
>From f55f2bfda5e89e7e4a2e5787f5f824dc7f7076e0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 15:17:17 -0500
Subject: [PATCH 08/14] [mlir][vector] Convert multi_reduction in 1-D to
reduction.
---
.../Transforms/LowerVectorMultiReduction.cpp | 58 +++++++++++++++++
.../unroll-vector-multi-reduction-inner.mlir | 65 ++++++++++++-------
.../Vector/unroll-vector-multi-reduction.mlir | 27 ++++++--
3 files changed, 120 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index a79b6ea10d543..595c2fce55619 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -413,6 +413,60 @@ struct TwoDimMultiReductionToReduction
}
};
+/// Terminal step of multi_reduction unrolling: a rank-1 vector.multi_reduction
+/// whose only dimension is reduced is rewritten into the equivalent
+/// vector.reduction. If the op sits inside a vector.mask, the replacement is
+/// wrapped in a vector.mask with the same mask.
+///
+/// Example:
+/// ```mlir
+/// // Input
+/// %r = vector.multi_reduction <add>, %v, %acc [0] : vector<Nxf32> to f32
+///
+/// // Output
+/// %r = vector.reduction <add>, %v, %acc : vector<Nxf32> into f32
+/// ```
+struct OneDimMultiReductionToReduction
+ : public OpRewritePattern<vector::MultiDimReductionOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+ PatternRewriter &rewriter) const override {
+ // Only rank-1 sources whose single dimension is reduced qualify.
+ if (multiReductionOp.getSourceVectorType().getRank() != 1)
+ return rewriter.notifyMatchFailure(multiReductionOp,
+ "expected source rank == 1.");
+ if (!multiReductionOp.isReducedDim(0))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected dimension 0 to be a reduction dimension.");
+
+ auto maskable =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+ const bool masked = maskable.isMasked();
+
+ // Unmasked: replace the op itself at its current position.
+ // Masked: replace the enclosing vector.mask, building in front of it.
+ OpBuilder::InsertionGuard guard(rewriter);
+ Operation *opToReplace = multiReductionOp;
+ Value maskValue;
+ if (masked) {
+ auto maskingOp = maskable.getMaskingOp();
+ rewriter.setInsertionPoint(maskingOp);
+ opToReplace = maskingOp;
+ maskValue = maskingOp.getMask();
+ }
+
+ Location loc = multiReductionOp.getLoc();
+ Operation *replacement = vector::ReductionOp::create(
+ rewriter, loc, multiReductionOp.getKind(), multiReductionOp.getSource(),
+ multiReductionOp.getAcc());
+ if (masked)
+ replacement = mlir::vector::maskOperation(rewriter, replacement, maskValue);
+
+ rewriter.replaceOp(opToReplace, replacement->getResult(0));
+ return success();
+ }
+};
+
/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
/// form with both a single parallel and reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
@@ -861,6 +915,8 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
+ patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
+
if (options == VectorMultiReductionLowering ::InnerReduction) {
patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
benefit);
@@ -874,6 +930,8 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
void mlir::vector::populateVectorUnrollMultiReduction(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
+ patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
+
if (options == VectorMultiReductionLowering::InnerReduction) {
patterns.add<UnrollMultiReductionInner>(patterns.getContext(), benefit);
} else {
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
index d6b3091fb9463..f21535c1dfc21 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -5,13 +5,15 @@
// Test UnrollVectorMultiReduction for Inner Reduction (Base Case)
//===----------------------------------------------------------------------===//
-// The pattern recursively reduces rank until we reach 1D multi_reductions.
+// The pattern recursively reduces rank until we reach 1D, then converts to
+// vector.reduction via OneDimMultiReductionToReduction.
// For vector<2x3x5xf32> with reduction on dim 2:
// - First pass: unrolls along dim 0 (size 2), creating vector<3x5xf32> multi_reductions
// - Second pass: unrolls along dim 0 (size 3), creating vector<5xf32> multi_reductions
+// - Final: 1-D multi_reductions are converted to vector.reduction
//
// The generated IR groups operations by phase:
-// extracts (source) → extracts (acc) → multi_reductions → inserts
+// extracts (source) → extracts (acc) → reductions → inserts
// CHECK-LABEL: func @unroll_vector_multi_reduction_inner(
// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
@@ -24,9 +26,9 @@ func.func @unroll_vector_multi_reduction_inner(%source: vector<2x3x5xf32>, %acc:
// CHECK: vector.extract %[[ACC]][0, 0]
// CHECK: vector.extract %[[ACC]][0, 1]
// CHECK: vector.extract %[[ACC]][0, 2]
- // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
- // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
- // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
+ // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
+ // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
// CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
@@ -37,17 +39,17 @@ func.func @unroll_vector_multi_reduction_inner(%source: vector<2x3x5xf32>, %acc:
// CHECK: vector.extract %[[ACC]][1, 0]
// CHECK: vector.extract %[[ACC]][1, 1]
// CHECK: vector.extract %[[ACC]][1, 2]
- // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
- // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
- // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
+ // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
+ // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
// CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
// Final inserts to assemble result
// CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
- // No original multi_reduction with [2] remains
- // CHECK-NOT: vector.multi_reduction <add>, {{.*}} [2]
+ // No original multi_reduction remains
+ // CHECK-NOT: vector.multi_reduction
%1 = vector.multi_reduction <add>, %source, %acc [2] : vector<2x3x5xf32> to vector<2x3xf32>
return %1 : vector<2x3xf32>
@@ -70,9 +72,9 @@ func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>
// CHECK: vector.extract %[[MASK]][0, 0]
// CHECK: vector.extract %[[MASK]][0, 1]
// CHECK: vector.extract %[[MASK]][0, 2]
- // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
- // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
- // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
+ // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
+ // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
// CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
@@ -86,17 +88,17 @@ func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>
// CHECK: vector.extract %[[MASK]][1, 0]
// CHECK: vector.extract %[[MASK]][1, 1]
// CHECK: vector.extract %[[MASK]][1, 2]
- // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
- // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
- // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
+ // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
+ // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
// CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
// Final inserts to assemble result
// CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
- // No original multi_reduction with [2] remains
- // CHECK-NOT: vector.multi_reduction <add>, {{.*}} [2]
+ // No original multi_reduction remains
+ // CHECK-NOT: vector.multi_reduction
%0 = vector.mask %mask {
%1 = vector.multi_reduction <add>, %source, %acc [2] : vector<2x3x5xf32> to vector<2x3xf32>
@@ -168,18 +170,33 @@ func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x
// -----
//===----------------------------------------------------------------------===//
-// Negative Test: Rank-1 multi_reduction should NOT be matched by UnrollMultiReductionInner
+// Test 1-D multi_reduction to vector.reduction conversion
//===----------------------------------------------------------------------===//
-// UnrollMultiReductionInner requires srcRank >= 2, so rank-1 should not match.
-// The op should remain unchanged.
+// OneDimMultiReductionToReduction converts rank-1 multi_reduction directly
+// to vector.reduction, which preserves the reduction semantics for backends.
-// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_rank1_negative(
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_1d(
// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
// CHECK-SAME: %[[ACC:.+]]: f32
-func.func @unroll_vector_multi_reduction_inner_rank1_negative(%source: vector<8xf32>, %acc: f32) -> f32 {
- // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[SOURCE]], %[[ACC]] [0] : vector<8xf32> to f32
+func.func @unroll_vector_multi_reduction_inner_1d(%source: vector<8xf32>, %acc: f32) -> f32 {
+ // CHECK: %[[RESULT:.+]] = vector.reduction <add>, %[[SOURCE]], %[[ACC]] : vector<8xf32> into f32
%0 = vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
// CHECK: return %[[RESULT]]
return %0 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_1d_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>,
+// CHECK-SAME: %[[ACC:.+]]: f32
+func.func @unroll_vector_multi_reduction_inner_1d_masked(%source: vector<8xf32>, %mask: vector<8xi1>, %acc: f32) -> f32 {
+ // CHECK: %[[RESULT:.+]] = vector.mask %[[MASK]] { vector.reduction <add>, %[[SOURCE]], %[[ACC]] : vector<8xf32> into f32 } : vector<8xi1> -> f32
+ %0 = vector.mask %mask {
+ vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
+ } : vector<8xi1> -> f32
+ // CHECK: return %[[RESULT]]
+ return %0 : f32
+}
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
index 8c36430bb6eee..6db425fe4fce9 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -94,18 +94,33 @@ func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf3
// -----
//===----------------------------------------------------------------------===//
-// Negative Test: Rank-1 multi_reduction should NOT be matched by unroll patterns
+// Test 1-D multi_reduction to vector.reduction conversion
//===----------------------------------------------------------------------===//
-// UnrollMultiReductionOuterBaseCase and UnrollMultiReductionOuterGeneralCase
-// should not match rank-1 multi_reduction ops. The op should remain unchanged.
+// OneDimMultiReductionToReduction converts rank-1 multi_reduction directly
+// to vector.reduction, which preserves the reduction semantics for backends.
-// CHECK-LABEL: func @unroll_vector_multi_reduction_rank1_negative(
+// CHECK-LABEL: func @unroll_vector_multi_reduction_1d(
// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
// CHECK-SAME: %[[ACC:.+]]: f32
-func.func @unroll_vector_multi_reduction_rank1_negative(%source: vector<8xf32>, %acc: f32) -> f32 {
- // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[SOURCE]], %[[ACC]] [0] : vector<8xf32> to f32
+func.func @unroll_vector_multi_reduction_1d(%source: vector<8xf32>, %acc: f32) -> f32 {
+ // CHECK: %[[RESULT:.+]] = vector.reduction <add>, %[[SOURCE]], %[[ACC]] : vector<8xf32> into f32
%0 = vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
// CHECK: return %[[RESULT]]
return %0 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_1d_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>,
+// CHECK-SAME: %[[ACC:.+]]: f32
+func.func @unroll_vector_multi_reduction_1d_masked(%source: vector<8xf32>, %mask: vector<8xi1>, %acc: f32) -> f32 {
+ // CHECK: %[[RESULT:.+]] = vector.mask %[[MASK]] { vector.reduction <add>, %[[SOURCE]], %[[ACC]] : vector<8xf32> into f32 } : vector<8xi1> -> f32
+ %0 = vector.mask %mask {
+ vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
+ } : vector<8xi1> -> f32
+ // CHECK: return %[[RESULT]]
+ return %0 : f32
+}
>From 88c6bfe77b71901a83cd9aafdba95652fa8d6933 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 15:33:20 -0500
Subject: [PATCH 09/14] [mlir][vector] Remove old patterns population.
Replace populateVectorMultiReductionUnrollingPatterns
with populateVectorUnrollMultiReduction which is more general.
---
.../Vector/Transforms/LoweringPatterns.h | 17 ----------------
.../Transforms/LowerVectorMultiReduction.cpp | 19 ++----------------
.../vector-multi-reduction-pass-lowering.mlir | 20 +++++++++++--------
3 files changed, 14 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 1e78983ce082d..e6e62f2250fb7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -89,23 +89,6 @@ void populateVectorMultiReductionFlatteningPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
-/// Collect a set of patterns to convert vector.multi_reduction op into
-/// a sequence of vector.reduction ops. The patterns comprise:
-///
-/// [TwoDimMultiReductionToElementWise]
-/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
-/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
-/// ops. This also has an opportunity for tree-reduction (in the future).
-///
-/// [TwoDimMultiReductionToReduction]
-/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
-/// dimension, unroll the outer dimension to obtain a sequence of extract +
-/// vector.reduction + insert. This can further lower to horizontal reduction
-/// ops.
-void populateVectorMultiReductionUnrollingPatterns(
- RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit = 1);
-
/// Collect a set of patterns to convert vector.multi_reduction op into
/// a sequence of vector.reduction ops. The patterns comprise:
///
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 595c2fce55619..cd651ecf6c745 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -883,8 +883,8 @@ struct LowerVectorMultiReductionPass
signalPassFailure();
RewritePatternSet unrollingPatterns(context);
- populateVectorMultiReductionUnrollingPatterns(unrollingPatterns,
- this->loweringStrategy);
+ populateVectorUnrollMultiReduction(unrollingPatterns,
+ this->loweringStrategy);
if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
signalPassFailure();
@@ -912,21 +912,6 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
benefit);
}
-void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
- RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit) {
- patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
-
- if (options == VectorMultiReductionLowering ::InnerReduction) {
- patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
- benefit);
- } else {
- patterns.add<UnrollMultiReductionOuterBaseCase,
- UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
- benefit);
- }
-}
-
void mlir::vector::populateVectorUnrollMultiReduction(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index e01bf446eb83c..ac833612681cf 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -10,13 +10,13 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
// INNER-REDUCTION-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
// INNER-REDUCTION: %[[V0:.+]] = vector.extract %[[INPUT]][0]
-// INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
-// INNER-REDUCTION: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
-// INNER-REDUCTION: %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0:.+]], %[[RESULT_VEC_0]] [0] : f32 into vector<2xf32>
// INNER-REDUCTION: %[[V1:.+]] = vector.extract %[[INPUT]][1]
+// INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
// INNER-REDUCTION: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+// INNER-REDUCTION: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
// INNER-REDUCTION: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
-// INNER-REDUCTION: %[[RESULT_VEC:.+]] = vector.insert %[[RV1:.+]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
+// INNER-REDUCTION: %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0]], %[[RESULT_VEC_0]] [0] : f32 into vector<2xf32>
+// INNER-REDUCTION: %[[RESULT_VEC:.+]] = vector.insert %[[RV1]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
// INNER-REDUCTION: return %[[RESULT_VEC]]
// INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
@@ -51,10 +51,14 @@ func.func @vector_multi_reduction_masked(%arg0: vector<2x4xf32>, %acc: vector<2x
// ALL-LABEL: func @vector_multi_reduction_masked
// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<2x4xi1>
-// INNER-REDUCTION: %[[INNERVEC:.+]] = vector.extract %[[INPUT]][0] : vector<4xf32> from vector<2x4xf32>
-// INNER-REDUCTION: %[[INNERACC:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
-// INNER-REDUCTION: %[[INNERMASK:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<2x4xi1>
-// INNER-REDUCTION: vector.mask %[[INNERMASK]] { vector.reduction <mul>, %[[INNERVEC]], %[[INNERACC]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
+// INNER-REDUCTION: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<4xf32> from vector<2x4xf32>
+// INNER-REDUCTION: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<4xf32> from vector<2x4xf32>
+// INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// INNER-REDUCTION: %[[ACC1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+// INNER-REDUCTION: %[[MASK0:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<2x4xi1>
+// INNER-REDUCTION: %[[MASK1:.+]] = vector.extract %[[MASK]][1] : vector<4xi1> from vector<2x4xi1>
+// INNER-REDUCTION: vector.mask %[[MASK0]] { vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
+// INNER-REDUCTION: vector.mask %[[MASK1]] { vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
// INNER-PARALLEL: %[[TPMASK:.+]] = vector.transpose %[[MASK]], [1, 0] : vector<2x4xi1> to vector<4x2xi1>
// INNER-PARALLEL: %[[TPINPUT:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// INNER-PARALLEL: %[[INNERVEC:.+]] = vector.extract %[[TPINPUT]][0] : vector<2xf32> from vector<4x2xf32>
>From b8a14a1df67d09b6a3a53dcce91a3a17e982775e Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 16:42:01 -0500
Subject: [PATCH 10/14] [mlir][vector] Add missing rank check
---
.../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index cd651ecf6c745..2e1da48a3f768 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -664,6 +664,11 @@ struct UnrollMultiReductionOuterGeneralCase
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
+ auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+ if (srcRank < 2)
+ return rewriter.notifyMatchFailure(multiReductionOp,
+ "expected source rank >= 2.");
+
if (!multiReductionOp.isReducedDim(0))
return rewriter.notifyMatchFailure(
multiReductionOp,
>From b513b1af291cad4ae67064e73368bbf2b11ab836 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 16:43:48 -0500
Subject: [PATCH 11/14] Split documentation
---
.../Transforms/LowerVectorMultiReduction.cpp | 67 +++++++++----------
1 file changed, 31 insertions(+), 36 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2e1da48a3f768..08fcff00b2c19 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -539,56 +539,27 @@ struct OneDimMultiReductionToTwoDim
}
};
-/// Unrolls outermost dimension for vector.multi_reduction.
-/// This patterns matches operations which reduce the outermost dimension,
-/// it does not transform operations for which the outermost dimension is not
-/// a reduction dimension.
+/// Unrolls vector.multi_reduction when the outermost dimension is the only
+/// reduction dimension. Transforms to elementwise arithmetic operations.
///
-/// There are two cases to consider:
-/// 1. The base case is when the outermost dimension is the only reduction
-/// dimension.
-/// 2. The general case is when the outermost dimension is not the only
-/// reduction dimension.
-///
-/// The base case transformation:
+/// Example:
///
/// ```mlir
-/// %res = vector.multi_reduction <add> %src, %acc [0] : vector<NxMx...xf32> to
-/// vector<Mx...xf32>
+/// %res = vector.multi_reduction <add>, %src, %acc [0]
+/// : vector<NxMx...xf32> to vector<Mx...xf32>
/// ```
///
-/// will extract N vectors from %src and then perform elementwise operations.
+/// becomes:
///
/// ```mlir
/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
/// ...
-/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...x.f32> from
-/// vector<NxMx...xf32>
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from vector<NxMx...xf32>
///
/// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
/// ...
/// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
/// ```
-///
-/// For the general case, we still extract N vectors, but produce N
-/// vector.multi_reduction instead of elementwise operations.
-///
-/// ```mlir
-/// %res = vector.multi_reduction <add> %src, %acc [0, [[REDUCTION_DIMS]] ] :
-/// vector<NxMx...xf32> to vector<Ix...xf32>
-///
-/// ```mlir
-/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
-/// ...
-/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...x.f32> from
-/// vector<NxMx...xf32>
-///
-/// %red0 = vector.multi_reduction %0, %acc [ [[REDUCTION_DIMS]] ] :
-/// vector<Mx...xf32> to vector<Ix...xf32>
-/// ...
-/// %res = vector.multi_reduction %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] :
-/// vector<Mx...xf32> to vector<Ix...xf32>
-/// ```
struct UnrollMultiReductionOuterBaseCase
: public OpRewritePattern<vector::MultiDimReductionOp> {
using Base::Base;
@@ -658,6 +629,30 @@ struct UnrollMultiReductionOuterBaseCase
}
};
+/// Unrolls vector.multi_reduction when the outermost dimension is one of
+/// multiple reduction dimensions. Extracts slices and chains smaller
+/// multi_reduction operations.
+///
+/// Example:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, 2]
+/// : vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction <add>, %0, %acc [1]
+/// : vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction <add>, %Nminus1, %redNminus2 [1]
+/// : vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
struct UnrollMultiReductionOuterGeneralCase
: public OpRewritePattern<vector::MultiDimReductionOp> {
using Base::Base;
>From 2240787542a6eb7a26ae6a0f575d3a1ad9ba347e Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 16:48:55 -0500
Subject: [PATCH 12/14] Rename
---
.../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 08fcff00b2c19..2f8d0a2e2da05 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -760,7 +760,7 @@ struct UnrollMultiReductionOuterGeneralCase
/// : vector<Cxf32> into vector<AxCxf32>
/// // ... repeat for indices 1 to A-1
/// ```
-struct UnrollMultiReductionInner
+struct UnrollInnerReductionAlongOuterParallel
: public OpRewritePattern<vector::MultiDimReductionOp> {
using Base::Base;
@@ -918,7 +918,8 @@ void mlir::vector::populateVectorUnrollMultiReduction(
patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
if (options == VectorMultiReductionLowering::InnerReduction) {
- patterns.add<UnrollMultiReductionInner>(patterns.getContext(), benefit);
+ patterns.add<UnrollInnerReductionAlongOuterParallel>(patterns.getContext(),
+ benefit);
} else {
patterns.add<UnrollMultiReductionOuterBaseCase,
UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
>From 36037c7db36b83dacd13ae77c0497231e2cceada Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 17:01:32 -0500
Subject: [PATCH 13/14] Style
---
.../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2f8d0a2e2da05..3c9e7f05105e1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -554,7 +554,8 @@ struct OneDimMultiReductionToTwoDim
/// ```mlir
/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
/// ...
-/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
///
/// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
/// ...
@@ -645,7 +646,8 @@ struct UnrollMultiReductionOuterBaseCase
/// ```mlir
/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
/// ...
-/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
///
/// %red0 = vector.multi_reduction <add>, %0, %acc [1]
/// : vector<Mx...xf32> to vector<Ix...xf32>
>From 5bf640c6d59a879ccf9632e3b80698097b252a9f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 18:57:15 -0500
Subject: [PATCH 14/14] Fix asan
---
.../Vector/Transforms/LowerVectorMultiReduction.cpp | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 3c9e7f05105e1..ea65b02fe71fa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -713,8 +713,8 @@ struct UnrollMultiReductionOuterGeneralCase
else
masks.push_back(nullptr);
- ArrayRef<bool> reductionMask =
- ArrayRef<bool>(multiReductionOp.getReductionMask()).drop_front();
+ SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+ ArrayRef<bool> reductionMask = ArrayRef<bool>(fullReductionMask).drop_front();
Value result = multiReductionOp.getAcc();
for (auto [innerVector, innerMask] : llvm::zip(vectors, masks)) {
@@ -829,8 +829,9 @@ struct UnrollInnerReductionAlongOuterParallel
// Compute new reduction mask by dropping the first element (dimension 0).
// Since dimension 0 is parallel (not reduced), all reduction indices shift
// down by 1.
+ SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
ArrayRef<bool> newReductionMask =
- ArrayRef<bool>(multiReductionOp.getReductionMask()).drop_front();
+ ArrayRef<bool>(fullReductionMask).drop_front();
SmallVector<Value> reductionResults;
for (auto [srcSlice, accSlice, maskSlice] :
More information about the Mlir-commits
mailing list