[Mlir-commits] [mlir] [mlir][vector] Generalize multi_reduction innerparallel unrolling to N dimensions (PR #182301)

Erick Ochoa Lopez llvmlistbot at llvm.org
Mon Feb 23 07:40:03 PST 2026


================
@@ -486,6 +431,157 @@ struct OneDimMultiReductionToTwoDim
   }
 };
 
+/// Unrolls the outermost dimension of a vector.multi_reduction.
+/// Matches when the outermost dimension is the only reduction dimension
+/// and the source vector has rank >= 2.
+///
+/// In the example below, `[0]` is the list of reduction dimensions: dimension
+/// 0 (of size N) is the outermost dimension of %src.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add> %src, %acc [0] : vector<NxMx...xf32> to
+/// vector<Mx...xf32>
+/// ```
+///
+/// is rewritten to extract N vectors from %src and combine them with %acc
+/// using elementwise operations (masked per slice when the op is masked):
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
+/// ...
+/// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
+/// ```
+struct UnrollMultiReductionInnerParallelBaseCase
+    : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+  /// Returns the final unrolled value, or a match failure if preconditions fail.
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank < 2)
+      return rewriter.notifyMatchFailure(multiReductionOp,
+                                         "expected source rank >= 2.");
+    // Dimension 0 (the outermost one) must be a reduction dimension ...
+    if (!multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected outermost dimension to be reduced dimension.");
+    // ... and it must be the only one; other cases are rejected here.
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+    if (reductionDims.size() > 1)
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected only one reduction dimension.");
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+    // The size of dimension 0 is the number of slices to combine.
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+    int64_t numElementwiseOps = srcShape.front();
+    // No mask (nullptr) when `maskingOp` is not set.
+    Value mask = maskingOp ? maskingOp.getMask() : nullptr;
+    // Extract each slice of the source at index i of dimension 0.
+    SmallVector<Value> vectors(numElementwiseOps);
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i);
+    // Extract matching mask slices; otherwise entries stay default-constructed.
+    SmallVector<Value> masks(numElementwiseOps);
+    if (mask)
+      for (int64_t i = 0; i < numElementwiseOps; ++i)
+        masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
+    // Chain the slices onto the accumulator one at a time, in index order.
+    Value result = multiReductionOp.getAcc();
+    for (auto [innerVector, innerMask] : llvm::zip_equal(vectors, masks))
+      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
+                                  innerVector, result, /*fastmath=*/nullptr,
+                                  innerMask);
+
+    return result;
+  }
+};
+
+/// Unrolls outermost dimension for vector.multi_reduction.
+/// Matches when the outermost dimension is not the only
+/// reduction dimension.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add> %src, %acc [0, [[REDUCTION_DIMS]] ] :
+/// vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction %0, %acc [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
+struct UnrollMultiReductionInnerParallelGeneralCase
----------------
amd-eochoalo wrote:

https://github.com/llvm/llvm-project/pull/182301/commits/a7c870b37f5e0fc12c00bf0b0112cd1bdea6fc2a

https://github.com/llvm/llvm-project/pull/182301


More information about the Mlir-commits mailing list