[Mlir-commits] [mlir] ed05f70 - [mlir][vector] Rename `ReduceMultiDimReductionRank` -> `FlattenMultiReduction` (NFC) (#183721)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 27 08:23:10 PST 2026
Author: Andrzej Warzyński
Date: 2026-02-27T16:23:06Z
New Revision: ed05f7012fe93af8b2b230f7059d0162d8aed126
URL: https://github.com/llvm/llvm-project/commit/ed05f7012fe93af8b2b230f7059d0162d8aed126
DIFF: https://github.com/llvm/llvm-project/commit/ed05f7012fe93af8b2b230f7059d0162d8aed126.diff
LOG: [mlir][vector] Rename `ReduceMultiDimReductionRank` -> `FlattenMultiReduction` (NFC) (#183721)
The updated name better captures what the pattern does and matches the
corresponding `populate*` hook,
`populateVectorMultiReductionFlatteningPatterns`, that only contains
this pattern.
Added:
Modified:
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 8a82e413e7dfc..9da4be88586f4 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -126,6 +126,9 @@ void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
/*force32BitVectorIndices=*/false);
}
+//===----------------------------------------------------------------------===//
+// Multi-reduction patterns
+//===----------------------------------------------------------------------===//
void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 70c5f33dd05cc..0d9ff95e1279c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -131,16 +131,27 @@ class InnerOuterDimReductionConversion
const bool useInnerDimsForReduction;
};
-/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
-/// dimensions are either inner most or outer most.
-class ReduceMultiDimReductionRank
+/// Flattens vector.multi_reduction to 2D
+///
+/// Given all reduction dimensions are either innermost or outermost,
+/// flattens all reduction and parallel dimensions so that only 2 dims remain.
+///
+/// BEFORE
+/// vector.multi_reduction <add>, %vec, %acc [2, 3] : vector<2x3x4x5xi32> to
+/// vector<2x3xi32>
+/// AFTER
+/// %vec_sc = vector.shape_cast %vec
+/// %acc_sc = vector.shape_cast %acc
+/// %res = vector.multi_reduction <add>, %vec_sc, %acc_sc [1] :
+/// vector<6x20xi32> to vector<6xi32> %res_sc = vector.shape_cast %res
+class FlattenMultiReduction
: public OpRewritePattern<vector::MultiDimReductionOp> {
public:
using Base::Base;
- explicit ReduceMultiDimReductionRank(
- MLIRContext *context, vector::VectorMultiReductionLowering options,
- PatternBenefit benefit = 1)
+ explicit FlattenMultiReduction(MLIRContext *context,
+ vector::VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1)
: mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
useInnerDimsForReduction(
options == vector::VectorMultiReductionLowering::InnerReduction) {}
@@ -552,8 +563,7 @@ void mlir::vector::populateVectorMultiReductionReorderAndExpandPatterns(
void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
- patterns.add<ReduceMultiDimReductionRank>(patterns.getContext(), options,
- benefit);
+ patterns.add<FlattenMultiReduction>(patterns.getContext(), options, benefit);
}
void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
index b8f970912909b..4ebcd60ccb13c 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
@@ -1,6 +1,10 @@
// RUN: mlir-opt %s --transform-interpreter='entry-point=innerreduction' | FileCheck %s --check-prefix=INNER_REDUCTION,ALL
// RUN: mlir-opt %s --transform-interpreter='entry-point=innerparallel' | FileCheck %s --check-prefix=INNER_PARALLEL,ALL
+//=============================================================================
+// Tests for the FlattenMultiReduction pattern
+//=============================================================================
+
// ALL-LABEL: func @negative_flattening_cases
func.func @negative_flattening_cases(
%v1d: vector<8xf32>,
More information about the Mlir-commits
mailing list