[Mlir-commits] [mlir] [mlir][vector] Replace OneDimMultiReductionToTwoDim with OneDimMultiReductionToReduction (PR #184241)

Erick Ochoa Lopez llvmlistbot at llvm.org
Wed Mar 4 07:51:47 PST 2026


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/184241

>From f7f51269cb283f7da1ebe5bd99b7a1e93ab41a53 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 2 Mar 2026 16:02:32 -0500
Subject: [PATCH 01/10] [mlir][vector] Add OneDimMultiReductionToReduction

---
 .../Vector/TransformOps/VectorTransformOps.td |  9 ++--
 .../Vector/Transforms/LoweringPatterns.h      |  6 ++-
 .../Transforms/LowerVectorMultiReduction.cpp  | 44 ++++++++++++++++++-
 .../vector-multi-reduction-unrolling.mlir     | 33 ++++++++++----
 4 files changed, 79 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 9fec5804d0b3b..bbe20a55eb10a 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -267,12 +267,15 @@ def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
     "apply_patterns.vector.multi_reduction_unrolling",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that 2-D vector multi_reduction operations should be unrolled
-    into either a sequence of vector.reduction ops (innerreduction) or
-    element-wise arith ops (innerparallel).
+    Indicates that vector multi_reduction operations should be unrolled.
+    1-D multi_reductions are converted directly to vector.reduction.
+    2-D multi_reductions are unrolled into either a sequence of
+    vector.reduction ops (innerreduction) or element-wise arith ops
+    (innerparallel).
 
     This populates the patterns from
     `populateVectorMultiReductionUnrollingPatterns`, i.e.:
+    * `OneDimMultiReductionToReduction`
     * `TwoDimMultiReductionToReduction` (innerreduction)
     * `TwoDimMultiReductionToElementWise` (innerparallel)
   }];
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index a933f68732a4d..7cb8e1df65ef8 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -74,7 +74,7 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
 /// thus fully exiting out of the vector.multi_reduction abstraction.
 void populateVectorMultiReductionReorderAndExpandPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
-    PatternBenefit benefit = 1);
+    PatternBenefit benefit = 2);
 
 /// Populate the pattern set with the following patterns:
 ///
@@ -89,6 +89,10 @@ void populateVectorMultiReductionFlatteningPatterns(
 
 /// Populate the pattern set with the following patterns:
 ///
+/// [OneDimMultiReductionToReduction]
+/// Converts 1-D vector.multi_reduction directly to vector.reduction.
+/// This is the terminal case for unrolling.
+///
 /// [TwoDimMultiReductionToElementWise]
 /// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
 /// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 0d9ff95e1279c..4a4e007d1381b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -375,7 +375,7 @@ struct TwoDimMultiReductionToElementWise
   }
 };
 
-/// Lowers 2D vector.multi_reduction to a squence of vector.reduction Ops
+/// Lowers 2D vector.multi_reduction to a sequence of vector.reduction ops.
 ///
 /// The reduction dimension must be the inner-most dimension.
 ///
@@ -443,6 +443,47 @@ struct TwoDimMultiReductionToReduction
   }
 };
 
+/// Converts 1-D vector.multi_reduction directly to vector.reduction.
+/// This is the terminal case for unrolling - once we reach rank 1,
+/// we convert to vector.reduction which backends can optimize.
+///
+/// Example:
+/// ```mlir
+/// // Before
+/// %r = vector.multi_reduction <add>, %v, %acc [0] : vector<Nxf32> to f32
+///
+/// // After
+/// %r = vector.reduction <add>, %v, %acc : vector<Nxf32> into f32
+/// ```
+struct OneDimMultiReductionToReduction
+    : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank != 1)
+      return failure();
+
+    if (!multiReductionOp.isReducedDim(0))
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    Value mask = maskingOp ? maskingOp.getMask() : Value();
+
+    Operation *reductionOp = vector::ReductionOp::create(
+        rewriter, loc, multiReductionOp.getKind(),
+        multiReductionOp.getSource(), multiReductionOp.getAcc());
+
+    if (mask)
+      reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
+
+    return reductionOp->getResult(0);
+  }
+};
+
 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
 /// form with both a single parallel and reduction dimension.
 /// This is achieved with a simple vector.shape_cast that inserts a leading 1.
@@ -569,6 +610,7 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
 void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
+  patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
   if (options == VectorMultiReductionLowering ::InnerReduction)
     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                   benefit);
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
index bc0d192e012ee..447416ccba637 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
@@ -1,15 +1,32 @@
 // RUN: mlir-opt %s --transform-interpreter='entry-point=innerreduction' | FileCheck %s --check-prefixes=INNER_REDUCTION,ALL
 // RUN: mlir-opt %s --transform-interpreter='entry-point=innerparallel' | FileCheck %s --check-prefixes=INNER_PARALLEL,ALL
 
-// ALL-LABEL: func @negative_rank1_and_rank3
-func.func @negative_rank1_and_rank3(
-    %rank1: vector<8xf32>, %rank1_acc: f32,
-    %rank3: vector<2x3x4xf32>, %rank3_acc: vector<2x3xf32>) -> (f32, vector<2x3xf32>) {
-  // ALL: vector.multi_reduction <add>, {{.+}} [0] : vector<8xf32> to f32
-  %0 = vector.multi_reduction <add>, %rank1, %rank1_acc [0] : vector<8xf32> to f32
+// ALL-LABEL: func @one_dim_reduction
+// ALL-SAME:    %[[INPUT:.+]]: vector<8xf32>, %[[ACC:.+]]: f32
+func.func @one_dim_reduction(%arg0: vector<8xf32>, %acc: f32) -> f32 {
+  // ALL: %[[RESULT:.+]] = vector.reduction <add>, %[[INPUT]], %[[ACC]] : vector<8xf32> into f32
+  %0 = vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
+  // ALL: return %[[RESULT]]
+  return %0 : f32
+}
+
+// ALL-LABEL: func @one_dim_reduction_masked
+// ALL-SAME:    %[[INPUT:.+]]: vector<8xf32>, %[[ACC:.+]]: f32, %[[MASK:.+]]: vector<8xi1>
+func.func @one_dim_reduction_masked(%arg0: vector<8xf32>, %acc: f32, %mask: vector<8xi1>) -> f32 {
+  // ALL: %[[RESULT:.+]] = vector.mask %[[MASK]] { vector.reduction <add>, %[[INPUT]], %[[ACC]] : vector<8xf32> into f32 } : vector<8xi1> -> f32
+  %0 = vector.mask %mask {
+    vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
+  } : vector<8xi1> -> f32
+  // ALL: return %[[RESULT]]
+  return %0 : f32
+}
+
+// ALL-LABEL: func @negative_rank3
+func.func @negative_rank3(
+    %rank3: vector<2x3x4xf32>, %rank3_acc: vector<2x3xf32>) -> vector<2x3xf32> {
   // ALL: vector.multi_reduction <add>, {{.+}} [2] : vector<2x3x4xf32> to vector<2x3xf32>
-  %1 = vector.multi_reduction <add>, %rank3, %rank3_acc [2] : vector<2x3x4xf32> to vector<2x3xf32>
-  return %0, %1 : f32, vector<2x3xf32>
+  %0 = vector.multi_reduction <add>, %rank3, %rank3_acc [2] : vector<2x3x4xf32> to vector<2x3xf32>
+  return %0 : vector<2x3xf32>
 }
 
 // ALL-LABEL: func @inner_reduction_2d

>From eab5c1e79dc1bcb73ad93909514363cd0c08527c Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 2 Mar 2026 16:24:09 -0500
Subject: [PATCH 02/10] [mlir][vector] Remove OneDimMultiReductionToTwoDim

---
 .../Vector/TransformOps/VectorTransformOps.td |  1 -
 .../Vector/Transforms/LoweringPatterns.h      |  6 --
 .../Transforms/LowerVectorMultiReduction.cpp  | 73 -------------------
 ...or-multi-reduction-reorder-and-expand.mlir | 46 +-----------
 4 files changed, 3 insertions(+), 123 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index bbe20a55eb10a..b5c5a3be872df 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -234,7 +234,6 @@ def ApplyReorderAndExpandMultiReductionPatternsOp: Op<Transform_Dialect,
     This populates the patterns from
     `populateVectorMultiReductionReorderAndExpandPatterns`, i.e.:
     * `InnerOuterDimReductionConversion`
-    * `OneDimMultiReductionToTwoDim`
   }];
 
   let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 7cb8e1df65ef8..d064e116a9c77 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -66,12 +66,6 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
 /// Rewrites vector.multi_reduction such that all reduction dimensions are
 /// either innermost or outermost, by adding the proper vector.transpose
 /// operations.
-///
-/// [OneDimMultiReductionToTwoDim]
-/// For cases that reduce to 1-D vector<k> reduction (and are thus missing
-/// either a parallel or a reduction), we lift them back up to 2-D with a simple
-/// vector.shape_cast to vector<1xk> so that the other patterns can kick in,
-/// thus fully exiting out of the vector.multi_reduction abstraction.
 void populateVectorMultiReductionReorderAndExpandPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit = 2);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 4a4e007d1381b..a0c6709d8a532 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -484,78 +484,6 @@ struct OneDimMultiReductionToReduction
   }
 };
 
-/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
-/// form with both a single parallel and reduction dimension.
-/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
-/// The case with a single parallel dimension is a noop and folds away
-/// separately.
-struct OneDimMultiReductionToTwoDim
-    : public OpRewritePattern<vector::MultiDimReductionOp> {
-  using Base::Base;
-
-  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
-    // Rank-1 or bail.
-    if (srcRank != 1)
-      return failure();
-
-    // Vector mask setup.
-    OpBuilder::InsertionGuard guard(rewriter);
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
-    Operation *rootOp;
-    Value mask;
-    if (maskableOp.isMasked()) {
-      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-      rootOp = maskableOp.getMaskingOp();
-      mask = maskableOp.getMaskingOp().getMask();
-    } else {
-      rootOp = multiReductionOp;
-    }
-
-    auto loc = multiReductionOp.getLoc();
-    auto srcVectorType = multiReductionOp.getSourceVectorType();
-    auto srcShape = srcVectorType.getShape();
-    auto castedType = VectorType::get(
-        ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
-        ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
-
-    auto accType =
-        VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
-    assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
-           "multi_reduction with a single dimension expects a scalar result");
-
-    // If the unique dim is reduced and we insert a parallel in front, we need a
-    // {false, true} mask.
-    SmallVector<bool, 2> reductionMask{false, true};
-
-    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
-    Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
-                                             multiReductionOp.getSource());
-    Value castAcc = vector::BroadcastOp::create(rewriter, loc, accType,
-                                                multiReductionOp.getAcc());
-    Value castMask;
-    if (maskableOp.isMasked()) {
-      auto maskType = llvm::cast<VectorType>(mask.getType());
-      auto castMaskType = VectorType::get(
-          ArrayRef<int64_t>{1, maskType.getShape().back()},
-          maskType.getElementType(),
-          ArrayRef<bool>{false, maskType.getScalableDims().back()});
-      castMask = vector::BroadcastOp::create(rewriter, loc, castMaskType, mask);
-    }
-
-    Operation *newOp = vector::MultiDimReductionOp::create(
-        rewriter, loc, cast, castAcc, reductionMask,
-        multiReductionOp.getKind());
-    newOp = vector::maskOperation(rewriter, newOp, castMask);
-
-    rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
-                                                   ArrayRef<int64_t>{0});
-    return success();
-  }
-};
-
 struct LowerVectorMultiReductionPass
     : public vector::impl::LowerVectorMultiReductionBase<
           LowerVectorMultiReductionPass> {
@@ -596,7 +524,6 @@ struct LowerVectorMultiReductionPass
 void mlir::vector::populateVectorMultiReductionReorderAndExpandPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
-  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
   patterns.add<InnerOuterDimReductionConversion>(patterns.getContext(), options,
                                                  benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-reorder-and-expand.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-reorder-and-expand.mlir
index 7f41f7e9e1ddc..fff075dcb47f2 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-reorder-and-expand.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-reorder-and-expand.mlir
@@ -36,50 +36,10 @@ func.func @transpose_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf3
     return %0 : vector<4xf32>
 }
 
-// ALL-LABEL: func @one_dim_to_two_dim
-// ALL-SAME:    %[[INPUT:.+]]: vector<8xf32>
-// ALL-SAME:    %[[ACC:.+]]: f32
-func.func @one_dim_to_two_dim(%arg0: vector<8xf32>, %acc: f32) -> f32 {
-    // ALL: %[[CAST:.+]] = vector.shape_cast %[[INPUT]] : vector<8xf32> to vector<1x8xf32>
-    // ALL: %[[BROADCAST:.+]] = vector.broadcast %[[ACC]] : f32 to vector<1xf32>
-    // INNER_REDUCTION: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CAST]], %[[BROADCAST]] [1]
-    // INNER_REDUCTION: %[[SCALAR:.+]] = vector.extract %[[RESULT]][0]
-    // INNER_PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[CAST]], [1, 0]
-    // INNER_PARALLEL: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[TRANSPOSED]], %[[BROADCAST]] [0]
-    // INNER_PARALLEL: %[[SCALAR:.+]] = vector.extract %[[RESULT]][0]
+// ALL-LABEL: func @negative_one_dim
+func.func @negative_one_dim(%arg0: vector<8xf32>, %acc: f32) -> f32 {
+    // ALL: vector.multi_reduction <add>, {{.+}} [0] : vector<8xf32> to f32
     %0 = vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
-    // ALL: return %[[SCALAR]]
-    return %0 : f32
-}
-
-// INNER_REDUCTION-LABEL: func @one_dim_to_two_dim_scalable
-// INNER_REDUCTION-SAME:    %[[INPUT:.+]]: vector<[4]xf32>
-// INNER_REDUCTION-SAME:    %[[ACC:.+]]: f32
-func.func @one_dim_to_two_dim_scalable(%arg0: vector<[4]xf32>, %acc: f32) -> f32 {
-    // INNER_REDUCTION: %[[CAST:.+]] = vector.shape_cast %[[INPUT]] : vector<[4]xf32> to vector<1x[4]xf32>
-    // INNER_REDUCTION: %[[BROADCAST:.+]] = vector.broadcast %[[ACC]] : f32 to vector<1xf32>
-    // INNER_REDUCTION: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CAST]], %[[BROADCAST]] [1]
-    %0 = vector.multi_reduction <add>, %arg0, %acc [0] : vector<[4]xf32> to f32
-    // INNER_REDUCTION: %[[EXTRACT:.+]] = vector.extract %[[RESULT]][0]
-    // INNER_REDUCTION: return %[[EXTRACT]]
-    return %0 : f32
-}
-
-// INNER_REDUCTION-LABEL: func @one_dim_to_two_dim_masked
-// INNER_REDUCTION-SAME:    %[[INPUT:.+]]: vector<8xf32>
-// INNER_REDUCTION-SAME:    %[[ACC:.+]]: f32
-// INNER_REDUCTION-SAME:    %[[MASK:.+]]: vector<8xi1>
-func.func @one_dim_to_two_dim_masked(%arg0: vector<8xf32>, %acc: f32, %mask: vector<8xi1>) -> f32 {
-    // INNER_REDUCTION: %[[CAST:.+]] = vector.shape_cast %[[INPUT]] : vector<8xf32> to vector<1x8xf32>
-    // INNER_REDUCTION: %[[BROADCAST_ACC:.+]] = vector.broadcast %[[ACC]] : f32 to vector<1xf32>
-    // INNER_REDUCTION: %[[BROADCAST_MASK:.+]] = vector.broadcast %[[MASK]] : vector<8xi1> to vector<1x8xi1>
-    // INNER_REDUCTION: %[[RESULT:.+]] = vector.mask %[[BROADCAST_MASK]] {
-    // INNER_REDUCTION:   vector.multi_reduction <add>, %[[CAST]], %[[BROADCAST_ACC]] [1]
-    %0 = vector.mask %mask {
-      vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
-    } : vector<8xi1> -> f32
-    // INNER_REDUCTION: %[[EXTRACT:.+]] = vector.extract %[[RESULT]][0]
-    // INNER_REDUCTION: return %[[EXTRACT]]
     return %0 : f32
 }
 

>From eb3ff3fbc01ba386e8ea68e1e64f631bf8aee6b5 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 2 Mar 2026 16:33:22 -0500
Subject: [PATCH 03/10] [mlir][vector] Rename reorder_and_expand to reorder

---
 .../Dialect/Vector/TransformOps/VectorTransformOps.td     | 6 +++---
 .../mlir/Dialect/Vector/Transforms/LoweringPatterns.h     | 2 +-
 .../Dialect/Vector/TransformOps/VectorTransformOps.cpp    | 4 ++--
 .../Vector/Transforms/LowerVectorMultiReduction.cpp       | 4 ++--
 mlir/test/Dialect/LLVM/transform-e2e.mlir                 | 2 +-
 mlir/test/Dialect/Vector/transform-vector.mlir            | 2 +-
 ...nd-expand.mlir => vector-multi-reduction-reorder.mlir} | 4 ++--
 .../Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir  | 2 +-
 .../Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir  | 2 +-
 .../Dialect/Linalg/CPU/test-matmul-masked-vec.mlir        | 2 +-
 mlir/test/python/dialects/transform_vector_ext.py         | 8 ++++----
 11 files changed, 19 insertions(+), 19 deletions(-)
 rename mlir/test/Dialect/Vector/{vector-multi-reduction-reorder-and-expand.mlir => vector-multi-reduction-reorder.mlir} (93%)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index b5c5a3be872df..b370d17ea6ea9 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -223,8 +223,8 @@ def ApplyMaterializeMasksPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
-def ApplyReorderAndExpandMultiReductionPatternsOp: Op<Transform_Dialect,
-    "apply_patterns.vector.reorder_and_expand_multi_reduction_dims",
+def ApplyReorderMultiReductionPatternsOp: Op<Transform_Dialect,
+    "apply_patterns.vector.reorder_multi_reduction_dims",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
     Indicates that vector multi_reduction-like operations should be
@@ -232,7 +232,7 @@ def ApplyReorderAndExpandMultiReductionPatternsOp: Op<Transform_Dialect,
     outermost, and 1-D reductions are lifted to 2-D.
 
     This populates the patterns from
-    `populateVectorMultiReductionReorderAndExpandPatterns`, i.e.:
+    `populateVectorMultiReductionReorderPatterns`, i.e.:
     * `InnerOuterDimReductionConversion`
   }];
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index d064e116a9c77..a843cc7f49d35 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -66,7 +66,7 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
 /// Rewrites vector.multi_reduction such that all reduction dimensions are
 /// either innermost or outermost, by adding the proper vector.transpose
 /// operations.
-void populateVectorMultiReductionReorderAndExpandPatterns(
+void populateVectorMultiReductionReorderPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit = 2);
 
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 9da4be88586f4..312bd28ad48cf 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -129,11 +129,11 @@ void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
 //===----------------------------------------------------------------------===//
 // Multi-reduction patterns
 //===----------------------------------------------------------------------===//
-void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns(
+void transform::ApplyReorderMultiReductionPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::VectorTransformsOptions vectorTransformOptions;
   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
-  vector::populateVectorMultiReductionReorderAndExpandPatterns(
+  vector::populateVectorMultiReductionReorderPatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index a0c6709d8a532..1ba371005ab8a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -496,7 +496,7 @@ struct LowerVectorMultiReductionPass
     MLIRContext *context = op->getContext();
 
     RewritePatternSet patterns(context);
-    mlir::vector::populateVectorMultiReductionReorderAndExpandPatterns(
+    mlir::vector::populateVectorMultiReductionReorderPatterns(
         patterns, this->loweringStrategy);
     if (failed(applyPatternsGreedily(op, std::move(patterns))))
       signalPassFailure();
@@ -521,7 +521,7 @@ struct LowerVectorMultiReductionPass
 
 } // namespace
 
-void mlir::vector::populateVectorMultiReductionReorderAndExpandPatterns(
+void mlir::vector::populateVectorMultiReductionReorderPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
   patterns.add<InnerOuterDimReductionConversion>(patterns.getContext(), options,
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index ab58dda91a914..bf7eba6e50174 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -30,7 +30,7 @@ module attributes {transform.with_named_sequence} {
     transform.apply_patterns to %f {
       transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
       transform.apply_patterns.vector.transfer_permutation_patterns
-      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index a37105d573219..4dc11c26e83f1 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -39,7 +39,7 @@ module attributes {transform.with_named_sequence} {
     } : !transform.any_op
 
     transform.apply_patterns to %f {
-      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
     } : !transform.any_op
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-reorder-and-expand.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-reorder.mlir
similarity index 93%
rename from mlir/test/Dialect/Vector/vector-multi-reduction-reorder-and-expand.mlir
rename to mlir/test/Dialect/Vector/vector-multi-reduction-reorder.mlir
index fff075dcb47f2..0a22205f61f90 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-reorder-and-expand.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-reorder.mlir
@@ -47,7 +47,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
     transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
+      transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
     transform.yield
   }
@@ -55,7 +55,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
     transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
     } : !transform.op<"func.func">
     transform.yield
   }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
index 25b65080339d5..a7b0b27ca5fb9 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
@@ -150,7 +150,7 @@ module attributes {transform.with_named_sequence} {
     // Step 3: Lower vector.multi_reduction
     transform.apply_patterns to %func {
       transform.apply_patterns.vector.lower_masked_transfers
-      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
+      transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
index 6072b44adf4fa..4adc68966f17a 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
@@ -155,7 +155,7 @@ module attributes {transform.with_named_sequence} {
     // Step 3: Lower vector.multi_reduction
     transform.apply_patterns to %func {
       transform.apply_patterns.vector.lower_masked_transfers
-      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
+      transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index 3c4f10316d0f3..0883e7b698f55 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -53,7 +53,7 @@ module attributes {transform.with_named_sequence} {
     %func_op = transform.get_parent_op %0 : (!transform.any_op) -> !transform.op<"func.func">
     transform.structured.vectorize %0 vector_sizes [4, 4, 2] : !transform.any_op
     transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
+      transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 8a3091d0b1b02..a3c53a45048b2 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -87,11 +87,11 @@ def enum_configurable_patterns():
         lowering_strategy=vector.VectorContractLowering.ParallelArith
     )
 
-    # CHECK: transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims
-    vector.ApplyReorderAndExpandMultiReductionPatternsOp()
-    # CHECK: transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims
+    # CHECK: transform.apply_patterns.vector.reorder_multi_reduction_dims
+    vector.ApplyReorderMultiReductionPatternsOp()
+    # CHECK: transform.apply_patterns.vector.reorder_multi_reduction_dims
     # CHECK-SAME: lowering_strategy = innerreduction
-    vector.ApplyReorderAndExpandMultiReductionPatternsOp(
+    vector.ApplyReorderMultiReductionPatternsOp(
         lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
     )
 

>From 4bfe9e403c7ea1ad32f369375f3ba2bbb729fa8c Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 2 Mar 2026 16:53:32 -0500
Subject: [PATCH 04/10] Style

---
 .../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp   | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 1ba371005ab8a..d1f7e0f16fc2a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -474,8 +474,8 @@ struct OneDimMultiReductionToReduction
     Value mask = maskingOp ? maskingOp.getMask() : Value();
 
     Operation *reductionOp = vector::ReductionOp::create(
-        rewriter, loc, multiReductionOp.getKind(),
-        multiReductionOp.getSource(), multiReductionOp.getAcc());
+        rewriter, loc, multiReductionOp.getKind(), multiReductionOp.getSource(),
+        multiReductionOp.getAcc());
 
     if (mask)
       reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);

>From 61aa1fd74d1ebf9eff0905fc21896aafc81d1f5b Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 4 Mar 2026 10:29:33 -0500
Subject: [PATCH 05/10] Address review comment

---
 .../mlir/Dialect/Vector/TransformOps/VectorTransformOps.td      | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index b370d17ea6ea9..dcd5f6ff3ad74 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -229,7 +229,7 @@ def ApplyReorderMultiReductionPatternsOp: Op<Transform_Dialect,
   let description = [{
     Indicates that vector multi_reduction-like operations should be
     transformed such that all reduction dimensions become innermost or
-    outermost, and 1-D reductions are lifted to 2-D.
+    outermost, depending on `lowering_strategy`.
 
     This populates the patterns from
     `populateVectorMultiReductionReorderPatterns`, i.e.:

>From bf84cdcdf6df7f6b24ea9af78842a4ee38997163 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 4 Mar 2026 10:40:51 -0500
Subject: [PATCH 06/10] Revert benefit to be 1

---
 mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index a843cc7f49d35..5181f94b8ca49 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -68,7 +68,7 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
 /// operations.
 void populateVectorMultiReductionReorderPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
-    PatternBenefit benefit = 2);
+    PatternBenefit benefit = 1);
 
 /// Populate the pattern set with the following patterns:
 ///

>From eaca0557a3027f5ce22becf6b22695b032ae6388 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 4 Mar 2026 10:47:00 -0500
Subject: [PATCH 07/10] Address comment

---
 .../lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index d1f7e0f16fc2a..40de61fc9d402 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -444,8 +444,6 @@ struct TwoDimMultiReductionToReduction
 };
 
 /// Converts 1-D vector.multi_reduction directly to vector.reduction.
-/// This is the terminal case for unrolling - once we reach rank 1,
-/// we convert to vector.reduction which backends can optimize.
 ///
 /// Example:
 /// ```mlir

>From 0e646894c3b0df4c31d189a48000b7ac7372a1ab Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 4 Mar 2026 10:49:51 -0500
Subject: [PATCH 08/10] Address comments

---
 mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 5181f94b8ca49..aa75eff409ef9 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -84,8 +84,7 @@ void populateVectorMultiReductionFlatteningPatterns(
 /// Populate the pattern set with the following patterns:
 ///
 /// [OneDimMultiReductionToReduction]
-/// Converts 1-D vector.multi_reduction directly to vector.reduction.
-/// This is the terminal case for unrolling.
+/// Converts 1-D vector.multi_reduction to vector.reduction.
 ///
 /// [TwoDimMultiReductionToElementWise]
 /// Once in 2-D vector.multi_reduction form, with an **outermost** reduction

>From b59fc99ae3027cc1b410fcd5b945a604bccfaaa0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 4 Mar 2026 10:50:44 -0500
Subject: [PATCH 09/10] address comment

---
 .../lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 40de61fc9d402..dccc52bb7d55a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -375,7 +375,7 @@ struct TwoDimMultiReductionToElementWise
   }
 };
 
-/// Lowers 2D vector.multi_reduction to a sequence of vector.reduction ops.
+/// Lowers 2D vector.multi_reduction to a sequence of vector.reduction Ops.
 ///
 /// The reduction dimension must be the inner-most dimension.
 ///

>From c5894e452192fcca184f8dd13d3552a033d3d124 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 4 Mar 2026 10:51:18 -0500
Subject: [PATCH 10/10] Consistency

---
 .../lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index dccc52bb7d55a..76599822fbfe4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -443,7 +443,7 @@ struct TwoDimMultiReductionToReduction
   }
 };
 
-/// Converts 1-D vector.multi_reduction directly to vector.reduction.
+/// Converts 1D vector.multi_reduction directly to vector.reduction.
 ///
 /// Example:
 /// ```mlir



More information about the Mlir-commits mailing list