[Mlir-commits] [mlir] [mlir][vector] Add finer grained populate methods for multi_reduction (NFC). (PR #180750)

Erick Ochoa Lopez llvmlistbot at llvm.org
Tue Feb 10 07:04:14 PST 2026


https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/180750

These commits add three more populate methods for
`vector.multi_reduction`'s lowering patterns:

* populateVectorMultiReductionTransformationPatterns
* populateVectorMultiReductionFlatteningPatterns
* populateVectorMultiReductionUnrollingPatterns

These methods are finer grained: users can independently select
unrolling, flattening, and the preparatory transformations that put
operations into a form suitable for unrolling and flattening.

The previous populateVectorMultiReductionLoweringPatterns method
is rewritten in terms of these new methods.
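
The refactoring described above can be sketched without MLIR as a minimal, self-contained model. Everything in it is a hypothetical stand-in for illustration: `PatternNames` (a list of strings) replaces `RewritePatternSet`, `Strategy` replaces `VectorMultiReductionLowering`, and the shortened function names are not the real API. The grouping of patterns mirrors the registrations in the `.cpp` changes of this patch; the non-`InnerReduction` branch is assumed to register `TwoDimMultiReductionToElementWise`, since that part of the hunk is not shown.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins, NOT the MLIR types.
enum class Strategy { InnerReduction, InnerParallel };
using PatternNames = std::vector<std::string>;

// Patterns that set up multi_reduction ops for the later stages.
void populateTransformation(PatternNames &patterns, Strategy /*options*/) {
  patterns.push_back("InnerOuterDimReductionConversion");
  patterns.push_back("ReduceMultiDimReductionRank");
}

// Strategy-independent stage.
void populateFlattening(PatternNames &patterns) {
  patterns.push_back("OneDimMultiReductionToTwoDim");
}

// Strategy-dependent stage; the else branch is an assumption (see lead-in).
void populateUnrolling(PatternNames &patterns, Strategy options) {
  if (options == Strategy::InnerReduction)
    patterns.push_back("TwoDimMultiReductionToReduction");
  else
    patterns.push_back("TwoDimMultiReductionToElementWise");
}

// Model of the pre-patch monolithic populate method.
void populateLoweringMonolithic(PatternNames &patterns, Strategy options) {
  patterns.push_back("InnerOuterDimReductionConversion");
  patterns.push_back("ReduceMultiDimReductionRank");
  patterns.push_back("OneDimMultiReductionToTwoDim");
  if (options == Strategy::InnerReduction)
    patterns.push_back("TwoDimMultiReductionToReduction");
  else
    patterns.push_back("TwoDimMultiReductionToElementWise");
}

// Model of the post-patch method, composed from the three finer-grained ones.
void populateLowering(PatternNames &patterns, Strategy options) {
  populateTransformation(patterns, options);
  populateFlattening(patterns);
  populateUnrolling(patterns, options);
}

// Checks that both versions register the same patterns in the same order.
void checkEquivalent(Strategy options) {
  PatternNames before, after;
  populateLoweringMonolithic(before, options);
  populateLowering(after, options);
  assert(before == after);
}
```

Calling `checkEquivalent(Strategy::InnerReduction)` and `checkEquivalent(Strategy::InnerParallel)` exercises both branches. In this model, the composed method yields the same pattern list as the monolithic one, which is the sense in which the change is intended to be NFC.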

From aad6145c09f0831ff4a1f2bc2995322f4250dbea Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 10 Feb 2026 09:39:41 -0500
Subject: [PATCH 1/2] [mlir][vector] Add populate methods for multi_reduction.
 (NFC)

This commit adds three more populate methods for
`vector.multi_reduction`'s lowering patterns. These methods have a
finer level of granularity and allow users to select between unrolling,
flattening, and applying transformations that would set up operations
for unrolling and flattening.
---
 .../Vector/Transforms/LoweringPatterns.h      | 42 +++++++++++++++++++
 .../Transforms/LowerVectorMultiReduction.cpp  | 20 ++++++++-
 2 files changed, 61 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index efa3c6d8ac238..c45b3b15b760b 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -60,6 +60,48 @@ void populateVectorContractLoweringPatterns(
 void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [InnerOuterDimReductionConversion]
+/// Rewrites vector.multi_reduction such that all reduction dimensions are
+/// either innermost or outermost, by adding the proper vector.transpose
+/// operations.
+///
+/// [OneDimMultiReductionToTwoDim]
+/// For cases that reduce to 1-D vector<k> reduction (and are thus missing
+/// either a parallel or a reduction), we lift them back up to 2-D with a simple
+/// vector.shape_cast to vector<1xk> so that the other patterns can kick in,
+/// thus fully exiting out of the vector.multi_reduction abstraction.
+void populateVectorMultiReductionTransformationPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [ReduceMultiDimReductionRank]
+/// Once in innermost or outermost reduction
+/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
+/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
+/// back.
+void populateVectorMultiReductionFlatteningPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [TwoDimMultiReductionToElementWise]
+/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
+/// ops. This also has an opportunity for tree-reduction (in the future).
+///
+/// [TwoDimMultiReductionToReduction]
+/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of extract +
+/// vector.reduction + insert. This can further lower to horizontal reduction
+/// ops.
+void populateVectorMultiReductionUnrollingPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 1);
+
 /// Collect a set of patterns to convert vector.multi_reduction op into
 /// a sequence of vector.reduction ops. The patterns comprise:
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index e86e2a97038db..9003e49b30986 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -511,12 +511,21 @@ struct LowerVectorMultiReductionPass
 
 } // namespace
 
-void mlir::vector::populateVectorMultiReductionLoweringPatterns(
+void mlir::vector::populateVectorMultiReductionTransformationPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
   patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
       patterns.getContext(), options, benefit);
+}
+
+void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
   if (options == VectorMultiReductionLowering ::InnerReduction)
     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                   benefit);
@@ -525,6 +534,15 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
                                                     benefit);
 }
 
+void mlir::vector::populateVectorMultiReductionLoweringPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
+  populateVectorMultiReductionTransformationPatterns(patterns, options,
+                                                     benefit);
+  populateVectorMultiReductionFlatteningPatterns(patterns, benefit);
+  populateVectorMultiReductionUnrollingPatterns(patterns, options, benefit);
+}
+
 std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
     vector::VectorMultiReductionLowering option) {
   return std::make_unique<LowerVectorMultiReductionPass>(option);

From 1650ccea1b27eaa23ffaa4d17c84da8a77b1a796 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 10 Feb 2026 09:57:36 -0500
Subject: [PATCH 2/2] [mlir][vector] Split lower-vector-multi-reduction. (NFC)

This commit splits the lower-vector-multi-reduction pass into three
stages that are applied sequentially. This is intended to demonstrate a
lowering approach that uses the separate unrolling and flattening
patterns at different stages, and to show that the changes are
functionally equivalent.
---
 .../Transforms/LowerVectorMultiReduction.cpp  | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 9003e49b30986..dd6972ab07a04 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -496,11 +496,22 @@ struct LowerVectorMultiReductionPass
     Operation *op = getOperation();
     MLIRContext *context = op->getContext();
 
-    RewritePatternSet loweringPatterns(context);
-    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
-                                                 this->loweringStrategy);
+    RewritePatternSet patterns(context);
+    mlir::vector::populateVectorMultiReductionTransformationPatterns(
+        patterns, this->loweringStrategy);
+    if (failed(applyPatternsGreedily(op, std::move(patterns))))
+      signalPassFailure();
+
+    RewritePatternSet flatteningPatterns(context);
+    mlir::vector::populateVectorMultiReductionFlatteningPatterns(
+        flatteningPatterns);
+    if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
+      signalPassFailure();
 
-    if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
+    RewritePatternSet unrollingPatterns(context);
+    mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+        unrollingPatterns, this->loweringStrategy);
+    if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
       signalPassFailure();
   }
 
