[Mlir-commits] [mlir] [mlir][vector] Add multi_reduction_rank_reducing_unrolling (PR #182301)

Erick Ochoa Lopez llvmlistbot at llvm.org
Mon Feb 23 12:49:49 PST 2026


================
@@ -486,6 +431,157 @@ struct OneDimMultiReductionToTwoDim
   }
 };
 
+/// Unrolls outermost dimension for vector.multi_reduction.
+/// Matches when the outermost dimension is the only reduction
+/// dimension.
+///
+/// In this case [0] is the list of reduction dimensions: index 0 names the
+/// dimension of size N, which is the outermost dimension.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add> %src, %acc [0] : vector<NxMx...xf32> to
+/// vector<Mx...xf32>
+/// ```
+///
+/// will extract N vectors from %src and then perform elementwise operations.
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
+/// ...
+/// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
+/// ```
+struct UnrollMultiReductionInnerParallelBaseCase
+    : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank < 2)
----------------
amd-eochoalo wrote:

Yes, but nothing was exercising this path. Every test that would have exercised it was specialized to inner-reduction, not inner-parallel.
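
For readers following along, the unrolling described in the doc comment above can be modeled outside MLIR. The following is only a rough sketch in plain C++, not the pattern's implementation: `Row`, `addRows`, and `reduceOutermostAdd` are hypothetical names, a `std::vector<Row>` stands in for `vector<NxMxf32>`, and masking and the rank check are ignored.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one vector<Mxf32> slice of the source.
using Row = std::vector<float>;

// Elementwise add of two equally sized rows (models arith.addf on slices).
Row addRows(const Row &a, const Row &b) {
  assert(a.size() == b.size());
  Row out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i)
    out[i] = a[i] + b[i];
  return out;
}

// Reduces over dimension 0 by taking each outermost slice in turn
// (models vector.extract %src[i]) and chaining it onto the accumulator.
Row reduceOutermostAdd(const std::vector<Row> &src, const Row &acc) {
  Row res = acc;
  for (const Row &slice : src)
    res = addRows(slice, res);
  return res;
}
```

With N = 3 and M = 2, calling `reduceOutermostAdd` on rows {1, 2}, {3, 4}, {5, 6} with accumulator {10, 20} performs three elementwise adds, one per extracted slice, mirroring the `%res0 ... %res` chain in the comment.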

https://github.com/llvm/llvm-project/pull/182301