[Mlir-commits] [mlir] [mlir][vector] Rename `ReduceMultiDimReductionRank` as `FlattenMultiReduction` (PR #183721)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Feb 27 02:19:33 PST 2026
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/183721
The updated name better captures what the pattern does and matches the
corresponding `populate*` hook,
`populateVectorMultiReductionFlatteningPatterns`, which contains only
this pattern.
Also adds a dedicated Transform Dialect (TD) op for
`populateVectorMultiReductionFlatteningPatterns` so that this pattern
can be tested in isolation.
>From 2b5df405bb2b70a50b78377e5f621ecfdab9beab Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 27 Feb 2026 10:15:52 +0000
Subject: [PATCH] [mlir][vector] Rename `ReduceMultiDimReductionRank` as
`FlattenMultiReduction`
The updated name better captures what the pattern does and matches the
corresponding `populate*` hook,
`populateVectorMultiReductionFlatteningPatterns`, which contains only
this pattern.
Also adds a dedicated Transform Dialect (TD) op for
`populateVectorMultiReductionFlatteningPatterns` so that this pattern
can be tested in isolation.
---
.../Vector/TransformOps/VectorTransformOps.td | 17 ++++++
.../TransformOps/VectorTransformOps.cpp | 11 ++++
.../Transforms/LowerVectorMultiReduction.cpp | 26 ++++++---
.../vector-multi-reduction-flatten.mlir | 57 +++++++++++++++++++
4 files changed, 103 insertions(+), 8 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-flatten.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 462c61df72108..baec2d1024de0 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -266,6 +266,23 @@ def ApplyReorderAndExpandMultiReductionPatternsOp: Op<Transform_Dialect,
}];
}
+def ApplyFlattenMultiReductionPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.multi_reduction_flatten",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Populates the patterns from
+ `populateVectorMultiReductionFlatteningPatterns`. These collapse the
+ parallel dimensions and the reduction dimensions of a
+ `vector.multi_reduction` into a single dimension each, provided that all
+ reduction dimensions are either the innermost or the outermost ones.
+
+ `lowering_strategy` selects which of the two layouts is flattened:
+ * `innerreduction`: the reduction dimensions are the innermost ones,
+ * `innerparallel` (default): the reduction dimensions are the outermost
+ ones.
+ }];
+
+ let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+ "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+ );
+
+ let assemblyFormat = [{
+ (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+ }];
+}
+
def ApplyLowerOuterProductPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_outerproduct",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 4e2b97aa07084..146eeac7d953d 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -126,6 +126,9 @@ void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
/*force32BitVectorIndices=*/false);
}
+//===----------------------------------------------------------------------===//
+// Multi-reduction patterns
+//===----------------------------------------------------------------------===//
void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
@@ -146,6 +149,14 @@ void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
+/// Adds the multi-reduction flattening patterns, configured with the
+/// op's `lowering_strategy` attribute (mirrors the sibling ops above).
+void transform::ApplyFlattenMultiReductionPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::VectorTransformsOptions vectorTransformOptions;
+ vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+ vector::populateVectorMultiReductionFlatteningPatterns(
+ patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
+
void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorOuterProductLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2d6a49bad27bc..3db9273b32431 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -131,16 +131,27 @@ class InnerOuterDimReductionConversion
const bool useInnerDimsForReduction;
};
-/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
-/// dimensions are either inner most or outer most.
-class ReduceMultiDimReductionRank
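+/// Flattens a vector.multi_reduction to a 2-D reduction.
+///
+/// Given that all reduction dimensions are either the innermost or the
+/// outermost ones, collapses all parallel dimensions into a single dimension
+/// and all reduction dimensions into another single dimension.
+/// `VectorMultiReductionLowering::InnerReduction` expects the reduction
+/// dimensions to be innermost; `InnerParallel` expects them to be outermost.
+///
+/// Example (InnerReduction):
+///
+///   BEFORE:
+///     %res = vector.multi_reduction <add>, %vec, %acc [2, 3]
+///       : vector<2x3x4x5xi32> to vector<2x3xi32>
+///
+///   AFTER:
+///     %vec_sc = vector.shape_cast %vec
+///       : vector<2x3x4x5xi32> to vector<6x20xi32>
+///     %acc_sc = vector.shape_cast %acc : vector<2x3xi32> to vector<6xi32>
+///     %red = vector.multi_reduction <add>, %vec_sc, %acc_sc [1]
+///       : vector<6x20xi32> to vector<6xi32>
+///     %res = vector.shape_cast %red : vector<6xi32> to vector<2x3xi32>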
+class FlattenMultiReduction
: public OpRewritePattern<vector::MultiDimReductionOp> {
public:
using Base::Base;
- explicit ReduceMultiDimReductionRank(
- MLIRContext *context, vector::VectorMultiReductionLowering options,
- PatternBenefit benefit = 1)
+ explicit FlattenMultiReduction(MLIRContext *context,
+ vector::VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1)
: mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
useInnerDimsForReduction(
options == vector::VectorMultiReductionLowering::InnerReduction) {}
@@ -533,8 +544,7 @@ void mlir::vector::populateVectorMultiReductionReorderAndExpandPatterns(
+/// Adds the FlattenMultiReduction pattern to `patterns`. `options` is
+/// forwarded to the pattern and selects whether reduction dimensions are
+/// expected to be innermost (InnerReduction) or outermost (InnerParallel).
void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
- patterns.add<ReduceMultiDimReductionRank>(patterns.getContext(), options,
- benefit);
+ patterns.add<FlattenMultiReduction>(patterns.getContext(), options, benefit);
}
void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flatten.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flatten.mlir
new file mode 100644
index 0000000000000..509655d587ceb
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flatten.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerreduction' | FileCheck %s --check-prefixes=ALL,INNER_REDUCTION
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerparallel' | FileCheck %s --check-prefixes=ALL,INNER_PARALLEL
+
+//=============================================================================
+// Tests for the FlattenMultiReduction pattern
+//
+// Each function is checked under both lowering strategies. Reductions over
+// the innermost dims are expected to be flattened with `innerreduction`,
+// reductions over the outermost dims with `innerparallel`. In the remaining
+// case the input is expected to be left unchanged.
+//=============================================================================
+
+// ALL-LABEL: func.func @reduction_inner(
+// ALL-SAME: %[[VEC:.*]]: vector<2x3x4x5xi32>,
+// ALL-SAME: %[[ACC:.*]]: vector<2x3xi32>) -> vector<2x3xi32> {
+
+// INNER_REDUCTION: %[[VEC_SC:.*]] = vector.shape_cast %[[VEC]] : vector<2x3x4x5xi32> to vector<6x20xi32>
+// INNER_REDUCTION: %[[ACC_SC:.*]] = vector.shape_cast %[[ACC]] : vector<2x3xi32> to vector<6xi32>
+// INNER_REDUCTION: %[[MULTI_REDUCTION_0:.*]] = vector.multi_reduction <add>, %[[VEC_SC]], %[[ACC_SC]] [1] : vector<6x20xi32> to vector<6xi32>
+// INNER_REDUCTION: %[[RES:.*]] = vector.shape_cast %[[MULTI_REDUCTION_0]] : vector<6xi32> to vector<2x3xi32>
+// INNER_REDUCTION: return %[[RES]]
+
+// INNER_PARALLEL-NOT: vector.shape_cast
+// INNER_PARALLEL: %[[RES:.*]] = vector.multi_reduction <add>, %[[VEC]], %[[ACC]] [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+// INNER_PARALLEL: return %[[RES]]
+func.func @reduction_inner(%vec: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
+ %0 = vector.multi_reduction <add>, %vec, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+ return %0 : vector<2x3xi32>
+}
+
+// ALL-LABEL: func.func @reduction_outer(
+// ALL-SAME: %[[VEC:.*]]: vector<2x3x4x5xi32>,
+// ALL-SAME: %[[ACC:.*]]: vector<4x5xi32>) -> vector<4x5xi32> {
+
+// INNER_REDUCTION-NOT: vector.shape_cast
+// INNER_REDUCTION: %[[RES:.*]] = vector.multi_reduction <add>, %[[VEC]], %[[ACC]] [0, 1] : vector<2x3x4x5xi32> to vector<4x5xi32>
+// INNER_REDUCTION: return %[[RES]] : vector<4x5xi32>
+
+// INNER_PARALLEL: %[[VEC_SC:.*]] = vector.shape_cast %[[VEC]] : vector<2x3x4x5xi32> to vector<6x20xi32>
+// INNER_PARALLEL: %[[ACC_SC:.*]] = vector.shape_cast %[[ACC]] : vector<4x5xi32> to vector<20xi32>
+// INNER_PARALLEL: %[[MULTI_REDUCTION_0:.*]] = vector.multi_reduction <add>, %[[VEC_SC]], %[[ACC_SC]] [0] : vector<6x20xi32> to vector<20xi32>
+// INNER_PARALLEL: %[[RES:.*]] = vector.shape_cast %[[MULTI_REDUCTION_0]] : vector<20xi32> to vector<4x5xi32>
+// INNER_PARALLEL: return %[[RES]]
+func.func @reduction_outer(%vec: vector<2x3x4x5xi32>, %acc: vector<4x5xi32>) -> vector<4x5xi32> {
+ %0 = vector.multi_reduction <add>, %vec, %acc [0, 1] : vector<2x3x4x5xi32> to vector<4x5xi32>
+ return %0 : vector<4x5xi32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.multi_reduction_flatten lowering_strategy = "innerreduction"
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+
+ transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.multi_reduction_flatten lowering_strategy = "innerparallel"
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list