[Mlir-commits] [mlir] [mlir][vector] Separate multi_reduce lowering into transformations, flattening, and unrolling. (PR #178974)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Mon Feb 2 07:23:51 PST 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/178974
>From 5c3d7cc4f98ede4b98177565b2e2dd972133b78d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 29 Jan 2026 10:29:59 -0500
Subject: [PATCH 01/18] [mlir][vector] Finer pattern selection for
VectorMultiReduction (NFC).
* Adds populateVectorMultiReductionInnerOuterDimPatterns which populates
InnerOuterDimReductionConversion and OneDimMultiReductionToTwoDim.
These patterns establish the invariants that the subsequent patterns rely on.
* Adds populateVectorMultiReductionFlatteningPatterns which populates
ReduceMultiDimReductionRank, TwoDimMultiReductionToElementWise, and
TwoDimMultiReductionToReduction.
---
.../Vector/Transforms/LoweringPatterns.h | 40 +++++++++++++++++++
.../TransformOps/VectorTransformOps.cpp | 2 +
.../Transforms/LowerVectorMultiReduction.cpp | 36 +++++++++++++++--
3 files changed, 74 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 7bd96c8a6d1a1..6d39183d3e66d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -60,6 +60,46 @@ void populateVectorContractLoweringPatterns(
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Collect a set of patterns that establish the invariants the subsequent
+/// vector.multi_reduction lowering patterns rely on. The patterns comprise:
+///
+/// [InnerOuterDimReductionConversion]
+/// Rewrites vector.multi_reduction such that all reduction dimensions are
+/// either innermost or outermost, by adding the proper vector.transpose
+/// operations. `options` is forwarded only to this pattern.
+///
+/// [OneDimMultiReductionToTwoDim]
+/// For cases that reduce to a 1-D vector<k> reduction (and are thus missing
+/// either a parallel or a reduction dimension), we lift them back up to 2-D
+/// with a simple vector.shape_cast to vector<1xk> so that the other patterns
+/// can kick in, thus fully exiting out of the vector.multi_reduction
+/// abstraction.
+void populateVectorMultiReductionInnerOuterDimPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1);
+
+/// Collect a set of patterns to convert vector.multi_reduction ops that are
+/// already in innermost or outermost reduction form (see
+/// populateVectorMultiReductionInnerOuterDimPatterns) into a sequence of
+/// vector.reduction or elementwise ops. The patterns comprise:
+///
+/// [ReduceMultiDimReductionRank]
+/// Once in innermost or outermost reduction
+/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
+/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
+/// back.
+///
+/// [TwoDimMultiReductionToElementWise]
+/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
+/// ops. This also has an opportunity for tree-reduction (in the future).
+///
+/// [TwoDimMultiReductionToReduction]
+/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of extract +
+/// vector.reduction + insert. This can further lower to horizontal reduction
+/// ops.
+///
+/// Only one of the two 2-D patterns is registered:
+/// TwoDimMultiReductionToReduction when `options` is InnerReduction, and
+/// TwoDimMultiReductionToElementWise otherwise.
+void populateVectorMultiReductionFlatteningPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1);
+
/// Collect a set of patterns to convert vector.multi_reduction op into
/// a sequence of vector.reduction ops. The patterns comprise:
///
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 7faa222a9e574..f651e74873a91 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -130,6 +130,8 @@ void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+ vector::populateVectorMultiReductionInnerOuterDimPatterns(
+ patterns, vectorTransformOptions.vectorMultiReductionLowering);
vector::populateVectorMultiReductionLoweringPatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index e86e2a97038db..993b6a7f6e08f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -496,11 +496,18 @@ struct LowerVectorMultiReductionPass
Operation *op = getOperation();
MLIRContext *context = op->getContext();
- RewritePatternSet loweringPatterns(context);
- populateVectorMultiReductionLoweringPatterns(loweringPatterns,
- this->loweringStrategy);
+ RewritePatternSet innerOuterPatterns(context);
+ populateVectorMultiReductionInnerOuterDimPatterns(innerOuterPatterns,
+ this->loweringStrategy);
- if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
+ if (failed(applyPatternsGreedily(op, std::move(innerOuterPatterns))))
+ signalPassFailure();
+
+ RewritePatternSet flatteningPatterns(context);
+ populateVectorMultiReductionFlatteningPatterns(flatteningPatterns,
+ this->loweringStrategy);
+
+ if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
signalPassFailure();
}
@@ -511,6 +518,27 @@ struct LowerVectorMultiReductionPass
} // namespace
+// Registers the normalization patterns that run before flattening:
+// OneDimMultiReductionToTwoDim and InnerOuterDimReductionConversion (only the
+// latter takes `options`).
+void mlir::vector::populateVectorMultiReductionInnerOuterDimPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit) {
+ patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
+ patterns.add<InnerOuterDimReductionConversion>(patterns.getContext(), options,
+ benefit);
+}
+
+// Registers ReduceMultiDimReductionRank plus exactly one 2-D lowering pattern,
+// chosen by `options`.
+void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit) {
+ patterns.add<ReduceMultiDimReductionRank>(patterns.getContext(), options,
+ benefit);
+ // InnerReduction: extract + vector.reduction + insert per outer index.
+ // Otherwise: unroll the outer (reduced) dimension into elementwise ops.
+ if (options == VectorMultiReductionLowering ::InnerReduction)
+ patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
+ benefit);
+ else
+ patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
+ benefit);
+}
+
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
>From c275cccda7659e564db2d615da5feabf2ae25e4d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 29 Jan 2026 10:37:16 -0500
Subject: [PATCH 02/18] [mlir][vector] Finer pattern selection for
VectorMultiReduction (NFC)
* Removes the following patterns from
populateVectorMultiReductionFlatteningPatterns:
* TwoDimMultiReductionToElementWise
* TwoDimMultiReductionToReduction
* Adds populateVectorMultiReductionUnrollingPatterns with patterns
TwoDimMultiReductionToElementWise, TwoDimMultiReductionToReduction
---
.../Dialect/Vector/Transforms/LoweringPatterns.h | 8 +++++++-
.../Vector/Transforms/LowerVectorMultiReduction.cpp | 12 ++++++++++++
2 files changed, 19 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 6d39183d3e66d..1beb8e25cb683 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -85,6 +85,12 @@ void populateVectorMultiReductionInnerOuterDimPatterns(
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
/// back.
+void populateVectorMultiReductionFlatteningPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1);
+
+/// Collect a set of patterns to unroll 2-D vector.multi_reduction ops into
+/// a sequence of elementwise or vector.reduction ops. The patterns comprise:
///
/// [TwoDimMultiReductionToElementWise]
/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
@@ -96,7 +102,7 @@ void populateVectorMultiReductionInnerOuterDimPatterns(
/// dimension, unroll the outer dimension to obtain a sequence of extract +
/// vector.reduction + insert. This can further lower to horizontal reduction
/// ops.
-void populateVectorMultiReductionFlatteningPatterns(
+void populateVectorMultiReductionUnrollingPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 993b6a7f6e08f..2eb8048bb4a69 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -509,6 +509,13 @@ struct LowerVectorMultiReductionPass
if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
signalPassFailure();
+
+ RewritePatternSet unrollingPatterns(context);
+ populateVectorMultiReductionUnrollingPatterns(unrollingPatterns,
+ this->loweringStrategy);
+
+ if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
+ signalPassFailure();
}
void getDependentDialects(DialectRegistry ®istry) const override {
@@ -531,6 +538,11 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
PatternBenefit benefit) {
patterns.add<ReduceMultiDimReductionRank>(patterns.getContext(), options,
benefit);
+}
+
+void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit) {
if (options == VectorMultiReductionLowering ::InnerReduction)
patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
benefit);
>From 061615a77c44db310f67f32b9de53b02654f0b53 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 11:13:52 -0500
Subject: [PATCH 03/18] [mlir][vector] rank reduce unrolling for
vector.multi_reduction.
---
.../Transforms/LowerVectorMultiReduction.cpp | 199 +++++++++++++++++-
.../vector-multi-reduction-pass-lowering.mlir | 6 +-
2 files changed, 198 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2eb8048bb4a69..b2b01031587cf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -485,6 +485,195 @@ struct OneDimMultiReductionToTwoDim
}
};
+/// Unrolls the outermost dimension of a vector.multi_reduction.
+/// This pattern matches operations which reduce the outermost dimension; it
+/// does not transform operations for which the outermost dimension is not a
+/// reduction dimension. Operations whose outermost dimension is scalable are
+/// rejected as well: a scalable dimension has a runtime length, so it cannot
+/// be unrolled into a static number of slices.
+///
+/// There are two cases to consider:
+/// 1. The base case is when the outermost dimension is the only reduction
+///    dimension (handled by this pattern).
+/// 2. The general case is when the outermost dimension is not the only
+///    reduction dimension (handled by UnrollMultiReductionOuterGeneralCase).
+///
+/// The base case transformation:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0] :
+///   vector<NxMx...xf32> to vector<Mx...xf32>
+/// ```
+///
+/// will extract N vectors from %src and then perform elementwise operations.
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+///   vector<NxMx...xf32>
+///
+/// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
+/// ...
+/// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
+/// ```
+///
+/// When the op is masked, the mask is sliced along the outermost dimension in
+/// the same way and each elementwise step is combined under its mask slice.
+///
+/// For the general case, we still extract N vectors, but produce N
+/// vector.multi_reduction ops instead of elementwise operations. The
+/// remaining reduction dimensions are each shifted down by one, since the
+/// outermost dimension has been peeled off.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, [[REDUCTION_DIMS]] ] :
+///   vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+///   vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction <add>, %0, %acc [ [[REDUCTION_DIMS-1]] ] :
+///   vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction <add>, %Nminus1, %redNminus2
+///   [ [[REDUCTION_DIMS-1]] ] : vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
+struct UnrollMultiReductionOuterBaseCase
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    if (!multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected outermost dimension to be reduced dimension.");
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected integer or float element type.");
+
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+    if (reductionDims.size() > 1)
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected only one reduction dimension.");
+
+    // The unrolling below emits one slice per *static* index of the outermost
+    // dimension. A scalable dimension's runtime length is a multiple of its
+    // static size, so unrolling it would silently drop elements.
+    VectorType srcVectorType = multiReductionOp.getSourceVectorType();
+    if (srcVectorType.getScalableDims().front())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "cannot unroll a scalable outermost dimension.");
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+    // One extract (and one elementwise combine) per outermost index.
+    int64_t numElementwiseOps = srcVectorType.getShape().front();
+
+    // For a masked op, build the replacement in front of the enclosing
+    // vector.mask and replace the mask op as a whole.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    bool isMasked = maskableOp.isMasked();
+    Operation *rootOp;
+    Value mask = nullptr;
+    if (isMasked) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
+    SmallVector<Value> vectors;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      vectors.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+    // Slice the mask the same way; a null entry means the step is unmasked.
+    SmallVector<Value> masks;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      if (isMasked)
+        masks.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+      else
+        masks.push_back(nullptr);
+
+    // Chain the elementwise combines, seeded with the original accumulator.
+    Value result = multiReductionOp.getAcc();
+    for (auto [innerVector, innerMask] : llvm::zip(vectors, masks))
+      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
+                                  innerVector, result, /*fastmath=*/nullptr,
+                                  innerMask);
+
+    rewriter.replaceOp(rootOp, result);
+    return success();
+  }
+};
+
+/// Unrolls the outermost dimension of a vector.multi_reduction that reduces
+/// the outermost dimension together with at least one other dimension (the
+/// "general case"). The single-reduction-dimension "base case" is handled by
+/// UnrollMultiReductionOuterBaseCase.
+///
+/// One vector.multi_reduction is emitted per slice of the outermost
+/// dimension. It reduces the remaining reduction dimensions (each shifted
+/// down by one) and accumulates into the result of the previous slice,
+/// starting from the original accumulator. If the op is masked, each slice
+/// reduction is wrapped in a vector.mask over the matching mask slice.
+/// Scalable outermost dimensions are not unrolled.
+///
+/// Example:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, 2] :
+///   vector<2x3x5xf32> to vector<3xf32>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<3x5xf32> from vector<2x3x5xf32>
+/// %1 = vector.extract %src[1] : vector<3x5xf32> from vector<2x3x5xf32>
+/// %r0 = vector.multi_reduction <add>, %0, %acc [1] :
+///   vector<3x5xf32> to vector<3xf32>
+/// %res = vector.multi_reduction <add>, %1, %r0 [1] :
+///   vector<3x5xf32> to vector<3xf32>
+/// ```
+struct UnrollMultiReductionOuterGeneralCase
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    if (!multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected outermost dimension to be reduced dimension.");
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected integer or float element type.");
+
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+    if (reductionDims.size() <= 1)
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected more than one reduction dimension.");
+
+    // The unrolling below emits one slice per *static* index of the outermost
+    // dimension. A scalable dimension's runtime length is a multiple of its
+    // static size, so unrolling it would silently drop elements.
+    VectorType srcVectorType = multiReductionOp.getSourceVectorType();
+    if (srcVectorType.getScalableDims().front())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "cannot unroll a scalable outermost dimension.");
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+    // One slice (and one smaller multi_reduction) per outermost index.
+    int64_t numElementwiseOps = srcVectorType.getShape().front();
+
+    // For a masked op, build the replacement in front of the enclosing
+    // vector.mask and replace the mask op as a whole.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    bool isMasked = maskableOp.isMasked();
+    Operation *rootOp;
+    Value mask = nullptr;
+    if (isMasked) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
+    SmallVector<Value> vectors;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      vectors.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+    SmallVector<Value> masks;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      if (isMasked)
+        masks.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+      else
+        masks.push_back(nullptr);
+
+    // getReductionMask() returns its mask by value, so keep it in an owning
+    // container: an ArrayRef taken directly from the call would dangle once
+    // the temporary is destroyed. Dropping the first entry removes the
+    // dimension being unrolled, which shifts the other reduction dims down.
+    SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+    ArrayRef<bool> reductionMask =
+        ArrayRef<bool>(fullReductionMask).drop_front();
+    Value result = multiReductionOp.getAcc();
+    for (auto [innerVector, innerMask] : llvm::zip(vectors, masks)) {
+      auto reductionOp = vector::MultiDimReductionOp::create(
+          rewriter, loc, innerVector, result, reductionMask,
+          multiReductionOp.getKind());
+
+      if (isMasked) {
+        auto maskOp = vector::maskOperation(rewriter, reductionOp, innerMask);
+        result = maskOp->getResult(0);
+      } else {
+        result = reductionOp.getResult();
+      }
+    }
+
+    rewriter.replaceOp(rootOp, result);
+    return success();
+  }
+};
+
struct LowerVectorMultiReductionPass
: public vector::impl::LowerVectorMultiReductionBase<
LowerVectorMultiReductionPass> {
@@ -543,12 +732,14 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
+ // InnerReduction keeps the 2-D extract + vector.reduction + insert lowering;
+ // any other strategy unrolls the outermost reduction dimension instead.
- if (options == VectorMultiReductionLowering ::InnerReduction)
+ if (options == VectorMultiReductionLowering ::InnerReduction) {
patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
benefit);
- else
- patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
- benefit);
+ } else {
+ patterns.add<UnrollMultiReductionOuterBaseCase,
+ UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
+ benefit);
+ }
}
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index ddbc5c7bdb2c0..e01bf446eb83c 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -21,12 +21,12 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
// INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// INNER-PARALLEL: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
// INNER-PARALLEL: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
// INNER-PARALLEL: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
// INNER-PARALLEL: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+// INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+// INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+// INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
// INNER-PARALLEL: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
// INNER-PARALLEL: return %[[RESULT_VEC]] : vector<2xf32>
>From 01f8163355718f4cf8a50606e406691881eee7d4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 12:07:30 -0500
Subject: [PATCH 04/18] [mlir][vector] Add tests for multi_reduction unrolling.
---
.../Vector/TransformOps/VectorTransformOps.td | 20 ++++
.../Vector/Transforms/LoweringPatterns.h | 24 +++++
.../TransformOps/VectorTransformOps.cpp | 8 ++
.../Transforms/LowerVectorMultiReduction.cpp | 16 ++++
.../Vector/td/unroll-multi-reduction.mlir | 24 +++++
.../Vector/unroll-vector-multi-reduction.mlir | 92 +++++++++++++++++++
6 files changed, 184 insertions(+)
create mode 100644 mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
create mode 100644 mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 03d25505dc65c..373f9a4210c41 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -243,6 +243,26 @@ def ApplyLowerMultiReductionPatternsOp : Op<Transform_Dialect,
}];
}
+def ApplyUnrollMultiReductionPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.unroll_multi_reduction",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Unrolls vector.multi_reduction operations by progressively reducing rank
+ along the outermost dimension.
+
+ This is an alternative to the flattening-based lowering that preserves
+ the n-D structure during progressive lowering.
+
+ The optional `lowering_strategy` attribute (defaulting to InnerParallel)
+ selects which unrolling patterns are applied; see
+ `vector::populateVectorUnrollMultiReduction`.
+ }];
+
+ let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+ "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+ );
+
+ let assemblyFormat = [{
+ (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+ }];
+}
+
def ApplyLowerOuterProductPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_outerproduct",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 1beb8e25cb683..1e78983ce082d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -140,6 +140,30 @@ void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
+/// Collect a set of patterns to unroll vector.multi_reduction ops by
+/// progressively reducing rank along the outermost dimension. `options`
+/// selects which pattern set is registered.
+///
+/// For any strategy other than InnerReduction (e.g. InnerParallel), the
+/// outermost reduction dimension is unrolled:
+/// [UnrollMultiReductionOuterBaseCase]
+/// When the outermost dimension is the only reduction dimension, unroll to
+/// produce elementwise arithmetic operations.
+///
+/// [UnrollMultiReductionOuterGeneralCase]
+/// When the outermost dimension is one of multiple reduction dimensions,
+/// unroll to produce smaller multi_reduction operations.
+///
+/// For InnerReduction (innermost dim is reduction):
+/// [UnrollMultiReductionInnerBaseCase]
+/// When the innermost dimension is the only reduction dimension, unroll along
+/// the outermost parallel dimension.
+///
+/// [UnrollMultiReductionInnerGeneralCase]
+/// When the innermost dimension is one of multiple reduction dimensions,
+/// unroll along the outermost parallel dimension.
+///
+/// NOTE(review): the InnerReduction patterns are being added incrementally;
+/// until they are registered, InnerReduction falls back to
+/// TwoDimMultiReductionToReduction. Confirm against the implementation.
+void populateVectorUnrollMultiReduction(RewritePatternSet &patterns,
+ VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1);
+
/// Populate the pattern set with the following patterns:
///
/// [TransferReadToVectorLoadLowering]
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index f651e74873a91..4562393559c94 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -136,6 +136,14 @@ void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
+void transform::ApplyUnrollMultiReductionPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  // The op's `lowering_strategy` attribute is exactly the option the populate
+  // function expects, so forward it without an intermediate options object.
+  vector::populateVectorUnrollMultiReduction(patterns, getLoweringStrategy());
+}
+
void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorOuterProductLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index b2b01031587cf..8548e8c2d311d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -742,6 +742,22 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
}
}
+void mlir::vector::populateVectorUnrollMultiReduction(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
+  MLIRContext *ctx = patterns.getContext();
+
+  // Any strategy other than InnerReduction unrolls the outermost reduction
+  // dimension directly on the n-D op.
+  if (options != VectorMultiReductionLowering::InnerReduction) {
+    patterns.add<UnrollMultiReductionOuterBaseCase,
+                 UnrollMultiReductionOuterGeneralCase>(ctx, benefit);
+    return;
+  }
+
+  // TODO: Add UnrollMultiReductionInnerBaseCase and
+  // UnrollMultiReductionInnerGeneralCase patterns here once implemented.
+  // For now, fall back to the existing 2-D based lowering.
+  patterns.add<TwoDimMultiReductionToReduction>(ctx, benefit);
+}
+
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir b/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
new file mode 100644
index 0000000000000..96a68723266d3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
@@ -0,0 +1,24 @@
+// Transform-dialect entry points used by the vector.multi_reduction unrolling
+// tests (loaded via -transform-preload-library).
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @unroll_multi_reduction(%module_op: !transform.any_op {transform.readonly}) {
+
+ %func_op = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
+ // Unroll using the default lowering strategy (InnerParallel).
+ transform.apply_patterns.vector.unroll_multi_reduction
+ } : !transform.any_op
+
+ transform.yield
+ }
+ transform.named_sequence @unroll_multi_reduction_inner(%module_op: !transform.any_op {transform.readonly}) {
+
+ %func_op = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
+ // Unroll using the InnerReduction lowering strategy.
+ transform.apply_patterns.vector.unroll_multi_reduction lowering_strategy = "innerreduction"
+ } : !transform.any_op
+
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
new file mode 100644
index 0000000000000..79086e2b0b9ad
--- /dev/null
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt --split-input-file %s -transform-preload-library='transform-library-paths=%p/td/unroll-multi-reduction.mlir' \
+// RUN: -transform-interpreter=entry-point=unroll_multi_reduction | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Tests for UnrollMultiReductionOuterBaseCase and
+// UnrollMultiReductionOuterGeneralCase
+//===----------------------------------------------------------------------===//
+
+// Base case: dim 0 is the only reduction dim, so the op unrolls into a chain
+// of elementwise arith.addf ops seeded with the accumulator.
+// CHECK-LABEL: func @unroll_vector_multi_reduction(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @unroll_vector_multi_reduction(%source: vector<2x3x5xf32>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
+ // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_0]] : vector<3x5xf32>
+ %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
+
+ // CHECK: return %[[RES_1]]
+ return %1 : vector<3x5xf32>
+}
+
+// -----
+
+// Masked base case: the mask is sliced along dim 0 and each elementwise step
+// is guarded by an arith.select on its mask slice.
+// CHECK-LABEL: func @unroll_vector_multi_reduction_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @unroll_vector_multi_reduction_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+ // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
+ // CHECK: %[[RES_MASKED_0:.+]] = arith.select %[[MASK_0]], %[[RES_0]], %[[ACC]] : vector<3x5xi1>, vector<3x5xf32>
+
+ // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_MASKED_0]] : vector<3x5xf32>
+ // CHECK: %[[RES_MASKED_1:.+]] = arith.select %[[MASK_1]], %[[RES_1]], %[[RES_MASKED_0]] : vector<3x5xi1>, vector<3x5xf32>
+
+ %0 = vector.mask %mask {
+ %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
+ } : vector<2x3x5xi1> -> vector<3x5xf32>
+
+ // CHECK: return %[[RES_MASKED_1]]
+ return %0 : vector<3x5xf32>
+}
+
+// -----
+
+// General case: dim 0 is one of several reduction dims, so the op unrolls
+// into smaller vector.multi_reduction ops over the remaining reduction dims
+// (shifted down by one: [0, 2] becomes [1]).
+// CHECK-LABEL: func @unroll_vector_multi_reduction_general(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
+func.func @unroll_vector_multi_reduction_general(%source: vector<2x3x5xf32>, %acc: vector<3xf32>) -> (vector<3xf32>) {
+
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK: %[[RES_0:.+]] = vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32>
+ // CHECK: %[[RES_1:.+]] = vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32>
+
+ %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
+
+ // CHECK: return %[[RES_1]]
+ return %1 : vector<3xf32>
+}
+
+// -----
+
+// Masked general case: each smaller vector.multi_reduction is wrapped in a
+// vector.mask over the matching slice of the original mask.
+// CHECK-LABEL: func @unroll_vector_multi_reduction_general_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
+func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3xf32>) -> (vector<3xf32>) {
+
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+ // CHECK: %[[RES_0:.+]] = vector.mask %[[MASK_0]] { vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
+ // CHECK: %[[RES_1:.+]] = vector.mask %[[MASK_1]] { vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
+
+ %0 = vector.mask %mask {
+ %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
+ } : vector<2x3x5xi1> -> vector<3xf32>
+
+ // CHECK: return %[[RES_1]]
+ return %0 : vector<3xf32>
+}
>From 29c0e50abc926126b9e1b8cd9ee9c83afff90cca Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 12:49:09 -0500
Subject: [PATCH 05/18] [mlir][vector] Add unrolling pattern for innermost
reduction base case
---
.../Transforms/LowerVectorMultiReduction.cpp | 138 +++++++++++++++++-
.../unroll-vector-multi-reduction-inner.mlir | 106 ++++++++++++++
2 files changed, 239 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 8548e8c2d311d..5e0d6da82f834 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -674,6 +674,135 @@ struct UnrollMultiReductionOuterGeneralCase
}
};
+/// Unrolls innermost dimension for vector.multi_reduction when the innermost
+/// dimension is the only reduction dimension.
+///
+/// This pattern matches operations where:
+/// - The innermost dimension is a reduction dimension
+/// - The outermost dimension is a parallel dimension
+/// - There is exactly one reduction dimension
+///
+/// The transformation unrolls along the outermost parallel dimension:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add> %src, %acc [N-1]
+/// : vector<AxBx...xMxf32> to vector<AxBx...xf32>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %result = arith.constant dense<0.0> : vector<AxBx...xf32>
+/// %0 = vector.extract %src[0] : vector<Bx...xMxf32> from vector<AxBx...xMxf32>
+/// %acc0 = vector.extract %acc[0] : vector<Bx...xf32> from vector<AxBx...xf32>
+/// %red0 = vector.multi_reduction <add>, %0, %acc0 [N-2]
+/// : vector<Bx...xMxf32> to vector<Bx...xf32>
+/// %res0 = vector.insert %red0, %result[0]
+/// : vector<Bx...xf32> into vector<AxBx...xf32>
+/// // ... repeat for indices 1 to A-1
+/// ```
+struct UnrollMultiReductionInnerBaseCase
+ : public OpRewritePattern<vector::MultiDimReductionOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+ PatternRewriter &rewriter) const override {
+ auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+ if (srcRank < 2)
+ return rewriter.notifyMatchFailure(multiReductionOp,
+ "expected source rank >= 2.");
+
+ if (!multiReductionOp.isReducedDim(srcRank - 1))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected innermost dimension to be a reduction dimension.");
+
+ if (multiReductionOp.isReducedDim(0))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected outermost dimension to be a parallel dimension.");
+
+ ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+ if (reductionDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected exactly one reduction dimension.");
+
+ Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+ if (!elementType.isIntOrIndexOrFloat())
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected integer or float element type.");
+
+ Location loc = multiReductionOp.getLoc();
+ Value source = multiReductionOp.getSource();
+ Value acc = multiReductionOp.getAcc();
+
+ ArrayRef<int64_t> srcShape =
+ multiReductionOp.getSourceVectorType().getShape();
+ int64_t numSlices = srcShape.front();
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+ bool isMasked = maskableOp.isMasked();
+ Operation *rootOp;
+ Value mask = nullptr;
+ if (isMasked) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ mask = maskableOp.getMaskingOp().getMask();
+ } else {
+ rootOp = multiReductionOp;
+ }
+
+ SmallVector<Value> srcSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ srcSlices.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+ SmallVector<Value> accSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ accSlices.push_back(vector::ExtractOp::create(rewriter, loc, acc, i));
+
+ SmallVector<Value> maskSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ if (isMasked)
+ maskSlices.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+ else
+ maskSlices.push_back(nullptr);
+
+ // Compute new reduction mask: for the extracted slice (rank srcRank-1),
+ // the innermost dimension is still the reduction dimension.
+ // New mask has srcRank-1 elements, with the last one being true.
+ SmallVector<bool> newReductionMask(srcRank - 1, false);
+ newReductionMask.back() = true;
+
+ SmallVector<Value> reductionResults;
+ for (auto [srcSlice, accSlice, maskSlice] :
+ llvm::zip(srcSlices, accSlices, maskSlices)) {
+ Operation *newReductionOp = vector::MultiDimReductionOp::create(
+ rewriter, loc, srcSlice, accSlice, newReductionMask,
+ multiReductionOp.getKind());
+
+ if (isMasked)
+ newReductionOp =
+ mlir::vector::maskOperation(rewriter, newReductionOp, maskSlice);
+
+ reductionResults.push_back(newReductionOp->getResult(0));
+ }
+
+ Value result = arith::ConstantOp::create(
+ rewriter, loc, multiReductionOp.getDestType(),
+ rewriter.getZeroAttr(multiReductionOp.getDestType()));
+
+ for (int64_t i = 0; i < numSlices; ++i)
+ result = vector::InsertOp::create(rewriter, loc, reductionResults[i],
+ result, i);
+
+ rewriter.replaceOp(rootOp, result);
+ return success();
+ }
+};
+
struct LowerVectorMultiReductionPass
: public vector::impl::LowerVectorMultiReductionBase<
LowerVectorMultiReductionPass> {
@@ -746,11 +875,10 @@ void mlir::vector::populateVectorUnrollMultiReduction(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
if (options == VectorMultiReductionLowering::InnerReduction) {
- // TODO: Add UnrollMultiReductionInnerBaseCase and
- // UnrollMultiReductionInnerGeneralCase patterns here once implemented.
- // For now, fall back to the existing 2-D based lowering.
- patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
- benefit);
+ // TODO: Add UnrollMultiReductionInnerGeneralCase pattern here once
+ // implemented.
+ patterns.add<UnrollMultiReductionInnerBaseCase>(patterns.getContext(),
+ benefit);
} else {
patterns.add<UnrollMultiReductionOuterBaseCase,
UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
new file mode 100644
index 0000000000000..89af4580bf627
--- /dev/null
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -0,0 +1,106 @@
+// RUN: mlir-opt --split-input-file %s -transform-preload-library='transform-library-paths=%p/td/unroll-multi-reduction.mlir' \
+// RUN: -transform-interpreter=entry-point=unroll_multi_reduction_inner | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test UnrollVectorMultiReduction for Inner Reduction (Base Case)
+//===----------------------------------------------------------------------===//
+
+// The pattern recursively reduces rank until we reach 1D multi_reductions.
+// For vector<2x3x5xf32> with reduction on dim 2:
+// - First pass: unrolls along dim 0 (size 2), creating vector<3x5xf32> multi_reductions
+// - Second pass: unrolls along dim 0 (size 3), creating vector<5xf32> multi_reductions
+//
+// The generated IR groups operations by phase:
+// extracts (source) → extracts (acc) → multi_reductions → inserts
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2x3xf32>
+func.func @unroll_vector_multi_reduction_inner(%source: vector<2x3x5xf32>, %acc: vector<2x3xf32>) -> (vector<2x3xf32>) {
+ // First slice [0, ...]: extracts → reductions → inserts
+ // CHECK: vector.extract %[[SOURCE]][0, 0]
+ // CHECK: vector.extract %[[SOURCE]][0, 1]
+ // CHECK: vector.extract %[[SOURCE]][0, 2]
+ // CHECK: vector.extract %[[ACC]][0, 0]
+ // CHECK: vector.extract %[[ACC]][0, 1]
+ // CHECK: vector.extract %[[ACC]][0, 2]
+ // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
+ // Second slice [1, ...]: extracts → reductions → inserts
+ // CHECK: vector.extract %[[SOURCE]][1, 0]
+ // CHECK: vector.extract %[[SOURCE]][1, 1]
+ // CHECK: vector.extract %[[SOURCE]][1, 2]
+ // CHECK: vector.extract %[[ACC]][1, 0]
+ // CHECK: vector.extract %[[ACC]][1, 1]
+ // CHECK: vector.extract %[[ACC]][1, 2]
+ // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
+ // Final inserts to assemble result
+ // CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
+ // CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
+ // No original multi_reduction with [2] remains
+ // CHECK-NOT: vector.multi_reduction <add>, {{.*}} [2]
+ %1 = vector.multi_reduction <add>, %source, %acc [2] : vector<2x3x5xf32> to vector<2x3xf32>
+
+ return %1 : vector<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2x3xf32>
+func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2x3xf32>) -> (vector<2x3xf32>) {
+ // First slice [0, ...]: extracts (source, acc, mask) → masked reductions → inserts
+ // CHECK: vector.extract %[[SOURCE]][0, 0]
+ // CHECK: vector.extract %[[SOURCE]][0, 1]
+ // CHECK: vector.extract %[[SOURCE]][0, 2]
+ // CHECK: vector.extract %[[ACC]][0, 0]
+ // CHECK: vector.extract %[[ACC]][0, 1]
+ // CHECK: vector.extract %[[ACC]][0, 2]
+ // CHECK: vector.extract %[[MASK]][0, 0]
+ // CHECK: vector.extract %[[MASK]][0, 1]
+ // CHECK: vector.extract %[[MASK]][0, 2]
+ // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
+ // Second slice [1, ...]: extracts → masked reductions → inserts
+ // CHECK: vector.extract %[[SOURCE]][1, 0]
+ // CHECK: vector.extract %[[SOURCE]][1, 1]
+ // CHECK: vector.extract %[[SOURCE]][1, 2]
+ // CHECK: vector.extract %[[ACC]][1, 0]
+ // CHECK: vector.extract %[[ACC]][1, 1]
+ // CHECK: vector.extract %[[ACC]][1, 2]
+ // CHECK: vector.extract %[[MASK]][1, 0]
+ // CHECK: vector.extract %[[MASK]][1, 1]
+ // CHECK: vector.extract %[[MASK]][1, 2]
+ // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
+ // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
+ // Final inserts to assemble result
+ // CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
+ // CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
+ // No original multi_reduction with [2] remains
+ // CHECK-NOT: vector.multi_reduction <add>, {{.*}} [2]
+
+ %0 = vector.mask %mask {
+ %1 = vector.multi_reduction <add>, %source, %acc [2] : vector<2x3x5xf32> to vector<2x3xf32>
+ } : vector<2x3x5xi1> -> vector<2x3xf32>
+
+ return %0 : vector<2x3xf32>
+}
>From 91532b1522da5128a75b8d9599d26b376d4417c7 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 13:16:22 -0500
Subject: [PATCH 06/18] [mlir][vector] Implement general pattern
---
.../Transforms/LowerVectorMultiReduction.cpp | 46 ++++++--------
.../unroll-vector-multi-reduction-inner.mlir | 60 +++++++++++++++++++
2 files changed, 79 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 5e0d6da82f834..99b766eb9ad7c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -674,34 +674,34 @@ struct UnrollMultiReductionOuterGeneralCase
}
};
-/// Unrolls innermost dimension for vector.multi_reduction when the innermost
-/// dimension is the only reduction dimension.
+/// Unrolls vector.multi_reduction along the outermost parallel dimension
+/// when the innermost dimension is a reduction dimension.
///
/// This pattern matches operations where:
/// - The innermost dimension is a reduction dimension
/// - The outermost dimension is a parallel dimension
-/// - There is exactly one reduction dimension
///
-/// The transformation unrolls along the outermost parallel dimension:
+/// The transformation extracts slices along the outermost parallel dimension,
+/// creates smaller multi_reductions, and assembles the results:
///
/// ```mlir
-/// %res = vector.multi_reduction <add> %src, %acc [N-1]
-/// : vector<AxBx...xMxf32> to vector<AxBx...xf32>
+/// %res = vector.multi_reduction <add>, %src, %acc [1, 3]
+/// : vector<AxBxCxDxf32> to vector<AxCxf32>
/// ```
///
/// becomes:
///
/// ```mlir
-/// %result = arith.constant dense<0.0> : vector<AxBx...xf32>
-/// %0 = vector.extract %src[0] : vector<Bx...xMxf32> from vector<AxBx...xMxf32>
-/// %acc0 = vector.extract %acc[0] : vector<Bx...xf32> from vector<AxBx...xf32>
-/// %red0 = vector.multi_reduction <add>, %0, %acc0 [N-2]
-/// : vector<Bx...xMxf32> to vector<Bx...xf32>
+/// %result = arith.constant dense<0.0> : vector<AxCxf32>
+/// %0 = vector.extract %src[0] : vector<BxCxDxf32> from vector<AxBxCxDxf32>
+/// %acc0 = vector.extract %acc[0] : vector<Cxf32> from vector<AxCxf32>
+/// %red0 = vector.multi_reduction <add>, %0, %acc0 [0, 2]
+/// : vector<BxCxDxf32> to vector<Cxf32>
/// %res0 = vector.insert %red0, %result[0]
-/// : vector<Bx...xf32> into vector<AxBx...xf32>
+/// : vector<Cxf32> into vector<AxCxf32>
/// // ... repeat for indices 1 to A-1
/// ```
-struct UnrollMultiReductionInnerBaseCase
+struct UnrollMultiReductionInner
: public OpRewritePattern<vector::MultiDimReductionOp> {
using Base::Base;
@@ -723,11 +718,6 @@ struct UnrollMultiReductionInnerBaseCase
multiReductionOp,
"expected outermost dimension to be a parallel dimension.");
- ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
- if (reductionDims.size() != 1)
- return rewriter.notifyMatchFailure(
- multiReductionOp, "expected exactly one reduction dimension.");
-
Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
if (!elementType.isIntOrIndexOrFloat())
return rewriter.notifyMatchFailure(
@@ -770,11 +765,11 @@ struct UnrollMultiReductionInnerBaseCase
else
maskSlices.push_back(nullptr);
- // Compute new reduction mask: for the extracted slice (rank srcRank-1),
- // the innermost dimension is still the reduction dimension.
- // New mask has srcRank-1 elements, with the last one being true.
- SmallVector<bool> newReductionMask(srcRank - 1, false);
- newReductionMask.back() = true;
+ // Compute new reduction mask by dropping the first element (dimension 0).
+ // Since dimension 0 is parallel (not reduced), all reduction indices shift
+ // down by 1. getReductionMask() returns by value, so own the storage here.
+ SmallVector<bool> newReductionMask = multiReductionOp.getReductionMask();
+ newReductionMask.erase(newReductionMask.begin());
SmallVector<Value> reductionResults;
for (auto [srcSlice, accSlice, maskSlice] :
@@ -875,10 +870,7 @@ void mlir::vector::populateVectorUnrollMultiReduction(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
if (options == VectorMultiReductionLowering::InnerReduction) {
- // TODO: Add UnrollMultiReductionInnerGeneralCase pattern here once
- // implemented.
- patterns.add<UnrollMultiReductionInnerBaseCase>(patterns.getContext(),
- benefit);
+ patterns.add<UnrollMultiReductionInner>(patterns.getContext(), benefit);
} else {
patterns.add<UnrollMultiReductionOuterBaseCase,
UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
index 89af4580bf627..fcb97967f088e 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -104,3 +104,63 @@ func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>
return %0 : vector<2x3xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test UnrollVectorMultiReduction for Inner Reduction (General Case)
+//===----------------------------------------------------------------------===//
+
+// The general case handles multiple reduction dimensions.
+// For vector<2x3x5xf32> with reduction on dims [1, 2]:
+// - Unrolls along dim 0 (size 2), creating vector<3x5xf32> multi_reductions
+// - Each new multi_reduction has reduction dims [0, 1] (shifted from [1, 2])
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_vector_multi_reduction_inner_general(%source: vector<2x3x5xf32>, %acc: vector<2xf32>) -> (vector<2xf32>) {
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+ // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+
+ // CHECK: %[[RED_0:.+]] = vector.multi_reduction <add>, %[[VEC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32
+ // CHECK: %[[RED_1:.+]] = vector.multi_reduction <add>, %[[VEC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32
+ // CHECK: %[[INSERT_0:.+]] = vector.insert %[[RED_0]], %{{.*}} [0] : f32 into vector<2xf32>
+ // CHECK: %[[INSERT_1:.+]] = vector.insert %[[RED_1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+ %1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+
+ // CHECK: return %[[INSERT_1]]
+ return %1 : vector<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2xf32>) -> (vector<2xf32>) {
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+ // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+
+ // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+ // CHECK: %[[RED_0:.+]] = vector.mask %[[MASK_0]] { vector.multi_reduction <add>, %[[VEC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32 } : vector<3x5xi1> -> f32
+ // CHECK: %[[RED_1:.+]] = vector.mask %[[MASK_1]] { vector.multi_reduction <add>, %[[VEC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32 } : vector<3x5xi1> -> f32
+ // CHECK: %[[INSERT_0:.+]] = vector.insert %[[RED_0]], %{{.*}} [0] : f32 into vector<2xf32>
+ // CHECK: %[[INSERT_1:.+]] = vector.insert %[[RED_1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+
+ %0 = vector.mask %mask {
+ %1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+ } : vector<2x3x5xi1> -> vector<2xf32>
+
+ // CHECK: return %[[INSERT_1]]
+ return %0 : vector<2xf32>
+}
>From f3fe16c9d259e1ff910e6ecaa606f843f2b723e4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 14:05:20 -0500
Subject: [PATCH 07/18] [mlir][vector] Skip matching multi_reduction on rank 1
---
.../Transforms/LowerVectorMultiReduction.cpp | 5 +++++
.../unroll-vector-multi-reduction-inner.mlir | 19 +++++++++++++++++++
.../Vector/unroll-vector-multi-reduction.mlir | 19 +++++++++++++++++++
3 files changed, 43 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 99b766eb9ad7c..a79b6ea10d543 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -541,6 +541,11 @@ struct UnrollMultiReductionOuterBaseCase
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
+ auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+ if (srcRank < 2)
+ return rewriter.notifyMatchFailure(multiReductionOp,
+ "expected source rank >= 2.");
+
if (!multiReductionOp.isReducedDim(0))
return rewriter.notifyMatchFailure(
multiReductionOp,
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
index fcb97967f088e..d6b3091fb9463 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -164,3 +164,22 @@ func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x
// CHECK: return %[[INSERT_1]]
return %0 : vector<2xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Negative Test: Rank-1 multi_reduction should NOT be matched by UnrollMultiReductionInner
+//===----------------------------------------------------------------------===//
+
+// UnrollMultiReductionInner requires srcRank >= 2, so rank-1 should not match.
+// The op should remain unchanged.
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_rank1_negative(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[ACC:.+]]: f32
+func.func @unroll_vector_multi_reduction_inner_rank1_negative(%source: vector<8xf32>, %acc: f32) -> f32 {
+ // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[SOURCE]], %[[ACC]] [0] : vector<8xf32> to f32
+ %0 = vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
+ // CHECK: return %[[RESULT]]
+ return %0 : f32
+}
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
index 79086e2b0b9ad..8c36430bb6eee 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -90,3 +90,22 @@ func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf3
// CHECK: return %[[RES_1]]
return %0 : vector<3xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Negative Test: Rank-1 multi_reduction should NOT be matched by unroll patterns
+//===----------------------------------------------------------------------===//
+
+// UnrollMultiReductionOuterBaseCase and UnrollMultiReductionOuterGeneralCase
+// should not match rank-1 multi_reduction ops. The op should remain unchanged.
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_rank1_negative(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[ACC:.+]]: f32
+func.func @unroll_vector_multi_reduction_rank1_negative(%source: vector<8xf32>, %acc: f32) -> f32 {
+ // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[SOURCE]], %[[ACC]] [0] : vector<8xf32> to f32
+ %0 = vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
+ // CHECK: return %[[RESULT]]
+ return %0 : f32
+}
>From edc537eedf5caa554dd8e8d025b14d0692a26275 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 15:17:17 -0500
Subject: [PATCH 08/18] [mlir][vector] Convert multi_reduction in 1-D to
reduction.
---
.../Transforms/LowerVectorMultiReduction.cpp | 58 +++++++++++++++++
.../unroll-vector-multi-reduction-inner.mlir | 65 ++++++++++++-------
.../Vector/unroll-vector-multi-reduction.mlir | 27 ++++++--
3 files changed, 120 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index a79b6ea10d543..595c2fce55619 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -413,6 +413,60 @@ struct TwoDimMultiReductionToReduction
}
};
+/// Converts 1-D vector.multi_reduction directly to vector.reduction.
+/// This is the terminal case for unrolling: once the source is rank 1, it is
+/// rewritten to vector.reduction, re-applying any enclosing vector.mask.
+///
+/// Example:
+/// ```mlir
+/// // Before
+/// %r = vector.multi_reduction <add>, %v, %acc [0] : vector<Nxf32> to f32
+///
+/// // After
+/// %r = vector.reduction <add>, %v, %acc : vector<Nxf32> into f32
+/// ```
+struct OneDimMultiReductionToReduction
+ : public OpRewritePattern<vector::MultiDimReductionOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+ PatternRewriter &rewriter) const override {
+ auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+ if (srcRank != 1)
+ return rewriter.notifyMatchFailure(multiReductionOp,
+ "expected source rank == 1.");
+
+ if (!multiReductionOp.isReducedDim(0))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected dimension 0 to be a reduction dimension.");
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+ Operation *rootOp;
+ Value mask;
+ if (maskableOp.isMasked()) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ mask = maskableOp.getMaskingOp().getMask();
+ } else {
+ rootOp = multiReductionOp;
+ }
+
+ auto loc = multiReductionOp.getLoc();
+ Operation *reductionOp = vector::ReductionOp::create(
+ rewriter, loc, multiReductionOp.getKind(), multiReductionOp.getSource(),
+ multiReductionOp.getAcc());
+
+ if (maskableOp.isMasked())
+ reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
+
+ rewriter.replaceOp(rootOp, reductionOp->getResult(0));
+ return success();
+ }
+};
+
/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
/// form with both a single parallel and reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
@@ -861,6 +915,8 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
+ patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
+
if (options == VectorMultiReductionLowering ::InnerReduction) {
patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
benefit);
@@ -874,6 +930,8 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
void mlir::vector::populateVectorUnrollMultiReduction(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
+ patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
+
if (options == VectorMultiReductionLowering::InnerReduction) {
patterns.add<UnrollMultiReductionInner>(patterns.getContext(), benefit);
} else {
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
index d6b3091fb9463..f21535c1dfc21 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -5,13 +5,15 @@
// Test UnrollVectorMultiReduction for Inner Reduction (Base Case)
//===----------------------------------------------------------------------===//
-// The pattern recursively reduces rank until we reach 1D multi_reductions.
+// The pattern recursively reduces rank until we reach 1D, then converts to
+// vector.reduction via OneDimMultiReductionToReduction.
// For vector<2x3x5xf32> with reduction on dim 2:
// - First pass: unrolls along dim 0 (size 2), creating vector<3x5xf32> multi_reductions
// - Second pass: unrolls along dim 0 (size 3), creating vector<5xf32> multi_reductions
+// - Final: 1-D multi_reductions are converted to vector.reduction
//
// The generated IR groups operations by phase:
-// extracts (source) → extracts (acc) → multi_reductions → inserts
+// extracts (source) → extracts (acc) → reductions → inserts
// CHECK-LABEL: func @unroll_vector_multi_reduction_inner(
// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
@@ -24,9 +26,9 @@ func.func @unroll_vector_multi_reduction_inner(%source: vector<2x3x5xf32>, %acc:
// CHECK: vector.extract %[[ACC]][0, 0]
// CHECK: vector.extract %[[ACC]][0, 1]
// CHECK: vector.extract %[[ACC]][0, 2]
- // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
- // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
- // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
+ // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
+ // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
// CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
@@ -37,17 +39,17 @@ func.func @unroll_vector_multi_reduction_inner(%source: vector<2x3x5xf32>, %acc:
// CHECK: vector.extract %[[ACC]][1, 0]
// CHECK: vector.extract %[[ACC]][1, 1]
// CHECK: vector.extract %[[ACC]][1, 2]
- // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
- // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
- // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+ // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
+ // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
+ // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
// CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
// Final inserts to assemble result
// CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
- // No original multi_reduction with [2] remains
- // CHECK-NOT: vector.multi_reduction <add>, {{.*}} [2]
+ // No original multi_reduction remains
+ // CHECK-NOT: vector.multi_reduction
%1 = vector.multi_reduction <add>, %source, %acc [2] : vector<2x3x5xf32> to vector<2x3xf32>
return %1 : vector<2x3xf32>
@@ -70,9 +72,9 @@ func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>
// CHECK: vector.extract %[[MASK]][0, 0]
// CHECK: vector.extract %[[MASK]][0, 1]
// CHECK: vector.extract %[[MASK]][0, 2]
- // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
- // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
- // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
+ // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
+ // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
// CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
@@ -86,17 +88,17 @@ func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>
// CHECK: vector.extract %[[MASK]][1, 0]
// CHECK: vector.extract %[[MASK]][1, 1]
// CHECK: vector.extract %[[MASK]][1, 2]
- // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
- // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
- // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+ // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
+ // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
+ // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
// CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
// Final inserts to assemble result
// CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
- // No original multi_reduction with [2] remains
- // CHECK-NOT: vector.multi_reduction <add>, {{.*}} [2]
+ // No original multi_reduction remains
+ // CHECK-NOT: vector.multi_reduction
%0 = vector.mask %mask {
%1 = vector.multi_reduction <add>, %source, %acc [2] : vector<2x3x5xf32> to vector<2x3xf32>
@@ -168,18 +170,33 @@ func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x
// -----
//===----------------------------------------------------------------------===//
-// Negative Test: Rank-1 multi_reduction should NOT be matched by UnrollMultiReductionInner
+// Test 1-D multi_reduction to vector.reduction conversion
//===----------------------------------------------------------------------===//
-// UnrollMultiReductionInner requires srcRank >= 2, so rank-1 should not match.
-// The op should remain unchanged.
+// OneDimMultiReductionToReduction converts rank-1 multi_reduction directly
+// to vector.reduction, which preserves the reduction semantics for backends.
-// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_rank1_negative(
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_1d(
// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
// CHECK-SAME: %[[ACC:.+]]: f32
-func.func @unroll_vector_multi_reduction_inner_rank1_negative(%source: vector<8xf32>, %acc: f32) -> f32 {
- // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[SOURCE]], %[[ACC]] [0] : vector<8xf32> to f32
+func.func @unroll_vector_multi_reduction_inner_1d(%source: vector<8xf32>, %acc: f32) -> f32 {
+ // CHECK: %[[RESULT:.+]] = vector.reduction <add>, %[[SOURCE]], %[[ACC]] : vector<8xf32> into f32
%0 = vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
// CHECK: return %[[RESULT]]
return %0 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_1d_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>,
+// CHECK-SAME: %[[ACC:.+]]: f32
+func.func @unroll_vector_multi_reduction_inner_1d_masked(%source: vector<8xf32>, %mask: vector<8xi1>, %acc: f32) -> f32 {
+ // CHECK: %[[RESULT:.+]] = vector.mask %[[MASK]] { vector.reduction <add>, %[[SOURCE]], %[[ACC]] : vector<8xf32> into f32 } : vector<8xi1> -> f32
+ %0 = vector.mask %mask {
+ vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
+ } : vector<8xi1> -> f32
+ // CHECK: return %[[RESULT]]
+ return %0 : f32
+}
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
index 8c36430bb6eee..6db425fe4fce9 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -94,18 +94,33 @@ func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf3
// -----
//===----------------------------------------------------------------------===//
-// Negative Test: Rank-1 multi_reduction should NOT be matched by unroll patterns
+// Test 1-D multi_reduction to vector.reduction conversion
//===----------------------------------------------------------------------===//
-// UnrollMultiReductionOuterBaseCase and UnrollMultiReductionOuterGeneralCase
-// should not match rank-1 multi_reduction ops. The op should remain unchanged.
+// OneDimMultiReductionToReduction converts rank-1 multi_reduction directly
+// to vector.reduction, which preserves the reduction semantics for backends.
-// CHECK-LABEL: func @unroll_vector_multi_reduction_rank1_negative(
+// CHECK-LABEL: func @unroll_vector_multi_reduction_1d(
// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
// CHECK-SAME: %[[ACC:.+]]: f32
-func.func @unroll_vector_multi_reduction_rank1_negative(%source: vector<8xf32>, %acc: f32) -> f32 {
- // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[SOURCE]], %[[ACC]] [0] : vector<8xf32> to f32
+func.func @unroll_vector_multi_reduction_1d(%source: vector<8xf32>, %acc: f32) -> f32 {
+ // CHECK: %[[RESULT:.+]] = vector.reduction <add>, %[[SOURCE]], %[[ACC]] : vector<8xf32> into f32
%0 = vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
// CHECK: return %[[RESULT]]
return %0 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_1d_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>,
+// CHECK-SAME: %[[ACC:.+]]: f32
+func.func @unroll_vector_multi_reduction_1d_masked(%source: vector<8xf32>, %mask: vector<8xi1>, %acc: f32) -> f32 {
+ // CHECK: %[[RESULT:.+]] = vector.mask %[[MASK]] { vector.reduction <add>, %[[SOURCE]], %[[ACC]] : vector<8xf32> into f32 } : vector<8xi1> -> f32
+ %0 = vector.mask %mask {
+ vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
+ } : vector<8xi1> -> f32
+ // CHECK: return %[[RESULT]]
+ return %0 : f32
+}
>From 764bbeab848872f0416335b64c84d38ffe2fe974 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 15:33:20 -0500
Subject: [PATCH 09/18] [mlir][vector] Remove old patterns population.
Replace populateVectorMultiReductionUnrollingPatterns
with populateVectorUnrollMultiReduction, which is more general.
---
.../Vector/Transforms/LoweringPatterns.h | 17 ----------------
.../Transforms/LowerVectorMultiReduction.cpp | 19 ++----------------
.../vector-multi-reduction-pass-lowering.mlir | 20 +++++++++++--------
3 files changed, 14 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 1e78983ce082d..e6e62f2250fb7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -89,23 +89,6 @@ void populateVectorMultiReductionFlatteningPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
-/// Collect a set of patterns to convert vector.multi_reduction op into
-/// a sequence of vector.reduction ops. The patterns comprise:
-///
-/// [TwoDimMultiReductionToElementWise]
-/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
-/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
-/// ops. This also has an opportunity for tree-reduction (in the future).
-///
-/// [TwoDimMultiReductionToReduction]
-/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
-/// dimension, unroll the outer dimension to obtain a sequence of extract +
-/// vector.reduction + insert. This can further lower to horizontal reduction
-/// ops.
-void populateVectorMultiReductionUnrollingPatterns(
- RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit = 1);
-
/// Collect a set of patterns to convert vector.multi_reduction op into
/// a sequence of vector.reduction ops. The patterns comprise:
///
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 595c2fce55619..cd651ecf6c745 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -883,8 +883,8 @@ struct LowerVectorMultiReductionPass
signalPassFailure();
RewritePatternSet unrollingPatterns(context);
- populateVectorMultiReductionUnrollingPatterns(unrollingPatterns,
- this->loweringStrategy);
+ populateVectorUnrollMultiReduction(unrollingPatterns,
+ this->loweringStrategy);
if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
signalPassFailure();
@@ -912,21 +912,6 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
benefit);
}
-void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
- RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit) {
- patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
-
- if (options == VectorMultiReductionLowering ::InnerReduction) {
- patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
- benefit);
- } else {
- patterns.add<UnrollMultiReductionOuterBaseCase,
- UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
- benefit);
- }
-}
-
void mlir::vector::populateVectorUnrollMultiReduction(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index e01bf446eb83c..ac833612681cf 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -10,13 +10,13 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
// INNER-REDUCTION-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
// INNER-REDUCTION: %[[V0:.+]] = vector.extract %[[INPUT]][0]
-// INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
-// INNER-REDUCTION: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
-// INNER-REDUCTION: %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0:.+]], %[[RESULT_VEC_0]] [0] : f32 into vector<2xf32>
// INNER-REDUCTION: %[[V1:.+]] = vector.extract %[[INPUT]][1]
+// INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
// INNER-REDUCTION: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+// INNER-REDUCTION: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
// INNER-REDUCTION: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
-// INNER-REDUCTION: %[[RESULT_VEC:.+]] = vector.insert %[[RV1:.+]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
+// INNER-REDUCTION: %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0]], %[[RESULT_VEC_0]] [0] : f32 into vector<2xf32>
+// INNER-REDUCTION: %[[RESULT_VEC:.+]] = vector.insert %[[RV1]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
// INNER-REDUCTION: return %[[RESULT_VEC]]
// INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
@@ -51,10 +51,14 @@ func.func @vector_multi_reduction_masked(%arg0: vector<2x4xf32>, %acc: vector<2x
// ALL-LABEL: func @vector_multi_reduction_masked
// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<2x4xi1>
-// INNER-REDUCTION: %[[INNERVEC:.+]] = vector.extract %[[INPUT]][0] : vector<4xf32> from vector<2x4xf32>
-// INNER-REDUCTION: %[[INNERACC:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
-// INNER-REDUCTION: %[[INNERMASK:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<2x4xi1>
-// INNER-REDUCTION: vector.mask %[[INNERMASK]] { vector.reduction <mul>, %[[INNERVEC]], %[[INNERACC]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
+// INNER-REDUCTION: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<4xf32> from vector<2x4xf32>
+// INNER-REDUCTION: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<4xf32> from vector<2x4xf32>
+// INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// INNER-REDUCTION: %[[ACC1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+// INNER-REDUCTION: %[[MASK0:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<2x4xi1>
+// INNER-REDUCTION: %[[MASK1:.+]] = vector.extract %[[MASK]][1] : vector<4xi1> from vector<2x4xi1>
+// INNER-REDUCTION: vector.mask %[[MASK0]] { vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
+// INNER-REDUCTION: vector.mask %[[MASK1]] { vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
// INNER-PARALLEL: %[[TPMASK:.+]] = vector.transpose %[[MASK]], [1, 0] : vector<2x4xi1> to vector<4x2xi1>
// INNER-PARALLEL: %[[TPINPUT:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// INNER-PARALLEL: %[[INNERVEC:.+]] = vector.extract %[[TPINPUT]][0] : vector<2xf32> from vector<4x2xf32>
>From 1df3fa0d2d4713f15f70e5ef33a3aa4488a89904 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 16:42:01 -0500
Subject: [PATCH 10/18] [mlir][vector] Add missing rank check
---
.../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index cd651ecf6c745..2e1da48a3f768 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -664,6 +664,11 @@ struct UnrollMultiReductionOuterGeneralCase
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
+ auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+ if (srcRank < 2)
+ return rewriter.notifyMatchFailure(multiReductionOp,
+ "expected source rank >= 2.");
+
if (!multiReductionOp.isReducedDim(0))
return rewriter.notifyMatchFailure(
multiReductionOp,
>From c458d03f842cf962b73c677f85f1a753d9c440ec Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 16:43:48 -0500
Subject: [PATCH 11/18] Split documentation
---
.../Transforms/LowerVectorMultiReduction.cpp | 67 +++++++++----------
1 file changed, 31 insertions(+), 36 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2e1da48a3f768..08fcff00b2c19 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -539,56 +539,27 @@ struct OneDimMultiReductionToTwoDim
}
};
-/// Unrolls outermost dimension for vector.multi_reduction.
-/// This patterns matches operations which reduce the outermost dimension,
-/// it does not transform operations for which the outermost dimension is not
-/// a reduction dimension.
+/// Unrolls vector.multi_reduction when the outermost dimension is the only
+/// reduction dimension. Transforms to elementwise arithmetic operations.
///
-/// There are two cases to consider:
-/// 1. The base case is when the outermost dimension is the only reduction
-/// dimension.
-/// 2. The general case is when the outermost dimension is not the only
-/// reduction dimension.
-///
-/// The base case transformation:
+/// Example:
///
/// ```mlir
-/// %res = vector.multi_reduction <add> %src, %acc [0] : vector<NxMx...xf32> to
-/// vector<Mx...xf32>
+/// %res = vector.multi_reduction <add>, %src, %acc [0]
+/// : vector<NxMx...xf32> to vector<Mx...xf32>
/// ```
///
-/// will extract N vectors from %src and then perform elementwise operations.
+/// becomes:
///
/// ```mlir
/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
/// ...
-/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...x.f32> from
-/// vector<NxMx...xf32>
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from vector<NxMx...xf32>
///
/// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
/// ...
/// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
/// ```
-///
-/// For the general case, we still extract N vectors, but produce N
-/// vector.multi_reduction instead of elementwise operations.
-///
-/// ```mlir
-/// %res = vector.multi_reduction <add> %src, %acc [0, [[REDUCTION_DIMS]] ] :
-/// vector<NxMx...xf32> to vector<Ix...xf32>
-///
-/// ```mlir
-/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
-/// ...
-/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...x.f32> from
-/// vector<NxMx...xf32>
-///
-/// %red0 = vector.multi_reduction %0, %acc [ [[REDUCTION_DIMS]] ] :
-/// vector<Mx...xf32> to vector<Ix...xf32>
-/// ...
-/// %res = vector.multi_reduction %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] :
-/// vector<Mx...xf32> to vector<Ix...xf32>
-/// ```
struct UnrollMultiReductionOuterBaseCase
: public OpRewritePattern<vector::MultiDimReductionOp> {
using Base::Base;
@@ -658,6 +629,30 @@ struct UnrollMultiReductionOuterBaseCase
}
};
+/// Unrolls vector.multi_reduction when the outermost dimension is one of
+/// multiple reduction dimensions. Extracts slices and chains smaller
+/// multi_reduction operations.
+///
+/// Example:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, 2]
+/// : vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction <add>, %0, %acc [1]
+/// : vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction <add>, %Nminus1, %redNminus2 [1]
+/// : vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
struct UnrollMultiReductionOuterGeneralCase
: public OpRewritePattern<vector::MultiDimReductionOp> {
using Base::Base;
>From 10dd7655b481bc336c2b4ae7846b10763b24bb0b Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 16:48:55 -0500
Subject: [PATCH 12/18] Rename UnrollMultiReductionInner to UnrollInnerReductionAlongOuterParallel
---
.../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 08fcff00b2c19..2f8d0a2e2da05 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -760,7 +760,7 @@ struct UnrollMultiReductionOuterGeneralCase
/// : vector<Cxf32> into vector<AxCxf32>
/// // ... repeat for indices 1 to A-1
/// ```
-struct UnrollMultiReductionInner
+struct UnrollInnerReductionAlongOuterParallel
: public OpRewritePattern<vector::MultiDimReductionOp> {
using Base::Base;
@@ -918,7 +918,8 @@ void mlir::vector::populateVectorUnrollMultiReduction(
patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
if (options == VectorMultiReductionLowering::InnerReduction) {
- patterns.add<UnrollMultiReductionInner>(patterns.getContext(), benefit);
+ patterns.add<UnrollInnerReductionAlongOuterParallel>(patterns.getContext(),
+ benefit);
} else {
patterns.add<UnrollMultiReductionOuterBaseCase,
UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
>From 32018932b45d9714f8048840e3be4d4ab31570dd Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 17:01:32 -0500
Subject: [PATCH 13/18] Style
---
.../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2f8d0a2e2da05..3c9e7f05105e1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -554,7 +554,8 @@ struct OneDimMultiReductionToTwoDim
/// ```mlir
/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
/// ...
-/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
///
/// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
/// ...
@@ -645,7 +646,8 @@ struct UnrollMultiReductionOuterBaseCase
/// ```mlir
/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
/// ...
-/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
///
/// %red0 = vector.multi_reduction <add>, %0, %acc [1]
/// : vector<Mx...xf32> to vector<Ix...xf32>
>From 9e02ca79276fbff311996ffd2a7fc19439c2963c Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 18:57:15 -0500
Subject: [PATCH 14/18] Fix ASan failure: avoid ArrayRef dangling into temporary reduction mask
---
.../Vector/Transforms/LowerVectorMultiReduction.cpp | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 3c9e7f05105e1..ea65b02fe71fa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -713,8 +713,8 @@ struct UnrollMultiReductionOuterGeneralCase
else
masks.push_back(nullptr);
- ArrayRef<bool> reductionMask =
- ArrayRef<bool>(multiReductionOp.getReductionMask()).drop_front();
+ SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+ ArrayRef<bool> reductionMask = ArrayRef<bool>(fullReductionMask).drop_front();
Value result = multiReductionOp.getAcc();
for (auto [innerVector, innerMask] : llvm::zip(vectors, masks)) {
@@ -829,8 +829,9 @@ struct UnrollInnerReductionAlongOuterParallel
// Compute new reduction mask by dropping the first element (dimension 0).
// Since dimension 0 is parallel (not reduced), all reduction indices shift
// down by 1.
+ SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
ArrayRef<bool> newReductionMask =
- ArrayRef<bool>(multiReductionOp.getReductionMask()).drop_front();
+ ArrayRef<bool>(fullReductionMask).drop_front();
SmallVector<Value> reductionResults;
for (auto [srcSlice, accSlice, maskSlice] :
>From 444ffd68e2189dc04e14f5f9065dc19f80149350 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 19:00:09 -0500
Subject: [PATCH 15/18] Style
---
.../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index ea65b02fe71fa..3b3aaa2b22eca 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -714,7 +714,8 @@ struct UnrollMultiReductionOuterGeneralCase
masks.push_back(nullptr);
SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
- ArrayRef<bool> reductionMask = ArrayRef<bool>(fullReductionMask).drop_front();
+ ArrayRef<bool> reductionMask =
+ ArrayRef<bool>(fullReductionMask).drop_front();
Value result = multiReductionOp.getAcc();
for (auto [innerVector, innerMask] : llvm::zip(vectors, masks)) {
>From fbeb0d9503cf2e4daad9d362fb4407438fdb205a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 2 Feb 2026 10:10:15 -0500
Subject: [PATCH 16/18] [mlir][vector] Unroll until 1-D.
If innermost and outermost reduction patterns are applied in a mutually
exclusive manner, operations may not unroll fully to 1-D. For example,
innermost-reduction unrolling stops as soon as it reaches the reduction
dimensions.
To address this, we populate all patterns unconditionally. This allows
innermost reductions to unroll and then be matched by the outermost
reduction unrolling patterns, so unrolling proceeds until 1-D vectors
are generated in the IR.
---
.../Transforms/LowerVectorMultiReduction.cpp | 11 +--
.../unroll-vector-multi-reduction-inner.mlir | 72 ++++++++++++----
.../Vector/unroll-vector-multi-reduction.mlir | 86 +++++++++++++++----
3 files changed, 125 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 3b3aaa2b22eca..0e16b260cc604 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -920,15 +920,10 @@ void mlir::vector::populateVectorUnrollMultiReduction(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
-
- if (options == VectorMultiReductionLowering::InnerReduction) {
- patterns.add<UnrollInnerReductionAlongOuterParallel>(patterns.getContext(),
- benefit);
- } else {
- patterns.add<UnrollMultiReductionOuterBaseCase,
- UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
+ patterns.add<UnrollMultiReductionOuterBaseCase,
+ UnrollMultiReductionOuterGeneralCase,
+ UnrollInnerReductionAlongOuterParallel>(patterns.getContext(),
benefit);
- }
}
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
index f21535c1dfc21..ea53e3668f8a2 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -115,23 +115,41 @@ func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>
// The general case handles multiple reduction dimensions.
// For vector<2x3x5xf32> with reduction on dims [1, 2]:
-// - Unrolls along dim 0 (size 2), creating vector<3x5xf32> multi_reductions
-// - Each new multi_reduction has reduction dims [0, 1] (shifted from [1, 2])
+// - First: UnrollInnerReductionAlongOuterParallel unrolls along dim 0 (size 2),
+// creating vector<3x5xf32> multi_reductions with dims [0, 1]
+// - Then: UnrollMultiReductionOuterGeneralCase handles these (outermost is now
+// reduction), extracting along dim 0 and chaining 1-D multi_reductions
+// - Finally: OneDimMultiReductionToReduction converts to vector.reduction
// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general(
// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
func.func @unroll_vector_multi_reduction_inner_general(%source: vector<2x3x5xf32>, %acc: vector<2xf32>) -> (vector<2xf32>) {
- // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
- // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
-
// CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
// CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
- // CHECK: %[[RED_0:.+]] = vector.multi_reduction <add>, %[[VEC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32
- // CHECK: %[[RED_1:.+]] = vector.multi_reduction <add>, %[[VEC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32
- // CHECK: %[[INSERT_0:.+]] = vector.insert %[[RED_0]], %{{.*}} [0] : f32 into vector<2xf32>
- // CHECK: %[[INSERT_1:.+]] = vector.insert %[[RED_1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+ // First slice [0, ...]: extracts → chained reductions
+ // CHECK: vector.extract %[[SOURCE]][0, 0] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][0, 1] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][0, 2] : vector<5xf32>
+ // CHECK: %[[R0_0:.+]] = vector.reduction <add>, {{.*}}, %[[ACC_0]] : vector<5xf32> into f32
+ // CHECK: %[[R0_1:.+]] = vector.reduction <add>, {{.*}}, %[[R0_0]] : vector<5xf32> into f32
+ // CHECK: %[[R0_2:.+]] = vector.reduction <add>, {{.*}}, %[[R0_1]] : vector<5xf32> into f32
+
+ // Second slice [1, ...]: extracts → chained reductions
+ // CHECK: vector.extract %[[SOURCE]][1, 0] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][1, 1] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][1, 2] : vector<5xf32>
+ // CHECK: %[[R1_0:.+]] = vector.reduction <add>, {{.*}}, %[[ACC_1]] : vector<5xf32> into f32
+ // CHECK: %[[R1_1:.+]] = vector.reduction <add>, {{.*}}, %[[R1_0]] : vector<5xf32> into f32
+ // CHECK: %[[R1_2:.+]] = vector.reduction <add>, {{.*}}, %[[R1_1]] : vector<5xf32> into f32
+
+ // Final inserts
+ // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0_2]], %{{.*}} [0] : f32 into vector<2xf32>
+ // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1_2]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+
+ // No original multi_reduction remains
+ // CHECK-NOT: vector.multi_reduction
%1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
// CHECK: return %[[INSERT_1]]
@@ -145,19 +163,37 @@ func.func @unroll_vector_multi_reduction_inner_general(%source: vector<2x3x5xf32
// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2xf32>) -> (vector<2xf32>) {
- // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
- // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
-
// CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
// CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
- // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
- // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+ // First slice [0, ...]: extracts (source, mask) → chained masked reductions
+ // CHECK: vector.extract %[[SOURCE]][0, 0] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][0, 1] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][0, 2] : vector<5xf32>
+ // CHECK: vector.extract %[[MASK]][0, 0] : vector<5xi1>
+ // CHECK: vector.extract %[[MASK]][0, 1] : vector<5xi1>
+ // CHECK: vector.extract %[[MASK]][0, 2] : vector<5xi1>
+ // CHECK: %[[R0_0:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[ACC_0]] : vector<5xf32> into f32 }
+ // CHECK: %[[R0_1:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R0_0]] : vector<5xf32> into f32 }
+ // CHECK: %[[R0_2:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R0_1]] : vector<5xf32> into f32 }
+
+ // Second slice [1, ...]: extracts (source, mask) → chained masked reductions
+ // CHECK: vector.extract %[[SOURCE]][1, 0] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][1, 1] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][1, 2] : vector<5xf32>
+ // CHECK: vector.extract %[[MASK]][1, 0] : vector<5xi1>
+ // CHECK: vector.extract %[[MASK]][1, 1] : vector<5xi1>
+ // CHECK: vector.extract %[[MASK]][1, 2] : vector<5xi1>
+ // CHECK: %[[R1_0:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[ACC_1]] : vector<5xf32> into f32 }
+ // CHECK: %[[R1_1:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R1_0]] : vector<5xf32> into f32 }
+ // CHECK: %[[R1_2:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R1_1]] : vector<5xf32> into f32 }
+
+ // Final inserts
+ // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0_2]], %{{.*}} [0] : f32 into vector<2xf32>
+ // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1_2]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
- // CHECK: %[[RED_0:.+]] = vector.mask %[[MASK_0]] { vector.multi_reduction <add>, %[[VEC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32 } : vector<3x5xi1> -> f32
- // CHECK: %[[RED_1:.+]] = vector.mask %[[MASK_1]] { vector.multi_reduction <add>, %[[VEC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32 } : vector<3x5xi1> -> f32
- // CHECK: %[[INSERT_0:.+]] = vector.insert %[[RED_0]], %{{.*}} [0] : f32 into vector<2xf32>
- // CHECK: %[[INSERT_1:.+]] = vector.insert %[[RED_1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+ // No original multi_reduction remains
+ // CHECK-NOT: vector.multi_reduction
%0 = vector.mask %mask {
%1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
index 6db425fe4fce9..3d1bfc258792c 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -49,20 +49,47 @@ func.func @unroll_vector_multi_reduction_masked(%source: vector<2x3x5xf32>, %mas
// -----
+// The general case (multiple reduction dims where outermost is reduction) now
+// fully lowers through:
+// - UnrollMultiReductionOuterGeneralCase: extracts along outermost reduction,
+// chains smaller multi_reductions with remaining reduction dims
+// - UnrollInnerReductionAlongOuterParallel: handles resulting operations where
+// outermost becomes parallel
+// - OneDimMultiReductionToReduction: converts 1-D multi_reductions to vector.reduction
+
// CHECK-LABEL: func @unroll_vector_multi_reduction_general(
// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
func.func @unroll_vector_multi_reduction_general(%source: vector<2x3x5xf32>, %acc: vector<3xf32>) -> (vector<3xf32>) {
-
- // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
- // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
-
- // CHECK: %[[RES_0:.+]] = vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32>
- // CHECK: %[[RES_1:.+]] = vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32>
-
+ // First row of reductions (source[0, ...])
+ // CHECK: vector.extract %[[SOURCE]][0, 0] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][0, 1] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][0, 2] : vector<5xf32>
+ // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32
+ // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32
+ // CHECK-DAG: %[[ACC_2:.+]] = vector.extract %[[ACC]][2] : f32
+ // CHECK: %[[R0_0:.+]] = vector.reduction <add>, {{.*}}, %[[ACC_0]] : vector<5xf32> into f32
+ // CHECK: %[[R0_1:.+]] = vector.reduction <add>, {{.*}}, %[[ACC_1]] : vector<5xf32> into f32
+ // CHECK: %[[R0_2:.+]] = vector.reduction <add>, {{.*}}, %[[ACC_2]] : vector<5xf32> into f32
+
+ // Second row of reductions (source[1, ...]), chaining from first row results
+ // CHECK: vector.extract %[[SOURCE]][1, 0] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][1, 1] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][1, 2] : vector<5xf32>
+ // CHECK: %[[R1_0:.+]] = vector.reduction <add>, {{.*}}, %[[R0_0]] : vector<5xf32> into f32
+ // CHECK: %[[R1_1:.+]] = vector.reduction <add>, {{.*}}, %[[R0_1]] : vector<5xf32> into f32
+ // CHECK: %[[R1_2:.+]] = vector.reduction <add>, {{.*}}, %[[R0_2]] : vector<5xf32> into f32
+
+ // Final inserts to assemble result
+ // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R1_0]], %{{.*}} [0] : f32 into vector<3xf32>
+ // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1_1]], %[[INSERT_0]] [1] : f32 into vector<3xf32>
+ // CHECK: %[[INSERT_2:.+]] = vector.insert %[[R1_2]], %[[INSERT_1]] [2] : f32 into vector<3xf32>
+
+ // No original multi_reduction remains
+ // CHECK-NOT: vector.multi_reduction
%1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
- // CHECK: return %[[RES_1]]
+ // CHECK: return %[[INSERT_2]]
return %1 : vector<3xf32>
}
@@ -73,21 +100,44 @@ func.func @unroll_vector_multi_reduction_general(%source: vector<2x3x5xf32>, %ac
// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3xf32>) -> (vector<3xf32>) {
-
- // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
- // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
-
- // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
- // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
-
- // CHECK: %[[RES_0:.+]] = vector.mask %[[MASK_0]] { vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
- // CHECK: %[[RES_1:.+]] = vector.mask %[[MASK_1]] { vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
+ // First row of reductions (source[0, ...])
+ // CHECK: vector.extract %[[SOURCE]][0, 0] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][0, 1] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][0, 2] : vector<5xf32>
+ // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32
+ // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32
+ // CHECK-DAG: %[[ACC_2:.+]] = vector.extract %[[ACC]][2] : f32
+ // CHECK: vector.extract %[[MASK]][0, 0] : vector<5xi1>
+ // CHECK: vector.extract %[[MASK]][0, 1] : vector<5xi1>
+ // CHECK: vector.extract %[[MASK]][0, 2] : vector<5xi1>
+ // CHECK: %[[R0_0:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[ACC_0]] : vector<5xf32> into f32 }
+ // CHECK: %[[R0_1:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[ACC_1]] : vector<5xf32> into f32 }
+ // CHECK: %[[R0_2:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[ACC_2]] : vector<5xf32> into f32 }
+
+ // Second row of reductions (source[1, ...]), chaining from first row results
+ // CHECK: vector.extract %[[SOURCE]][1, 0] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][1, 1] : vector<5xf32>
+ // CHECK: vector.extract %[[SOURCE]][1, 2] : vector<5xf32>
+ // CHECK: vector.extract %[[MASK]][1, 0] : vector<5xi1>
+ // CHECK: vector.extract %[[MASK]][1, 1] : vector<5xi1>
+ // CHECK: vector.extract %[[MASK]][1, 2] : vector<5xi1>
+ // CHECK: %[[R1_0:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R0_0]] : vector<5xf32> into f32 }
+ // CHECK: %[[R1_1:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R0_1]] : vector<5xf32> into f32 }
+ // CHECK: %[[R1_2:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R0_2]] : vector<5xf32> into f32 }
+
+ // Final inserts to assemble result
+ // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R1_0]], %{{.*}} [0] : f32 into vector<3xf32>
+ // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1_1]], %[[INSERT_0]] [1] : f32 into vector<3xf32>
+ // CHECK: %[[INSERT_2:.+]] = vector.insert %[[R1_2]], %[[INSERT_1]] [2] : f32 into vector<3xf32>
+
+ // No original multi_reduction remains
+ // CHECK-NOT: vector.multi_reduction
%0 = vector.mask %mask {
%1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
} : vector<2x3x5xi1> -> vector<3xf32>
- // CHECK: return %[[RES_1]]
+ // CHECK: return %[[INSERT_2]]
return %0 : vector<3xf32>
}
>From b3a88c2133a918fb3fa52c7a1fbc92a2a553a429 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 2 Feb 2026 10:17:34 -0500
Subject: [PATCH 17/18] Update comments
---
.../unroll-vector-multi-reduction-inner.mlir | 24 -------------------
.../Vector/unroll-vector-multi-reduction.mlir | 23 ++----------------
2 files changed, 2 insertions(+), 45 deletions(-)
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
index ea53e3668f8a2..033d81ed14b41 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -19,7 +19,6 @@
// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
// CHECK-SAME: %[[ACC:.+]]: vector<2x3xf32>
func.func @unroll_vector_multi_reduction_inner(%source: vector<2x3x5xf32>, %acc: vector<2x3xf32>) -> (vector<2x3xf32>) {
- // First slice [0, ...]: extracts → reductions → inserts
// CHECK: vector.extract %[[SOURCE]][0, 0]
// CHECK: vector.extract %[[SOURCE]][0, 1]
// CHECK: vector.extract %[[SOURCE]][0, 2]
@@ -32,7 +31,6 @@ func.func @unroll_vector_multi_reduction_inner(%source: vector<2x3x5xf32>, %acc:
// CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
- // Second slice [1, ...]: extracts → reductions → inserts
// CHECK: vector.extract %[[SOURCE]][1, 0]
// CHECK: vector.extract %[[SOURCE]][1, 1]
// CHECK: vector.extract %[[SOURCE]][1, 2]
@@ -45,10 +43,8 @@ func.func @unroll_vector_multi_reduction_inner(%source: vector<2x3x5xf32>, %acc:
// CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
- // Final inserts to assemble result
// CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
- // No original multi_reduction remains
// CHECK-NOT: vector.multi_reduction
%1 = vector.multi_reduction <add>, %source, %acc [2] : vector<2x3x5xf32> to vector<2x3xf32>
@@ -62,7 +58,6 @@ func.func @unroll_vector_multi_reduction_inner(%source: vector<2x3x5xf32>, %acc:
// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
// CHECK-SAME: %[[ACC:.+]]: vector<2x3xf32>
func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2x3xf32>) -> (vector<2x3xf32>) {
- // First slice [0, ...]: extracts (source, acc, mask) → masked reductions → inserts
// CHECK: vector.extract %[[SOURCE]][0, 0]
// CHECK: vector.extract %[[SOURCE]][0, 1]
// CHECK: vector.extract %[[SOURCE]][0, 2]
@@ -78,7 +73,6 @@ func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>
// CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
- // Second slice [1, ...]: extracts → masked reductions → inserts
// CHECK: vector.extract %[[SOURCE]][1, 0]
// CHECK: vector.extract %[[SOURCE]][1, 1]
// CHECK: vector.extract %[[SOURCE]][1, 2]
@@ -94,10 +88,8 @@ func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>
// CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
// CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
- // Final inserts to assemble result
// CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
- // No original multi_reduction remains
// CHECK-NOT: vector.multi_reduction
%0 = vector.mask %mask {
@@ -127,28 +119,20 @@ func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>
func.func @unroll_vector_multi_reduction_inner_general(%source: vector<2x3x5xf32>, %acc: vector<2xf32>) -> (vector<2xf32>) {
// CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
// CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
-
- // First slice [0, ...]: extracts → chained reductions
// CHECK: vector.extract %[[SOURCE]][0, 0] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][0, 1] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][0, 2] : vector<5xf32>
// CHECK: %[[R0_0:.+]] = vector.reduction <add>, {{.*}}, %[[ACC_0]] : vector<5xf32> into f32
// CHECK: %[[R0_1:.+]] = vector.reduction <add>, {{.*}}, %[[R0_0]] : vector<5xf32> into f32
// CHECK: %[[R0_2:.+]] = vector.reduction <add>, {{.*}}, %[[R0_1]] : vector<5xf32> into f32
-
- // Second slice [1, ...]: extracts → chained reductions
// CHECK: vector.extract %[[SOURCE]][1, 0] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][1, 1] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][1, 2] : vector<5xf32>
// CHECK: %[[R1_0:.+]] = vector.reduction <add>, {{.*}}, %[[ACC_1]] : vector<5xf32> into f32
// CHECK: %[[R1_1:.+]] = vector.reduction <add>, {{.*}}, %[[R1_0]] : vector<5xf32> into f32
// CHECK: %[[R1_2:.+]] = vector.reduction <add>, {{.*}}, %[[R1_1]] : vector<5xf32> into f32
-
- // Final inserts
// CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0_2]], %{{.*}} [0] : f32 into vector<2xf32>
// CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1_2]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
-
- // No original multi_reduction remains
// CHECK-NOT: vector.multi_reduction
%1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
@@ -165,8 +149,6 @@ func.func @unroll_vector_multi_reduction_inner_general(%source: vector<2x3x5xf32
func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2xf32>) -> (vector<2xf32>) {
// CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
// CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
-
- // First slice [0, ...]: extracts (source, mask) → chained masked reductions
// CHECK: vector.extract %[[SOURCE]][0, 0] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][0, 1] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][0, 2] : vector<5xf32>
@@ -176,8 +158,6 @@ func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x
// CHECK: %[[R0_0:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[ACC_0]] : vector<5xf32> into f32 }
// CHECK: %[[R0_1:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R0_0]] : vector<5xf32> into f32 }
// CHECK: %[[R0_2:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R0_1]] : vector<5xf32> into f32 }
-
- // Second slice [1, ...]: extracts (source, mask) → chained masked reductions
// CHECK: vector.extract %[[SOURCE]][1, 0] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][1, 1] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][1, 2] : vector<5xf32>
@@ -187,12 +167,8 @@ func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x
// CHECK: %[[R1_0:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[ACC_1]] : vector<5xf32> into f32 }
// CHECK: %[[R1_1:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R1_0]] : vector<5xf32> into f32 }
// CHECK: %[[R1_2:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R1_1]] : vector<5xf32> into f32 }
-
- // Final inserts
// CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0_2]], %{{.*}} [0] : f32 into vector<2xf32>
// CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1_2]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
-
- // No original multi_reduction remains
// CHECK-NOT: vector.multi_reduction
%0 = vector.mask %mask {
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
index 3d1bfc258792c..f4048fa80e5e8 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -49,19 +49,13 @@ func.func @unroll_vector_multi_reduction_masked(%source: vector<2x3x5xf32>, %mas
// -----
-// The general case (multiple reduction dims where outermost is reduction) now
-// fully lowers through:
-// - UnrollMultiReductionOuterGeneralCase: extracts along outermost reduction,
-// chains smaller multi_reductions with remaining reduction dims
-// - UnrollInnerReductionAlongOuterParallel: handles resulting operations where
-// outermost becomes parallel
-// - OneDimMultiReductionToReduction: converts 1-D multi_reductions to vector.reduction
+// Multiple reduction dims, where the outermost dim is a reduction dim. Fully
+// lowers to vector.reduction by combining outer and inner unrolling patterns.
// CHECK-LABEL: func @unroll_vector_multi_reduction_general(
// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
func.func @unroll_vector_multi_reduction_general(%source: vector<2x3x5xf32>, %acc: vector<3xf32>) -> (vector<3xf32>) {
- // First row of reductions (source[0, ...])
// CHECK: vector.extract %[[SOURCE]][0, 0] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][0, 1] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][0, 2] : vector<5xf32>
@@ -71,21 +65,15 @@ func.func @unroll_vector_multi_reduction_general(%source: vector<2x3x5xf32>, %ac
// CHECK: %[[R0_0:.+]] = vector.reduction <add>, {{.*}}, %[[ACC_0]] : vector<5xf32> into f32
// CHECK: %[[R0_1:.+]] = vector.reduction <add>, {{.*}}, %[[ACC_1]] : vector<5xf32> into f32
// CHECK: %[[R0_2:.+]] = vector.reduction <add>, {{.*}}, %[[ACC_2]] : vector<5xf32> into f32
-
- // Second row of reductions (source[1, ...]), chaining from first row results
// CHECK: vector.extract %[[SOURCE]][1, 0] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][1, 1] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][1, 2] : vector<5xf32>
// CHECK: %[[R1_0:.+]] = vector.reduction <add>, {{.*}}, %[[R0_0]] : vector<5xf32> into f32
// CHECK: %[[R1_1:.+]] = vector.reduction <add>, {{.*}}, %[[R0_1]] : vector<5xf32> into f32
// CHECK: %[[R1_2:.+]] = vector.reduction <add>, {{.*}}, %[[R0_2]] : vector<5xf32> into f32
-
- // Final inserts to assemble result
// CHECK: %[[INSERT_0:.+]] = vector.insert %[[R1_0]], %{{.*}} [0] : f32 into vector<3xf32>
// CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1_1]], %[[INSERT_0]] [1] : f32 into vector<3xf32>
// CHECK: %[[INSERT_2:.+]] = vector.insert %[[R1_2]], %[[INSERT_1]] [2] : f32 into vector<3xf32>
-
- // No original multi_reduction remains
// CHECK-NOT: vector.multi_reduction
%1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
@@ -100,7 +88,6 @@ func.func @unroll_vector_multi_reduction_general(%source: vector<2x3x5xf32>, %ac
// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3xf32>) -> (vector<3xf32>) {
- // First row of reductions (source[0, ...])
// CHECK: vector.extract %[[SOURCE]][0, 0] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][0, 1] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][0, 2] : vector<5xf32>
@@ -113,8 +100,6 @@ func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf3
// CHECK: %[[R0_0:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[ACC_0]] : vector<5xf32> into f32 }
// CHECK: %[[R0_1:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[ACC_1]] : vector<5xf32> into f32 }
// CHECK: %[[R0_2:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[ACC_2]] : vector<5xf32> into f32 }
-
- // Second row of reductions (source[1, ...]), chaining from first row results
// CHECK: vector.extract %[[SOURCE]][1, 0] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][1, 1] : vector<5xf32>
// CHECK: vector.extract %[[SOURCE]][1, 2] : vector<5xf32>
@@ -124,13 +109,9 @@ func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf3
// CHECK: %[[R1_0:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R0_0]] : vector<5xf32> into f32 }
// CHECK: %[[R1_1:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R0_1]] : vector<5xf32> into f32 }
// CHECK: %[[R1_2:.+]] = vector.mask {{.*}} { vector.reduction <add>, {{.*}}, %[[R0_2]] : vector<5xf32> into f32 }
-
- // Final inserts to assemble result
// CHECK: %[[INSERT_0:.+]] = vector.insert %[[R1_0]], %{{.*}} [0] : f32 into vector<3xf32>
// CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1_1]], %[[INSERT_0]] [1] : f32 into vector<3xf32>
// CHECK: %[[INSERT_2:.+]] = vector.insert %[[R1_2]], %[[INSERT_1]] [2] : f32 into vector<3xf32>
-
- // No original multi_reduction remains
// CHECK-NOT: vector.multi_reduction
%0 = vector.mask %mask {
>From 3cb942082ff3b238f76bebb039baf5f40d21b37a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 2 Feb 2026 10:23:23 -0500
Subject: [PATCH 18/18] Update description for pass
---
.../Dialect/Vector/TransformOps/VectorTransformOps.td | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 373f9a4210c41..4f112ba88a564 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -247,11 +247,10 @@ def ApplyUnrollMultiReductionPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.unroll_multi_reduction",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Unrolls vector.multi_reduction operations by progressively reducing rank
- along the outermost dimension.
-
- This is an alternative to the flattening-based lowering that preserves
- the n-D structure during progressive lowering.
+ Populates patterns that unroll vector.multi_reduction operations by
+ progressively reducing rank. Terminal cases lower to either:
+ - vector.reduction (for 1-D reductions), or
+ - elementwise arith ops (when the outermost dim is the only reduction dim).
}];
let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
More information about the Mlir-commits
mailing list