[Mlir-commits] [mlir] [vector][multi_reduction] Add unrolling for vector.multi_reduction (PR #185033)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Fri Mar 6 10:25:38 PST 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/185033
>From d07bef658ff8167892b8e56769bc4e52b5b9a712 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 09:47:24 -0500
Subject: [PATCH 1/6] [mlir][vector] Rename multi_reduction's unrolling to
lowering.
Stated previously:
> Below are some notes on naming
> * RewriteXAsY or LowerXToY (if we're changing op kind)
>
> Concretely, TwoDimMultiReductionToReduction looks like a lowering
> (it rewrites to vector.reduction), not an unrolling,
https://github.com/llvm/llvm-project/pull/182301#issuecomment-3951827980
---
.../Dialect/Vector/TransformOps/VectorTransformOps.td | 8 ++++----
.../mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 2 +-
.../Dialect/Vector/TransformOps/VectorTransformOps.cpp | 4 ++--
.../Vector/Transforms/LowerVectorMultiReduction.cpp | 10 +++++-----
mlir/test/Dialect/LLVM/transform-e2e.mlir | 2 +-
mlir/test/Dialect/Vector/transform-vector.mlir | 2 +-
.../Vector/vector-multi-reduction-unrolling.mlir | 4 ++--
.../Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir | 2 +-
.../Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir | 2 +-
.../Dialect/Linalg/CPU/test-matmul-masked-vec.mlir | 2 +-
mlir/test/python/dialects/transform_vector_ext.py | 6 +++---
11 files changed, 22 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index dcd5f6ff3ad74..76e77b9f90a13 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -262,18 +262,18 @@ def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
}];
}
-def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
- "apply_patterns.vector.multi_reduction_unrolling",
+def ApplyMultiReductionLoweringPatternsOp: Op<Transform_Dialect,
+ "apply_patterns.vector.multi_reduction_lowering",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Indicates that vector multi_reduction operations should be unrolled.
+ Indicates that vector multi_reduction operations should be lowered.
1-D multi_reductions are converted directly to vector.reduction.
2-D multi_reductions are unrolled into either a sequence of
vector.reduction ops (innerreduction) or element-wise arith ops
(innerparallel).
This populates the patterns from
- `populateVectorMultiReductionUnrollingPatterns`, i.e.:
+ `populateVectorMultiReductionLoweringPatterns`, i.e.:
* `OneDimMultiReductionToReduction`
* `TwoDimMultiReductionToReduction` (innerreduction)
* `TwoDimMultiReductionToElementWise` (innerparallel)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index aa75eff409ef9..2ea3367da414d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -96,7 +96,7 @@ void populateVectorMultiReductionFlatteningPatterns(
/// dimension, unroll the outer dimension to obtain a sequence of extract +
/// vector.reduction + insert. This can further lower to horizontal reduction
/// ops.
-void populateVectorMultiReductionUnrollingPatterns(
+void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 312bd28ad48cf..d1799b3bd84d8 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -145,11 +145,11 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
-void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns(
+void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
- vector::populateVectorMultiReductionUnrollingPatterns(
+ vector::populateVectorMultiReductionLoweringPatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 76599822fbfe4..88e669d1a34b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -505,10 +505,10 @@ struct LowerVectorMultiReductionPass
if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
signalPassFailure();
- RewritePatternSet unrollingPatterns(context);
- mlir::vector::populateVectorMultiReductionUnrollingPatterns(
- unrollingPatterns, this->loweringStrategy);
- if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
+ RewritePatternSet loweringPatterns(context);
+ mlir::vector::populateVectorMultiReductionLoweringPatterns(
+ loweringPatterns, this->loweringStrategy);
+ if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
signalPassFailure();
}
@@ -532,7 +532,7 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
patterns.add<FlattenMultiReduction>(patterns.getContext(), options, benefit);
}
-void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index bf7eba6e50174..7f4c78a7fc2c8 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -32,7 +32,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.transfer_permutation_patterns
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 4dc11c26e83f1..9027bef2eac5c 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -41,7 +41,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns to %f {
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
} : !transform.any_op
transform.apply_patterns to %f {
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
index 447416ccba637..a7ec4b9206c33 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
@@ -117,7 +117,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
}
@@ -125,7 +125,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
} : !transform.op<"func.func">
transform.yield
}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
index a7b0b27ca5fb9..15f5bfd3619eb 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
@@ -152,7 +152,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.lower_masked_transfers
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
index 4adc68966f17a..0cfba97654769 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
@@ -157,7 +157,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.lower_masked_transfers
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index 0883e7b698f55..ce85720d665c2 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -55,7 +55,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
}
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index a3c53a45048b2..cfaf91d24f471 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -103,9 +103,9 @@ def enum_configurable_patterns():
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
)
- # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
- vector.ApplyMultiReductionUnrollingPatternsOp()
- # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
+ # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
+ vector.ApplyMultiReductionLoweringPatternsOp()
+ # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
# CHECK-SAME: lowering_strategy = innerreduction
vector.ApplyMultiReductionUnrollingPatternsOp(
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
>From ba1fffa23f4b9cf6862897959db2cbecf2eea9d4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 10:01:00 -0500
Subject: [PATCH 2/6] [mlir][vector] Add multi_reduction unrolling
innerparallel.
When unrolling a vector.multi_reduction whose outermost
dimension is reduced, extract the vectors along the outermost
dimension and chain them with vector.multi_reductions.
---
.../Vector/TransformOps/VectorTransformOps.td | 14 ++++
.../Vector/Transforms/LoweringPatterns.h | 5 ++
.../TransformOps/VectorTransformOps.cpp | 5 ++
.../Transforms/LowerVectorMultiReduction.cpp | 82 +++++++++++++++++++
4 files changed, 106 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 76e77b9f90a13..dff2803f95ae6 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -262,6 +262,20 @@ def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
}];
}
+def ApplyUnrollMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
+ "apply_patterns.vector.multi_reduction_unrolling",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector multi_reduction operations with more than one
+ reduction dimension should be unrolled in a rank-reducing way.
+
+ This populates the patterns:
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+
def ApplyMultiReductionLoweringPatternsOp: Op<Transform_Dialect,
"apply_patterns.vector.multi_reduction_lowering",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 2ea3367da414d..d9237bac63b78 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -81,6 +81,11 @@ void populateVectorMultiReductionFlatteningPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+void populateVectorMultiReductionUnrollingPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Populate the pattern set with the following patterns:
///
/// [OneDimMultiReductionToReduction]
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index d1799b3bd84d8..256384f6b7c18 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -145,6 +145,11 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
+void transform::ApplyUnrollMultiReductionUnrollingPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorMultiReductionUnrollingPatterns(patterns);
+}
+
void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 88e669d1a34b0..51b2f64c9a7da 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -443,6 +443,82 @@ struct TwoDimMultiReductionToReduction
}
};
+/// Unrolls the outermost dimension of a vector.multi_reduction.
+/// Matches when the outermost dimension is a reduction dimension,
+/// but not the only one.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, [[REDUCTION_DIMS]] ] :
+/// vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction <add>, %0, %acc [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction <add>, %Nminus1, %redNminus2
+/// [ [[REDUCTION_DIMS]] ] : vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
+struct UnrollMultiReductionInnerParallel
+ : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
+ if (!multiReductionOp.isReducedDim(0))
+ return failure();
+
+ ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+ if (reductionDims.size() <= 1)
+ return failure();
+
+ Location loc = multiReductionOp.getLoc();
+ Value source = multiReductionOp.getSource();
+
+ ArrayRef<int64_t> srcShape =
+ multiReductionOp.getSourceVectorType().getShape();
+ int64_t outerDimSize = srcShape.front();
+
+ Value mask = maskingOp ? maskingOp.getMask() : nullptr;
+
+ SmallVector<Value> vectors(outerDimSize);
+ for (int64_t i = 0; i < outerDimSize; ++i)
+ vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i);
+
+ SmallVector<Value> masks(outerDimSize);
+ if (mask)
+ for (int64_t i = 0; i < outerDimSize; ++i)
+ masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
+
+ SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+ ArrayRef<bool> reductionMask =
+ ArrayRef<bool>(fullReductionMask).drop_front();
+ Value result = multiReductionOp.getAcc();
+ for (auto [innerVector, innerMask] : llvm::zip_equal(vectors, masks)) {
+ auto reductionOp = vector::MultiDimReductionOp::create(
+ rewriter, loc, innerVector, result, reductionMask,
+ multiReductionOp.getKind());
+
+ if (innerMask) {
+ Operation *maskOp =
+ vector::maskOperation(rewriter, reductionOp, innerMask);
+ result = maskOp->getResult(0);
+ } else {
+ result = reductionOp.getResult();
+ }
+ }
+
+ return result;
+ }
+};
+
/// Converts 1D vector.multi_reduction directly to vector.reduction.
///
/// Example:
@@ -544,6 +620,12 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
benefit);
}
+void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollMultiReductionInnerParallel>(patterns.getContext(),
+ benefit);
+}
+
std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
vector::VectorMultiReductionLowering option) {
return std::make_unique<LowerVectorMultiReductionPass>(option);
>From 2ea3b59224e103f515c1abebc802435147a4dbff Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 10:37:47 -0500
Subject: [PATCH 3/6] [mlir][vector] Add multi_reduction unrolling
innerreduction.
Unrolls multi_reduction via a series of extractions, multi_reductions,
and insertions.
---
.../Vector/TransformOps/VectorTransformOps.td | 13 +-
.../Vector/Transforms/LoweringPatterns.h | 5 +-
.../TransformOps/VectorTransformOps.cpp | 5 +-
.../Transforms/LowerVectorMultiReduction.cpp | 125 +++++++++++++++++-
.../python/dialects/transform_vector_ext.py | 4 +-
5 files changed, 143 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index dff2803f95ae6..c9226af30e674 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -269,10 +269,19 @@ def ApplyUnrollMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
Indicates that vector multi_reduction operations with more than one
reduction dimension should be unrolled in a rank-reducing way.
- This populates the patterns:
+ This populates the patterns from
+ `populateVectorMultiReductionUnrollingPatterns`, i.e.:
+ * `UnrollMultiReductionInnerReduction` (innerreduction)
+ * `UnrollMultiReductionInnerParallel` (innerparallel)
}];
- let assemblyFormat = "attr-dict";
+ let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+ "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+ );
+
+ let assemblyFormat = [{
+ (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+ }];
}
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index d9237bac63b78..744a7d4b3b425 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -83,8 +83,9 @@ void populateVectorMultiReductionFlatteningPatterns(
/// Populate the pattern set with the following patterns:
///
-void populateVectorMultiReductionUnrollingPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+void populateVectorMultiReductionUnrollingPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1);
/// Populate the pattern set with the following patterns:
///
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 256384f6b7c18..f63d43a7d165c 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -147,7 +147,10 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
void transform::ApplyUnrollMultiReductionUnrollingPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- vector::populateVectorMultiReductionUnrollingPatterns(patterns);
+ vector::VectorTransformsOptions vectorTransformOptions;
+ vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+ vector::populateVectorMultiReductionUnrollingPatterns(
+ patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 51b2f64c9a7da..e347ba15537bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -519,6 +519,120 @@ struct UnrollMultiReductionInnerParallel
}
};
+/// Unrolls vector.multi_reduction along the outermost parallel dimension
+/// when the innermost dimension is a reduction dimension.
+///
+/// This pattern matches operations where:
+/// - The innermost dimension is a reduction dimension
+/// - The outermost dimension is a parallel dimension
+///
+/// The transformation extracts slices along the outermost parallel dimension,
+/// creates smaller multi_reductions, and assembles the results:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [1, 3]
+/// : vector<AxBxCxDxf32> to vector<AxCxf32>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %result = arith.constant dense<0.0> : vector<AxCxf32>
+/// %0 = vector.extract %src[0] : vector<BxCxDxf32> from vector<AxBxCxDxf32>
+/// %acc0 = vector.extract %acc[0] : vector<Cxf32> from vector<AxCxf32>
+/// %red0 = vector.multi_reduction <add>, %0, %acc0 [0, 2]
+/// : vector<BxCxDxf32> to vector<Cxf32>
+/// %res0 = vector.insert %red0, %result[0]
+/// : vector<Cxf32> into vector<AxCxf32>
+/// // ... repeat for indices 1 to A-1
+/// ```
+struct UnrollMultiReductionInnerReduction
+ : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
+ auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+ if (srcRank < 2)
+ return rewriter.notifyMatchFailure(multiReductionOp,
+ "expected source rank >= 2.");
+
+ if (!multiReductionOp.isReducedDim(srcRank - 1))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected innermost dimension to be a reduction dimension.");
+
+ if (multiReductionOp.isReducedDim(0))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected outermost dimension to be a parallel dimension.");
+
+ Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+ if (!elementType.isIntOrIndexOrFloat())
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected integer or float element type.");
+
+ Location loc = multiReductionOp.getLoc();
+ Value source = multiReductionOp.getSource();
+ Value acc = multiReductionOp.getAcc();
+
+ ArrayRef<int64_t> srcShape =
+ multiReductionOp.getSourceVectorType().getShape();
+ int64_t numSlices = srcShape.front();
+
+ Value mask = maskingOp ? maskingOp.getMask() : Value();
+
+ SmallVector<Value> srcSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ srcSlices.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+ SmallVector<Value> accSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ accSlices.push_back(vector::ExtractOp::create(rewriter, loc, acc, i));
+
+ SmallVector<Value> maskSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ if (mask)
+ maskSlices.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+ else
+ maskSlices.push_back(nullptr);
+
+ // Compute new reduction mask by dropping the first element (dimension 0).
+ // Since dimension 0 is parallel (not reduced), all reduction indices shift
+ // down by 1.
+ SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+ ArrayRef<bool> newReductionMask =
+ ArrayRef<bool>(fullReductionMask).drop_front();
+
+ SmallVector<Value> reductionResults;
+ for (auto [srcSlice, accSlice, maskSlice] :
+ llvm::zip(srcSlices, accSlices, maskSlices)) {
+ Operation *newReductionOp = vector::MultiDimReductionOp::create(
+ rewriter, loc, srcSlice, accSlice, newReductionMask,
+ multiReductionOp.getKind());
+
+ if (maskSlice)
+ newReductionOp =
+ mlir::vector::maskOperation(rewriter, newReductionOp, maskSlice);
+
+ reductionResults.push_back(newReductionOp->getResult(0));
+ }
+
+ Value result = arith::ConstantOp::create(
+ rewriter, loc, multiReductionOp.getDestType(),
+ rewriter.getZeroAttr(multiReductionOp.getDestType()));
+
+ for (int64_t i = 0; i < numSlices; ++i)
+ result = vector::InsertOp::create(rewriter, loc, reductionResults[i],
+ result, i);
+
+ return result;
+ }
+};
+
/// Converts 1D vector.multi_reduction directly to vector.reduction.
///
/// Example:
@@ -621,9 +735,14 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
}
void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<UnrollMultiReductionInnerParallel>(patterns.getContext(),
- benefit);
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit) {
+ if (options == VectorMultiReductionLowering ::InnerReduction)
+ patterns.add<UnrollMultiReductionInnerReduction>(patterns.getContext(),
+ benefit);
+ else
+ patterns.add<UnrollMultiReductionInnerParallel>(patterns.getContext(),
+ benefit);
}
std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index cfaf91d24f471..84d70181b5288 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -105,7 +105,9 @@ def enum_configurable_patterns():
# CHECK: transform.apply_patterns.vector.multi_reduction_lowering
vector.ApplyMultiReductionLoweringPatternsOp()
- # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
+ # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
+ vector.ApplyUnrollMultiReductionUnrollingPatternsOp()
+ # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
# CHECK-SAME: lowering_strategy = innerreduction
vector.ApplyMultiReductionUnrollingPatternsOp(
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
>From 4328d792b9c104cb50346deb039440e08f071f1f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 10:39:07 -0500
Subject: [PATCH 4/6] Add tests
---
...unroll-multi-reduction-inner-parallel.mlir | 27 ++++++++
...nroll-multi-reduction-inner-reduction.mlir | 62 +++++++++++++++++++
2 files changed, 89 insertions(+)
create mode 100644 mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir
create mode 100644 mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir
diff --git a/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir
new file mode 100644
index 0000000000000..4f67d5678d6d3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: func @unroll_multi_reduction_inner_parallel
+// CHECK-SAME: %[[INPUT:.+]]: vector<4x2x3xf32>, %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_multi_reduction_inner_parallel(%arg0: vector<4x2x3xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ // CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2x3xf32> from vector<4x2x3xf32>
+ // CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2x3xf32> from vector<4x2x3xf32>
+ // CHECK: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2x3xf32> from vector<4x2x3xf32>
+ // CHECK: %[[V3:.+]] = vector.extract %[[INPUT]][3] : vector<2x3xf32> from vector<4x2x3xf32>
+ // CHECK: %[[RV0:.+]] = vector.multi_reduction <mul>, %[[V0]], %[[ACC]] [1] : vector<2x3xf32> to vector<2xf32>
+ // CHECK: %[[RV1:.+]] = vector.multi_reduction <mul>, %[[V1]], %[[RV0]] [1] : vector<2x3xf32> to vector<2xf32>
+ // CHECK: %[[RV2:.+]] = vector.multi_reduction <mul>, %[[V2]], %[[RV1]] [1] : vector<2x3xf32> to vector<2xf32>
+ // CHECK: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[V3]], %[[RV2]] [1] : vector<2x3xf32> to vector<2xf32>
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 2] : vector<4x2x3xf32> to vector<2xf32>
+ // CHECK: return %[[RESULT]]
+ return %0 : vector<2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.multi_reduction_unrolling
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir
new file mode 100644
index 0000000000000..04217479e377b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test UnrollVectorMultiReduction for Inner Reduction
+//===----------------------------------------------------------------------===//
+//
+// The general case handles multiple reduction dimensions.
+// For vector<2x3x5xf32> with reduction on dims [1, 2]:
+// UnrollMultiReductionInnerReduction unrolls along dim 0 (size 2), creating two
+// vector<3x5xf32> multi_reductions with dims [0, 1], then inserts the results.
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_vector_multi_reduction_inner_general(%source: vector<2x3x5xf32>, %acc: vector<2xf32>) -> (vector<2xf32>) {
+ // CHECK-DAG: %[[SRC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[SRC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+ // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+ // CHECK: %[[R0:.+]] = vector.multi_reduction <add>, %[[SRC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32
+ // CHECK: %[[R1:.+]] = vector.multi_reduction <add>, %[[SRC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32
+ // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0]], %{{.+}} [0] : f32 into vector<2xf32>
+ // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+ %1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+
+ // CHECK: return %[[INSERT_1]]
+ return %1 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2xf32>) -> (vector<2xf32>) {
+ // CHECK-DAG: %[[SRC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[SRC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+ // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+ // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+ // CHECK: %[[R0:.+]] = vector.mask %[[MASK_0]] {{.*}} vector.multi_reduction <add>, %[[SRC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32
+ // CHECK: %[[R1:.+]] = vector.mask %[[MASK_1]] {{.*}} vector.multi_reduction <add>, %[[SRC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32
+ // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0]], %{{.+}} [0] : f32 into vector<2xf32>
+ // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+
+ %0 = vector.mask %mask {
+ %1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+ } : vector<2x3x5xi1> -> vector<2xf32>
+
+ // CHECK: return %[[INSERT_1]]
+ return %0 : vector<2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = innerreduction
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
>From 9245142ef7cad0ebe090216ad5b179507f6b5edd Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 11:07:58 -0500
Subject: [PATCH 5/6] Update benefits to ensure previous behaviour remains the
same
---
.../Vector/Transforms/LoweringPatterns.h | 22 ++++++++++++++++---
.../Transforms/LowerVectorMultiReduction.cpp | 16 +++++---------
mlir/test/Dialect/LLVM/transform-e2e.mlir | 1 +
.../test/Dialect/Vector/transform-vector.mlir | 1 +
.../Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir | 1 +
.../Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir | 1 +
.../Linalg/CPU/test-matmul-masked-vec.mlir | 1 +
7 files changed, 29 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 744a7d4b3b425..f9eeaf4136370 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -66,9 +66,12 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
/// Rewrites vector.multi_reduction such that all reduction dimensions are
/// either innermost or outermost, by adding the proper vector.transpose
/// operations.
+///
+/// The default benefit is higher than that of the unrolling patterns, so these
+/// patterns take precedence on operations that both pattern sets match.
void populateVectorMultiReductionReorderPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit = 1);
+ PatternBenefit benefit = 2);
/// Populate the pattern set with the following patterns:
///
@@ -77,12 +80,22 @@ void populateVectorMultiReductionReorderPatterns(
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
/// back.
+///
+/// The default benefit is higher than that of the unrolling patterns, so these
+/// patterns take precedence on operations that both pattern sets match.
void populateVectorMultiReductionFlatteningPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit = 1);
+ PatternBenefit benefit = 2);
/// Populate the pattern set with the following patterns:
///
+/// [UnrollMultiReductionInnerReduction]
+/// Extracts vectors along the outer (parallel) dimension, performs a
+/// lower-rank vector.multi_reduction on each and inserts the results back.
+///
+/// [UnrollMultiReductionInnerParallel]
+/// Extracts vectors along the outer (reduction) dimension and chains
+/// lower-rank vector.multi_reduction operations through the accumulator.
void populateVectorMultiReductionUnrollingPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
@@ -102,9 +115,12 @@ void populateVectorMultiReductionUnrollingPatterns(
/// dimension, unroll the outer dimension to obtain a sequence of extract +
/// vector.reduction + insert. This can further lower to horizontal reduction
/// ops.
+///
+/// The default benefit is higher than that of the unrolling patterns, so these
+/// patterns take precedence on operations that both pattern sets match.
void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit = 1);
+ PatternBenefit benefit = 2);
/// Populate the pattern set with the following patterns:
///
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index e347ba15537bd..3f2b399e55343 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -686,19 +686,13 @@ struct LowerVectorMultiReductionPass
RewritePatternSet patterns(context);
mlir::vector::populateVectorMultiReductionReorderPatterns(
patterns, this->loweringStrategy);
- if (failed(applyPatternsGreedily(op, std::move(patterns))))
- signalPassFailure();
-
- RewritePatternSet flatteningPatterns(context);
mlir::vector::populateVectorMultiReductionFlatteningPatterns(
- flatteningPatterns, this->loweringStrategy);
- if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
- signalPassFailure();
-
- RewritePatternSet loweringPatterns(context);
+ patterns, this->loweringStrategy);
+ mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+ patterns, this->loweringStrategy);
mlir::vector::populateVectorMultiReductionLoweringPatterns(
- loweringPatterns, this->loweringStrategy);
- if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
+ patterns, this->loweringStrategy);
+ if (failed(applyPatternsGreedily(op, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index 7f4c78a7fc2c8..b4cb4b75a5f4d 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -32,6 +32,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.transfer_permutation_patterns
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 9027bef2eac5c..b8729d06c9cfc 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -41,6 +41,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns to %f {
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
} : !transform.any_op
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
index 15f5bfd3619eb..23623b8b3f21e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
@@ -152,6 +152,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.lower_masked_transfers
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
index 0cfba97654769..bf6ab21afc909 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
@@ -157,6 +157,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.lower_masked_transfers
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index ce85720d665c2..8d973527847c1 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -55,6 +55,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
>From 365480502c15eada5308f9d14af3d09132d45b2c Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 13:24:38 -0500
Subject: [PATCH 6/6] Fix identifier name
---
.../Vector/TransformOps/VectorTransformOps.td | 2 +-
.../Vector/TransformOps/VectorTransformOps.cpp | 2 +-
mlir/test/python/dialects/transform_vector_ext.py | 12 +++++++++---
3 files changed, 11 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index c9226af30e674..b91bd09053c7e 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -262,7 +262,7 @@ def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
}];
}
-def ApplyUnrollMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
+def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
"apply_patterns.vector.multi_reduction_unrolling",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index f63d43a7d165c..55953aa22dd88 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -145,7 +145,7 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
-void transform::ApplyUnrollMultiReductionUnrollingPatternsOp::populatePatterns(
+void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 84d70181b5288..b1dae910052b7 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -103,16 +103,22 @@ def enum_configurable_patterns():
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
)
- # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
- vector.ApplyMultiReductionLoweringPatternsOp()
# CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
- vector.ApplyUnrollMultiReductionUnrollingPatternsOp()
+ vector.ApplyMultiReductionUnrollingPatternsOp()
# CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
# CHECK-SAME: lowering_strategy = innerreduction
vector.ApplyMultiReductionUnrollingPatternsOp(
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
)
+ # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
+ vector.ApplyMultiReductionLoweringPatternsOp()
+ # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
+ # CHECK-SAME: lowering_strategy = innerreduction
+ vector.ApplyMultiReductionLoweringPatternsOp(
+ lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
+ )
+
# CHECK: transform.apply_patterns.vector.lower_transpose
vector.ApplyLowerTransposePatternsOp()
# CHECK: transform.apply_patterns.vector.lower_transpose
More information about the Mlir-commits
mailing list