[Mlir-commits] [mlir] [mlir][vector] Rename `ReduceMultiDimReductionRank` as `FlattenMultiReduction` (PR #183721)

Andrzej WarzyƄski llvmlistbot at llvm.org
Fri Feb 27 02:19:33 PST 2026


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/183721

The updated name better captures what the pattern does and matches the
corresponding `populate*` hook,
`populateVectorMultiReductionFlatteningPatterns`, which contains only
this pattern.

Also adds a dedicated TD Op for
`populateVectorMultiReductionFlatteningPatterns` so that we can test
this pattern in isolation.


>From 2b5df405bb2b70a50b78377e5f621ecfdab9beab Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 27 Feb 2026 10:15:52 +0000
Subject: [PATCH] [mlir][vector] Rename `ReduceMultiDimReductionRank` as
 `FlattenMultiReduction`

The updated name better captures what the pattern does and matches the
corresponding `populate*` hook,
`populateVectorMultiReductionFlatteningPatterns`, which contains only
this pattern.

Also adds a dedicated TD Op for
`populateVectorMultiReductionFlatteningPatterns` so that we can test
this pattern in isolation.
---
 .../Vector/TransformOps/VectorTransformOps.td | 17 ++++++
 .../TransformOps/VectorTransformOps.cpp       | 11 ++++
 .../Transforms/LowerVectorMultiReduction.cpp  | 26 ++++++---
 .../vector-multi-reduction-flatten.mlir       | 57 +++++++++++++++++++
 4 files changed, 103 insertions(+), 8 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-flatten.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 462c61df72108..baec2d1024de0 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -266,6 +266,23 @@ def ApplyReorderAndExpandMultiReductionPatternsOp: Op<Transform_Dialect,
   }];
 }
 
+def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
+    "apply_patterns.vector.multi_reduction_flatten",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Populates the patterns from
+    `populateVectorMultiReductionFlatteningPatterns`.
+  }];
+
+  let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+      "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+  );
+
+  let assemblyFormat = [{
+    (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+  }];
+}
+
 def ApplyLowerOuterProductPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_outerproduct",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 4e2b97aa07084..146eeac7d953d 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -126,6 +126,9 @@ void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
                                              /*force32BitVectorIndices=*/false);
 }
 
+//===----------------------------------------------------------------------===//
+// Multi-reduction patterns
+//===----------------------------------------------------------------------===//
 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::VectorTransformsOptions vectorTransformOptions;
@@ -146,6 +149,14 @@ void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
+void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::VectorTransformsOptions vectorTransformOptions;
+  vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+  vector::populateVectorMultiReductionFlatteningPatterns(
+      patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorOuterProductLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2d6a49bad27bc..3db9273b32431 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -131,16 +131,27 @@ class InnerOuterDimReductionConversion
   const bool useInnerDimsForReduction;
 };
 
-/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
-/// dimensions are either inner most or outer most.
-class ReduceMultiDimReductionRank
+/// Flattens vector.multi_reduction to 2D, given that all reduction dims are
+/// either inner-most or outer-most. Consecutive reduction (resp. parallel) dims
+/// are collapsed into one, e.g. (with the InnerReduction strategy):
+///
+/// BEFORE:
+///   %res = vector.multi_reduction <add>, %vec, %acc [2, 3]
+///       : vector<2x3x4x5xi32> to vector<2x3xi32>
+/// AFTER:
+///   %vec_sc = vector.shape_cast %vec : vector<2x3x4x5xi32> to vector<6x20xi32>
+///   %acc_sc = vector.shape_cast %acc : vector<2x3xi32> to vector<6xi32>
+///   %red = vector.multi_reduction <add>, %vec_sc, %acc_sc [1]
+///       : vector<6x20xi32> to vector<6xi32>
+///   %res = vector.shape_cast %red : vector<6xi32> to vector<2x3xi32>
+class FlattenMultiReduction
     : public OpRewritePattern<vector::MultiDimReductionOp> {
 public:
   using Base::Base;
 
-  explicit ReduceMultiDimReductionRank(
-      MLIRContext *context, vector::VectorMultiReductionLowering options,
-      PatternBenefit benefit = 1)
+  explicit FlattenMultiReduction(MLIRContext *context,
+                                 vector::VectorMultiReductionLowering options,
+                                 PatternBenefit benefit = 1)
       : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
         useInnerDimsForReduction(
             options == vector::VectorMultiReductionLowering::InnerReduction) {}
@@ -533,8 +544,7 @@ void mlir::vector::populateVectorMultiReductionReorderAndExpandPatterns(
 void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
-  patterns.add<ReduceMultiDimReductionRank>(patterns.getContext(), options,
-                                            benefit);
+  patterns.add<FlattenMultiReduction>(patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flatten.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flatten.mlir
new file mode 100644
index 0000000000000..509655d587ceb
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flatten.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerreduction' | FileCheck %s --check-prefixes=INNER_REDUCTION,ALL
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerparallel' | FileCheck %s --check-prefixes=INNER_PARALLEL,ALL
+
+//=============================================================================
+// Tests for the FlattenMultiReduction pattern
+//=============================================================================
+
+// ALL-LABEL:   func.func @reduction_inner(
+// ALL-SAME:                  %[[VEC:.*]]: vector<2x3x4x5xi32>,
+// ALL-SAME:                  %[[ACC:.*]]: vector<2x3xi32>) -> vector<2x3xi32> {
+
+// INNER_REDUCTION:           %[[VEC_SC:.*]] = vector.shape_cast %[[VEC]] : vector<2x3x4x5xi32> to vector<6x20xi32>
+// INNER_REDUCTION:           %[[ACC_SC:.*]] = vector.shape_cast %[[ACC]] : vector<2x3xi32> to vector<6xi32>
+// INNER_REDUCTION:           %[[MULTI_REDUCTION_0:.*]] = vector.multi_reduction <add>, %[[VEC_SC]], %[[ACC_SC]] [1] : vector<6x20xi32> to vector<6xi32>
+// INNER_REDUCTION:           %[[RES:.*]] = vector.shape_cast %[[MULTI_REDUCTION_0]] : vector<6xi32> to vector<2x3xi32>
+// INNER_REDUCTION:           return %[[RES]]
+
+// INNER_PARALLEL:            %[[RES:.*]] = vector.multi_reduction <add>, %[[VEC]], %[[ACC]] [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+// INNER_PARALLEL:            return %[[RES]]
+func.func @reduction_inner(%vec: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
+    %0 = vector.multi_reduction <add>, %vec, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+    return %0 : vector<2x3xi32>
+}
+
+// ALL-LABEL:   func.func @reduction_outer(
+// ALL-SAME:                  %[[VEC:.*]]: vector<2x3x4x5xi32>,
+// ALL-SAME:                  %[[ACC:.*]]: vector<4x5xi32>) -> vector<4x5xi32> {
+// INNER_REDUCTION:           %[[RES:.*]] = vector.multi_reduction <add>, %[[VEC]], %[[ACC]] [0, 1] : vector<2x3x4x5xi32> to vector<4x5xi32>
+// INNER_REDUCTION:           return %[[RES]] : vector<4x5xi32>
+
+// INNER_PARALLEL:            %[[VEC_SC:.*]] = vector.shape_cast %[[VEC]] : vector<2x3x4x5xi32> to vector<6x20xi32>
+// INNER_PARALLEL:            %[[ACC_SC:.*]] = vector.shape_cast %[[ACC]] : vector<4x5xi32> to vector<20xi32>
+// INNER_PARALLEL:            %[[MULTI_REDUCTION_0:.*]] = vector.multi_reduction <add>, %[[VEC_SC]], %[[ACC_SC]] [0] : vector<6x20xi32> to vector<20xi32>
+// INNER_PARALLEL:            %[[RES:.*]] = vector.shape_cast %[[MULTI_REDUCTION_0]] : vector<20xi32> to vector<4x5xi32>
+// INNER_PARALLEL:            return %[[RES]]
+func.func @reduction_outer(%arg0: vector<2x3x4x5xi32>, %acc: vector<4x5xi32>) -> vector<4x5xi32> {
+    %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3x4x5xi32> to vector<4x5xi32>
+    return %0 : vector<4x5xi32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.multi_reduction_flatten lowering_strategy = "innerreduction"
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+
+  transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.multi_reduction_flatten lowering_strategy = "innerparallel"
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list