[Mlir-commits] [mlir] ed05f70 - [mlir][vector] Rename `ReduceMultiDimReductionRank` -> `FlattenMultiReduction` (NFC) (#183721)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 27 08:23:10 PST 2026


Author: Andrzej Warzyński
Date: 2026-02-27T16:23:06Z
New Revision: ed05f7012fe93af8b2b230f7059d0162d8aed126

URL: https://github.com/llvm/llvm-project/commit/ed05f7012fe93af8b2b230f7059d0162d8aed126
DIFF: https://github.com/llvm/llvm-project/commit/ed05f7012fe93af8b2b230f7059d0162d8aed126.diff

LOG: [mlir][vector] Rename `ReduceMultiDimReductionRank` -> `FlattenMultiReduction` (NFC) (#183721)

The updated name better captures what the pattern does and matches the
corresponding `populate*` hook,
`populateVectorMultiReductionFlatteningPatterns`, which contains only
this pattern.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
    mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 8a82e413e7dfc..9da4be88586f4 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -126,6 +126,9 @@ void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
                                             /*force32BitVectorIndices=*/false);
 }
 
+//===----------------------------------------------------------------------===//
+// Multi-reduction patterns
+//===----------------------------------------------------------------------===//
 void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::VectorTransformsOptions vectorTransformOptions;

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 70c5f33dd05cc..0d9ff95e1279c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -131,16 +131,27 @@ class InnerOuterDimReductionConversion
   const bool useInnerDimsForReduction;
 };
 
-/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
-/// dimensions are either inner most or outer most.
-class ReduceMultiDimReductionRank
+/// Flattens vector.multi_reduction to 2D.
+///
+/// Given that all reduction dimensions are either innermost or outermost,
+/// collapses the reduction and parallel dimensions so that only two remain:
+/// BEFORE
+///     vector.multi_reduction <add>, %vec, %acc [2, 3] : vector<2x3x4x5xi32> to
+///     vector<2x3xi32>
+/// AFTER
+///     %vec_sc = vector.shape_cast %vec
+///     %acc_sc = vector.shape_cast %acc
+///     %res = vector.multi_reduction <add>, %vec_sc, %acc_sc [1] :
+///         vector<6x20xi32> to vector<6xi32>
+///     %res_sc = vector.shape_cast %res
+class FlattenMultiReduction
     : public OpRewritePattern<vector::MultiDimReductionOp> {
 public:
   using Base::Base;
 
-  explicit ReduceMultiDimReductionRank(
-      MLIRContext *context, vector::VectorMultiReductionLowering options,
-      PatternBenefit benefit = 1)
+  explicit FlattenMultiReduction(MLIRContext *context,
+                                 vector::VectorMultiReductionLowering options,
+                                 PatternBenefit benefit = 1)
       : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
         useInnerDimsForReduction(
             options == vector::VectorMultiReductionLowering::InnerReduction) {}
@@ -552,8 +563,7 @@ void mlir::vector::populateVectorMultiReductionReorderAndExpandPatterns(
 void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
-  patterns.add<ReduceMultiDimReductionRank>(patterns.getContext(), options,
-                                            benefit);
+  patterns.add<FlattenMultiReduction>(patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populateVectorMultiReductionUnrollingPatterns(

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
index b8f970912909b..4ebcd60ccb13c 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
@@ -1,6 +1,10 @@
 // RUN: mlir-opt %s --transform-interpreter='entry-point=innerreduction' | FileCheck %s --check-prefix=INNER_REDUCTION,ALL
 // RUN: mlir-opt %s --transform-interpreter='entry-point=innerparallel' | FileCheck %s --check-prefix=INNER_PARALLEL,ALL
 
+//=============================================================================
+// Tests for the FlattenMultiReduction pattern
+//=============================================================================
+
 // ALL-LABEL: func @negative_flattening_cases
 func.func @negative_flattening_cases(
     %v1d: vector<8xf32>,


        


More information about the Mlir-commits mailing list