[Mlir-commits] [mlir] [mlir][vector] Separate multi_reduce lowering into invariants, flattening, and unrolling. (PR #178974)

Erick Ochoa Lopez llvmlistbot at llvm.org
Fri Jan 30 16:00:29 PST 2026


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/178974

>From c7e0fa9bb52ffecb21cd9bec4b51a16b02b73000 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 29 Jan 2026 10:29:59 -0500
Subject: [PATCH 01/15] [mlir][vector] Finer pattern selection for
 VectorMultiReduction (NFC).

* Adds populateVectorMultiReductionInnerOuterDimPatterns which populates
  InnerOuterDimReductionConversion and OneDimMultiReductionToTwoDim.
  These patterns establish the invariants that the later patterns rely on:
  all reduction dimensions are innermost or all are outermost, and rank >= 2.
* Adds populateVectorMultiReductionFlatteningPatterns which populates
  ReduceMultiDimReductionRank, TwoDimMultiReductionToElementWise, and
  TwoDimMultiReductionToReduction.
---
 .../Vector/Transforms/LoweringPatterns.h      | 40 +++++++++++++++++++
 .../TransformOps/VectorTransformOps.cpp       |  2 +
 .../Transforms/LowerVectorMultiReduction.cpp  | 36 +++++++++++++++--
 3 files changed, 74 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 7bd96c8a6d1a1..6d39183d3e66d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -60,6 +60,46 @@ void populateVectorContractLoweringPatterns(
 void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);
 
+/// Collect a set of patterns to set invariants for vector.multi_reduction's
+/// conversion. The patterns comprise:
+///
+/// [InnerOuterDimReductionConversion]
+/// Rewrites vector.multi_reduction such that all reduction dimensions are
+/// either innermost or outermost, by adding the proper vector.transpose
+/// operations.
+///
+/// [OneDimMultiReductionToTwoDim]
+/// For cases that reduce to 1-D vector<k> reduction (and are thus missing
+/// either a parallel or a reduction), we lift them back up to 2-D with a simple
+/// vector.shape_cast to vector<1xk> so that the other patterns can kick in,
+/// thus fully exiting out of the vector.multi_reduction abstraction.
+void populateVectorMultiReductionInnerOuterDimPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 1);
+
+/// Collect a set of patterns to convert vector.multi_reduction op into
+/// a sequence of vector.reduction ops. The patterns comprise:
+///
+/// [ReduceMultiDimReductionRank]
+/// Once in innermost or outermost reduction
+/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
+/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
+/// back.
+///
+/// [TwoDimMultiReductionToElementWise]
+/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
+/// ops. This also has an opportunity for tree-reduction (in the future).
+///
+/// [TwoDimMultiReductionToReduction]
+/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of extract +
+/// vector.reduction + insert. This can further lower to horizontal reduction
+/// ops.
+void populateVectorMultiReductionFlatteningPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 1);
+
 /// Collect a set of patterns to convert vector.multi_reduction op into
 /// a sequence of vector.reduction ops. The patterns comprise:
 ///
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 7faa222a9e574..f651e74873a91 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -130,6 +130,8 @@ void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::VectorTransformsOptions vectorTransformOptions;
   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+  vector::populateVectorMultiReductionInnerOuterDimPatterns(
+      patterns, vectorTransformOptions.vectorMultiReductionLowering);
   vector::populateVectorMultiReductionLoweringPatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index e86e2a97038db..993b6a7f6e08f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -496,11 +496,18 @@ struct LowerVectorMultiReductionPass
     Operation *op = getOperation();
     MLIRContext *context = op->getContext();
 
-    RewritePatternSet loweringPatterns(context);
-    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
-                                                 this->loweringStrategy);
+    RewritePatternSet innerOuterPatterns(context);
+    populateVectorMultiReductionInnerOuterDimPatterns(innerOuterPatterns,
+                                                      this->loweringStrategy);
 
-    if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
+    if (failed(applyPatternsGreedily(op, std::move(innerOuterPatterns))))
+      signalPassFailure();
+
+    RewritePatternSet flatteningPatterns(context);
+    populateVectorMultiReductionFlatteningPatterns(flatteningPatterns,
+                                                   this->loweringStrategy);
+
+    if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
       signalPassFailure();
   }
 
@@ -511,6 +518,27 @@ struct LowerVectorMultiReductionPass
 
 } // namespace
 
+void mlir::vector::populateVectorMultiReductionInnerOuterDimPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
+  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
+  patterns.add<InnerOuterDimReductionConversion>(patterns.getContext(), options,
+                                                 benefit);
+}
+
+void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
+  patterns.add<ReduceMultiDimReductionRank>(patterns.getContext(), options,
+                                            benefit);
+  if (options == VectorMultiReductionLowering ::InnerReduction)
+    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
+                                                  benefit);
+  else
+    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
+                                                    benefit);
+}
+
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {

>From 7a0e185022bd04ed2d878f4574fee560347a823f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 29 Jan 2026 10:37:16 -0500
Subject: [PATCH 02/15] [mlir][vector] Finer pattern selection for
 VectorMultiReduction (NFC)

* Removes the following patterns from
  populateVectorMultiReductionFlatteningPatterns:
  * TwoDimMultiReductionToElementWise
  * TwoDimMultiReductionToReduction
* Adds populateVectorMultiReductionUnrollingPatterns with patterns
  TwoDimMultiReductionToElementWise, TwoDimMultiReductionToReduction
---
 .../Dialect/Vector/Transforms/LoweringPatterns.h     |  8 +++++++-
 .../Vector/Transforms/LowerVectorMultiReduction.cpp  | 12 ++++++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 6d39183d3e66d..1beb8e25cb683 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -85,6 +85,12 @@ void populateVectorMultiReductionInnerOuterDimPatterns(
 /// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
 /// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
 /// back.
+void populateVectorMultiReductionFlatteningPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 1);
+
+/// Collect a set of patterns to unroll vector.multi_reduction ops into a
+/// sequence of vector.reduction or elementwise ops. The patterns comprise:
 ///
 /// [TwoDimMultiReductionToElementWise]
 /// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
@@ -96,7 +102,7 @@ void populateVectorMultiReductionInnerOuterDimPatterns(
 /// dimension, unroll the outer dimension to obtain a sequence of extract +
 /// vector.reduction + insert. This can further lower to horizontal reduction
 /// ops.
-void populateVectorMultiReductionFlatteningPatterns(
+void populateVectorMultiReductionUnrollingPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit = 1);
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 993b6a7f6e08f..2eb8048bb4a69 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -509,6 +509,13 @@ struct LowerVectorMultiReductionPass
 
     if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
       signalPassFailure();
+
+    RewritePatternSet unrollingPatterns(context);
+    populateVectorMultiReductionUnrollingPatterns(unrollingPatterns,
+                                                  this->loweringStrategy);
+
+    if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
+      signalPassFailure();
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -531,6 +538,11 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
     PatternBenefit benefit) {
   patterns.add<ReduceMultiDimReductionRank>(patterns.getContext(), options,
                                             benefit);
+}
+
+void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
   if (options == VectorMultiReductionLowering ::InnerReduction)
     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                   benefit);

>From 94dc8fd6cefd1ae166622950f6c6bfc8a6df18bf Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 11:13:52 -0500
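Subject: [PATCH 03/15] [mlir][vector] Add rank-reducing unrolling for
 vector.multi_reduction.

In the InnerParallel unrolling patterns, replace
TwoDimMultiReductionToElementWise with UnrollMultiReductionOuterBaseCase
and UnrollMultiReductionOuterGeneralCase, which unroll the reduced
outermost dimension of (optionally masked) n-D vector.multi_reduction
ops. All vector.extract ops are now emitted before the reduction chain,
which reorders the CHECK lines in vector-multi-reduction-pass-lowering.mlir.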
---
 .../Transforms/LowerVectorMultiReduction.cpp  | 199 +++++++++++++++++-
 .../vector-multi-reduction-pass-lowering.mlir |   6 +-
 2 files changed, 198 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2eb8048bb4a69..b2b01031587cf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -485,6 +485,195 @@ struct OneDimMultiReductionToTwoDim
   }
 };
 
+/// Unrolls the outermost dimension of vector.multi_reduction.
+/// These patterns match operations that reduce the outermost dimension;
+/// they do not transform operations for which the outermost dimension is not
+/// a reduction dimension.
+///
+/// There are two cases to consider:
+/// 1. Base case (UnrollMultiReductionOuterBaseCase): the outermost dimension
+/// is the only reduction dimension.
+/// 2. General case (UnrollMultiReductionOuterGeneralCase): the outermost
+/// dimension is one of several reduction dimensions.
+///
+/// The base case transformation:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0] : vector<NxMx...xf32> to
+/// vector<Mx...xf32>
+/// ```
+///
+/// is unrolled into N extracts (of %src and any mask) and N elementwise ops.
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
+/// ...
+/// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
+/// ```
+///
+/// For the general case, we still extract N vectors, but produce N
+/// vector.multi_reduction ops instead of elementwise operations.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, [[REDUCTION_DIMS]] ] :
+/// vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction %0, %acc [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
+struct UnrollMultiReductionOuterBaseCase
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    if (!multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected outermost dimension to be reduced dimension.");
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected integer or float element type.");
+
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+    if (reductionDims.size() > 1)
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected only one reduction dimension.");
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+    // TODO: reject a scalable outermost dim; only its base size is unrolled.
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+    int64_t numElementwiseOps = srcShape.front();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    bool isMasked = maskableOp.isMasked();
+    Operation *rootOp;
+    Value mask = nullptr;
+    if (isMasked) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = multiReductionOp;
+    }
+    // Slice the source (and the mask, if any) along the outermost dimension.
+    SmallVector<Value> vectors;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      vectors.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+    SmallVector<Value> masks;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      if (isMasked)
+        masks.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+      else
+        masks.push_back(nullptr);
+    // Chain the slices into the accumulator using the reduction kind.
+    Value result = multiReductionOp.getAcc();
+    for (auto [innerVector, innerMask] : llvm::zip(vectors, masks))
+      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
+                                  innerVector, result, /*fastmath=*/nullptr,
+                                  innerMask);
+
+    rewriter.replaceOp(rootOp, result);
+    return success();
+  }
+};
+/// General case of the outermost-dim unrolling documented above.
+struct UnrollMultiReductionOuterGeneralCase
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    if (!multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected outermost dimension to be reduced dimension.");
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected integer or float element type.");
+
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+    if (reductionDims.size() <= 1)
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected more than one reduction dimension.");
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+    // TODO: reject a scalable outermost dim; only its base size is unrolled.
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+    int64_t numElementwiseOps = srcShape.front();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    bool isMasked = maskableOp.isMasked();
+    Operation *rootOp;
+    Value mask = nullptr;
+    if (isMasked) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = multiReductionOp;
+    }
+    // Slice the source (and the mask, if any) along the outermost dimension.
+    SmallVector<Value> vectors;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      vectors.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+    SmallVector<Value> masks;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      if (isMasked)
+        masks.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+      else
+        masks.push_back(nullptr);
+    // Own the mask: getReductionMask() returns a temporary SmallVector.
+    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
+    reductionMask.erase(reductionMask.begin());
+    Value result = multiReductionOp.getAcc();
+    for (auto [innerVector, innerMask] : llvm::zip(vectors, masks)) {
+      // Each slice is reduced into the running result, used as its acc.
+      auto reductionOp = vector::MultiDimReductionOp::create(
+          rewriter, loc, innerVector, result, reductionMask,
+          multiReductionOp.getKind());
+
+      if (isMasked) {
+        auto maskOp = vector::maskOperation(rewriter, reductionOp, innerMask);
+        result = maskOp->getResult(0);
+      } else {
+        result = reductionOp.getResult();
+      }
+    }
+
+    rewriter.replaceOp(rootOp, result);
+    return success();
+  }
+};
+
 struct LowerVectorMultiReductionPass
     : public vector::impl::LowerVectorMultiReductionBase<
           LowerVectorMultiReductionPass> {
@@ -543,12 +732,14 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
 void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
-  if (options == VectorMultiReductionLowering ::InnerReduction)
+  if (options == VectorMultiReductionLowering ::InnerReduction) {
     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                   benefit);
-  else
-    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
-                                                    benefit);
+  } else {
+    patterns.add<UnrollMultiReductionOuterBaseCase,
+                 UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
+                                                       benefit);
+  }
 }
 
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index ddbc5c7bdb2c0..e01bf446eb83c 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -21,12 +21,12 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
 
 //      INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //      INNER-PARALLEL: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-//      INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
 //      INNER-PARALLEL: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-//      INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
 //      INNER-PARALLEL: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-//      INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
 //      INNER-PARALLEL: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+//      INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+//      INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+//      INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
 //      INNER-PARALLEL: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
 //      INNER-PARALLEL: return %[[RESULT_VEC]] : vector<2xf32>
 

>From 2ce43b870f8bc2c156a8f1436704c27ae8edb81d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 12:07:30 -0500
Subject: [PATCH 04/15] [mlir][vector] Add unroll_multi_reduction transform op
 and tests.

---
 .../Vector/TransformOps/VectorTransformOps.td | 20 ++++
 .../Vector/Transforms/LoweringPatterns.h      | 24 +++++
 .../TransformOps/VectorTransformOps.cpp       |  8 ++
 .../Transforms/LowerVectorMultiReduction.cpp  | 16 ++++
 .../Vector/td/unroll-multi-reduction.mlir     | 24 +++++
 .../Vector/unroll-vector-multi-reduction.mlir | 92 +++++++++++++++++++
 6 files changed, 184 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
 create mode 100644 mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 03d25505dc65c..373f9a4210c41 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -243,6 +243,26 @@ def ApplyLowerMultiReductionPatternsOp : Op<Transform_Dialect,
   }];
 }
 
+def ApplyUnrollMultiReductionPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.unroll_multi_reduction",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Unrolls vector.multi_reduction operations by progressively reducing rank
+    along the outermost dimension.
+
+    This is an alternative to the flattening-based lowering that preserves
+    the n-D structure (for now, `innerreduction` only handles 2-D ops).
+  }];
+
+  let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+      "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+  );
+
+  let assemblyFormat = [{
+    (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+  }];
+}
+
 def ApplyLowerOuterProductPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_outerproduct",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 1beb8e25cb683..1e78983ce082d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -140,6 +140,30 @@ void populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit = 1);
 
+/// Collect a set of patterns to unroll vector.multi_reduction ops by
+/// progressively reducing rank along the outermost dimension.
+///
+/// For InnerParallel (outermost dim is reduction):
+/// [UnrollMultiReductionOuterBaseCase]
+/// When the outermost dimension is the only reduction dimension, unroll to
+/// produce elementwise arithmetic operations.
+///
+/// [UnrollMultiReductionOuterGeneralCase]
+/// When the outermost dimension is one of multiple reduction dimensions,
+/// unroll to produce smaller multi_reduction operations.
+///
+/// For InnerReduction (innermost dim is reduction):
+/// [TwoDimMultiReductionToReduction]
+/// Current fallback: only 2-D ops with an innermost reduction dimension are
+/// handled, by unrolling the outer parallel dimension into extract +
+/// vector.reduction + insert.
+///
+/// TODO: add UnrollMultiReductionInnerBaseCase and
+/// UnrollMultiReductionInnerGeneralCase to handle n-D inner reductions.
+void populateVectorUnrollMultiReduction(RewritePatternSet &patterns,
+                                        VectorMultiReductionLowering options,
+                                        PatternBenefit benefit = 1);
+
 /// Populate the pattern set with the following patterns:
 ///
 /// [TransferReadToVectorLoadLowering]
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index f651e74873a91..4562393559c94 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -136,6 +136,14 @@ void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
+void transform::ApplyUnrollMultiReductionPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::VectorTransformsOptions vectorTransformOptions;
+  vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+  vector::populateVectorUnrollMultiReduction(
+      patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorOuterProductLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index b2b01031587cf..8548e8c2d311d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -742,6 +742,22 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
   }
 }
 
+void mlir::vector::populateVectorUnrollMultiReduction(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
+  if (options == VectorMultiReductionLowering::InnerReduction) {
+    // TODO: Add UnrollMultiReductionInnerBaseCase and
+    // UnrollMultiReductionInnerGeneralCase patterns here once implemented.
+    // For now, fall back to the existing 2-D based lowering.
+    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
+                                                  benefit);
+  } else {
+    patterns.add<UnrollMultiReductionOuterBaseCase,
+                 UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
+                                                       benefit);
+  }
+}
+
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir b/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
new file mode 100644
index 0000000000000..96a68723266d3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
@@ -0,0 +1,24 @@
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @unroll_multi_reduction(%module_op: !transform.any_op {transform.readonly}) {
+
+    %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func_op {
+      // Test patterns
+      transform.apply_patterns.vector.unroll_multi_reduction
+    } : !transform.any_op
+
+    transform.yield
+  }
+  transform.named_sequence @unroll_multi_reduction_inner(%module_op: !transform.any_op {transform.readonly}) {
+
+    %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func_op {
+      // Test patterns
+      transform.apply_patterns.vector.unroll_multi_reduction lowering_strategy = "innerreduction"
+    } : !transform.any_op
+
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
new file mode 100644
index 0000000000000..79086e2b0b9ad
--- /dev/null
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt --split-input-file %s -transform-preload-library='transform-library-paths=%p/td/unroll-multi-reduction.mlir' \
+// RUN: -transform-interpreter=entry-point=unroll_multi_reduction | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test UnrollVectorMultiReduction
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @unroll_vector_multi_reduction(%source: vector<2x3x5xf32>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+  // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
+  // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_0]] : vector<3x5xf32>
+  %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
+
+  // CHECK: return %[[RES_1]]
+  return %1 : vector<3x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @unroll_vector_multi_reduction_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+  // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+  // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+  // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
+  // CHECK: %[[RES_MASKED_0:.+]] = arith.select %[[MASK_0]], %[[RES_0]], %[[ACC]] : vector<3x5xi1>, vector<3x5xf32>
+
+  // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_MASKED_0]] : vector<3x5xf32>
+  // CHECK: %[[RES_MASKED_1:.+]] = arith.select %[[MASK_1]], %[[RES_1]], %[[RES_MASKED_0]] : vector<3x5xi1>, vector<3x5xf32>
+
+  %0 = vector.mask %mask {
+    %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
+  } : vector<2x3x5xi1> -> vector<3x5xf32>
+
+  // CHECK: return %[[RES_MASKED_1]]
+  return %0 : vector<3x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_general(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
+func.func @unroll_vector_multi_reduction_general(%source: vector<2x3x5xf32>, %acc: vector<3xf32>) -> (vector<3xf32>) {
+
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+  // CHECK: %[[RES_0:.+]] = vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32>
+  // CHECK: %[[RES_1:.+]] = vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32>
+
+  %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
+
+  // CHECK: return %[[RES_1]]
+  return %1 : vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_general_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
+func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3xf32>) -> (vector<3xf32>) {
+
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+  // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+  // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+  // CHECK: %[[RES_0:.+]] = vector.mask %[[MASK_0]] { vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
+  // CHECK: %[[RES_1:.+]] = vector.mask %[[MASK_1]] { vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
+
+  %0 = vector.mask %mask {
+    %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
+  } : vector<2x3x5xi1> -> vector<3xf32>
+
+  // CHECK: return %[[RES_1]]
+  return %0 : vector<3xf32>
+}

>From 7561c5be684af765fd722addaaee230ee79b3166 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 12:49:09 -0500
Subject: [PATCH 05/15] [mlir][vector] Add unrolling pattern for innermost
 reduction base case

---
 .../Transforms/LowerVectorMultiReduction.cpp  | 138 +++++++++++++++++-
 .../unroll-vector-multi-reduction-inner.mlir  | 106 ++++++++++++++
 2 files changed, 239 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 8548e8c2d311d..5e0d6da82f834 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -674,6 +674,135 @@ struct UnrollMultiReductionOuterGeneralCase
   }
 };
 
+/// Unrolls innermost dimension for vector.multi_reduction when the innermost
+/// dimension is the only reduction dimension.
+///
+/// This pattern matches operations where:
+/// - The innermost dimension is a reduction dimension
+/// - The outermost dimension is a parallel dimension
+/// - There is exactly one reduction dimension
+///
+/// The transformation unrolls along the outermost parallel dimension:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add> %src, %acc [N-1]
+///     : vector<AxBx...xMxf32> to vector<AxBx...xf32>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %result = arith.constant dense<0.0> : vector<AxBx...xf32>
+/// %0 = vector.extract %src[0] : vector<Bx...xMxf32> from vector<AxBx...xMxf32>
+/// %acc0 = vector.extract %acc[0] : vector<Bx...xf32> from vector<AxBx...xf32>
+/// %red0 = vector.multi_reduction <add>, %0, %acc0 [N-2]
+///     : vector<Bx...xMxf32> to vector<Bx...xf32>
+/// %res0 = vector.insert %red0, %result[0]
+///     : vector<Bx...xf32> into vector<AxBx...xf32>
+/// // ... repeat for indices 1 to A-1
+/// ```
+struct UnrollMultiReductionInnerBaseCase
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+    if (srcRank < 2)
+      return rewriter.notifyMatchFailure(multiReductionOp,
+                                         "expected source rank >= 2.");
+
+    if (!multiReductionOp.isReducedDim(srcRank - 1))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected innermost dimension to be a reduction dimension.");
+
+    if (multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected outermost dimension to be a parallel dimension.");
+
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected exactly one reduction dimension.");
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected integer or float element type.");
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+    Value acc = multiReductionOp.getAcc();
+
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+    int64_t numSlices = srcShape.front();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    bool isMasked = maskableOp.isMasked();
+    Operation *rootOp;
+    Value mask = nullptr;
+    if (isMasked) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
+    SmallVector<Value> srcSlices;
+    for (int64_t i = 0; i < numSlices; ++i)
+      srcSlices.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+    SmallVector<Value> accSlices;
+    for (int64_t i = 0; i < numSlices; ++i)
+      accSlices.push_back(vector::ExtractOp::create(rewriter, loc, acc, i));
+
+    SmallVector<Value> maskSlices;
+    for (int64_t i = 0; i < numSlices; ++i)
+      if (isMasked)
+        maskSlices.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+      else
+        maskSlices.push_back(nullptr);
+
+    // Compute new reduction mask: for the extracted slice (rank srcRank-1),
+    // the innermost dimension is still the reduction dimension.
+    // New mask has srcRank-1 elements, with the last one being true.
+    SmallVector<bool> newReductionMask(srcRank - 1, false);
+    newReductionMask.back() = true;
+
+    SmallVector<Value> reductionResults;
+    for (auto [srcSlice, accSlice, maskSlice] :
+         llvm::zip(srcSlices, accSlices, maskSlices)) {
+      Operation *newReductionOp = vector::MultiDimReductionOp::create(
+          rewriter, loc, srcSlice, accSlice, newReductionMask,
+          multiReductionOp.getKind());
+
+      if (isMasked)
+        newReductionOp =
+            mlir::vector::maskOperation(rewriter, newReductionOp, maskSlice);
+
+      reductionResults.push_back(newReductionOp->getResult(0));
+    }
+
+    Value result = arith::ConstantOp::create(
+        rewriter, loc, multiReductionOp.getDestType(),
+        rewriter.getZeroAttr(multiReductionOp.getDestType()));
+
+    for (int64_t i = 0; i < numSlices; ++i)
+      result = vector::InsertOp::create(rewriter, loc, reductionResults[i],
+                                        result, i);
+
+    rewriter.replaceOp(rootOp, result);
+    return success();
+  }
+};
+
 struct LowerVectorMultiReductionPass
     : public vector::impl::LowerVectorMultiReductionBase<
           LowerVectorMultiReductionPass> {
@@ -746,11 +875,10 @@ void mlir::vector::populateVectorUnrollMultiReduction(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
   if (options == VectorMultiReductionLowering::InnerReduction) {
-    // TODO: Add UnrollMultiReductionInnerBaseCase and
-    // UnrollMultiReductionInnerGeneralCase patterns here once implemented.
-    // For now, fall back to the existing 2-D based lowering.
-    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
-                                                  benefit);
+    // TODO: Add UnrollMultiReductionInnerGeneralCase pattern here once
+    // implemented.
+    patterns.add<UnrollMultiReductionInnerBaseCase>(patterns.getContext(),
+                                                    benefit);
   } else {
     patterns.add<UnrollMultiReductionOuterBaseCase,
                  UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
new file mode 100644
index 0000000000000..89af4580bf627
--- /dev/null
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -0,0 +1,106 @@
+// RUN: mlir-opt --split-input-file %s -transform-preload-library='transform-library-paths=%p/td/unroll-multi-reduction.mlir' \
+// RUN: -transform-interpreter=entry-point=unroll_multi_reduction_inner | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test UnrollVectorMultiReduction for Inner Reduction (Base Case)
+//===----------------------------------------------------------------------===//
+
+// The pattern recursively reduces rank until we reach 1D multi_reductions.
+// For vector<2x3x5xf32> with reduction on dim 2:
+// - First pass: unrolls along dim 0 (size 2), creating vector<3x5xf32> multi_reductions
+// - Second pass: unrolls along dim 0 (size 3), creating vector<5xf32> multi_reductions
+//
+// The generated IR groups operations by phase:
+// extracts (source) → extracts (acc) → multi_reductions → inserts
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2x3xf32>
+func.func @unroll_vector_multi_reduction_inner(%source: vector<2x3x5xf32>, %acc: vector<2x3xf32>) -> (vector<2x3xf32>) {
+  // First slice [0, ...]: extracts → reductions → inserts
+  // CHECK: vector.extract %[[SOURCE]][0, 0]
+  // CHECK: vector.extract %[[SOURCE]][0, 1]
+  // CHECK: vector.extract %[[SOURCE]][0, 2]
+  // CHECK: vector.extract %[[ACC]][0, 0]
+  // CHECK: vector.extract %[[ACC]][0, 1]
+  // CHECK: vector.extract %[[ACC]][0, 2]
+  // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+  // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+  // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+  // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
+  // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
+  // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
+  // Second slice [1, ...]: extracts → reductions → inserts
+  // CHECK: vector.extract %[[SOURCE]][1, 0]
+  // CHECK: vector.extract %[[SOURCE]][1, 1]
+  // CHECK: vector.extract %[[SOURCE]][1, 2]
+  // CHECK: vector.extract %[[ACC]][1, 0]
+  // CHECK: vector.extract %[[ACC]][1, 1]
+  // CHECK: vector.extract %[[ACC]][1, 2]
+  // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+  // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+  // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+  // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
+  // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
+  // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
+  // Final inserts to assemble result
+  // CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
+  // CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
+  // No original multi_reduction with [2] remains
+  // CHECK-NOT: vector.multi_reduction <add>, {{.*}} [2]
+  %1 = vector.multi_reduction <add>, %source, %acc [2] : vector<2x3x5xf32> to vector<2x3xf32>
+
+  return %1 : vector<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2x3xf32>
+func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2x3xf32>) -> (vector<2x3xf32>) {
+  // First slice [0, ...]: extracts (source, acc, mask) → masked reductions → inserts
+  // CHECK: vector.extract %[[SOURCE]][0, 0]
+  // CHECK: vector.extract %[[SOURCE]][0, 1]
+  // CHECK: vector.extract %[[SOURCE]][0, 2]
+  // CHECK: vector.extract %[[ACC]][0, 0]
+  // CHECK: vector.extract %[[ACC]][0, 1]
+  // CHECK: vector.extract %[[ACC]][0, 2]
+  // CHECK: vector.extract %[[MASK]][0, 0]
+  // CHECK: vector.extract %[[MASK]][0, 1]
+  // CHECK: vector.extract %[[MASK]][0, 2]
+  // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+  // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+  // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+  // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
+  // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
+  // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
+  // Second slice [1, ...]: extracts → masked reductions → inserts
+  // CHECK: vector.extract %[[SOURCE]][1, 0]
+  // CHECK: vector.extract %[[SOURCE]][1, 1]
+  // CHECK: vector.extract %[[SOURCE]][1, 2]
+  // CHECK: vector.extract %[[ACC]][1, 0]
+  // CHECK: vector.extract %[[ACC]][1, 1]
+  // CHECK: vector.extract %[[ACC]][1, 2]
+  // CHECK: vector.extract %[[MASK]][1, 0]
+  // CHECK: vector.extract %[[MASK]][1, 1]
+  // CHECK: vector.extract %[[MASK]][1, 2]
+  // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+  // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+  // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+  // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
+  // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
+  // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
+  // Final inserts to assemble result
+  // CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
+  // CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
+  // No original multi_reduction with [2] remains
+  // CHECK-NOT: vector.multi_reduction <add>, {{.*}} [2]
+
+  %0 = vector.mask %mask {
+    %1 = vector.multi_reduction <add>, %source, %acc [2] : vector<2x3x5xf32> to vector<2x3xf32>
+  } : vector<2x3x5xi1> -> vector<2x3xf32>
+
+  return %0 : vector<2x3xf32>
+}

>From 966b4cb70efd50f309dd3e5592bba37ddcc06d12 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 13:16:22 -0500
Subject: [PATCH 06/15] [mlir][vector] Implement general pattern

---
 .../Transforms/LowerVectorMultiReduction.cpp  | 46 ++++++--------
 .../unroll-vector-multi-reduction-inner.mlir  | 60 +++++++++++++++++++
 2 files changed, 79 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 5e0d6da82f834..99b766eb9ad7c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -674,34 +674,34 @@ struct UnrollMultiReductionOuterGeneralCase
   }
 };
 
-/// Unrolls innermost dimension for vector.multi_reduction when the innermost
-/// dimension is the only reduction dimension.
+/// Unrolls vector.multi_reduction along the outermost parallel dimension
+/// when the innermost dimension is a reduction dimension.
 ///
 /// This pattern matches operations where:
 /// - The innermost dimension is a reduction dimension
 /// - The outermost dimension is a parallel dimension
-/// - There is exactly one reduction dimension
 ///
-/// The transformation unrolls along the outermost parallel dimension:
+/// The transformation extracts slices along the outermost parallel dimension,
+/// creates smaller multi_reductions, and assembles the results:
 ///
 /// ```mlir
-/// %res = vector.multi_reduction <add> %src, %acc [N-1]
-///     : vector<AxBx...xMxf32> to vector<AxBx...xf32>
+/// %res = vector.multi_reduction <add>, %src, %acc [1, 3]
+///     : vector<AxBxCxDxf32> to vector<AxCxf32>
 /// ```
 ///
 /// becomes:
 ///
 /// ```mlir
-/// %result = arith.constant dense<0.0> : vector<AxBx...xf32>
-/// %0 = vector.extract %src[0] : vector<Bx...xMxf32> from vector<AxBx...xMxf32>
-/// %acc0 = vector.extract %acc[0] : vector<Bx...xf32> from vector<AxBx...xf32>
-/// %red0 = vector.multi_reduction <add>, %0, %acc0 [N-2]
-///     : vector<Bx...xMxf32> to vector<Bx...xf32>
+/// %result = arith.constant dense<0.0> : vector<AxCxf32>
+/// %0 = vector.extract %src[0] : vector<BxCxDxf32> from vector<AxBxCxDxf32>
+/// %acc0 = vector.extract %acc[0] : vector<Cxf32> from vector<AxCxf32>
+/// %red0 = vector.multi_reduction <add>, %0, %acc0 [0, 2]
+///     : vector<BxCxDxf32> to vector<Cxf32>
 /// %res0 = vector.insert %red0, %result[0]
-///     : vector<Bx...xf32> into vector<AxBx...xf32>
+///     : vector<Cxf32> into vector<AxCxf32>
 /// // ... repeat for indices 1 to A-1
 /// ```
-struct UnrollMultiReductionInnerBaseCase
+struct UnrollMultiReductionInner
     : public OpRewritePattern<vector::MultiDimReductionOp> {
   using Base::Base;
 
@@ -723,11 +723,6 @@ struct UnrollMultiReductionInnerBaseCase
           multiReductionOp,
           "expected outermost dimension to be a parallel dimension.");
 
-    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
-    if (reductionDims.size() != 1)
-      return rewriter.notifyMatchFailure(
-          multiReductionOp, "expected exactly one reduction dimension.");
-
     Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
     if (!elementType.isIntOrIndexOrFloat())
       return rewriter.notifyMatchFailure(
@@ -770,11 +765,11 @@ struct UnrollMultiReductionInnerBaseCase
       else
         maskSlices.push_back(nullptr);
 
-    // Compute new reduction mask: for the extracted slice (rank srcRank-1),
-    // the innermost dimension is still the reduction dimension.
-    // New mask has srcRank-1 elements, with the last one being true.
-    SmallVector<bool> newReductionMask(srcRank - 1, false);
-    newReductionMask.back() = true;
+    // Drop dimension 0 (parallel); the remaining reduction indices shift down
+    // by 1. getReductionMask() returns a temporary, so keep an owning copy;
+    // an ArrayRef into it would dangle after this statement.
+    SmallVector<bool> newReductionMask = multiReductionOp.getReductionMask();
+    newReductionMask.erase(newReductionMask.begin());
 
     SmallVector<Value> reductionResults;
     for (auto [srcSlice, accSlice, maskSlice] :
@@ -875,10 +870,7 @@ void mlir::vector::populateVectorUnrollMultiReduction(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
   if (options == VectorMultiReductionLowering::InnerReduction) {
-    // TODO: Add UnrollMultiReductionInnerGeneralCase pattern here once
-    // implemented.
-    patterns.add<UnrollMultiReductionInnerBaseCase>(patterns.getContext(),
-                                                    benefit);
+    patterns.add<UnrollMultiReductionInner>(patterns.getContext(), benefit);
   } else {
     patterns.add<UnrollMultiReductionOuterBaseCase,
                  UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
index 89af4580bf627..fcb97967f088e 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -104,3 +104,63 @@ func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>
 
   return %0 : vector<2x3xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test UnrollVectorMultiReduction for Inner Reduction (General Case)
+//===----------------------------------------------------------------------===//
+
+// The general case handles multiple reduction dimensions.
+// For vector<2x3x5xf32> with reduction on dims [1, 2]:
+// - Unrolls along dim 0 (size 2), creating vector<3x5xf32> multi_reductions
+// - Each new multi_reduction has reduction dims [0, 1] (shifted from [1, 2])
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_vector_multi_reduction_inner_general(%source: vector<2x3x5xf32>, %acc: vector<2xf32>) -> (vector<2xf32>) {
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+  // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+  // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+
+  // CHECK: %[[RED_0:.+]] = vector.multi_reduction <add>, %[[VEC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32
+  // CHECK: %[[RED_1:.+]] = vector.multi_reduction <add>, %[[VEC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32
+  // CHECK: %[[INSERT_0:.+]] = vector.insert %[[RED_0]], %{{.*}} [0] : f32 into vector<2xf32>
+  // CHECK: %[[INSERT_1:.+]] = vector.insert %[[RED_1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+  %1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+
+  // CHECK: return %[[INSERT_1]]
+  return %1 : vector<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2xf32>) -> (vector<2xf32>) {
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+  // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+  // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+
+  // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+  // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+  // CHECK: %[[RED_0:.+]] = vector.mask %[[MASK_0]] { vector.multi_reduction <add>, %[[VEC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32 } : vector<3x5xi1> -> f32
+  // CHECK: %[[RED_1:.+]] = vector.mask %[[MASK_1]] { vector.multi_reduction <add>, %[[VEC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32 } : vector<3x5xi1> -> f32
+  // CHECK: %[[INSERT_0:.+]] = vector.insert %[[RED_0]], %{{.*}} [0] : f32 into vector<2xf32>
+  // CHECK: %[[INSERT_1:.+]] = vector.insert %[[RED_1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+
+  %0 = vector.mask %mask {
+    %1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+  } : vector<2x3x5xi1> -> vector<2xf32>
+
+  // CHECK: return %[[INSERT_1]]
+  return %0 : vector<2xf32>
+}

>From cb0d3fb8bc3a2982b9223372ba2d434c8a42fb63 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 14:05:20 -0500
Subject: [PATCH 07/15] [mlir][vector] Skip matching multi_reduction on rank 1

---
 .../Transforms/LowerVectorMultiReduction.cpp  |  5 +++++
 .../unroll-vector-multi-reduction-inner.mlir  | 19 +++++++++++++++++++
 .../Vector/unroll-vector-multi-reduction.mlir | 19 +++++++++++++++++++
 3 files changed, 43 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 99b766eb9ad7c..a79b6ea10d543 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -541,6 +541,11 @@ struct UnrollMultiReductionOuterBaseCase
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank < 2)
+      return rewriter.notifyMatchFailure(multiReductionOp,
+                                         "expected source rank >= 2.");
+
     if (!multiReductionOp.isReducedDim(0))
       return rewriter.notifyMatchFailure(
           multiReductionOp,
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
index fcb97967f088e..d6b3091fb9463 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -164,3 +164,22 @@ func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x
   // CHECK: return %[[INSERT_1]]
   return %0 : vector<2xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Negative Test: Rank-1 multi_reduction should NOT be matched by UnrollMultiReductionInner
+//===----------------------------------------------------------------------===//
+
+// UnrollMultiReductionInner requires srcRank >= 2, so rank-1 should not match.
+// The op should remain unchanged.
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_rank1_negative(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[ACC:.+]]: f32
+func.func @unroll_vector_multi_reduction_inner_rank1_negative(%source: vector<8xf32>, %acc: f32) -> f32 {
+  // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[SOURCE]], %[[ACC]] [0] : vector<8xf32> to f32
+  %0 = vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
+  // CHECK: return %[[RESULT]]
+  return %0 : f32
+}
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
index 79086e2b0b9ad..8c36430bb6eee 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -90,3 +90,22 @@ func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf3
   // CHECK: return %[[RES_1]]
   return %0 : vector<3xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Negative Test: Rank-1 multi_reduction should NOT be matched by unroll patterns
+//===----------------------------------------------------------------------===//
+
+// UnrollMultiReductionOuterBaseCase and UnrollMultiReductionOuterGeneralCase
+// should not match rank-1 multi_reduction ops. The op should remain unchanged.
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_rank1_negative(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[ACC:.+]]: f32
+func.func @unroll_vector_multi_reduction_rank1_negative(%source: vector<8xf32>, %acc: f32) -> f32 {
+  // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[SOURCE]], %[[ACC]] [0] : vector<8xf32> to f32
+  %0 = vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
+  // CHECK: return %[[RESULT]]
+  return %0 : f32
+}

>From f55f2bfda5e89e7e4a2e5787f5f824dc7f7076e0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 15:17:17 -0500
Subject: [PATCH 08/15] [mlir][vector] Convert multi_reduction in 1-D to
 reduction.

---
 .../Transforms/LowerVectorMultiReduction.cpp  | 58 +++++++++++++++++
 .../unroll-vector-multi-reduction-inner.mlir  | 65 ++++++++++++-------
 .../Vector/unroll-vector-multi-reduction.mlir | 27 ++++++--
 3 files changed, 120 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index a79b6ea10d543..595c2fce55619 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -413,6 +413,60 @@ struct TwoDimMultiReductionToReduction
   }
 };
 
+/// Converts 1-D vector.multi_reduction directly to vector.reduction.
+/// This is the terminal case for unrolling - once we reach rank 1,
+/// we convert to vector.reduction which backends can optimize.
+///
+/// Example:
+/// ```mlir
+/// // Before
+/// %r = vector.multi_reduction <add>, %v, %acc [0] : vector<Nxf32> to f32
+///
+/// // After
+/// %r = vector.reduction <add>, %v, %acc : vector<Nxf32> into f32
+/// ```
+struct OneDimMultiReductionToReduction
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank != 1)
+      return rewriter.notifyMatchFailure(multiReductionOp,
+                                         "expected source rank == 1.");
+
+    if (!multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected dimension 0 to be a reduction dimension.");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    Operation *rootOp; // Replaced below: masking op if masked, else this op.
+    Value mask;        // Set only when the op is masked.
+    if (maskableOp.isMasked()) { // Build the new op at the masking op.
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
+    auto loc = multiReductionOp.getLoc();
+    Operation *reductionOp = vector::ReductionOp::create(
+        rewriter, loc, multiReductionOp.getKind(), multiReductionOp.getSource(),
+        multiReductionOp.getAcc());
+
+    if (maskableOp.isMasked()) // Re-wrap with the original mask.
+      reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
+
+    rewriter.replaceOp(rootOp, reductionOp->getResult(0));
+    return success();
+  }
+};
+
 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
 /// form with both a single parallel and reduction dimension.
 /// This is achieved with a simple vector.shape_cast that inserts a leading 1.
@@ -861,6 +915,8 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
 void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
+  patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
+
   if (options == VectorMultiReductionLowering ::InnerReduction) {
     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                   benefit);
@@ -874,6 +930,8 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
 void mlir::vector::populateVectorUnrollMultiReduction(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
+  patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
+
   if (options == VectorMultiReductionLowering::InnerReduction) {
     patterns.add<UnrollMultiReductionInner>(patterns.getContext(), benefit);
   } else {
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
index d6b3091fb9463..f21535c1dfc21 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction-inner.mlir
@@ -5,13 +5,15 @@
 // Test UnrollVectorMultiReduction for Inner Reduction (Base Case)
 //===----------------------------------------------------------------------===//
 
-// The pattern recursively reduces rank until we reach 1D multi_reductions.
+// The pattern recursively reduces rank until we reach 1D, then converts to
+// vector.reduction via OneDimMultiReductionToReduction.
 // For vector<2x3x5xf32> with reduction on dim 2:
 // - First pass: unrolls along dim 0 (size 2), creating vector<3x5xf32> multi_reductions
 // - Second pass: unrolls along dim 0 (size 3), creating vector<5xf32> multi_reductions
+// - Final: 1-D multi_reductions are converted to vector.reduction
 //
 // The generated IR groups operations by phase:
-// extracts (source) → extracts (acc) → multi_reductions → inserts
+// extracts (source) → extracts (acc) → reductions → inserts
 
 // CHECK-LABEL: func @unroll_vector_multi_reduction_inner(
 // CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
@@ -24,9 +26,9 @@ func.func @unroll_vector_multi_reduction_inner(%source: vector<2x3x5xf32>, %acc:
   // CHECK: vector.extract %[[ACC]][0, 0]
   // CHECK: vector.extract %[[ACC]][0, 1]
   // CHECK: vector.extract %[[ACC]][0, 2]
-  // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
-  // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
-  // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+  // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
+  // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
+  // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
   // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
   // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
   // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
@@ -37,17 +39,17 @@ func.func @unroll_vector_multi_reduction_inner(%source: vector<2x3x5xf32>, %acc:
   // CHECK: vector.extract %[[ACC]][1, 0]
   // CHECK: vector.extract %[[ACC]][1, 1]
   // CHECK: vector.extract %[[ACC]][1, 2]
-  // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
-  // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
-  // CHECK: vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32
+  // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
+  // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
+  // CHECK: vector.reduction <add>, {{.*}} : vector<5xf32> into f32
   // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
   // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
   // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
   // Final inserts to assemble result
   // CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
   // CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
-  // No original multi_reduction with [2] remains
-  // CHECK-NOT: vector.multi_reduction <add>, {{.*}} [2]
+  // No original multi_reduction remains
+  // CHECK-NOT: vector.multi_reduction
   %1 = vector.multi_reduction <add>, %source, %acc [2] : vector<2x3x5xf32> to vector<2x3xf32>
 
   return %1 : vector<2x3xf32>
@@ -70,9 +72,9 @@ func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>
   // CHECK: vector.extract %[[MASK]][0, 0]
   // CHECK: vector.extract %[[MASK]][0, 1]
   // CHECK: vector.extract %[[MASK]][0, 2]
-  // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
-  // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
-  // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+  // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
+  // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
+  // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
   // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
   // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
   // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
@@ -86,17 +88,17 @@ func.func @unroll_vector_multi_reduction_inner_masked(%source: vector<2x3x5xf32>
   // CHECK: vector.extract %[[MASK]][1, 0]
   // CHECK: vector.extract %[[MASK]][1, 1]
   // CHECK: vector.extract %[[MASK]][1, 2]
-  // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
-  // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
-  // CHECK: vector.mask {{.*}} { vector.multi_reduction <add>, {{.*}} [0] : vector<5xf32> to f32 }
+  // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
+  // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
+  // CHECK: vector.mask {{.*}} { vector.reduction <add>, {{.*}} : vector<5xf32> into f32 }
   // CHECK: vector.insert {{.*}} [0] : f32 into vector<3xf32>
   // CHECK: vector.insert {{.*}} [1] : f32 into vector<3xf32>
   // CHECK: vector.insert {{.*}} [2] : f32 into vector<3xf32>
   // Final inserts to assemble result
   // CHECK: vector.insert {{.*}} [0] : vector<3xf32> into vector<2x3xf32>
   // CHECK: vector.insert {{.*}} [1] : vector<3xf32> into vector<2x3xf32>
-  // No original multi_reduction with [2] remains
-  // CHECK-NOT: vector.multi_reduction <add>, {{.*}} [2]
+  // No original multi_reduction remains
+  // CHECK-NOT: vector.multi_reduction
 
   %0 = vector.mask %mask {
     %1 = vector.multi_reduction <add>, %source, %acc [2] : vector<2x3x5xf32> to vector<2x3xf32>
@@ -168,18 +170,33 @@ func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x
 // -----
 
 //===----------------------------------------------------------------------===//
-// Negative Test: Rank-1 multi_reduction should NOT be matched by UnrollMultiReductionInner
+// Test 1-D multi_reduction to vector.reduction conversion
 //===----------------------------------------------------------------------===//
 
-// UnrollMultiReductionInner requires srcRank >= 2, so rank-1 should not match.
-// The op should remain unchanged.
+// OneDimMultiReductionToReduction converts rank-1 multi_reduction directly
+// to vector.reduction, which preserves the reduction semantics for backends.
 
-// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_rank1_negative(
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_1d(
 // CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
 // CHECK-SAME: %[[ACC:.+]]: f32
-func.func @unroll_vector_multi_reduction_inner_rank1_negative(%source: vector<8xf32>, %acc: f32) -> f32 {
-  // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[SOURCE]], %[[ACC]] [0] : vector<8xf32> to f32
+func.func @unroll_vector_multi_reduction_inner_1d(%source: vector<8xf32>, %acc: f32) -> f32 {
+  // CHECK: %[[RESULT:.+]] = vector.reduction <add>, %[[SOURCE]], %[[ACC]] : vector<8xf32> into f32
   %0 = vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
   // CHECK: return %[[RESULT]]
   return %0 : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_1d_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>,
+// CHECK-SAME: %[[ACC:.+]]: f32
+func.func @unroll_vector_multi_reduction_inner_1d_masked(%source: vector<8xf32>, %mask: vector<8xi1>, %acc: f32) -> f32 {
+  // CHECK: %[[RESULT:.+]] = vector.mask %[[MASK]] { vector.reduction <add>, %[[SOURCE]], %[[ACC]] : vector<8xf32> into f32 } : vector<8xi1> -> f32
+  %0 = vector.mask %mask {
+    vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
+  } : vector<8xi1> -> f32
+  // CHECK: return %[[RESULT]]
+  return %0 : f32
+}
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
index 8c36430bb6eee..6db425fe4fce9 100644
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -94,18 +94,33 @@ func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf3
 // -----
 
 //===----------------------------------------------------------------------===//
-// Negative Test: Rank-1 multi_reduction should NOT be matched by unroll patterns
+// Test 1-D multi_reduction to vector.reduction conversion
 //===----------------------------------------------------------------------===//
 
-// UnrollMultiReductionOuterBaseCase and UnrollMultiReductionOuterGeneralCase
-// should not match rank-1 multi_reduction ops. The op should remain unchanged.
+// OneDimMultiReductionToReduction converts rank-1 multi_reduction directly
+// to vector.reduction, which preserves the reduction semantics for backends.
 
-// CHECK-LABEL: func @unroll_vector_multi_reduction_rank1_negative(
+// CHECK-LABEL: func @unroll_vector_multi_reduction_1d(
 // CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
 // CHECK-SAME: %[[ACC:.+]]: f32
-func.func @unroll_vector_multi_reduction_rank1_negative(%source: vector<8xf32>, %acc: f32) -> f32 {
-  // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[SOURCE]], %[[ACC]] [0] : vector<8xf32> to f32
+func.func @unroll_vector_multi_reduction_1d(%source: vector<8xf32>, %acc: f32) -> f32 {
+  // CHECK: %[[RESULT:.+]] = vector.reduction <add>, %[[SOURCE]], %[[ACC]] : vector<8xf32> into f32
   %0 = vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
   // CHECK: return %[[RESULT]]
   return %0 : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_1d_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>,
+// CHECK-SAME: %[[ACC:.+]]: f32
+func.func @unroll_vector_multi_reduction_1d_masked(%source: vector<8xf32>, %mask: vector<8xi1>, %acc: f32) -> f32 {
+  // CHECK: %[[RESULT:.+]] = vector.mask %[[MASK]] { vector.reduction <add>, %[[SOURCE]], %[[ACC]] : vector<8xf32> into f32 } : vector<8xi1> -> f32
+  %0 = vector.mask %mask {
+    vector.multi_reduction <add>, %source, %acc [0] : vector<8xf32> to f32
+  } : vector<8xi1> -> f32
+  // CHECK: return %[[RESULT]]
+  return %0 : f32
+}

>From 88c6bfe77b71901a83cd9aafdba95652fa8d6933 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 15:33:20 -0500
Subject: [PATCH 09/15] [mlir][vector] Remove old patterns population.

Replace populateVectorMultiReductionUnrollingPatterns
with populateVectorUnrollMultiReduction which is more general.
---
 .../Vector/Transforms/LoweringPatterns.h      | 17 ----------------
 .../Transforms/LowerVectorMultiReduction.cpp  | 19 ++----------------
 .../vector-multi-reduction-pass-lowering.mlir | 20 +++++++++++--------
 3 files changed, 14 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 1e78983ce082d..e6e62f2250fb7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -89,23 +89,6 @@ void populateVectorMultiReductionFlatteningPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit = 1);
 
-/// Collect a set of patterns to convert vector.multi_reduction op into
-/// a sequence of vector.reduction ops. The patterns comprise:
-///
-/// [TwoDimMultiReductionToElementWise]
-/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
-/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
-/// ops. This also has an opportunity for tree-reduction (in the future).
-///
-/// [TwoDimMultiReductionToReduction]
-/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
-/// dimension, unroll the outer dimension to obtain a sequence of extract +
-/// vector.reduction + insert. This can further lower to horizontal reduction
-/// ops.
-void populateVectorMultiReductionUnrollingPatterns(
-    RewritePatternSet &patterns, VectorMultiReductionLowering options,
-    PatternBenefit benefit = 1);
-
 /// Collect a set of patterns to convert vector.multi_reduction op into
 /// a sequence of vector.reduction ops. The patterns comprise:
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 595c2fce55619..cd651ecf6c745 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -883,8 +883,8 @@ struct LowerVectorMultiReductionPass
       signalPassFailure();
 
     RewritePatternSet unrollingPatterns(context);
-    populateVectorMultiReductionUnrollingPatterns(unrollingPatterns,
-                                                  this->loweringStrategy);
+    populateVectorUnrollMultiReduction(unrollingPatterns,
+                                       this->loweringStrategy);
 
     if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
       signalPassFailure();
@@ -912,21 +912,6 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
                                             benefit);
 }
 
-void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
-    RewritePatternSet &patterns, VectorMultiReductionLowering options,
-    PatternBenefit benefit) {
-  patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
-
-  if (options == VectorMultiReductionLowering ::InnerReduction) {
-    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
-                                                  benefit);
-  } else {
-    patterns.add<UnrollMultiReductionOuterBaseCase,
-                 UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
-                                                       benefit);
-  }
-}
-
 void mlir::vector::populateVectorUnrollMultiReduction(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index e01bf446eb83c..ac833612681cf 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -10,13 +10,13 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
 //            ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
 // INNER-REDUCTION-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
 //     INNER-REDUCTION: %[[V0:.+]] = vector.extract %[[INPUT]][0]
-//     INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
-//     INNER-REDUCTION: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
-//     INNER-REDUCTION: %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0:.+]], %[[RESULT_VEC_0]] [0] : f32 into vector<2xf32>
 //     INNER-REDUCTION: %[[V1:.+]] = vector.extract %[[INPUT]][1]
+//     INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
 //     INNER-REDUCTION: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+//     INNER-REDUCTION: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
 //     INNER-REDUCTION: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
-//     INNER-REDUCTION: %[[RESULT_VEC:.+]] = vector.insert %[[RV1:.+]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
+//     INNER-REDUCTION: %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0]], %[[RESULT_VEC_0]] [0] : f32 into vector<2xf32>
+//     INNER-REDUCTION: %[[RESULT_VEC:.+]] = vector.insert %[[RV1]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
 //     INNER-REDUCTION: return %[[RESULT_VEC]]
 
 //      INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
@@ -51,10 +51,14 @@ func.func @vector_multi_reduction_masked(%arg0: vector<2x4xf32>, %acc: vector<2x
 
 //       ALL-LABEL: func @vector_multi_reduction_masked
 //        ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<2x4xi1>
-// INNER-REDUCTION: %[[INNERVEC:.+]] = vector.extract %[[INPUT]][0] : vector<4xf32> from vector<2x4xf32>
-// INNER-REDUCTION: %[[INNERACC:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
-// INNER-REDUCTION: %[[INNERMASK:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<2x4xi1>
-// INNER-REDUCTION: vector.mask %[[INNERMASK]] { vector.reduction <mul>, %[[INNERVEC]], %[[INNERACC]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
+// INNER-REDUCTION: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<4xf32> from vector<2x4xf32>
+// INNER-REDUCTION: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<4xf32> from vector<2x4xf32>
+// INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// INNER-REDUCTION: %[[ACC1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+// INNER-REDUCTION: %[[MASK0:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<2x4xi1>
+// INNER-REDUCTION: %[[MASK1:.+]] = vector.extract %[[MASK]][1] : vector<4xi1> from vector<2x4xi1>
+// INNER-REDUCTION: vector.mask %[[MASK0]] { vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
+// INNER-REDUCTION: vector.mask %[[MASK1]] { vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
 //  INNER-PARALLEL: %[[TPMASK:.+]] = vector.transpose %[[MASK]], [1, 0] : vector<2x4xi1> to vector<4x2xi1>
 //  INNER-PARALLEL: %[[TPINPUT:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //  INNER-PARALLEL: %[[INNERVEC:.+]] = vector.extract %[[TPINPUT]][0] : vector<2xf32> from vector<4x2xf32>

>From b8a14a1df67d09b6a3a53dcce91a3a17e982775e Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 16:42:01 -0500
Subject: [PATCH 10/15] [mlir][vector] Add missing rank check

---
 .../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp  | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index cd651ecf6c745..2e1da48a3f768 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -664,6 +664,11 @@ struct UnrollMultiReductionOuterGeneralCase
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank < 2)
+      return rewriter.notifyMatchFailure(multiReductionOp,
+                                         "expected source rank >= 2.");
+
     if (!multiReductionOp.isReducedDim(0))
       return rewriter.notifyMatchFailure(
           multiReductionOp,

>From b513b1af291cad4ae67064e73368bbf2b11ab836 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 16:43:48 -0500
Subject: [PATCH 11/15] Split outer-unrolling docs into base and general cases

---
 .../Transforms/LowerVectorMultiReduction.cpp  | 67 +++++++++----------
 1 file changed, 31 insertions(+), 36 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2e1da48a3f768..08fcff00b2c19 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -539,56 +539,27 @@ struct OneDimMultiReductionToTwoDim
   }
 };
 
-/// Unrolls outermost dimension for vector.multi_reduction.
-/// This patterns matches operations which reduce the outermost dimension,
-/// it does not transform operations for which the outermost dimension is not
-/// a reduction dimension.
+/// Unrolls vector.multi_reduction when the outermost dimension is the only
+/// reduction dimension. Transforms to elementwise arithmetic operations.
 ///
-/// There are two cases to consider:
-/// 1. The base case is when the outermost dimension is the only reduction
-/// dimension.
-/// 2. The general case is when the outermost dimension is not the only
-/// reduction dimension.
-///
-/// The base case transformation:
+/// Example:
 ///
 /// ```mlir
-/// %res = vector.multi_reduction <add> %src, %acc [0] : vector<NxMx...xf32> to
-/// vector<Mx...xf32>
+/// %res = vector.multi_reduction <add>, %src, %acc [0]
+///     : vector<NxMx...xf32> to vector<Mx...xf32>
 /// ```
 ///
-/// will extract N vectors from %src and then perform elementwise operations.
+/// becomes:
 ///
 /// ```mlir
 /// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
 /// ...
-/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...x.f32> from
-/// vector<NxMx...xf32>
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from vector<NxMx...xf32>
 ///
 /// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
 /// ...
 /// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
 /// ```
-///
-/// For the general case, we still extract N vectors, but produce N
-/// vector.multi_reduction instead of elementwise operations.
-///
-/// ```mlir
-/// %res = vector.multi_reduction <add> %src, %acc [0, [[REDUCTION_DIMS]] ] :
-/// vector<NxMx...xf32> to vector<Ix...xf32>
-///
-/// ```mlir
-/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
-/// ...
-/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...x.f32> from
-/// vector<NxMx...xf32>
-///
-/// %red0 = vector.multi_reduction %0, %acc [ [[REDUCTION_DIMS]] ] :
-/// vector<Mx...xf32> to vector<Ix...xf32>
-/// ...
-/// %res = vector.multi_reduction %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] :
-/// vector<Mx...xf32> to vector<Ix...xf32>
-/// ```
 struct UnrollMultiReductionOuterBaseCase
     : public OpRewritePattern<vector::MultiDimReductionOp> {
   using Base::Base;
@@ -658,6 +629,30 @@ struct UnrollMultiReductionOuterBaseCase
   }
 };
 
+/// Unrolls vector.multi_reduction when the outermost dimension is one of
+/// multiple reduction dimensions. Extracts slices and chains smaller
+/// multi_reduction operations.
+///
+/// Example:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, 2]
+///     : vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction <add>, %0, %acc [1]
+///     : vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction <add>, %Nminus1, %redNminus2 [1]
+///     : vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
 struct UnrollMultiReductionOuterGeneralCase
     : public OpRewritePattern<vector::MultiDimReductionOp> {
   using Base::Base;

>From 2240787542a6eb7a26ae6a0f575d3a1ad9ba347e Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 16:48:55 -0500
Subject: [PATCH 12/15] Rename UnrollMultiReductionInner to UnrollInnerReductionAlongOuterParallel

---
 .../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp  | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 08fcff00b2c19..2f8d0a2e2da05 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -760,7 +760,7 @@ struct UnrollMultiReductionOuterGeneralCase
 ///     : vector<Cxf32> into vector<AxCxf32>
 /// // ... repeat for indices 1 to A-1
 /// ```
-struct UnrollMultiReductionInner
+struct UnrollInnerReductionAlongOuterParallel
     : public OpRewritePattern<vector::MultiDimReductionOp> {
   using Base::Base;
 
@@ -918,7 +918,8 @@ void mlir::vector::populateVectorUnrollMultiReduction(
   patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
 
   if (options == VectorMultiReductionLowering::InnerReduction) {
-    patterns.add<UnrollMultiReductionInner>(patterns.getContext(), benefit);
+    patterns.add<UnrollInnerReductionAlongOuterParallel>(patterns.getContext(),
+                                                         benefit);
   } else {
     patterns.add<UnrollMultiReductionOuterBaseCase,
                  UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),

>From 36037c7db36b83dacd13ae77c0497231e2cceada Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 17:01:32 -0500
Subject: [PATCH 13/15] Style

---
 .../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2f8d0a2e2da05..3c9e7f05105e1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -554,7 +554,8 @@ struct OneDimMultiReductionToTwoDim
 /// ```mlir
 /// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
 /// ...
-/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
 ///
 /// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
 /// ...
@@ -645,7 +646,8 @@ struct UnrollMultiReductionOuterBaseCase
 /// ```mlir
 /// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
 /// ...
-/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
 ///
 /// %red0 = vector.multi_reduction <add>, %0, %acc [1]
 ///     : vector<Mx...xf32> to vector<Ix...xf32>

>From 5bf640c6d59a879ccf9632e3b80698097b252a9f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 18:57:15 -0500
Subject: [PATCH 14/15] Fix ASan failure: no ArrayRef into temporary reduction mask

---
 .../Vector/Transforms/LowerVectorMultiReduction.cpp        | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 3c9e7f05105e1..ea65b02fe71fa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -713,8 +713,8 @@ struct UnrollMultiReductionOuterGeneralCase
       else
         masks.push_back(nullptr);
 
-    ArrayRef<bool> reductionMask =
-        ArrayRef<bool>(multiReductionOp.getReductionMask()).drop_front();
+    SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+    ArrayRef<bool> reductionMask = ArrayRef<bool>(fullReductionMask).drop_front();
     Value result = multiReductionOp.getAcc();
     for (auto [innerVector, innerMask] : llvm::zip(vectors, masks)) {
 
@@ -829,8 +829,9 @@ struct UnrollInnerReductionAlongOuterParallel
     // Compute new reduction mask by dropping the first element (dimension 0).
     // Since dimension 0 is parallel (not reduced), all reduction indices shift
     // down by 1.
+    SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
     ArrayRef<bool> newReductionMask =
-        ArrayRef<bool>(multiReductionOp.getReductionMask()).drop_front();
+        ArrayRef<bool>(fullReductionMask).drop_front();
 
     SmallVector<Value> reductionResults;
     for (auto [srcSlice, accSlice, maskSlice] :

>From 0c9e013db584bb4c84648065400a3eadfb6b4622 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 19:00:09 -0500
Subject: [PATCH 15/15] Style

---
 .../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp    | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index ea65b02fe71fa..3b3aaa2b22eca 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -714,7 +714,8 @@ struct UnrollMultiReductionOuterGeneralCase
         masks.push_back(nullptr);
 
     SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
-    ArrayRef<bool> reductionMask = ArrayRef<bool>(fullReductionMask).drop_front();
+    ArrayRef<bool> reductionMask =
+        ArrayRef<bool>(fullReductionMask).drop_front();
     Value result = multiReductionOp.getAcc();
     for (auto [innerVector, innerMask] : llvm::zip(vectors, masks)) {
 



More information about the Mlir-commits mailing list