[Mlir-commits] [mlir] [mlir][vector] Generalize multi_reduction innerparallel unrolling to N dimensions (PR #182301)

Erick Ochoa Lopez llvmlistbot at llvm.org
Mon Feb 23 07:44:33 PST 2026


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/182301

>From 9a7a2774b1d904841fc04d07b4f91cfa60ba74d5 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 11:13:52 -0500
Subject: [PATCH 01/23] [mlir][vector] rank reduce unrolling for
 vector.multi_reduction.

---
 .../Transforms/LowerVectorMultiReduction.cpp  | 199 +++++++++++++++++-
 .../vector-multi-reduction-pass-lowering.mlir |   6 +-
 2 files changed, 198 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index fec04c967c9e1..90fed006327e2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -486,6 +486,195 @@ struct OneDimMultiReductionToTwoDim
   }
 };
 
+/// Unrolls outermost dimension for vector.multi_reduction.
+/// This pattern matches operations which reduce the outermost dimension,
+/// it does not transform operations for which the outermost dimension is not
+/// a reduction dimension.
+///
+/// There are two cases to consider:
+/// 1. The base case is when the outermost dimension is the only reduction
+/// dimension.
+/// 2. The general case is when the outermost dimension is not the only
+/// reduction dimension.
+///
+/// The base case transformation:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0] : vector<NxMx...xf32> to
+/// vector<Mx...xf32>
+/// ```
+///
+/// will extract N vectors from %src and then perform elementwise operations.
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
+/// ...
+/// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
+/// ```
+///
+/// For the general case, we still extract N vectors, but produce N
+/// vector.multi_reduction instead of elementwise operations.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, [[REDUCTION_DIMS]] ] :
+/// vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction <add>, %0, %acc [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction <add>, %Nminus1, %redNminus2
+/// [ [[REDUCTION_DIMS]] ] : vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
+struct UnrollMultiReductionOuterBaseCase
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    if (!multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected outermost dimension to be reduced dimension.");
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected integer or float element type.");
+
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+    if (reductionDims.size() > 1)
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected only one reduction dimension.");
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+    int64_t numElementwiseOps = srcShape.front();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    bool isMasked = maskableOp.isMasked();
+    Operation *rootOp;
+    Value mask = nullptr;
+    if (isMasked) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
+    SmallVector<Value> vectors;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      vectors.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+    SmallVector<Value> masks;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      if (isMasked)
+        masks.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+      else
+        masks.push_back(nullptr);
+
+    Value result = multiReductionOp.getAcc();
+    for (auto [innerVector, innerMask] : llvm::zip(vectors, masks))
+      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
+                                  innerVector, result, /*fastmath=*/nullptr,
+                                  innerMask);
+
+    rewriter.replaceOp(rootOp, result);
+    return success();
+  }
+};
+
+struct UnrollMultiReductionOuterGeneralCase
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    if (!multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected outermost dimension to be reduced dimension.");
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected integer or float element type.");
+
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+    if (reductionDims.size() <= 1)
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected more than one reduction dimension.");
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+    int64_t numElementwiseOps = srcShape.front();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    bool isMasked = maskableOp.isMasked();
+    Operation *rootOp;
+    Value mask = nullptr;
+    if (isMasked) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
+    SmallVector<Value> vectors;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      vectors.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+    SmallVector<Value> masks;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      if (isMasked)
+        masks.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+      else
+        masks.push_back(nullptr);
+
+    ArrayRef<bool> reductionMask =
+        ArrayRef<bool>(multiReductionOp.getReductionMask()).drop_front();
+    Value result = multiReductionOp.getAcc();
+    for (auto [innerVector, innerMask] : llvm::zip(vectors, masks)) {
+
+      auto reductionOp = vector::MultiDimReductionOp::create(
+          rewriter, loc, innerVector, result, reductionMask,
+          multiReductionOp.getKind());
+
+      if (isMasked) {
+        auto maskOp = vector::maskOperation(rewriter, reductionOp, innerMask);
+        result = maskOp->getResult(0);
+      } else {
+        result = reductionOp.getResult();
+      }
+    }
+
+    rewriter.replaceOp(rootOp, result);
+    return success();
+  }
+};
+
 struct LowerVectorMultiReductionPass
     : public vector::impl::LowerVectorMultiReductionBase<
           LowerVectorMultiReductionPass> {
@@ -541,12 +730,14 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
 void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
-  if (options == VectorMultiReductionLowering ::InnerReduction)
+  if (options == VectorMultiReductionLowering ::InnerReduction) {
     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                   benefit);
-  else
-    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
-                                                    benefit);
+  } else {
+    patterns.add<UnrollMultiReductionOuterBaseCase,
+                 UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
+                                                       benefit);
+  }
 }
 
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index ddbc5c7bdb2c0..e01bf446eb83c 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -21,12 +21,12 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
 
 //      INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //      INNER-PARALLEL: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-//      INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
 //      INNER-PARALLEL: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-//      INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
 //      INNER-PARALLEL: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-//      INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
 //      INNER-PARALLEL: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+//      INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+//      INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+//      INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
 //      INNER-PARALLEL: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
 //      INNER-PARALLEL: return %[[RESULT_VEC]] : vector<2xf32>
 

>From 12f3f0df22bb8b0520676d1a79abbfd3f24616fa Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 30 Jan 2026 12:07:30 -0500
Subject: [PATCH 02/23] [mlir][vector] Add tests for multi_reduction unrolling.

---
 .../Vector/TransformOps/VectorTransformOps.td | 20 ++++
 .../Vector/Transforms/LoweringPatterns.h      | 24 +++++
 .../TransformOps/VectorTransformOps.cpp       |  8 ++
 .../Transforms/LowerVectorMultiReduction.cpp  | 16 ++++
 .../Vector/td/unroll-multi-reduction.mlir     | 24 +++++
 .../Vector/unroll-vector-multi-reduction.mlir | 92 +++++++++++++++++++
 6 files changed, 184 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
 create mode 100644 mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 685c88c17e556..01fb33274828a 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,6 +306,26 @@ def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
   }];
 }
 
+def ApplyUnrollMultiReductionPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.unroll_multi_reduction",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Unrolls vector.multi_reduction operations by progressively reducing rank
+    along the outermost dimension.
+
+    This is an alternative to the flattening-based lowering that preserves
+    the n-D structure during progressive lowering.
+  }];
+
+  let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+      "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+  );
+
+  let assemblyFormat = [{
+    (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+  }];
+}
+
 def ApplyLowerOuterProductPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_outerproduct",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 33487a9d8d6e0..a0334afa05dad 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -118,6 +118,30 @@ void populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit = 1);
 
+/// Collect a set of patterns to unroll vector.multi_reduction ops by
+/// progressively reducing rank along the outermost dimension.
+///
+/// For OuterReduction (outermost dim is reduction):
+/// [UnrollMultiReductionOuterBaseCase]
+/// When the outermost dimension is the only reduction dimension, unroll to
+/// produce elementwise arithmetic operations.
+///
+/// [UnrollMultiReductionOuterGeneralCase]
+/// When the outermost dimension is one of multiple reduction dimensions,
+/// unroll to produce smaller multi_reduction operations.
+///
+/// For InnerReduction (innermost dim is reduction):
+/// [UnrollMultiReductionInnerBaseCase]
+/// When the innermost dimension is the only reduction dimension, unroll along
+/// the outermost parallel dimension.
+///
+/// [UnrollMultiReductionInnerGeneralCase]
+/// When the innermost dimension is one of multiple reduction dimensions,
+/// unroll along the outermost parallel dimension.
+void populateVectorUnrollMultiReduction(RewritePatternSet &patterns,
+                                        VectorMultiReductionLowering options,
+                                        PatternBenefit benefit = 1);
+
 /// Populate the pattern set with the following patterns:
 ///
 /// [TransferReadToVectorLoadLowering]
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 6c97c6501a23e..fca1ca6bade92 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -162,6 +162,14 @@ void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
+void transform::ApplyUnrollMultiReductionPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::VectorTransformsOptions vectorTransformOptions;
+  vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+  vector::populateVectorUnrollMultiReduction(
+      patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorOuterProductLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 90fed006327e2..5ac13dd9fe4d5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -740,6 +740,22 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
   }
 }
 
+void mlir::vector::populateVectorUnrollMultiReduction(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
+  if (options == VectorMultiReductionLowering::InnerReduction) {
+    // TODO: Add UnrollMultiReductionInnerBaseCase and
+    // UnrollMultiReductionInnerGeneralCase patterns here once implemented.
+    // For now, fall back to the existing 2-D based lowering.
+    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
+                                                  benefit);
+  } else {
+    patterns.add<UnrollMultiReductionOuterBaseCase,
+                 UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
+                                                       benefit);
+  }
+}
+
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir b/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
new file mode 100644
index 0000000000000..96a68723266d3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
@@ -0,0 +1,24 @@
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @unroll_multi_reduction(%module_op: !transform.any_op {transform.readonly}) {
+
+    %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func_op {
+      // Test patterns
+      transform.apply_patterns.vector.unroll_multi_reduction
+    } : !transform.any_op
+
+    transform.yield
+  }
+  transform.named_sequence @unroll_multi_reduction_inner(%module_op: !transform.any_op {transform.readonly}) {
+
+    %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func_op {
+      // Test patterns
+      transform.apply_patterns.vector.unroll_multi_reduction lowering_strategy = "innerreduction"
+    } : !transform.any_op
+
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
new file mode 100644
index 0000000000000..79086e2b0b9ad
--- /dev/null
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt --split-input-file %s -transform-preload-library='transform-library-paths=%p/td/unroll-multi-reduction.mlir' \
+// RUN: -transform-interpreter=entry-point=unroll_multi_reduction | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test UnrollVectorMultiReduction
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @unroll_vector_multi_reduction(%source: vector<2x3x5xf32>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+  // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
+  // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_0]] : vector<3x5xf32>
+  %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
+
+  // CHECK: return %[[RES_1]]
+  return %1 : vector<3x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @unroll_vector_multi_reduction_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+  // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+  // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+  // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
+  // CHECK: %[[RES_MASKED_0:.+]] = arith.select %[[MASK_0]], %[[RES_0]], %[[ACC]] : vector<3x5xi1>, vector<3x5xf32>
+
+  // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_MASKED_0]] : vector<3x5xf32>
+  // CHECK: %[[RES_MASKED_1:.+]] = arith.select %[[MASK_1]], %[[RES_1]], %[[RES_MASKED_0]] : vector<3x5xi1>, vector<3x5xf32>
+
+  %0 = vector.mask %mask {
+    %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
+  } : vector<2x3x5xi1> -> vector<3x5xf32>
+
+  // CHECK: return %[[RES_MASKED_1]]
+  return %0 : vector<3x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_general(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
+func.func @unroll_vector_multi_reduction_general(%source: vector<2x3x5xf32>, %acc: vector<3xf32>) -> (vector<3xf32>) {
+
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+  // CHECK: %[[RES_0:.+]] = vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32>
+  // CHECK: %[[RES_1:.+]] = vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32>
+
+  %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
+
+  // CHECK: return %[[RES_1]]
+  return %1 : vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_general_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
+func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3xf32>) -> (vector<3xf32>) {
+
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+  // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+  // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+  // CHECK: %[[RES_0:.+]] = vector.mask %[[MASK_0]] { vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
+  // CHECK: %[[RES_1:.+]] = vector.mask %[[MASK_1]] { vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
+
+  %0 = vector.mask %mask {
+    %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
+  } : vector<2x3x5xi1> -> vector<3xf32>
+
+  // CHECK: return %[[RES_1]]
+  return %0 : vector<3xf32>
+}

>From 5638833b7314f0ba2f3b55f7dbc1684947de03ee Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 18 Feb 2026 16:20:55 -0500
Subject: [PATCH 03/23] Use already existing infrastructure

---
 .../Vector/TransformOps/VectorTransformOps.td | 20 ----
 .../Vector/Transforms/LoweringPatterns.h      | 24 -----
 .../TransformOps/VectorTransformOps.cpp       |  8 --
 .../Transforms/LowerVectorMultiReduction.cpp  | 16 ----
 .../Vector/unroll-vector-multi-reduction.mlir | 92 -------------------
 5 files changed, 160 deletions(-)
 delete mode 100644 mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 01fb33274828a..685c88c17e556 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,26 +306,6 @@ def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
   }];
 }
 
-def ApplyUnrollMultiReductionPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.vector.unroll_multi_reduction",
-    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
-  let description = [{
-    Unrolls vector.multi_reduction operations by progressively reducing rank
-    along the outermost dimension.
-
-    This is an alternative to the flattening-based lowering that preserves
-    the n-D structure during progressive lowering.
-  }];
-
-  let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
-      "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
-  );
-
-  let assemblyFormat = [{
-    (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
-  }];
-}
-
 def ApplyLowerOuterProductPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_outerproduct",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index a0334afa05dad..33487a9d8d6e0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -118,30 +118,6 @@ void populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit = 1);
 
-/// Collect a set of patterns to unroll vector.multi_reduction ops by
-/// progressively reducing rank along the outermost dimension.
-///
-/// For OuterReduction (outermost dim is reduction):
-/// [UnrollMultiReductionOuterBaseCase]
-/// When the outermost dimension is the only reduction dimension, unroll to
-/// produce elementwise arithmetic operations.
-///
-/// [UnrollMultiReductionOuterGeneralCase]
-/// When the outermost dimension is one of multiple reduction dimensions,
-/// unroll to produce smaller multi_reduction operations.
-///
-/// For InnerReduction (innermost dim is reduction):
-/// [UnrollMultiReductionInnerBaseCase]
-/// When the innermost dimension is the only reduction dimension, unroll along
-/// the outermost parallel dimension.
-///
-/// [UnrollMultiReductionInnerGeneralCase]
-/// When the innermost dimension is one of multiple reduction dimensions,
-/// unroll along the outermost parallel dimension.
-void populateVectorUnrollMultiReduction(RewritePatternSet &patterns,
-                                        VectorMultiReductionLowering options,
-                                        PatternBenefit benefit = 1);
-
 /// Populate the pattern set with the following patterns:
 ///
 /// [TransferReadToVectorLoadLowering]
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index fca1ca6bade92..6c97c6501a23e 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -162,14 +162,6 @@ void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
-void transform::ApplyUnrollMultiReductionPatternsOp::populatePatterns(
-    RewritePatternSet &patterns) {
-  vector::VectorTransformsOptions vectorTransformOptions;
-  vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
-  vector::populateVectorUnrollMultiReduction(
-      patterns, vectorTransformOptions.vectorMultiReductionLowering);
-}
-
 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorOuterProductLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 5ac13dd9fe4d5..90fed006327e2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -740,22 +740,6 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
   }
 }
 
-void mlir::vector::populateVectorUnrollMultiReduction(
-    RewritePatternSet &patterns, VectorMultiReductionLowering options,
-    PatternBenefit benefit) {
-  if (options == VectorMultiReductionLowering::InnerReduction) {
-    // TODO: Add UnrollMultiReductionInnerBaseCase and
-    // UnrollMultiReductionInnerGeneralCase patterns here once implemented.
-    // For now, fall back to the existing 2-D based lowering.
-    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
-                                                  benefit);
-  } else {
-    patterns.add<UnrollMultiReductionOuterBaseCase,
-                 UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
-                                                       benefit);
-  }
-}
-
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
deleted file mode 100644
index 79086e2b0b9ad..0000000000000
--- a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
+++ /dev/null
@@ -1,92 +0,0 @@
-// RUN: mlir-opt --split-input-file %s -transform-preload-library='transform-library-paths=%p/td/unroll-multi-reduction.mlir' \
-// RUN: -transform-interpreter=entry-point=unroll_multi_reduction | FileCheck %s
-
-//===----------------------------------------------------------------------===//
-// Test UnrollVectorMultiReduction
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func @unroll_vector_multi_reduction(
-// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
-// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
-func.func @unroll_vector_multi_reduction(%source: vector<2x3x5xf32>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
-  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
-  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
-
-  // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
-  // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_0]] : vector<3x5xf32>
-  %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
-
-  // CHECK: return %[[RES_1]]
-  return %1 : vector<3x5xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @unroll_vector_multi_reduction_masked(
-// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
-// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
-// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
-func.func @unroll_vector_multi_reduction_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
-  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
-  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
-
-  // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
-  // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
-
-  // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
-  // CHECK: %[[RES_MASKED_0:.+]] = arith.select %[[MASK_0]], %[[RES_0]], %[[ACC]] : vector<3x5xi1>, vector<3x5xf32>
-
-  // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_MASKED_0]] : vector<3x5xf32>
-  // CHECK: %[[RES_MASKED_1:.+]] = arith.select %[[MASK_1]], %[[RES_1]], %[[RES_MASKED_0]] : vector<3x5xi1>, vector<3x5xf32>
-
-  %0 = vector.mask %mask {
-    %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
-  } : vector<2x3x5xi1> -> vector<3x5xf32>
-
-  // CHECK: return %[[RES_MASKED_1]]
-  return %0 : vector<3x5xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @unroll_vector_multi_reduction_general(
-// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
-// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
-func.func @unroll_vector_multi_reduction_general(%source: vector<2x3x5xf32>, %acc: vector<3xf32>) -> (vector<3xf32>) {
-
-  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
-  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
-
-  // CHECK: %[[RES_0:.+]] = vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32>
-  // CHECK: %[[RES_1:.+]] = vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32>
-
-  %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
-
-  // CHECK: return %[[RES_1]]
-  return %1 : vector<3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @unroll_vector_multi_reduction_general_masked(
-// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
-// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
-// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
-func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3xf32>) -> (vector<3xf32>) {
-
-  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
-  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
-
-  // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
-  // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
-
-  // CHECK: %[[RES_0:.+]] = vector.mask %[[MASK_0]] { vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
-  // CHECK: %[[RES_1:.+]] = vector.mask %[[MASK_1]] { vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
-
-  %0 = vector.mask %mask {
-    %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
-  } : vector<2x3x5xi1> -> vector<3xf32>
-
-  // CHECK: return %[[RES_1]]
-  return %0 : vector<3xf32>
-}

>From 952c3e41ccc578cf220862adfde43df4bfd6156f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 18 Feb 2026 16:34:54 -0500
Subject: [PATCH 04/23] fix tests

---
 .../Transforms/LowerVectorMultiReduction.cpp  |  5 +++++
 .../vector-multi-reduction-unrolling.mlir     | 20 ++++++++++---------
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 90fed006327e2..6df55232c605d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -542,6 +542,11 @@ struct UnrollMultiReductionOuterBaseCase
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank < 2)
+      return rewriter.notifyMatchFailure(multiReductionOp,
+                                         "expected source rank >= 2.");
+
     if (!multiReductionOp.isReducedDim(0))
       return rewriter.notifyMatchFailure(
           multiReductionOp,
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
index bc0d192e012ee..dc6d1a94c4bdb 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
@@ -58,12 +58,12 @@ func.func @inner_reduction_2d_scalable(%input: vector<2x[4]xf32>, %acc: vector<2
 // ALL-SAME:    %[[INPUT:.+]]: vector<4x2xf32>, %[[ACC:.+]]: vector<2xf32>
 func.func @inner_parallel_2d(%arg0: vector<4x2xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
     // INNER_PARALLEL: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2xf32> from vector<4x2xf32>
-    // INNER_PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
     // INNER_PARALLEL: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2xf32> from vector<4x2xf32>
-    // INNER_PARALLEL: %[[RV1:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
     // INNER_PARALLEL: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2xf32> from vector<4x2xf32>
-    // INNER_PARALLEL: %[[RV2:.+]] = arith.mulf %[[V2]], %[[RV1]] : vector<2xf32>
     // INNER_PARALLEL: %[[V3:.+]] = vector.extract %[[INPUT]][3] : vector<2xf32> from vector<4x2xf32>
+    // INNER_PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+    // INNER_PARALLEL: %[[RV1:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+    // INNER_PARALLEL: %[[RV2:.+]] = arith.mulf %[[V2]], %[[RV1]] : vector<2xf32>
     // INNER_PARALLEL: %[[RESULT:.+]] = arith.mulf %[[V3]], %[[RV2]] : vector<2xf32>
     // INNER_REDUCTION: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[INPUT]], %[[ACC]] [0]
     // ALL:             return %[[RESULT]] : vector<2xf32>
@@ -75,19 +75,21 @@ func.func @inner_parallel_2d(%arg0: vector<4x2xf32>, %acc: vector<2xf32>) -> vec
 // ALL-SAME:    %[[INPUT:.+]]: vector<4x2xf32>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<4x2xi1>
 func.func @inner_parallel_2d_masked(%arg0: vector<4x2xf32>, %acc: vector<2xf32>, %mask: vector<4x2xi1>) -> vector<2xf32> {
     // INNER_PARALLEL: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2xf32> from vector<4x2xf32>
+    // INNER_PARALLEL: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2xf32> from vector<4x2xf32>
+    // INNER_PARALLEL: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2xf32> from vector<4x2xf32>
+    // INNER_PARALLEL: %[[V3:.+]] = vector.extract %[[INPUT]][3] : vector<2xf32> from vector<4x2xf32>
+
     // INNER_PARALLEL: %[[M0:.+]] = vector.extract %[[MASK]][0] : vector<2xi1> from vector<4x2xi1>
+    // INNER_PARALLEL: %[[M1:.+]] = vector.extract %[[MASK]][1] : vector<2xi1> from vector<4x2xi1>
+    // INNER_PARALLEL: %[[M2:.+]] = vector.extract %[[MASK]][2] : vector<2xi1> from vector<4x2xi1>
+    // INNER_PARALLEL: %[[M3:.+]] = vector.extract %[[MASK]][3] : vector<2xi1> from vector<4x2xi1>
+
     // INNER_PARALLEL: %[[RED0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
     // INNER_PARALLEL: %[[RV0:.+]] = arith.select %[[M0]], %[[RED0]], %[[ACC]] : vector<2xi1>, vector<2xf32>
-    // INNER_PARALLEL: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2xf32> from vector<4x2xf32>
-    // INNER_PARALLEL: %[[M1:.+]] = vector.extract %[[MASK]][1] : vector<2xi1> from vector<4x2xi1>
     // INNER_PARALLEL: %[[RED1:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
     // INNER_PARALLEL: %[[RV1:.+]] = arith.select %[[M1]], %[[RED1]], %[[RV0]] : vector<2xi1>, vector<2xf32>
-    // INNER_PARALLEL: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2xf32> from vector<4x2xf32>
-    // INNER_PARALLEL: %[[M2:.+]] = vector.extract %[[MASK]][2] : vector<2xi1> from vector<4x2xi1>
     // INNER_PARALLEL: %[[RED2:.+]] = arith.mulf %[[V2]], %[[RV1]] : vector<2xf32>
     // INNER_PARALLEL: %[[RV2:.+]] = arith.select %[[M2]], %[[RED2]], %[[RV1]] : vector<2xi1>, vector<2xf32>
-    // INNER_PARALLEL: %[[V3:.+]] = vector.extract %[[INPUT]][3] : vector<2xf32> from vector<4x2xf32>
-    // INNER_PARALLEL: %[[M3:.+]] = vector.extract %[[MASK]][3] : vector<2xi1> from vector<4x2xi1>
     // INNER_PARALLEL: %[[RED3:.+]] = arith.mulf %[[V3]], %[[RV2]] : vector<2xf32>
     // INNER_PARALLEL: %[[RESULT:.+]] = arith.select %[[M3]], %[[RED3]], %[[RV2]] : vector<2xi1>, vector<2xf32>
     // INNER_REDUCTION: %[[RESULT:.+]] = vector.mask %[[MASK]] { vector.multi_reduction <mul>, %[[INPUT]], %[[ACC]] [0] {{.+}} } : vector<4x2xi1> -> vector<2xf32>

>From 63e02987b39df6c438e4dfc540be14b2c053fbda Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 08:45:44 -0500
Subject: [PATCH 05/23] Deletes old implementation and renames new
 implementation

---
 .../Vector/TransformOps/VectorTransformOps.td |  3 +-
 .../Vector/Transforms/LoweringPatterns.h      | 11 ++--
 .../Transforms/LowerVectorMultiReduction.cpp  | 65 ++-----------------
 3 files changed, 14 insertions(+), 65 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 685c88c17e556..a7de823de7705 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -294,7 +294,8 @@ def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
     This populates the patterns from
     `populateVectorMultiReductionUnrollingPatterns`, i.e.:
     * `TwoDimMultiReductionToReduction` (innerreduction)
-    * `TwoDimMultiReductionToElementWise` (innerparallel)
+    * `UnrollMultiReductionInnerParallelBaseCase`
+    * `UnrollMultiReductionInnerParallelGeneralCase`
   }];
 
   let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 33487a9d8d6e0..ecc6c420e3a82 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -89,10 +89,13 @@ void populateVectorMultiReductionFlatteningPatterns(
 
 /// Populate the pattern set with the following patterns:
 ///
-/// [TwoDimMultiReductionToElementWise]
-/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
-/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
-/// ops. This also has an opportunity for tree-reduction (in the future).
+/// [UnrollMultiReductionInnerParallelBaseCase]
+/// Rank reducing unrolling for inner-parallel case, when there is only one
+/// reduction dimension and it is the outermost one.
+///
+/// [UnrollMultiReductionInnerParallelGeneralCase]
+/// Rank reducing unrolling for inner-parallel general case, when there is
+/// more than one reduction dimension and the outermost one is among them.
 ///
 /// [TwoDimMultiReductionToReduction]
 /// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 6df55232c605d..a68300c8b9088 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -301,61 +301,6 @@ class ReduceMultiDimReductionRank
   const bool useInnerDimsForReduction;
 };
 
-/// Unrolls vector.multi_reduction with outermost reductions
-/// and combines results
-struct TwoDimMultiReductionToElementWise
-    : public OpRewritePattern<vector::MultiDimReductionOp> {
-  using Base::Base;
-
-  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
-    // Rank-2 ["parallel", "reduce"] or bail.
-    if (srcRank != 2)
-      return failure();
-
-    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
-      return failure();
-
-    auto loc = multiReductionOp.getLoc();
-    ArrayRef<int64_t> srcShape =
-        multiReductionOp.getSourceVectorType().getShape();
-
-    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
-    if (!elementType.isIntOrIndexOrFloat())
-      return failure();
-
-    OpBuilder::InsertionGuard guard(rewriter);
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
-    Operation *rootOp;
-    Value mask = nullptr;
-    if (maskableOp.isMasked()) {
-      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-      rootOp = maskableOp.getMaskingOp();
-      mask = maskableOp.getMaskingOp().getMask();
-    } else {
-      rootOp = multiReductionOp;
-    }
-
-    Value result = multiReductionOp.getAcc();
-    for (int64_t i = 0; i < srcShape[0]; i++) {
-      auto operand = vector::ExtractOp::create(rewriter, loc,
-                                               multiReductionOp.getSource(), i);
-      Value extractMask = nullptr;
-      if (mask) {
-        extractMask = vector::ExtractOp::create(rewriter, loc, mask, i);
-      }
-      result =
-          makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
-                             result, /*fastmath=*/nullptr, extractMask);
-    }
-
-    rewriter.replaceOp(rootOp, result);
-    return success();
-  }
-};
-
 /// Converts 2d vector.multi_reduction with inner most reduction dimension into
 /// a sequence of vector.reduction ops.
 struct TwoDimMultiReductionToReduction
@@ -536,7 +481,7 @@ struct OneDimMultiReductionToTwoDim
 /// %res = vector.multi_reduction %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] :
 /// vector<Mx...xf32> to vector<Ix...xf32>
 /// ```
-struct UnrollMultiReductionOuterBaseCase
+struct UnrollMultiReductionInnerParallelBaseCase
     : public OpRewritePattern<vector::MultiDimReductionOp> {
   using Base::Base;
 
@@ -605,7 +550,7 @@ struct UnrollMultiReductionOuterBaseCase
   }
 };
 
-struct UnrollMultiReductionOuterGeneralCase
+struct UnrollMultiReductionInnerParallelGeneralCase
     : public OpRewritePattern<vector::MultiDimReductionOp> {
   using Base::Base;
 
@@ -739,9 +684,9 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                   benefit);
   } else {
-    patterns.add<UnrollMultiReductionOuterBaseCase,
-                 UnrollMultiReductionOuterGeneralCase>(patterns.getContext(),
-                                                       benefit);
+    patterns.add<UnrollMultiReductionInnerParallelBaseCase,
+                 UnrollMultiReductionInnerParallelGeneralCase>(
+        patterns.getContext(), benefit);
   }
 }
 

>From b62cb5388e63d58eea0893c10488c9a172a361e0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 08:47:55 -0500
Subject: [PATCH 06/23] Remove unnecessary file after re-structuring

---
 .../Vector/td/unroll-multi-reduction.mlir     | 24 -------------------
 1 file changed, 24 deletions(-)
 delete mode 100644 mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir

diff --git a/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir b/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
deleted file mode 100644
index 96a68723266d3..0000000000000
--- a/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
+++ /dev/null
@@ -1,24 +0,0 @@
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @unroll_multi_reduction(%module_op: !transform.any_op {transform.readonly}) {
-
-    %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func_op {
-      // Test patterns
-      transform.apply_patterns.vector.unroll_multi_reduction
-    } : !transform.any_op
-
-    transform.yield
-  }
-  transform.named_sequence @unroll_multi_reduction_inner(%module_op: !transform.any_op {transform.readonly}) {
-
-    %func_op = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func_op {
-      // Test patterns
-      transform.apply_patterns.vector.unroll_multi_reduction lowering_strategy = "innerreduction"
-    } : !transform.any_op
-
-    transform.yield
-  }
-}

>From b8b22fe02faa213ca875fac5c5beed702cb797b5 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 09:05:08 -0500
Subject: [PATCH 07/23] Add new case for general case

---
 .../vector-multi-reduction-unrolling.mlir     | 25 ++++++++++++++++---
 1 file changed, 21 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
index dc6d1a94c4bdb..da375965ff837 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
@@ -54,9 +54,9 @@ func.func @inner_reduction_2d_scalable(%input: vector<2x[4]xf32>, %acc: vector<2
     return %0 : vector<2xf32>
 }
 
-// ALL-LABEL: func @inner_parallel_2d
+// ALL-LABEL: func @inner_parallel_base
 // ALL-SAME:    %[[INPUT:.+]]: vector<4x2xf32>, %[[ACC:.+]]: vector<2xf32>
-func.func @inner_parallel_2d(%arg0: vector<4x2xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+func.func @inner_parallel_base(%arg0: vector<4x2xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
     // INNER_PARALLEL: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2xf32> from vector<4x2xf32>
     // INNER_PARALLEL: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2xf32> from vector<4x2xf32>
     // INNER_PARALLEL: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2xf32> from vector<4x2xf32>
@@ -71,9 +71,26 @@ func.func @inner_parallel_2d(%arg0: vector<4x2xf32>, %acc: vector<2xf32>) -> vec
     return %0 : vector<2xf32>
 }
 
-// ALL-LABEL: func @inner_parallel_2d_masked
+// ALL-LABEL: func @inner_parallel_general
+// ALL-SAME:    %[[INPUT:.+]]: vector<4x2x3xf32>, %[[ACC:.+]]: vector<2xf32>
+func.func @inner_parallel_general(%arg0: vector<4x2x3xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    // INNER_PARALLEL: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2x3xf32> from vector<4x2x3xf32>
+    // INNER_PARALLEL: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2x3xf32> from vector<4x2x3xf32>
+    // INNER_PARALLEL: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2x3xf32> from vector<4x2x3xf32>
+    // INNER_PARALLEL: %[[V3:.+]] = vector.extract %[[INPUT]][3] : vector<2x3xf32> from vector<4x2x3xf32>
+    // INNER_PARALLEL: %[[RV0:.+]] = vector.multi_reduction <mul>, %[[V0]], %[[ACC]] [1] : vector<2x3xf32>
+    // INNER_PARALLEL: %[[RV1:.+]] = vector.multi_reduction <mul>, %[[V1]], %[[RV0]] [1] : vector<2x3xf32>
+    // INNER_PARALLEL: %[[RV2:.+]] = vector.multi_reduction <mul>, %[[V2]], %[[RV1]] [1] : vector<2x3xf32>
+    // INNER_PARALLEL: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[V3]], %[[RV2]] [1] : vector<2x3xf32>
+    // INNER_REDUCTION: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[INPUT]], %[[ACC]] [0, 2]
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 2] : vector<4x2x3xf32> to vector<2xf32>
+    // ALL:             return %[[RESULT]]
+    return %0 : vector<2xf32>
+}
+
+// ALL-LABEL: func @inner_parallel_base_masked
 // ALL-SAME:    %[[INPUT:.+]]: vector<4x2xf32>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<4x2xi1>
-func.func @inner_parallel_2d_masked(%arg0: vector<4x2xf32>, %acc: vector<2xf32>, %mask: vector<4x2xi1>) -> vector<2xf32> {
+func.func @inner_parallel_base_masked(%arg0: vector<4x2xf32>, %acc: vector<2xf32>, %mask: vector<4x2xi1>) -> vector<2xf32> {
     // INNER_PARALLEL: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2xf32> from vector<4x2xf32>
     // INNER_PARALLEL: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2xf32> from vector<4x2xf32>
     // INNER_PARALLEL: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2xf32> from vector<4x2xf32>

>From f4e4ddfbb81640ca9a4834fc4ba8b6e7b5aa6dc6 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 09:19:11 -0500
Subject: [PATCH 08/23] Use MaskableOpRewritePattern

---
 .../Transforms/LowerVectorMultiReduction.cpp  | 61 ++++++-------------
 1 file changed, 19 insertions(+), 42 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index a68300c8b9088..fcbbd06a4716a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -482,11 +482,13 @@ struct OneDimMultiReductionToTwoDim
 /// vector<Mx...xf32> to vector<Ix...xf32>
 /// ```
 struct UnrollMultiReductionInnerParallelBaseCase
-    : public OpRewritePattern<vector::MultiDimReductionOp> {
-  using Base::Base;
+    : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     if (srcRank < 2)
       return rewriter.notifyMatchFailure(multiReductionOp,
@@ -514,19 +516,7 @@ struct UnrollMultiReductionInnerParallelBaseCase
         multiReductionOp.getSourceVectorType().getShape();
     int64_t numElementwiseOps = srcShape.front();
 
-    OpBuilder::InsertionGuard guard(rewriter);
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
-    bool isMasked = maskableOp.isMasked();
-    Operation *rootOp;
-    Value mask = nullptr;
-    if (isMasked) {
-      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-      rootOp = maskableOp.getMaskingOp();
-      mask = maskableOp.getMaskingOp().getMask();
-    } else {
-      rootOp = multiReductionOp;
-    }
+    Value mask = maskingOp ? maskingOp.getMask() : nullptr;
 
     SmallVector<Value> vectors;
     for (int64_t i = 0; i < numElementwiseOps; ++i)
@@ -534,7 +524,7 @@ struct UnrollMultiReductionInnerParallelBaseCase
 
     SmallVector<Value> masks;
     for (int64_t i = 0; i < numElementwiseOps; ++i)
-      if (isMasked)
+      if (mask)
         masks.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
       else
         masks.push_back(nullptr);
@@ -545,17 +535,18 @@ struct UnrollMultiReductionInnerParallelBaseCase
                                   innerVector, result, /*fastmath=*/nullptr,
                                   innerMask);
 
-    rewriter.replaceOp(rootOp, result);
-    return success();
+    return result;
   }
 };
 
 struct UnrollMultiReductionInnerParallelGeneralCase
-    : public OpRewritePattern<vector::MultiDimReductionOp> {
-  using Base::Base;
+    : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     if (!multiReductionOp.isReducedDim(0))
       return rewriter.notifyMatchFailure(
           multiReductionOp,
@@ -578,19 +569,7 @@ struct UnrollMultiReductionInnerParallelGeneralCase
         multiReductionOp.getSourceVectorType().getShape();
     int64_t numElementwiseOps = srcShape.front();
 
-    OpBuilder::InsertionGuard guard(rewriter);
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
-    bool isMasked = maskableOp.isMasked();
-    Operation *rootOp;
-    Value mask = nullptr;
-    if (isMasked) {
-      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-      rootOp = maskableOp.getMaskingOp();
-      mask = maskableOp.getMaskingOp().getMask();
-    } else {
-      rootOp = multiReductionOp;
-    }
+    Value mask = maskingOp ? maskingOp.getMask() : nullptr;
 
     SmallVector<Value> vectors;
     for (int64_t i = 0; i < numElementwiseOps; ++i)
@@ -598,7 +577,7 @@ struct UnrollMultiReductionInnerParallelGeneralCase
 
     SmallVector<Value> masks;
     for (int64_t i = 0; i < numElementwiseOps; ++i)
-      if (isMasked)
+      if (mask)
         masks.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
       else
         masks.push_back(nullptr);
@@ -607,12 +586,11 @@ struct UnrollMultiReductionInnerParallelGeneralCase
         ArrayRef<bool>(multiReductionOp.getReductionMask()).drop_front();
     Value result = multiReductionOp.getAcc();
     for (auto [innerVector, innerMask] : llvm::zip(vectors, masks)) {
-
       auto reductionOp = vector::MultiDimReductionOp::create(
           rewriter, loc, innerVector, result, reductionMask,
           multiReductionOp.getKind());
 
-      if (isMasked) {
+      if (innerMask) {
         auto maskOp = vector::maskOperation(rewriter, reductionOp, innerMask);
         result = maskOp->getResult(0);
       } else {
@@ -620,8 +598,7 @@ struct UnrollMultiReductionInnerParallelGeneralCase
       }
     }
 
-    rewriter.replaceOp(rootOp, result);
-    return success();
+    return result;
   }
 };
 

>From 5be1bce88737e7b776d2072176ae863b399bcd2f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 09:27:21 -0500
Subject: [PATCH 09/23] Remove notifyMatchFailure in complementary patterns

---
 .../Vector/Transforms/LowerVectorMultiReduction.cpp    | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index fcbbd06a4716a..d33370fe2d2d0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -548,19 +548,15 @@ struct UnrollMultiReductionInnerParallelGeneralCase
                             vector::MaskingOpInterface maskingOp,
                             PatternRewriter &rewriter) const override {
     if (!multiReductionOp.isReducedDim(0))
-      return rewriter.notifyMatchFailure(
-          multiReductionOp,
-          "expected outermost dimension to be reduced dimension.");
+      return failure();
 
     Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
     if (!elementType.isIntOrIndexOrFloat())
-      return rewriter.notifyMatchFailure(
-          multiReductionOp, "expected integer or float element type.");
+      return failure();
 
     ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
     if (reductionDims.size() <= 1)
-      return rewriter.notifyMatchFailure(
-          multiReductionOp, "expected more than one reduction dimension.");
+      return failure();
 
     Location loc = multiReductionOp.getLoc();
     Value source = multiReductionOp.getSource();

>From 0a9e3ffcb5b0aad1cb960d59b57e0c41525644e0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 09:32:04 -0500
Subject: [PATCH 10/23] Remove unnecessary check

---
 .../Vector/Transforms/LowerVectorMultiReduction.cpp      | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index d33370fe2d2d0..4821d40172c70 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -499,11 +499,6 @@ struct UnrollMultiReductionInnerParallelBaseCase
           multiReductionOp,
           "expected outermost dimension to be reduced dimension.");
 
-    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
-    if (!elementType.isIntOrIndexOrFloat())
-      return rewriter.notifyMatchFailure(
-          multiReductionOp, "expected integer or float element type.");
-
     ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
     if (reductionDims.size() > 1)
       return rewriter.notifyMatchFailure(
@@ -550,10 +545,6 @@ struct UnrollMultiReductionInnerParallelGeneralCase
     if (!multiReductionOp.isReducedDim(0))
       return failure();
 
-    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
-    if (!elementType.isIntOrIndexOrFloat())
-      return failure();
-
     ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
     if (reductionDims.size() <= 1)
       return failure();

>From 4ae7defcc9b39105cf901c0c91c779f939006b76 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 09:49:44 -0500
Subject: [PATCH 11/23] Improve loops

---
 .../Transforms/LowerVectorMultiReduction.cpp  | 20 ++++++++-----------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 4821d40172c70..7e6f90a128114 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -513,16 +513,14 @@ struct UnrollMultiReductionInnerParallelBaseCase
 
     Value mask = maskingOp ? maskingOp.getMask() : nullptr;
 
-    SmallVector<Value> vectors;
+    SmallVector<Value> vectors(numElementwiseOps);
     for (int64_t i = 0; i < numElementwiseOps; ++i)
-      vectors.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+      vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i);
 
-    SmallVector<Value> masks;
+    SmallVector<Value> masks(numElementwiseOps);
     for (int64_t i = 0; i < numElementwiseOps; ++i)
       if (mask)
-        masks.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
-      else
-        masks.push_back(nullptr);
+        masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
 
     Value result = multiReductionOp.getAcc();
     for (auto [innerVector, innerMask] : llvm::zip(vectors, masks))
@@ -558,16 +556,14 @@ struct UnrollMultiReductionInnerParallelGeneralCase
 
     Value mask = maskingOp ? maskingOp.getMask() : nullptr;
 
-    SmallVector<Value> vectors;
+    SmallVector<Value> vectors(numElementwiseOps);
     for (int64_t i = 0; i < numElementwiseOps; ++i)
-      vectors.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+      vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i);
 
-    SmallVector<Value> masks;
+    SmallVector<Value> masks(numElementwiseOps);
     for (int64_t i = 0; i < numElementwiseOps; ++i)
       if (mask)
-        masks.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
-      else
-        masks.push_back(nullptr);
+        masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
 
     ArrayRef<bool> reductionMask =
         ArrayRef<bool>(multiReductionOp.getReductionMask()).drop_front();

>From 655de6313b990a56772271003163e17ed44adc06 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 10:11:34 -0500
Subject: [PATCH 12/23] sink loop and fix dangling array ref

---
 .../Vector/Transforms/LowerVectorMultiReduction.cpp   | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 7e6f90a128114..6c9b89bb62ae7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -518,8 +518,8 @@ struct UnrollMultiReductionInnerParallelBaseCase
       vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i);
 
     SmallVector<Value> masks(numElementwiseOps);
-    for (int64_t i = 0; i < numElementwiseOps; ++i)
-      if (mask)
+    if (mask)
+      for (int64_t i = 0; i < numElementwiseOps; ++i)
         masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
 
     Value result = multiReductionOp.getAcc();
@@ -561,12 +561,13 @@ struct UnrollMultiReductionInnerParallelGeneralCase
       vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i);
 
     SmallVector<Value> masks(numElementwiseOps);
-    for (int64_t i = 0; i < numElementwiseOps; ++i)
-      if (mask)
+    if (mask)
+      for (int64_t i = 0; i < numElementwiseOps; ++i)
         masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
 
+    SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
     ArrayRef<bool> reductionMask =
-        ArrayRef<bool>(multiReductionOp.getReductionMask()).drop_front();
+        ArrayRef<bool>(fullReductionMask).drop_front();
     Value result = multiReductionOp.getAcc();
     for (auto [innerVector, innerMask] : llvm::zip(vectors, masks)) {
       auto reductionOp = vector::MultiDimReductionOp::create(

>From ca57413669bc2cce23c95a5035bc97c5cd7b7154 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 10:14:23 -0500
Subject: [PATCH 13/23] Fix documentation

---
 .../mlir/Dialect/Vector/TransformOps/VectorTransformOps.td      | 2 +-
 mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h  | 2 +-
 .../lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 1 +
 3 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index a7de823de7705..10950e701faa5 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -287,7 +287,7 @@ def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
     "apply_patterns.vector.multi_reduction_unrolling",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that 2-D vector multi_reduction operations should be unrolled
+    Indicates that vector multi_reduction operations should be unrolled
     into either a sequence of vector.reduction ops (innerreduction) or
     element-wise arith ops (innerparallel).
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index ecc6c420e3a82..7a1823c26a46e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -95,7 +95,7 @@ void populateVectorMultiReductionFlatteningPatterns(
 ///
 /// [UnrollMultiReductionInnerParallelGeneralCase]
 /// Rank reducing unrolling for inner-parallel general case, when there is
-/// more than one reduction and it is the outermost one.
+/// more than one reduction dimension and it is the outermost one.
 ///
 /// [TwoDimMultiReductionToReduction]
 /// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 6c9b89bb62ae7..b7a2c21f356e6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -468,6 +468,7 @@ struct OneDimMultiReductionToTwoDim
 /// ```mlir
 /// %res = vector.multi_reduction <add> %src, %acc [0, [[REDUCTION_DIMS]] ] :
 /// vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
 ///
 /// ```mlir
 /// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>

>From dc69e171b1e1af0b3613834cc4b60d07471e0512 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 10:19:37 -0500
Subject: [PATCH 14/23] Rename variable

---
 .../Vector/Transforms/LowerVectorMultiReduction.cpp    | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index b7a2c21f356e6..a6b3e8bba06d7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -553,17 +553,17 @@ struct UnrollMultiReductionInnerParallelGeneralCase
 
     ArrayRef<int64_t> srcShape =
         multiReductionOp.getSourceVectorType().getShape();
-    int64_t numElementwiseOps = srcShape.front();
+    int64_t outerDimSize = srcShape.front();
 
     Value mask = maskingOp ? maskingOp.getMask() : nullptr;
 
-    SmallVector<Value> vectors(numElementwiseOps);
-    for (int64_t i = 0; i < numElementwiseOps; ++i)
+    SmallVector<Value> vectors(outerDimSize);
+    for (int64_t i = 0; i < outerDimSize; ++i)
       vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i);
 
-    SmallVector<Value> masks(numElementwiseOps);
+    SmallVector<Value> masks(outerDimSize);
     if (mask)
-      for (int64_t i = 0; i < numElementwiseOps; ++i)
+      for (int64_t i = 0; i < outerDimSize; ++i)
         masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
 
     SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();

>From 6201ee875f79cab971e890c5572fcec700d093e3 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 10:22:05 -0500
Subject: [PATCH 15/23] Splits documentation

---
 .../Transforms/LowerVectorMultiReduction.cpp  | 53 ++++++++-----------
 1 file changed, 22 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index a6b3e8bba06d7..38ae3caa89d07 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -432,17 +432,8 @@ struct OneDimMultiReductionToTwoDim
 };
 
 /// Unrolls outermost dimension for vector.multi_reduction.
-/// This patterns matches operations which reduce the outermost dimension,
-/// it does not transform operations for which the outermost dimension is not
-/// a reduction dimension.
-///
-/// There are two cases to consider:
-/// 1. The base case is when the outermost dimension is the only reduction
+/// Matches when the outermost dimension is the only reduction
 /// dimension.
-/// 2. The general case is when the outermost dimension is not the only
-/// reduction dimension.
-///
-/// The base case transformation:
 ///
 /// ```mlir
 /// %res = vector.multi_reduction <add> %src, %acc [0] : vector<NxMx...xf32> to
@@ -461,27 +452,6 @@ struct OneDimMultiReductionToTwoDim
 /// ...
 /// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
 /// ```
-///
-/// For the general case, we still extract N vectors, but produce N
-/// vector.multi_reduction instead of elementwise operations.
-///
-/// ```mlir
-/// %res = vector.multi_reduction <add> %src, %acc [0, [[REDUCTION_DIMS]] ] :
-/// vector<NxMx...xf32> to vector<Ix...xf32>
-/// ```
-///
-/// ```mlir
-/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
-/// ...
-/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...x.f32> from
-/// vector<NxMx...xf32>
-///
-/// %red0 = vector.multi_reduction %0, %acc [ [[REDUCTION_DIMS]] ] :
-/// vector<Mx...xf32> to vector<Ix...xf32>
-/// ...
-/// %res = vector.multi_reduction %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] :
-/// vector<Mx...xf32> to vector<Ix...xf32>
-/// ```
 struct UnrollMultiReductionInnerParallelBaseCase
     : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
   using MaskableOpRewritePattern::MaskableOpRewritePattern;
@@ -533,6 +503,27 @@ struct UnrollMultiReductionInnerParallelBaseCase
   }
 };
 
+/// Unrolls outermost dimension for vector.multi_reduction.
+/// Matches when the outermost dimension is not the only
+/// reduction dimension.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add> %src, %acc [0, [[REDUCTION_DIMS]] ] :
+/// vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction %0, %acc [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
 struct UnrollMultiReductionInnerParallelGeneralCase
     : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
   using MaskableOpRewritePattern::MaskableOpRewritePattern;

>From fa1a713ec5953f95beaac5f4a01d0265de7ee131 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 10:29:14 -0500
Subject: [PATCH 16/23] Make type explicit

---
 .../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp    | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 38ae3caa89d07..3756d6d471ed7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -567,7 +567,8 @@ struct UnrollMultiReductionInnerParallelGeneralCase
           multiReductionOp.getKind());
 
       if (innerMask) {
-        auto maskOp = vector::maskOperation(rewriter, reductionOp, innerMask);
+        Operation *maskOp =
+            vector::maskOperation(rewriter, reductionOp, innerMask);
         result = maskOp->getResult(0);
       } else {
         result = reductionOp.getResult();

>From 9265a47b5bac81a6afc4b021f5452aaffefe3c66 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 10:46:00 -0500
Subject: [PATCH 17/23] Use zip_equal

---
 .../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp   | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 3756d6d471ed7..8b817c74f72a5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -494,7 +494,7 @@ struct UnrollMultiReductionInnerParallelBaseCase
         masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
 
     Value result = multiReductionOp.getAcc();
-    for (auto [innerVector, innerMask] : llvm::zip(vectors, masks))
+    for (auto [innerVector, innerMask] : llvm::zip_equal(vectors, masks))
       result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
                                   innerVector, result, /*fastmath=*/nullptr,
                                   innerMask);
@@ -561,7 +561,7 @@ struct UnrollMultiReductionInnerParallelGeneralCase
     ArrayRef<bool> reductionMask =
         ArrayRef<bool>(fullReductionMask).drop_front();
     Value result = multiReductionOp.getAcc();
-    for (auto [innerVector, innerMask] : llvm::zip(vectors, masks)) {
+    for (auto [innerVector, innerMask] : llvm::zip_equal(vectors, masks)) {
       auto reductionOp = vector::MultiDimReductionOp::create(
           rewriter, loc, innerVector, result, reductionMask,
           multiReductionOp.getKind());

>From fd6e54834f4bed2fdc8a84841c774dca967b473a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 10:49:21 -0500
Subject: [PATCH 18/23] clarify comment

---
 .../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp    | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 8b817c74f72a5..28257c2ec39a5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -435,6 +435,9 @@ struct OneDimMultiReductionToTwoDim
 /// Matches when the outermost dimension is the only reduction
 /// dimension.
 ///
+/// In the example below, the reduction index [0] selects the dimension of
+/// size N, i.e. the outermost dimension of the source vector.
+///
 /// ```mlir
 /// %res = vector.multi_reduction <add> %src, %acc [0] : vector<NxMx...xf32> to
 /// vector<Mx...xf32>

>From 9ff005fc6c3fbc51a29d93987478cd9a3c54ca10 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 23 Feb 2026 10:14:19 -0500
Subject: [PATCH 19/23] Update schedules to serialize lowering of
 multi_reduction.

This generalization is not compatible with the previous approach of
applying these patterns at the same time until a fixed point is reached.
This can be seen easily by noting that the generalization of unrolling
can match and rewrite vector.multi_reduction operations which are of
rank higher than two and would compete with other patterns found in
flattening or reorder_and_expand.
---
 mlir/test/Dialect/LLVM/transform-e2e.mlir     | 18 +++++++++++++---
 .../test/Dialect/Vector/transform-vector.mlir | 19 ++++++++++++-----
 .../Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir  | 21 +++++++++++++------
 .../Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir  | 20 ++++++++++++------
 .../Linalg/CPU/test-matmul-masked-vec.mlir    | 21 ++++++++++++++-----
 5 files changed, 74 insertions(+), 25 deletions(-)

diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index ab58dda91a914..c739b92760244 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -13,6 +13,17 @@ func.func @matmul_tensors(
 }
 
 module attributes {transform.with_named_sequence} {
+  transform.named_sequence @multi_reduction_lowering(%func_op: !transform.any_op {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+    } : !transform.any_op
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+    } : !transform.any_op
+    transform.yield
+  }
+
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.consumed}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op
     %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
@@ -30,10 +41,11 @@ module attributes {transform.with_named_sequence} {
     transform.apply_patterns to %f {
       transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
       transform.apply_patterns.vector.transfer_permutation_patterns
-      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
-      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
+    } : !transform.any_op
+    transform.include @multi_reduction_lowering failures(propagate) (%f)
+      : (!transform.any_op) -> ()
+    transform.apply_patterns to %f {
       transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
       transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1
       transform.apply_patterns.vector.lower_shape_cast
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index a37105d573219..6732f0c681d52 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -14,6 +14,18 @@ func.func @matmul_tensors(
 }
 
 module attributes {transform.with_named_sequence} {
+
+  transform.named_sequence @multi_reduction_lowering(%func_op: !transform.any_op {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+    } : !transform.any_op
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+    } : !transform.any_op
+    transform.yield
+  }
+
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.consumed}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op
     %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [8, 4, 2]
@@ -38,11 +50,8 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.vector.transfer_permutation_patterns
     } : !transform.any_op
 
-    transform.apply_patterns to %f {
-      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
-      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
-    } : !transform.any_op
+    transform.include @multi_reduction_lowering failures(propagate) (%f)
+      : (!transform.any_op) -> ()
 
     transform.apply_patterns to %f {
       transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
index 25b65080339d5..ed91a7a307631 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
@@ -132,6 +132,19 @@ func.func @generic_reduce_1d_f32() {
 }
 
 module attributes {transform.with_named_sequence} {
+
+  transform.named_sequence @multi_reduction_lowering(%func_op: !transform.any_op {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.lower_masked_transfers
+      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+    } : !transform.any_op
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+    } : !transform.any_op
+    transform.yield
+  }
+
   // A sequence that will tile and vectorise a Reduce Op
   transform.named_sequence @tile_and_vectorize_reduce(%func
     : !transform.op<"func.func"> {transform.readonly}) {
@@ -148,12 +161,8 @@ module attributes {transform.with_named_sequence} {
     transform.structured.vectorize %tiled_reduce vector_sizes [[4]] : !transform.any_op
 
     // Step 3: Lower vector.multi_reduction
-    transform.apply_patterns to %func {
-      transform.apply_patterns.vector.lower_masked_transfers
-      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
-      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
-    } : !transform.op<"func.func">
+    transform.include @multi_reduction_lowering failures(propagate) (%func)
+      : (!transform.any_op) -> ()
 
     transform.yield
   }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
index 6072b44adf4fa..4844de6df66fc 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
@@ -137,6 +137,18 @@ func.func @generic_reduce_2d_i32() {
 
 
 module attributes {transform.with_named_sequence} {
+  transform.named_sequence @multi_reduction_lowering(%func_op: !transform.any_op {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.lower_masked_transfers
+      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+    } : !transform.any_op
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+    } : !transform.any_op
+    transform.yield
+  }
+
   // A sequence that will tile and vectorise a Reduce Op
   transform.named_sequence @tile_and_vectorize_reduce(%func
     : !transform.op<"func.func"> {transform.readonly}) {
@@ -153,12 +165,8 @@ module attributes {transform.with_named_sequence} {
     transform.structured.vectorize %tiled_reduce vector_sizes [1, [4]] : !transform.any_op
 
     // Step 3: Lower vector.multi_reduction
-    transform.apply_patterns to %func {
-      transform.apply_patterns.vector.lower_masked_transfers
-      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
-      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
-    } : !transform.op<"func.func">
+    transform.include @multi_reduction_lowering failures(propagate) (%func)
+      : (!transform.any_op) -> ()
 
     transform.yield
   }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index 3c4f10316d0f3..4d1d0539ce2af 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -48,15 +48,26 @@ func.func @main() {
 }
 
 module attributes {transform.with_named_sequence} {
+
+  transform.named_sequence @multi_reduction_lowering(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+    } : !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %func_op = transform.get_parent_op %0 : (!transform.any_op) -> !transform.op<"func.func">
     transform.structured.vectorize %0 vector_sizes [4, 4, 2] : !transform.any_op
-    transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
-      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
-    } : !transform.op<"func.func">
+
+    transform.include @multi_reduction_lowering failures(propagate) (%func_op)
+      : (!transform.op<"func.func">) -> ()
+
     transform.yield
   }
 }

>From a7c870b37f5e0fc12c00bf0b0112cd1bdea6fc2a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 23 Feb 2026 10:27:38 -0500
Subject: [PATCH 20/23] Rename patterns

---
 .../Vector/Transforms/LowerVectorMultiReduction.cpp       | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index f3f3bd74837e4..5f609a9d7aaff 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -455,7 +455,7 @@ struct OneDimMultiReductionToTwoDim
 /// ...
 /// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
 /// ```
-struct UnrollMultiReductionInnerParallelBaseCase
+struct MultiReductionToArithOps
     : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
   using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
@@ -527,7 +527,7 @@ struct UnrollMultiReductionInnerParallelBaseCase
 /// %res = vector.multi_reduction %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] :
 /// vector<Mx...xf32> to vector<Ix...xf32>
 /// ```
-struct UnrollMultiReductionInnerParallelGeneralCase
+struct UnrollMultiReductionInnerParallel
     : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
   using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
@@ -641,8 +641,8 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                   benefit);
   } else {
-    patterns.add<UnrollMultiReductionInnerParallelBaseCase,
-                 UnrollMultiReductionInnerParallelGeneralCase>(
+    patterns.add<MultiReductionToArithOps,
+                 UnrollMultiReductionInnerParallel>(
         patterns.getContext(), benefit);
   }
 }

>From 211c6ae8b6f00964b0ccf4431869c60025edf2ab Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 23 Feb 2026 10:33:20 -0500
Subject: [PATCH 21/23] Update test function names

---
 .../Vector/vector-multi-reduction-unrolling.mlir     | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
index da375965ff837..bfbbc24ebe495 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
@@ -54,9 +54,9 @@ func.func @inner_reduction_2d_scalable(%input: vector<2x[4]xf32>, %acc: vector<2
     return %0 : vector<2xf32>
 }
 
-// ALL-LABEL: func @inner_parallel_base
+// ALL-LABEL: func @multi_reduction_to_arith_ops
 // ALL-SAME:    %[[INPUT:.+]]: vector<4x2xf32>, %[[ACC:.+]]: vector<2xf32>
-func.func @inner_parallel_base(%arg0: vector<4x2xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+func.func @multi_reduction_to_arith_ops(%arg0: vector<4x2xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
     // INNER_PARALLEL: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2xf32> from vector<4x2xf32>
     // INNER_PARALLEL: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2xf32> from vector<4x2xf32>
     // INNER_PARALLEL: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2xf32> from vector<4x2xf32>
@@ -71,9 +71,9 @@ func.func @inner_parallel_base(%arg0: vector<4x2xf32>, %acc: vector<2xf32>) -> v
     return %0 : vector<2xf32>
 }
 
-// ALL-LABEL: func @inner_parallel_general
+// ALL-LABEL: func @unroll_multi_reduction_inner_parallel
 // ALL-SAME:    %[[INPUT:.+]]: vector<4x2x3xf32>, %[[ACC:.+]]: vector<2xf32>
-func.func @inner_parallel_general(%arg0: vector<4x2x3xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+func.func @unroll_multi_reduction_inner_parallel(%arg0: vector<4x2x3xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
     // INNER_PARALLEL: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2x3xf32> from vector<4x2x3xf32>
     // INNER_PARALLEL: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2x3xf32> from vector<4x2x3xf32>
     // INNER_PARALLEL: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2x3xf32> from vector<4x2x3xf32>
@@ -88,9 +88,9 @@ func.func @inner_parallel_general(%arg0: vector<4x2x3xf32>, %acc: vector<2xf32>)
     return %0 : vector<2xf32>
 }
 
-// ALL-LABEL: func @inner_parallel_base_masked
+// ALL-LABEL: func @multi_reduction_to_arith_ops_masked
 // ALL-SAME:    %[[INPUT:.+]]: vector<4x2xf32>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<4x2xi1>
-func.func @inner_parallel_base_masked(%arg0: vector<4x2xf32>, %acc: vector<2xf32>, %mask: vector<4x2xi1>) -> vector<2xf32> {
+func.func @multi_reduction_to_arith_ops_masked(%arg0: vector<4x2xf32>, %acc: vector<2xf32>, %mask: vector<4x2xi1>) -> vector<2xf32> {
     // INNER_PARALLEL: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2xf32> from vector<4x2xf32>
     // INNER_PARALLEL: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2xf32> from vector<4x2xf32>
     // INNER_PARALLEL: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2xf32> from vector<4x2xf32>

>From e5687cc5a32df2ed24fea846b07042d1e758695c Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 23 Feb 2026 10:35:23 -0500
Subject: [PATCH 22/23] Remove punctuation from notify match failures

---
 .../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 5f609a9d7aaff..c0c7bc91b93a2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -466,17 +466,17 @@ struct MultiReductionToArithOps
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     if (srcRank < 2)
       return rewriter.notifyMatchFailure(multiReductionOp,
-                                         "expected source rank >= 2.");
+                                         "expected source rank >= 2");
 
     if (!multiReductionOp.isReducedDim(0))
       return rewriter.notifyMatchFailure(
           multiReductionOp,
-          "expected outermost dimension to be reduced dimension.");
+          "expected outermost dimension to be reduced dimension");
 
     ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
     if (reductionDims.size() > 1)
       return rewriter.notifyMatchFailure(
-          multiReductionOp, "expected only one reduction dimension.");
+          multiReductionOp, "expected only one reduction dimension");
 
     Location loc = multiReductionOp.getLoc();
     Value source = multiReductionOp.getSource();

>From eeae40b8c375f356d98462a60dcf71ff15627df9 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 23 Feb 2026 10:41:53 -0500
Subject: [PATCH 23/23] Style

---
 .../Vector/Transforms/LowerVectorMultiReduction.cpp       | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index c0c7bc91b93a2..80d4023c219aa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -637,14 +637,12 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
 void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
-  if (options == VectorMultiReductionLowering ::InnerReduction) {
+  if (options == VectorMultiReductionLowering ::InnerReduction)
     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                   benefit);
-  } else {
-    patterns.add<MultiReductionToArithOps,
-                 UnrollMultiReductionInnerParallel>(
+  else
+    patterns.add<MultiReductionToArithOps, UnrollMultiReductionInnerParallel>(
         patterns.getContext(), benefit);
-  }
 }
 
 std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(



More information about the Mlir-commits mailing list