[Mlir-commits] [mlir] [vector][multi_reduction] Add unrolling for vector.multi_reduction (PR #185033)

Erick Ochoa Lopez llvmlistbot at llvm.org
Fri Mar 6 10:25:38 PST 2026


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/185033

>From d07bef658ff8167892b8e56769bc4e52b5b9a712 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 09:47:24 -0500
Subject: [PATCH 1/6] [mlir][vector] Rename multi_reduction's unrolling to
 lowering.

Stated previously:

> Below are some notes on naming
> * RewriteXAsY or LowerXToY (if we're changing op kind)
>
> Concretely, TwoDimMultiReductionToReduction looks like a lowering
> (it rewrites to vector.reduction), not an unrolling,

https://github.com/llvm/llvm-project/pull/182301#issuecomment-3951827980
---
 .../Dialect/Vector/TransformOps/VectorTransformOps.td  |  8 ++++----
 .../mlir/Dialect/Vector/Transforms/LoweringPatterns.h  |  2 +-
 .../Dialect/Vector/TransformOps/VectorTransformOps.cpp |  4 ++--
 .../Vector/Transforms/LowerVectorMultiReduction.cpp    | 10 +++++-----
 mlir/test/Dialect/LLVM/transform-e2e.mlir              |  2 +-
 mlir/test/Dialect/Vector/transform-vector.mlir         |  2 +-
 .../Vector/vector-multi-reduction-unrolling.mlir       |  4 ++--
 .../Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir           |  2 +-
 .../Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir           |  2 +-
 .../Dialect/Linalg/CPU/test-matmul-masked-vec.mlir     |  2 +-
 mlir/test/python/dialects/transform_vector_ext.py      |  6 +++---
 11 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index dcd5f6ff3ad74..76e77b9f90a13 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -262,18 +262,18 @@ def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
   }];
 }
 
-def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
-    "apply_patterns.vector.multi_reduction_unrolling",
+def ApplyMultiReductionLoweringPatternsOp: Op<Transform_Dialect,
+    "apply_patterns.vector.multi_reduction_lowering",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that vector multi_reduction operations should be unrolled.
+    Indicates that vector multi_reduction operations should be lowered.
     1-D multi_reductions are converted directly to vector.reduction.
     2-D multi_reductions are unrolled into either a sequence of
     vector.reduction ops (innerreduction) or element-wise arith ops
     (innerparallel).
 
     This populates the patterns from
-    `populateVectorMultiReductionUnrollingPatterns`, i.e.:
+    `populateVectorMultiReductionLoweringPatterns`, i.e.:
     * `OneDimMultiReductionToReduction`
     * `TwoDimMultiReductionToReduction` (innerreduction)
     * `TwoDimMultiReductionToElementWise` (innerparallel)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index aa75eff409ef9..2ea3367da414d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -96,7 +96,7 @@ void populateVectorMultiReductionFlatteningPatterns(
 /// dimension, unroll the outer dimension to obtain a sequence of extract +
 /// vector.reduction + insert. This can further lower to horizontal reduction
 /// ops.
-void populateVectorMultiReductionUnrollingPatterns(
+void populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit = 1);
 
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 312bd28ad48cf..d1799b3bd84d8 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -145,11 +145,11 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
-void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns(
+void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::VectorTransformsOptions vectorTransformOptions;
   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
-  vector::populateVectorMultiReductionUnrollingPatterns(
+  vector::populateVectorMultiReductionLoweringPatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 76599822fbfe4..88e669d1a34b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -505,10 +505,10 @@ struct LowerVectorMultiReductionPass
     if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
       signalPassFailure();
 
-    RewritePatternSet unrollingPatterns(context);
-    mlir::vector::populateVectorMultiReductionUnrollingPatterns(
-        unrollingPatterns, this->loweringStrategy);
-    if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
+    RewritePatternSet loweringPatterns(context);
+    mlir::vector::populateVectorMultiReductionLoweringPatterns(
+        loweringPatterns, this->loweringStrategy);
+    if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
       signalPassFailure();
   }
 
@@ -532,7 +532,7 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
   patterns.add<FlattenMultiReduction>(patterns.getContext(), options, benefit);
 }
 
-void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
   patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index bf7eba6e50174..7f4c78a7fc2c8 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -32,7 +32,7 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.vector.transfer_permutation_patterns
       transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
       transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
       transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 4dc11c26e83f1..9027bef2eac5c 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -41,7 +41,7 @@ module attributes {transform.with_named_sequence} {
     transform.apply_patterns to %f {
       transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
     } : !transform.any_op
 
     transform.apply_patterns to %f {
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
index 447416ccba637..a7ec4b9206c33 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
@@ -117,7 +117,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
     transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+      transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
     transform.yield
   }
@@ -125,7 +125,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
     transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
     } : !transform.op<"func.func">
     transform.yield
   }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
index a7b0b27ca5fb9..15f5bfd3619eb 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
@@ -152,7 +152,7 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.vector.lower_masked_transfers
       transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+      transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
 
     transform.yield
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
index 4adc68966f17a..0cfba97654769 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
@@ -157,7 +157,7 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.vector.lower_masked_transfers
       transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+      transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
 
     transform.yield
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index 0883e7b698f55..ce85720d665c2 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -55,7 +55,7 @@ module attributes {transform.with_named_sequence} {
     transform.apply_patterns to %func_op {
       transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+      transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
     transform.yield
   }
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index a3c53a45048b2..cfaf91d24f471 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -103,9 +103,9 @@ def enum_configurable_patterns():
         lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
     )
 
-    # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
-    vector.ApplyMultiReductionUnrollingPatternsOp()
-    # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
+    # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
+    vector.ApplyMultiReductionLoweringPatternsOp()
+    # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
     # CHECK-SAME: lowering_strategy = innerreduction
     vector.ApplyMultiReductionUnrollingPatternsOp(
         lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction

>From ba1fffa23f4b9cf6862897959db2cbecf2eea9d4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 10:01:00 -0500
Subject: [PATCH 2/6] [mlir][vector] Add multi_reduction unrolling
 innerparallel.

When unrolling vector.multi_reduction with outermost
dimensions being reduced, extract outermost dimension
vectors and chain them with vector.multi_reductions.
---
 .../Vector/TransformOps/VectorTransformOps.td | 14 ++++
 .../Vector/Transforms/LoweringPatterns.h      |  5 ++
 .../TransformOps/VectorTransformOps.cpp       |  5 ++
 .../Transforms/LowerVectorMultiReduction.cpp  | 82 +++++++++++++++++++
 4 files changed, 106 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 76e77b9f90a13..dff2803f95ae6 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -262,6 +262,20 @@ def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
   }];
 }
 
+def ApplyUnrollMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
+    "apply_patterns.vector.multi_reduction_unrolling",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector multi_reduction operations with more than one
+    reduction dimension should be unrolled in a rank-reducing way.
+
+    This populates the patterns:
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+
 def ApplyMultiReductionLoweringPatternsOp: Op<Transform_Dialect,
     "apply_patterns.vector.multi_reduction_lowering",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 2ea3367da414d..d9237bac63b78 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -81,6 +81,11 @@ void populateVectorMultiReductionFlatteningPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+void populateVectorMultiReductionUnrollingPatterns(RewritePatternSet &patterns,
+                                                   PatternBenefit benefit = 1);
+
 /// Populate the pattern set with the following patterns:
 ///
 /// [OneDimMultiReductionToReduction]
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index d1799b3bd84d8..256384f6b7c18 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -145,6 +145,11 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
+void transform::ApplyUnrollMultiReductionUnrollingPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorMultiReductionUnrollingPatterns(patterns);
+}
+
 void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::VectorTransformsOptions vectorTransformOptions;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 88e669d1a34b0..51b2f64c9a7da 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -443,6 +443,82 @@ struct TwoDimMultiReductionToReduction
   }
 };
 
+/// Unrolls the outermost dimension of vector.multi_reduction.
+/// Matches when the outermost dimension is a reduction dimension
+/// and is not the only reduction dimension.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, [[REDUCTION_DIMS]] ] :
+/// vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction <add>, %0, %acc [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction <add>, %Nminus1, %redNminus2
+///     [ [[REDUCTION_DIMS]] ] : vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
+struct UnrollMultiReductionInnerParallel
+    : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    if (!multiReductionOp.isReducedDim(0))
+      return failure();
+
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+    if (reductionDims.size() <= 1)
+      return failure();
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+    int64_t outerDimSize = srcShape.front();
+
+    Value mask = maskingOp ? maskingOp.getMask() : nullptr;
+
+    SmallVector<Value> vectors(outerDimSize);
+    for (int64_t i = 0; i < outerDimSize; ++i)
+      vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i);
+
+    SmallVector<Value> masks(outerDimSize);
+    if (mask)
+      for (int64_t i = 0; i < outerDimSize; ++i)
+        masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
+
+    SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+    ArrayRef<bool> reductionMask =
+        ArrayRef<bool>(fullReductionMask).drop_front();
+    Value result = multiReductionOp.getAcc();
+    for (auto [innerVector, innerMask] : llvm::zip_equal(vectors, masks)) {
+      auto reductionOp = vector::MultiDimReductionOp::create(
+          rewriter, loc, innerVector, result, reductionMask,
+          multiReductionOp.getKind());
+
+      if (innerMask) {
+        Operation *maskOp =
+            vector::maskOperation(rewriter, reductionOp, innerMask);
+        result = maskOp->getResult(0);
+      } else {
+        result = reductionOp.getResult();
+      }
+    }
+
+    return result;
+  }
+};
+
 /// Converts 1D vector.multi_reduction directly to vector.reduction.
 ///
 /// Example:
@@ -544,6 +620,12 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
                                                     benefit);
 }
 
+void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<UnrollMultiReductionInnerParallel>(patterns.getContext(),
+                                                  benefit);
+}
+
 std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
     vector::VectorMultiReductionLowering option) {
   return std::make_unique<LowerVectorMultiReductionPass>(option);

>From 2ea3b59224e103f515c1abebc802435147a4dbff Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 10:37:47 -0500
Subject: [PATCH 3/6] [mlir][vector] Add multi_reduction unrolling
 innerreduction.

Unrolls multi_reduction by a series of extractions, multi_reductions,
and insertions.
---
 .../Vector/TransformOps/VectorTransformOps.td |  13 +-
 .../Vector/Transforms/LoweringPatterns.h      |   5 +-
 .../TransformOps/VectorTransformOps.cpp       |   5 +-
 .../Transforms/LowerVectorMultiReduction.cpp  | 125 +++++++++++++++++-
 .../python/dialects/transform_vector_ext.py   |   4 +-
 5 files changed, 143 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index dff2803f95ae6..c9226af30e674 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -269,10 +269,19 @@ def ApplyUnrollMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
     Indicates that vector multi_reduction operations with more than one
     reduction dimension should be unrolled in a rank-reducing way.
 
-    This populates the patterns:
+    This populates the patterns from
+    `populateVectorMultiReductionUnrollingPatterns`, i.e.:
+    * `UnrollMultiReductionInnerReduction` (innerreduction)
+    * `UnrollMultiReductionInnerParallel` (innerparallel)
   }];
 
-  let assemblyFormat = "attr-dict";
+  let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+      "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+  );
+
+  let assemblyFormat = [{
+    (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+  }];
 }
 
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index d9237bac63b78..744a7d4b3b425 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -83,8 +83,9 @@ void populateVectorMultiReductionFlatteningPatterns(
 
 /// Populate the pattern set with the following patterns:
 ///
-void populateVectorMultiReductionUnrollingPatterns(RewritePatternSet &patterns,
-                                                   PatternBenefit benefit = 1);
+void populateVectorMultiReductionUnrollingPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 1);
 
 /// Populate the pattern set with the following patterns:
 ///
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 256384f6b7c18..f63d43a7d165c 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -147,7 +147,10 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
 
 void transform::ApplyUnrollMultiReductionUnrollingPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::populateVectorMultiReductionUnrollingPatterns(patterns);
+  vector::VectorTransformsOptions vectorTransformOptions;
+  vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+  vector::populateVectorMultiReductionUnrollingPatterns(
+      patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
 void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 51b2f64c9a7da..e347ba15537bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -519,6 +519,120 @@ struct UnrollMultiReductionInnerParallel
   }
 };
 
+/// Unrolls vector.multi_reduction along the outermost parallel dimension
+/// when the innermost dimension is a reduction dimension.
+///
+/// This pattern matches operations where:
+/// - The innermost dimension is a reduction dimension
+/// - The outermost dimension is a parallel dimension
+///
+/// The transformation extracts slices along the outermost parallel dimension,
+/// creates smaller multi_reductions, and assembles the results:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [1, 3]
+///     : vector<AxBxCxDxf32> to vector<AxCxf32>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %result = arith.constant dense<0.0> : vector<AxCxf32>
+/// %0 = vector.extract %src[0] : vector<BxCxDxf32> from vector<AxBxCxDxf32>
+/// %acc0 = vector.extract %acc[0] : vector<Cxf32> from vector<AxCxf32>
+/// %red0 = vector.multi_reduction <add>, %0, %acc0 [0, 2]
+///     : vector<BxCxDxf32> to vector<Cxf32>
+/// %res0 = vector.insert %red0, %result[0]
+///     : vector<Cxf32> into vector<AxCxf32>
+/// // ... repeat for indices 1 to A-1
+/// ```
+struct UnrollMultiReductionInnerReduction
+    : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+    if (srcRank < 2)
+      return rewriter.notifyMatchFailure(multiReductionOp,
+                                         "expected source rank >= 2.");
+
+    if (!multiReductionOp.isReducedDim(srcRank - 1))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected innermost dimension to be a reduction dimension.");
+
+    if (multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected outermost dimension to be a parallel dimension.");
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected integer or float element type.");
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+    Value acc = multiReductionOp.getAcc();
+
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+    int64_t numSlices = srcShape.front();
+
+    Value mask = maskingOp ? maskingOp.getMask() : Value();
+
+    SmallVector<Value> srcSlices;
+    for (int64_t i = 0; i < numSlices; ++i)
+      srcSlices.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+    SmallVector<Value> accSlices;
+    for (int64_t i = 0; i < numSlices; ++i)
+      accSlices.push_back(vector::ExtractOp::create(rewriter, loc, acc, i));
+
+    SmallVector<Value> maskSlices;
+    for (int64_t i = 0; i < numSlices; ++i)
+      if (mask)
+        maskSlices.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+      else
+        maskSlices.push_back(nullptr);
+
+    // Compute new reduction mask by dropping the first element (dimension 0).
+    // Since dimension 0 is parallel (not reduced), all reduction indices shift
+    // down by 1.
+    SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+    ArrayRef<bool> newReductionMask =
+        ArrayRef<bool>(fullReductionMask).drop_front();
+
+    SmallVector<Value> reductionResults;
+    for (auto [srcSlice, accSlice, maskSlice] :
+         llvm::zip(srcSlices, accSlices, maskSlices)) {
+      Operation *newReductionOp = vector::MultiDimReductionOp::create(
+          rewriter, loc, srcSlice, accSlice, newReductionMask,
+          multiReductionOp.getKind());
+
+      if (maskSlice)
+        newReductionOp =
+            mlir::vector::maskOperation(rewriter, newReductionOp, maskSlice);
+
+      reductionResults.push_back(newReductionOp->getResult(0));
+    }
+
+    Value result = arith::ConstantOp::create(
+        rewriter, loc, multiReductionOp.getDestType(),
+        rewriter.getZeroAttr(multiReductionOp.getDestType()));
+
+    for (int64_t i = 0; i < numSlices; ++i)
+      result = vector::InsertOp::create(rewriter, loc, reductionResults[i],
+                                        result, i);
+
+    return result;
+  }
+};
+
 /// Converts 1D vector.multi_reduction directly to vector.reduction.
 ///
 /// Example:
@@ -621,9 +735,14 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
 }
 
 void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<UnrollMultiReductionInnerParallel>(patterns.getContext(),
-                                                  benefit);
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
+  if (options == VectorMultiReductionLowering ::InnerReduction)
+    patterns.add<UnrollMultiReductionInnerReduction>(patterns.getContext(),
+                                                     benefit);
+  else
+    patterns.add<UnrollMultiReductionInnerParallel>(patterns.getContext(),
+                                                    benefit);
 }
 
 std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index cfaf91d24f471..84d70181b5288 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -105,7 +105,9 @@ def enum_configurable_patterns():
 
     # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
     vector.ApplyMultiReductionLoweringPatternsOp()
-    # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
+    # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
+    vector.ApplyUnrollMultiReductionUnrollingPatternsOp()
+    # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
     # CHECK-SAME: lowering_strategy = innerreduction
     vector.ApplyMultiReductionUnrollingPatternsOp(
         lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction

>From 4328d792b9c104cb50346deb039440e08f071f1f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 10:39:07 -0500
Subject: [PATCH 4/6] Add tests

---
 ...unroll-multi-reduction-inner-parallel.mlir | 27 ++++++++
 ...nroll-multi-reduction-inner-reduction.mlir | 62 +++++++++++++++++++
 2 files changed, 89 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir
 create mode 100644 mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir

diff --git a/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir
new file mode 100644
index 0000000000000..4f67d5678d6d3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: func @unroll_multi_reduction_inner_parallel
+// CHECK-SAME:    %[[INPUT:.+]]: vector<4x2x3xf32>, %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_multi_reduction_inner_parallel(%arg0: vector<4x2x3xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    // CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2x3xf32> from vector<4x2x3xf32>
+    // CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2x3xf32> from vector<4x2x3xf32>
+    // CHECK: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2x3xf32> from vector<4x2x3xf32>
+    // CHECK: %[[V3:.+]] = vector.extract %[[INPUT]][3] : vector<2x3xf32> from vector<4x2x3xf32>
+    // CHECK: %[[RV0:.+]] = vector.multi_reduction <mul>, %[[V0]], %[[ACC]] [1] : vector<2x3xf32> to vector<2xf32>
+    // CHECK: %[[RV1:.+]] = vector.multi_reduction <mul>, %[[V1]], %[[RV0]] [1] : vector<2x3xf32> to vector<2xf32>
+    // CHECK: %[[RV2:.+]] = vector.multi_reduction <mul>, %[[V2]], %[[RV1]] [1] : vector<2x3xf32> to vector<2xf32>
+    // CHECK: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[V3]], %[[RV2]] [1] : vector<2x3xf32> to vector<2xf32>
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 2] : vector<4x2x3xf32> to vector<2xf32>
+    // CHECK:             return %[[RESULT]]
+    return %0 : vector<2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.multi_reduction_unrolling
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir
new file mode 100644
index 0000000000000..04217479e377b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test UnrollVectorMultiReduction for Inner Reduction
+//===----------------------------------------------------------------------===//
+//
+// The general case handles multiple reduction dimensions.
+// For vector<2x3x5xf32> with reduction on dims [1, 2]:
+// UnrollMultiReductionInnerReduction unrolls along dim 0 (size 2), creating
+// two vector<3x5xf32> multi_reductions with dims [0, 1], then inserts the results.
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_vector_multi_reduction_inner_general(%source: vector<2x3x5xf32>, %acc: vector<2xf32>) -> (vector<2xf32>) {
+  // CHECK-DAG: %[[SRC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[SRC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+  // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+  // CHECK: %[[R0:.+]] = vector.multi_reduction <add>, %[[SRC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32
+  // CHECK: %[[R1:.+]] = vector.multi_reduction <add>, %[[SRC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32
+  // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0]], %{{.+}} [0] : f32 into vector<2xf32>
+  // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+  %1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+
+  // CHECK: return %[[INSERT_1]]
+  return %1 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2xf32>) -> (vector<2xf32>) {
+  // CHECK-DAG: %[[SRC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[SRC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+  // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+  // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+  // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+  // CHECK: %[[R0:.+]] = vector.mask %[[MASK_0]] {{.*}} vector.multi_reduction <add>, %[[SRC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32
+  // CHECK: %[[R1:.+]] = vector.mask %[[MASK_1]] {{.*}} vector.multi_reduction <add>, %[[SRC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32
+  // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0]], %{{.+}} [0] : f32 into vector<2xf32>
+  // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+
+  %0 = vector.mask %mask {
+    %1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+  } : vector<2x3x5xi1> -> vector<2xf32>
+
+  // CHECK: return %[[INSERT_1]]
+  return %0 : vector<2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = innerreduction
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}

>From 9245142ef7cad0ebe090216ad5b179507f6b5edd Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 11:07:58 -0500
Subject: [PATCH 5/6] Update benefits to ensure previous behaviour remains the
 same

---
 .../Vector/Transforms/LoweringPatterns.h      | 22 ++++++++++++++++---
 .../Transforms/LowerVectorMultiReduction.cpp  | 16 +++++---------
 mlir/test/Dialect/LLVM/transform-e2e.mlir     |  1 +
 .../test/Dialect/Vector/transform-vector.mlir |  1 +
 .../Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir  |  1 +
 .../Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir  |  1 +
 .../Linalg/CPU/test-matmul-masked-vec.mlir    |  1 +
 7 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 744a7d4b3b425..f9eeaf4136370 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -66,9 +66,12 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
 /// Rewrites vector.multi_reduction such that all reduction dimensions are
 /// either innermost or outermost, by adding the proper vector.transpose
 /// operations.
+///
+/// The default benefit is higher than that of the unrolling patterns so that
+/// these patterns take priority on operations that both sets can match.
 void populateVectorMultiReductionReorderPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
-    PatternBenefit benefit = 1);
+    PatternBenefit benefit = 2);
 
 /// Populate the pattern set with the following patterns:
 ///
@@ -77,12 +80,22 @@ void populateVectorMultiReductionReorderPatterns(
 /// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
 /// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
 /// back.
+///
+/// The default benefit is higher than that of the unrolling patterns so that
+/// these patterns take priority on operations that both sets can match.
 void populateVectorMultiReductionFlatteningPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
-    PatternBenefit benefit = 1);
+    PatternBenefit benefit = 2);
 
 /// Populate the pattern set with the following patterns:
 ///
+/// [UnrollMultiReductionInnerReduction]
+/// Extracts slices along the outermost dimension, performs a lower-rank
+/// vector.multi_reduction on each slice and inserts the results into a vector.
+///
+/// [UnrollMultiReductionInnerParallel]
+/// Extracts slices along the outermost dimension and chains lower-rank
+/// vector.multi_reduction ops, each using the previous result as accumulator.
 void populateVectorMultiReductionUnrollingPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit = 1);
@@ -102,9 +115,12 @@ void populateVectorMultiReductionUnrollingPatterns(
 /// dimension, unroll the outer dimension to obtain a sequence of extract +
 /// vector.reduction + insert. This can further lower to horizontal reduction
 /// ops.
+///
+/// The default benefit is higher than that of the unrolling patterns so that
+/// these patterns take priority on operations that both sets can match.
 void populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
-    PatternBenefit benefit = 1);
+    PatternBenefit benefit = 2);
 
 /// Populate the pattern set with the following patterns:
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index e347ba15537bd..3f2b399e55343 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -686,19 +686,13 @@ struct LowerVectorMultiReductionPass
     RewritePatternSet patterns(context);
     mlir::vector::populateVectorMultiReductionReorderPatterns(
         patterns, this->loweringStrategy);
-    if (failed(applyPatternsGreedily(op, std::move(patterns))))
-      signalPassFailure();
-
-    RewritePatternSet flatteningPatterns(context);
     mlir::vector::populateVectorMultiReductionFlatteningPatterns(
-        flatteningPatterns, this->loweringStrategy);
-    if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
-      signalPassFailure();
-
-    RewritePatternSet loweringPatterns(context);
+        patterns, this->loweringStrategy);
+    mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+        patterns, this->loweringStrategy);
     mlir::vector::populateVectorMultiReductionLoweringPatterns(
-        loweringPatterns, this->loweringStrategy);
-    if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
+        patterns, this->loweringStrategy);
+    if (failed(applyPatternsGreedily(op, std::move(patterns))))
       signalPassFailure();
   }
 
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index 7f4c78a7fc2c8..b4cb4b75a5f4d 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -32,6 +32,7 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.vector.transfer_permutation_patterns
       transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
       transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 9027bef2eac5c..b8729d06c9cfc 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -41,6 +41,7 @@ module attributes {transform.with_named_sequence} {
     transform.apply_patterns to %f {
       transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
     } : !transform.any_op
 
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
index 15f5bfd3619eb..23623b8b3f21e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
@@ -152,6 +152,7 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.vector.lower_masked_transfers
       transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
 
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
index 0cfba97654769..bf6ab21afc909 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
@@ -157,6 +157,7 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.vector.lower_masked_transfers
       transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
 
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index ce85720d665c2..8d973527847c1 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -55,6 +55,7 @@ module attributes {transform.with_named_sequence} {
     transform.apply_patterns to %func_op {
       transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
       transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
     transform.yield

>From 365480502c15eada5308f9d14af3d09132d45b2c Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 6 Mar 2026 13:24:38 -0500
Subject: [PATCH 6/6] Fix identifier name

---
 .../Vector/TransformOps/VectorTransformOps.td        |  2 +-
 .../Vector/TransformOps/VectorTransformOps.cpp       |  2 +-
 mlir/test/python/dialects/transform_vector_ext.py    | 12 +++++++++---
 3 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index c9226af30e674..b91bd09053c7e 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -262,7 +262,7 @@ def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
   }];
 }
 
-def ApplyUnrollMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
+def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
     "apply_patterns.vector.multi_reduction_unrolling",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index f63d43a7d165c..55953aa22dd88 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -145,7 +145,7 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
-void transform::ApplyUnrollMultiReductionUnrollingPatternsOp::populatePatterns(
+void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::VectorTransformsOptions vectorTransformOptions;
   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 84d70181b5288..b1dae910052b7 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -103,16 +103,22 @@ def enum_configurable_patterns():
         lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
     )
 
-    # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
-    vector.ApplyMultiReductionLoweringPatternsOp()
     # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
-    vector.ApplyUnrollMultiReductionUnrollingPatternsOp()
+    vector.ApplyMultiReductionUnrollingPatternsOp()
     # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
     # CHECK-SAME: lowering_strategy = innerreduction
     vector.ApplyMultiReductionUnrollingPatternsOp(
         lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
     )
 
+    # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
+    vector.ApplyMultiReductionLoweringPatternsOp()
+    # CHECK: transform.apply_patterns.vector.multi_reduction_lowering
+    # CHECK-SAME: lowering_strategy = innerreduction
+    vector.ApplyMultiReductionLoweringPatternsOp(
+        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
+    )
+
     # CHECK: transform.apply_patterns.vector.lower_transpose
     vector.ApplyLowerTransposePatternsOp()
     # CHECK: transform.apply_patterns.vector.lower_transpose



More information about the Mlir-commits mailing list