[Mlir-commits] [mlir] [vector][multi_reduction] Add unrolling for vector.multi_reduction (PR #185033)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Fri Mar 6 08:18:15 PST 2026
https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/185033
The `vector.multi_reduction` unrolling patterns are currently misnamed: they do not unroll, but instead lower `vector.multi_reduction` to either arith operations or `vector.reduction`. This PR:
* renames `multi_reduction_unrolling` to `multi_reduction_lowering`
* adds a new `multi_reduction_unrolling` that performs actual unrolling.
The existing `vector.multi_reduction` expand, flattening and lowering patterns keep their current behaviour. Adding unrolling means a `vector.multi_reduction` can now be unrolled without first going through the expand or flattening patterns. This gives backends finer control over which patterns to apply, for example to better handle differences between the SPIR-V and LLVM lowerings.
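This is a breaking change for downstream projects. To migrate, apply `multi_reduction_lowering` wherever `multi_reduction_unrolling` was applied before (and, in C++, call `populateVectorMultiReductionLoweringPatterns` where `populateVectorMultiReductionUnrollingPatterns` was called before).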
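To make the lowering/unrolling distinction concrete, here is a minimal Python sketch that roughly models what the renamed lowering patterns do to a 2-D `vector.multi_reduction`. It is not MLIR code: vectors are nested lists, `"add"`/`"mul"` stand in for the combining kind, masking is ignored, and the function names are hypothetical. With `innerreduction`, each row becomes one 1-D `vector.reduction` seeded with its accumulator entry; with `innerparallel`, rows are folded into the accumulator with element-wise arith ops. Unrolling, by contrast, still emits (smaller) `vector.multi_reduction` ops.

```python
import operator
from functools import reduce

# Hypothetical stand-ins for the multi_reduction combining kinds.
COMBINE = {"add": operator.add, "mul": operator.mul}


def lower_2d_inner_reduction(kind, src, acc):
    """Roughly models TwoDimMultiReductionToReduction: src is NxM and dim 1
    is reduced. Each row becomes one 1-D reduction seeded with acc[i]."""
    combine = COMBINE[kind]
    return [reduce(combine, row, a) for row, a in zip(src, acc)]


def lower_2d_inner_parallel(kind, src, acc):
    """Roughly models TwoDimMultiReductionToElementWise: src is NxM and dim 0
    is reduced. Rows are folded into acc with element-wise arith ops."""
    combine = COMBINE[kind]
    result = list(acc)
    for row in src:
        result = [combine(r, x) for r, x in zip(result, row)]
    return result


src = [[1, 2, 3], [4, 5, 6]]
print(lower_2d_inner_reduction("add", src, [10, 20]))  # [16, 35]
print(lower_2d_inner_parallel("add", src, [0, 0, 0]))  # [5, 7, 9]
```

Neither function produces a multi-dimensional reduction any more, which is why "lowering" is the more accurate name for these patterns.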
>From d07bef658ff8167892b8e56769bc4e52b5b9a712 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 09:47:24 -0500
Subject: [PATCH 1/5] [mlir][vector] Rename multi_reduction's unrolling to
lowering.
Stated previously:
> Below are some notes on naming
> * RewriteXAsY or LowerXToY (if we're changing op kind)
>
> Concretely, TwoDimMultiReductionToReduction looks like a lowering
> (it rewrites to vector.reduction), not an unrolling,
https://github.com/llvm/llvm-project/pull/182301#issuecomment-3951827980
---
.../Dialect/Vector/TransformOps/VectorTransformOps.td | 8 ++++----
.../mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 2 +-
.../Dialect/Vector/TransformOps/VectorTransformOps.cpp | 4 ++--
.../Vector/Transforms/LowerVectorMultiReduction.cpp | 10 +++++-----
mlir/test/Dialect/LLVM/transform-e2e.mlir | 2 +-
mlir/test/Dialect/Vector/transform-vector.mlir | 2 +-
.../Vector/vector-multi-reduction-unrolling.mlir | 4 ++--
.../Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir | 2 +-
.../Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir | 2 +-
.../Dialect/Linalg/CPU/test-matmul-masked-vec.mlir | 2 +-
mlir/test/python/dialects/transform_vector_ext.py | 6 +++---
11 files changed, 22 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index dcd5f6ff3ad74..76e77b9f90a13 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -262,18 +262,18 @@ def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
}];
}
-def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
- "apply_patterns.vector.multi_reduction_unrolling",
+def ApplyMultiReductionLoweringPatternsOp: Op<Transform_Dialect,
+ "apply_patterns.vector.multi_reduction_lowering",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Indicates that vector multi_reduction operations should be unrolled.
+ Indicates that vector multi_reduction operations should be lowered.
1-D multi_reductions are converted directly to vector.reduction.
2-D multi_reductions are unrolled into either a sequence of
vector.reduction ops (innerreduction) or element-wise arith ops
(innerparallel).
This populates the patterns from
- `populateVectorMultiReductionUnrollingPatterns`, i.e.:
+ `populateVectorMultiReductionLoweringPatterns`, i.e.:
* `OneDimMultiReductionToReduction`
* `TwoDimMultiReductionToReduction` (innerreduction)
* `TwoDimMultiReductionToElementWise` (innerparallel)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index aa75eff409ef9..2ea3367da414d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -96,7 +96,7 @@ void populateVectorMultiReductionFlatteningPatterns(
/// dimension, unroll the outer dimension to obtain a sequence of extract +
/// vector.reduction + insert. This can further lower to horizontal reduction
/// ops.
-void populateVectorMultiReductionUnrollingPatterns(
+void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 312bd28ad48cf..d1799b3bd84d8 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -145,11 +145,11 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
-void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns(
+void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
- vector::populateVectorMultiReductionUnrollingPatterns(
+ vector::populateVectorMultiReductionLoweringPatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 76599822fbfe4..88e669d1a34b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -505,10 +505,10 @@ struct LowerVectorMultiReductionPass
if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
signalPassFailure();
- RewritePatternSet unrollingPatterns(context);
- mlir::vector::populateVectorMultiReductionUnrollingPatterns(
- unrollingPatterns, this->loweringStrategy);
- if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
+ RewritePatternSet loweringPatterns(context);
+ mlir::vector::populateVectorMultiReductionLoweringPatterns(
+ loweringPatterns, this->loweringStrategy);
+ if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
signalPassFailure();
}
@@ -532,7 +532,7 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
patterns.add<FlattenMultiReduction>(patterns.getContext(), options, benefit);
}
-void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index bf7eba6e50174..7f4c78a7fc2c8 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -32,7 +32,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.transfer_permutation_patterns
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 4dc11c26e83f1..9027bef2eac5c 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -41,7 +41,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns to %f {
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
} : !transform.any_op
transform.apply_patterns to %f {
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
index 447416ccba637..a7ec4b9206c33 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
@@ -117,7 +117,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
}
@@ -125,7 +125,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
} : !transform.op<"func.func">
transform.yield
}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
index a7b0b27ca5fb9..15f5bfd3619eb 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
@@ -152,7 +152,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.lower_masked_transfers
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
index 4adc68966f17a..0cfba97654769 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
@@ -157,7 +157,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.lower_masked_transfers
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index 0883e7b698f55..ce85720d665c2 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -55,7 +55,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
}
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index a3c53a45048b2..cfaf91d24f471 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -103,9 +103,9 @@ def enum_configurable_patterns():
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
)
- # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
- vector.ApplyMultiReductionUnrollingPatternsOp()
- # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
+ # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
+ vector.ApplyMultiReductionLoweringPatternsOp()
+ # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
# CHECK-SAME: lowering_strategy = innerreduction
vector.ApplyMultiReductionUnrollingPatternsOp(
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
>From ba1fffa23f4b9cf6862897959db2cbecf2eea9d4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 10:01:00 -0500
Subject: [PATCH 2/5] [mlir][vector] Add multi_reduction unrolling
innerparallel.
When the outermost dimension of a vector.multi_reduction is reduced,
extract the vectors along the outermost dimension and chain them through
a sequence of vector.multi_reductions.
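The inner-parallel unrolling can be sketched as a small Python model. This is an illustration under stated assumptions, not the patch's C++: vectors are nested lists, `mask[d]` is True when dimension d is reduced (in the spirit of `getReductionMask()`), masking via `vector.mask` is ignored, and `multi_reduction` / `unroll_inner_parallel` are hypothetical names. The match conditions mirror the pattern (dim 0 reduced, more than one reduced dim); each outer slice is reduced with the remaining mask and the accumulator is threaded through the chain.

```python
import operator
from itertools import product

# Hypothetical model: nested lists as vectors; mask[d] is True if dim d is reduced.
COMBINE = {"add": operator.add, "mul": operator.mul}


def _shape(v):
    shape = []
    while isinstance(v, list):
        shape.append(len(v))
        v = v[0]
    return shape


def _get(v, idx):
    for i in idx:
        v = v[i]
    return v


def _build(flat, shape, prefix=()):
    if len(prefix) == len(shape):
        return flat[prefix]
    return [_build(flat, shape, prefix + (i,)) for i in range(shape[len(prefix)])]


def multi_reduction(kind, src, acc, mask):
    """Reference semantics: fold every src element into the acc entry
    selected by its parallel (non-reduced) indices."""
    shape = _shape(src)
    par_shape = [n for n, r in zip(shape, mask) if not r]
    out = {p: _get(acc, p) for p in product(*(range(n) for n in par_shape))}
    for idx in product(*(range(n) for n in shape)):
        key = tuple(i for i, r in zip(idx, mask) if not r)
        out[key] = COMBINE[kind](out[key], _get(src, idx))
    return _build(out, par_shape)


def unroll_inner_parallel(kind, src, acc, mask):
    """Dim 0 must be reduced and must not be the only reduced dim.
    Returns None when the pattern would not match."""
    if not mask[0] or sum(mask) <= 1:
        return None
    result = acc
    for outer_slice in src:  # one vector.extract %src[i] per outer index
        # Smaller multi_reduction on the slice, chained through the accumulator.
        result = multi_reduction(kind, outer_slice, result, mask[1:])
    return result


src = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # 2x2x2, reduce dims [0, 2]
print(unroll_inner_parallel("add", src, [0, 0], [True, False, True]))  # [14, 22]
```

Chaining works because reducing dim 0 means every slice's result already has the original result shape, so it can serve directly as the next slice's accumulator.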
---
.../Vector/TransformOps/VectorTransformOps.td | 14 ++++
.../Vector/Transforms/LoweringPatterns.h | 5 ++
.../TransformOps/VectorTransformOps.cpp | 5 ++
.../Transforms/LowerVectorMultiReduction.cpp | 82 +++++++++++++++++++
4 files changed, 106 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 76e77b9f90a13..dff2803f95ae6 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -262,6 +262,20 @@ def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
}];
}
+def ApplyUnrollMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
+ "apply_patterns.vector.multi_reduction_unrolling",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector multi_reduction operations with more than one
+ reduction dimension should be unrolled in a rank-reducing way.
+
+ This populates the patterns:
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+
def ApplyMultiReductionLoweringPatternsOp: Op<Transform_Dialect,
"apply_patterns.vector.multi_reduction_lowering",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 2ea3367da414d..d9237bac63b78 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -81,6 +81,11 @@ void populateVectorMultiReductionFlatteningPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+void populateVectorMultiReductionUnrollingPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Populate the pattern set with the following patterns:
///
/// [OneDimMultiReductionToReduction]
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index d1799b3bd84d8..256384f6b7c18 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -145,6 +145,11 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
+void transform::ApplyUnrollMultiReductionUnrollingPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorMultiReductionUnrollingPatterns(patterns);
+}
+
void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 88e669d1a34b0..51b2f64c9a7da 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -443,6 +443,82 @@ struct TwoDimMultiReductionToReduction
}
};
+/// Unrolls the outermost dimension of vector.multi_reduction.
+/// Matches when the outermost dimension is reduced and is not the only
+/// reduction dimension.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, [[REDUCTION_DIMS]] ] :
+/// vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction <add>, %0, %acc [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction <add>, %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
+struct UnrollMultiReductionInnerParallel
+ : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
+ if (!multiReductionOp.isReducedDim(0))
+ return failure();
+
+ ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+ if (reductionDims.size() <= 1)
+ return failure();
+
+ Location loc = multiReductionOp.getLoc();
+ Value source = multiReductionOp.getSource();
+
+ ArrayRef<int64_t> srcShape =
+ multiReductionOp.getSourceVectorType().getShape();
+ int64_t outerDimSize = srcShape.front();
+
+ Value mask = maskingOp ? maskingOp.getMask() : nullptr;
+
+ SmallVector<Value> vectors(outerDimSize);
+ for (int64_t i = 0; i < outerDimSize; ++i)
+ vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i);
+
+ SmallVector<Value> masks(outerDimSize);
+ if (mask)
+ for (int64_t i = 0; i < outerDimSize; ++i)
+ masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
+
+ SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+ ArrayRef<bool> reductionMask =
+ ArrayRef<bool>(fullReductionMask).drop_front();
+ Value result = multiReductionOp.getAcc();
+ for (auto [innerVector, innerMask] : llvm::zip_equal(vectors, masks)) {
+ auto reductionOp = vector::MultiDimReductionOp::create(
+ rewriter, loc, innerVector, result, reductionMask,
+ multiReductionOp.getKind());
+
+ if (innerMask) {
+ Operation *maskOp =
+ vector::maskOperation(rewriter, reductionOp, innerMask);
+ result = maskOp->getResult(0);
+ } else {
+ result = reductionOp.getResult();
+ }
+ }
+
+ return result;
+ }
+};
+
/// Converts 1D vector.multi_reduction directly to vector.reduction.
///
/// Example:
@@ -544,6 +620,12 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
benefit);
}
+void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollMultiReductionInnerParallel>(patterns.getContext(),
+ benefit);
+}
+
std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
vector::VectorMultiReductionLowering option) {
return std::make_unique<LowerVectorMultiReductionPass>(option);
>From 2ea3b59224e103f515c1abebc802435147a4dbff Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 10:37:47 -0500
Subject: [PATCH 3/5] [mlir][vector] Add multi_reduction unrolling
innerreduction.
Unrolls multi_reduction into a series of extractions,
multi_reductions and insertions.
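The inner-reduction unrolling can likewise be sketched in Python. Again this is a hypothetical model, not the patch's C++: vectors are nested lists, `mask[d]` is True when dimension d is reduced, `vector.mask` is ignored, and `multi_reduction` / `unroll_inner_reduction` are illustrative names. The match conditions mirror the pattern (rank >= 2, innermost dim reduced, outermost dim parallel); both the source and the accumulator are sliced along the outer parallel dim, each pair is reduced with the remaining mask, and the results are collected, which models the `vector.insert`s into a zero constant.

```python
import operator
from itertools import product

# Hypothetical model: nested lists as vectors; mask[d] is True if dim d is reduced.
COMBINE = {"add": operator.add, "mul": operator.mul}


def _shape(v):
    shape = []
    while isinstance(v, list):
        shape.append(len(v))
        v = v[0]
    return shape


def _get(v, idx):
    for i in idx:
        v = v[i]
    return v


def _build(flat, shape, prefix=()):
    if len(prefix) == len(shape):
        return flat[prefix]
    return [_build(flat, shape, prefix + (i,)) for i in range(shape[len(prefix)])]


def multi_reduction(kind, src, acc, mask):
    """Reference semantics: fold every src element into the acc entry
    selected by its parallel (non-reduced) indices."""
    shape = _shape(src)
    par_shape = [n for n, r in zip(shape, mask) if not r]
    out = {p: _get(acc, p) for p in product(*(range(n) for n in par_shape))}
    for idx in product(*(range(n) for n in shape)):
        key = tuple(i for i, r in zip(idx, mask) if not r)
        out[key] = COMBINE[kind](out[key], _get(src, idx))
    return _build(out, par_shape)


def unroll_inner_reduction(kind, src, acc, mask):
    """Rank >= 2, innermost dim reduced, outermost dim parallel.
    Returns None when the pattern would not match."""
    if len(mask) < 2 or not mask[-1] or mask[0]:
        return None
    # One smaller multi_reduction per (src slice, acc slice) pair; the list
    # stands in for inserting each result into a zero-initialized vector.
    return [multi_reduction(kind, s, a, mask[1:]) for s, a in zip(src, acc)]


src = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # 2x2x2, reduce dims [1, 2]
print(unroll_inner_reduction("add", src, [10, 100], [False, True, True]))  # [20, 126]
```

Because dim 0 is parallel, the slices are independent, so unlike the inner-parallel case there is no accumulator chain between them.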
---
.../Vector/TransformOps/VectorTransformOps.td | 13 +-
.../Vector/Transforms/LoweringPatterns.h | 5 +-
.../TransformOps/VectorTransformOps.cpp | 5 +-
.../Transforms/LowerVectorMultiReduction.cpp | 125 +++++++++++++++++-
.../python/dialects/transform_vector_ext.py | 4 +-
5 files changed, 143 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index dff2803f95ae6..c9226af30e674 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -269,10 +269,19 @@ def ApplyUnrollMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
Indicates that vector multi_reduction operations with more than one
reduction dimension should be unrolled in a rank-reducing way.
- This populates the patterns:
+ This populates the patterns from
+ `populateVectorMultiReductionUnrollingPatterns`, i.e.:
+ * `UnrollMultiReductionInnerReduction` (innerreduction)
+ * `UnrollMultiReductionInnerParallel` (innerparallel)
}];
- let assemblyFormat = "attr-dict";
+ let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+ "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+ );
+
+ let assemblyFormat = [{
+ (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+ }];
}
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index d9237bac63b78..744a7d4b3b425 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -83,8 +83,9 @@ void populateVectorMultiReductionFlatteningPatterns(
/// Populate the pattern set with the following patterns:
///
-void populateVectorMultiReductionUnrollingPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+void populateVectorMultiReductionUnrollingPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1);
/// Populate the pattern set with the following patterns:
///
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 256384f6b7c18..f63d43a7d165c 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -147,7 +147,10 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
void transform::ApplyUnrollMultiReductionUnrollingPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- vector::populateVectorMultiReductionUnrollingPatterns(patterns);
+ vector::VectorTransformsOptions vectorTransformOptions;
+ vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+ vector::populateVectorMultiReductionUnrollingPatterns(
+ patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 51b2f64c9a7da..e347ba15537bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -519,6 +519,120 @@ struct UnrollMultiReductionInnerParallel
}
};
+/// Unrolls vector.multi_reduction along the outermost parallel dimension
+/// when the innermost dimension is a reduction dimension.
+///
+/// This pattern matches operations where:
+/// - The innermost dimension is a reduction dimension
+/// - The outermost dimension is a parallel dimension
+///
+/// The transformation extracts slices along the outermost parallel dimension,
+/// creates smaller multi_reductions, and assembles the results:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [1, 3]
+/// : vector<AxBxCxDxf32> to vector<AxCxf32>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %result = arith.constant dense<0.0> : vector<AxCxf32>
+/// %0 = vector.extract %src[0] : vector<BxCxDxf32> from vector<AxBxCxDxf32>
+/// %acc0 = vector.extract %acc[0] : vector<Cxf32> from vector<AxCxf32>
+/// %red0 = vector.multi_reduction <add>, %0, %acc0 [0, 2]
+/// : vector<BxCxDxf32> to vector<Cxf32>
+/// %res0 = vector.insert %red0, %result[0]
+/// : vector<Cxf32> into vector<AxCxf32>
+/// // ... repeat for indices 1 to A-1
+/// ```
+struct UnrollMultiReductionInnerReduction
+ : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
+ auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+ if (srcRank < 2)
+ return rewriter.notifyMatchFailure(multiReductionOp,
+ "expected source rank >= 2.");
+
+ if (!multiReductionOp.isReducedDim(srcRank - 1))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected innermost dimension to be a reduction dimension.");
+
+ if (multiReductionOp.isReducedDim(0))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected outermost dimension to be a parallel dimension.");
+
+ Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+ if (!elementType.isIntOrIndexOrFloat())
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected integer or float element type.");
+
+ Location loc = multiReductionOp.getLoc();
+ Value source = multiReductionOp.getSource();
+ Value acc = multiReductionOp.getAcc();
+
+ ArrayRef<int64_t> srcShape =
+ multiReductionOp.getSourceVectorType().getShape();
+ int64_t numSlices = srcShape.front();
+
+ Value mask = maskingOp ? maskingOp.getMask() : Value();
+
+ SmallVector<Value> srcSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ srcSlices.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+ SmallVector<Value> accSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ accSlices.push_back(vector::ExtractOp::create(rewriter, loc, acc, i));
+
+ SmallVector<Value> maskSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ if (mask)
+ maskSlices.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+ else
+ maskSlices.push_back(nullptr);
+
+ // Compute new reduction mask by dropping the first element (dimension 0).
+ // Since dimension 0 is parallel (not reduced), all reduction indices shift
+ // down by 1.
+ SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+ ArrayRef<bool> newReductionMask =
+ ArrayRef<bool>(fullReductionMask).drop_front();
+
+ SmallVector<Value> reductionResults;
+ for (auto [srcSlice, accSlice, maskSlice] :
+ llvm::zip(srcSlices, accSlices, maskSlices)) {
+ Operation *newReductionOp = vector::MultiDimReductionOp::create(
+ rewriter, loc, srcSlice, accSlice, newReductionMask,
+ multiReductionOp.getKind());
+
+ if (maskSlice)
+ newReductionOp =
+ mlir::vector::maskOperation(rewriter, newReductionOp, maskSlice);
+
+ reductionResults.push_back(newReductionOp->getResult(0));
+ }
+
+ Value result = arith::ConstantOp::create(
+ rewriter, loc, multiReductionOp.getDestType(),
+ rewriter.getZeroAttr(multiReductionOp.getDestType()));
+
+ for (int64_t i = 0; i < numSlices; ++i)
+ result = vector::InsertOp::create(rewriter, loc, reductionResults[i],
+ result, i);
+
+ return result;
+ }
+};
+
/// Converts 1D vector.multi_reduction directly to vector.reduction.
///
/// Example:
@@ -621,9 +735,14 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
}
void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<UnrollMultiReductionInnerParallel>(patterns.getContext(),
- benefit);
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit) {
+ if (options == VectorMultiReductionLowering::InnerReduction)
+ patterns.add<UnrollMultiReductionInnerReduction>(patterns.getContext(),
+ benefit);
+ else
+ patterns.add<UnrollMultiReductionInnerParallel>(patterns.getContext(),
+ benefit);
}
std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index cfaf91d24f471..84d70181b5288 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -105,7 +105,9 @@ def enum_configurable_patterns():
# CHECK: transform.apply_patterns.vector.multi_reduction_lowering
vector.ApplyMultiReductionLoweringPatternsOp()
- # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
+ # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
+ vector.ApplyUnrollMultiReductionUnrollingPatternsOp()
+ # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
# CHECK-SAME: lowering_strategy = innerreduction
vector.ApplyMultiReductionUnrollingPatternsOp(
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
>From 4328d792b9c104cb50346deb039440e08f071f1f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 10:39:07 -0500
Subject: [PATCH 4/5] Add tests
---
...unroll-multi-reduction-inner-parallel.mlir | 27 ++++++++
...nroll-multi-reduction-inner-reduction.mlir | 62 +++++++++++++++++++
2 files changed, 89 insertions(+)
create mode 100644 mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir
create mode 100644 mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir
diff --git a/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir
new file mode 100644
index 0000000000000..4f67d5678d6d3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: func @unroll_multi_reduction_inner_parallel
+// CHECK-SAME: %[[INPUT:.+]]: vector<4x2x3xf32>, %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_multi_reduction_inner_parallel(%arg0: vector<4x2x3xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ // CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2x3xf32> from vector<4x2x3xf32>
+ // CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2x3xf32> from vector<4x2x3xf32>
+ // CHECK: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2x3xf32> from vector<4x2x3xf32>
+ // CHECK: %[[V3:.+]] = vector.extract %[[INPUT]][3] : vector<2x3xf32> from vector<4x2x3xf32>
+ // CHECK: %[[RV0:.+]] = vector.multi_reduction <mul>, %[[V0]], %[[ACC]] [1] : vector<2x3xf32> to vector<2xf32>
+ // CHECK: %[[RV1:.+]] = vector.multi_reduction <mul>, %[[V1]], %[[RV0]] [1] : vector<2x3xf32> to vector<2xf32>
+ // CHECK: %[[RV2:.+]] = vector.multi_reduction <mul>, %[[V2]], %[[RV1]] [1] : vector<2x3xf32> to vector<2xf32>
+ // CHECK: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[V3]], %[[RV2]] [1] : vector<2x3xf32> to vector<2xf32>
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 2] : vector<4x2x3xf32> to vector<2xf32>
+ // CHECK: return %[[RESULT]]
+ return %0 : vector<2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.multi_reduction_unrolling
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
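The CHECK lines above expect the unrolling to extract the four vector<2x3xf32> slices along dim 0 and thread the accumulator through four rank-2 multi_reductions (RV0 -> RV1 -> RV2 -> RESULT). The following C++ is a minimal sketch, not MLIR code, that models this equivalence on plain nested std::array values; the type aliases and function names are hypothetical. Integer-valued inputs keep every product exact, so the direct and chained results can be compared with ==.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins for vector<4x2x3xf32>, vector<2x3xf32> and vector<2xf32>.
using Vec3D = std::array<std::array<std::array<double, 3>, 2>, 4>;
using Vec2D = std::array<std::array<double, 3>, 2>;
using Acc = std::array<double, 2>;

// Reference semantics of
//   vector.multi_reduction <mul>, %src, %acc [0, 2] : vector<4x2x3xf32> to vector<2xf32>
// Result lane j is acc[j] multiplied by every element src[i][j][k].
Acc reduceDims02(const Vec3D &src, Acc acc) {
  for (std::size_t i = 0; i < 4; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      for (std::size_t k = 0; k < 3; ++k)
        acc[j] *= src[i][j][k];
  return acc;
}

// One unrolled step:
//   vector.multi_reduction <mul>, %slice, %acc [1] : vector<2x3xf32> to vector<2xf32>
Acc reduceDim1(const Vec2D &slice, Acc acc) {
  for (std::size_t j = 0; j < 2; ++j)
    for (std::size_t k = 0; k < 3; ++k)
      acc[j] *= slice[j][k];
  return acc;
}

// Unrolled form: extract src[i] along dim 0 and pass each step's result
// as the accumulator of the next step.
Acc reduceUnrolledChained(const Vec3D &src, Acc acc) {
  for (std::size_t i = 0; i < 4; ++i)
    acc = reduceDim1(src[i], acc);
  return acc;
}
```

Chaining through the accumulator is what lets the unrolled form avoid any vector.insert here: dim 0 is a reduction dimension, so every slice contributes to the same two result lanes.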
diff --git a/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir
new file mode 100644
index 0000000000000..04217479e377b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test UnrollMultiReductionInnerReduction
+//===----------------------------------------------------------------------===//
+//
+// The general case handles multiple reduction dimensions.
+// For vector<2x3x5xf32> with reduction on dims [1, 2]:
+// UnrollMultiReductionInnerReduction unrolls along dim 0 (size 2), creating
+// two vector<3x5xf32> multi_reductions over dims [0, 1], then inserts the results.
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_vector_multi_reduction_inner_general(%source: vector<2x3x5xf32>, %acc: vector<2xf32>) -> (vector<2xf32>) {
+ // CHECK-DAG: %[[SRC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[SRC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+ // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+ // CHECK: %[[R0:.+]] = vector.multi_reduction <add>, %[[SRC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32
+ // CHECK: %[[R1:.+]] = vector.multi_reduction <add>, %[[SRC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32
+ // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0]], %{{.+}} [0] : f32 into vector<2xf32>
+ // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+ %1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+
+ // CHECK: return %[[INSERT_1]]
+ return %1 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2xf32>) -> (vector<2xf32>) {
+ // CHECK-DAG: %[[SRC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[SRC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+ // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+ // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+ // CHECK: %[[R0:.+]] = vector.mask %[[MASK_0]] {{.*}} vector.multi_reduction <add>, %[[SRC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32
+ // CHECK: %[[R1:.+]] = vector.mask %[[MASK_1]] {{.*}} vector.multi_reduction <add>, %[[SRC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32
+ // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0]], %{{.+}} [0] : f32 into vector<2xf32>
+ // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+
+ %0 = vector.mask %mask {
+ %1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+ } : vector<2x3x5xi1> -> vector<2xf32>
+
+ // CHECK: return %[[INSERT_1]]
+ return %0 : vector<2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = innerreduction
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
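For the inner-reduction case described in the header comment of this test, here is a hedged C++ sketch (plain std::array values stand in for MLIR vectors; all names are hypothetical). It shows that reducing vector<2x3x5xf32> over dims [1, 2] gives the same result as reducing each vector<3x5xf32> slice with its own scalar accumulator and inserting the scalars at positions 0 and 1. Masking is not modelled.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins for vector<2x3x5xf32>, vector<3x5xf32> and vector<2xf32>.
using Src = std::array<std::array<std::array<double, 5>, 3>, 2>;
using Slice = std::array<std::array<double, 5>, 3>;
using Acc = std::array<double, 2>;

// Reference semantics of
//   vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
Acc reduceDims12(const Src &src, Acc acc) {
  for (std::size_t i = 0; i < 2; ++i)
    for (std::size_t j = 0; j < 3; ++j)
      for (std::size_t k = 0; k < 5; ++k)
        acc[i] += src[i][j][k];
  return acc;
}

// One unrolled step:
//   vector.multi_reduction <add>, %slice, %acc_i [0, 1] : vector<3x5xf32> to f32
double reduceAll(const Slice &slice, double acc) {
  for (const auto &row : slice)
    for (double x : row)
      acc += x;
  return acc;
}

// Unrolled form: extract slice i and scalar acc[i], reduce them independently,
// then insert each scalar result at position i of the result vector.
Acc reduceUnrolledInsert(const Src &src, const Acc &acc) {
  Acc result{};
  for (std::size_t i = 0; i < 2; ++i)
    result[i] = reduceAll(src[i], acc[i]); // models vector.insert %Ri, ... [i]
  return result;
}
```

Because dim 0 is parallel here, the two slice reductions are independent of each other, unlike the chained inner-parallel case.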
>From 9245142ef7cad0ebe090216ad5b179507f6b5edd Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 11:07:58 -0500
Subject: [PATCH 5/5] Update benefits to ensure previous behaviour remains the
same
---
.../Vector/Transforms/LoweringPatterns.h | 22 ++++++++++++++++---
.../Transforms/LowerVectorMultiReduction.cpp | 16 +++++---------
mlir/test/Dialect/LLVM/transform-e2e.mlir | 1 +
.../test/Dialect/Vector/transform-vector.mlir | 1 +
.../Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir | 1 +
.../Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir | 1 +
.../Linalg/CPU/test-matmul-masked-vec.mlir | 1 +
7 files changed, 29 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 744a7d4b3b425..f9eeaf4136370 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -66,9 +66,12 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
/// Rewrites vector.multi_reduction such that all reduction dimensions are
/// either innermost or outermost, by adding the proper vector.transpose
/// operations.
+///
+/// The default benefit is higher than that of the unrolling patterns, since
+/// both sets can match the same operations and these patterns must win.
void populateVectorMultiReductionReorderPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit = 1);
+ PatternBenefit benefit = 2);
/// Populate the pattern set with the following patterns:
///
@@ -77,12 +80,22 @@ void populateVectorMultiReductionReorderPatterns(
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
/// back.
+///
+/// The default benefit is higher than that of the unrolling patterns, since
+/// both sets can match the same operations and these patterns must win.
void populateVectorMultiReductionFlatteningPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit = 1);
+ PatternBenefit benefit = 2);
/// Populate the pattern set with the following patterns:
///
+/// [UnrollMultiReductionInnerReduction]
+/// Extracts vectors along the outer (parallel) dimension, performs a
+/// multi_reduction on each, and inserts the results into the result vector.
+///
+/// [UnrollMultiReductionInnerParallel]
+/// Extracts vectors along the outer (reduction) dimension and chains the
+/// multi_reduction operations through their accumulators.
void populateVectorMultiReductionUnrollingPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
@@ -102,9 +115,12 @@ void populateVectorMultiReductionUnrollingPatterns(
/// dimension, unroll the outer dimension to obtain a sequence of extract +
/// vector.reduction + insert. This can further lower to horizontal reduction
/// ops.
+///
+/// The benefit is set to be higher than the unrolling patterns. Otherwise
+/// patterns here would match the same operations as those in unrolling.
void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit = 1);
+ PatternBenefit benefit = 2);
/// Populate the pattern set with the following patterns:
///
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index e347ba15537bd..3f2b399e55343 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -686,19 +686,13 @@ struct LowerVectorMultiReductionPass
RewritePatternSet patterns(context);
mlir::vector::populateVectorMultiReductionReorderPatterns(
patterns, this->loweringStrategy);
- if (failed(applyPatternsGreedily(op, std::move(patterns))))
- signalPassFailure();
-
- RewritePatternSet flatteningPatterns(context);
mlir::vector::populateVectorMultiReductionFlatteningPatterns(
- flatteningPatterns, this->loweringStrategy);
- if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
- signalPassFailure();
-
- RewritePatternSet loweringPatterns(context);
+ patterns, this->loweringStrategy);
+ mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+ patterns, this->loweringStrategy);
mlir::vector::populateVectorMultiReductionLoweringPatterns(
- loweringPatterns, this->loweringStrategy);
- if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
+ patterns, this->loweringStrategy);
+ if (failed(applyPatternsGreedily(op, std::move(patterns))))
signalPassFailure();
}
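With this change the pass applies the reorder, flattening, unrolling and lowering patterns as one set, so the benefits declared in LoweringPatterns.h decide which pattern wins when several match the same op. The sketch below is a toy model of that choice, not MLIR's pattern applicator: ToyMultiReduction, ToyPattern, pickPattern and the rank-based match predicates are invented for illustration. It only assumes that patterns with a higher benefit are tried first.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical, simplified op: only its rank matters in this toy.
struct ToyMultiReduction {
  std::size_t rank;
};

// Hypothetical pattern: a name, a benefit, and a match predicate.
struct ToyPattern {
  std::string name;
  int benefit;
  std::function<bool(const ToyMultiReduction &)> matches;
};

// Try patterns from highest to lowest benefit and return the first that
// matches. Equal benefits keep their listed order (a choice of this toy).
std::string pickPattern(std::vector<ToyPattern> patterns,
                        const ToyMultiReduction &op) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &a, const ToyPattern &b) {
                     return a.benefit > b.benefit;
                   });
  for (const ToyPattern &p : patterns)
    if (p.matches(op))
      return p.name;
  return "<none>";
}
```

For example, if an "unrolling" pattern with benefit 1 matched rank > 1 and a "flattening" pattern with benefit 2 matched rank > 2, a rank-3 op would be flattened first (the previous behaviour), while a rank-2 op would fall through to unrolling.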
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index 7f4c78a7fc2c8..b4cb4b75a5f4d 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -32,6 +32,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.transfer_permutation_patterns
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 9027bef2eac5c..b8729d06c9cfc 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -41,6 +41,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns to %f {
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
} : !transform.any_op
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
index 15f5bfd3619eb..23623b8b3f21e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
@@ -152,6 +152,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.lower_masked_transfers
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
index 0cfba97654769..bf6ab21afc909 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
@@ -157,6 +157,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.lower_masked_transfers
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index ce85720d665c2..8d973527847c1 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -55,6 +55,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
More information about the Mlir-commits
mailing list