[Mlir-commits] [mlir] [mlir][vector] Add finer grained populate methods for multi_reduction (NFC). (PR #180750)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 10 07:05:02 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Erick Ochoa Lopez (amd-eochoalo)

<details>
<summary>Changes</summary>

These commits add three more populate methods for
`vector.multi_reduction`'s lowering patterns:

* populateVectorMultiReductionTransformationPatterns
* populateVectorMultiReductionFlatteningPatterns
* populateVectorMultiReductionUnrollingPatterns

These methods have a
finer level of granularity and allow users to select between unrolling,
flattening, and applying transformations that would set up operations
for unrolling and flattening.

The previous populateVectorMultiReductionLoweringPatterns method
is rewritten in terms of these new methods.

---
Full diff: https://github.com/llvm/llvm-project/pull/180750.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+42) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+34-5) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index efa3c6d8ac238..c45b3b15b760b 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -60,6 +60,48 @@ void populateVectorContractLoweringPatterns(
 void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [InnerOuterDimReductionConversion]
+/// Rewrites vector.multi_reduction such that all reduction dimensions are
+/// either innermost or outermost, by adding the proper vector.transpose
+/// operations.
+///
+/// [OneDimMultiReductionToTwoDim]
+/// For cases that reduce to 1-D vector<k> reduction (and are thus missing
+/// either a parallel or a reduction), we lift them back up to 2-D with a simple
+/// vector.shape_cast to vector<1xk> so that the other patterns can kick in,
+/// thus fully exiting out of the vector.multi_reduction abstraction.
+void populateVectorMultiReductionTransformationPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [ReduceMultiDimReductionRank]
+/// Once in innermost or outermost reduction
+/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
+/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
+/// back.
+void populateVectorMultiReductionFlatteningPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [TwoDimMultiReductionToElementWise]
+/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
+/// ops. This also has an opportunity for tree-reduction (in the future).
+///
+/// [TwoDimMultiReductionToReduction]
+/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of extract +
+/// vector.reduction + insert. This can further lower to horizontal reduction
+/// ops.
+void populateVectorMultiReductionUnrollingPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 1);
+
 /// Collect a set of patterns to convert vector.multi_reduction op into
 /// a sequence of vector.reduction ops. The patterns comprise:
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index e86e2a97038db..dd6972ab07a04 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -496,11 +496,22 @@ struct LowerVectorMultiReductionPass
     Operation *op = getOperation();
     MLIRContext *context = op->getContext();
 
-    RewritePatternSet loweringPatterns(context);
-    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
-                                                 this->loweringStrategy);
+    RewritePatternSet patterns(context);
+    mlir::vector::populateVectorMultiReductionTransformationPatterns(
+        patterns, this->loweringStrategy);
+    if (failed(applyPatternsGreedily(op, std::move(patterns))))
+      signalPassFailure();
+
+    RewritePatternSet flatteningPatterns(context);
+    mlir::vector::populateVectorMultiReductionFlatteningPatterns(
+        flatteningPatterns);
+    if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
+      signalPassFailure();
 
-    if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
+    RewritePatternSet unrollingPatterns(context);
+    mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+        unrollingPatterns, this->loweringStrategy);
+    if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
       signalPassFailure();
   }
 
@@ -511,12 +522,21 @@ struct LowerVectorMultiReductionPass
 
 } // namespace
 
-void mlir::vector::populateVectorMultiReductionLoweringPatterns(
+void mlir::vector::populateVectorMultiReductionTransformationPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
   patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
       patterns.getContext(), options, benefit);
+}
+
+void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
   if (options == VectorMultiReductionLowering ::InnerReduction)
     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                   benefit);
@@ -525,6 +545,15 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
                                                     benefit);
 }
 
+void mlir::vector::populateVectorMultiReductionLoweringPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
+  populateVectorMultiReductionTransformationPatterns(patterns, options,
+                                                     benefit);
+  populateVectorMultiReductionFlatteningPatterns(patterns, benefit);
+  populateVectorMultiReductionUnrollingPatterns(patterns, options, benefit);
+}
+
 std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
     vector::VectorMultiReductionLowering option) {
   return std::make_unique<LowerVectorMultiReductionPass>(option);

``````````

</details>


https://github.com/llvm/llvm-project/pull/180750


More information about the Mlir-commits mailing list