[Mlir-commits] [mlir] [mlir][vector] Generalize multi_reduction innerparallel unrolling to N dimensions (PR #182301)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Mon Feb 23 07:40:03 PST 2026
================
@@ -486,6 +431,157 @@ struct OneDimMultiReductionToTwoDim
}
};
+/// Unrolls the outermost dimension of a vector.multi_reduction.
+/// Matches when the outermost dimension (index 0) is the only reduction
+/// dimension.
+///
+/// Here `[0]` is the list of reduction dimensions: dimension 0, i.e. the
+/// outermost (leading) dimension of size N, is the one being reduced.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add> %src, %acc [0] : vector<NxMx...xf32> to
+/// vector<Mx...xf32>
+/// ```
+///
+/// is rewritten into N vector.extract ops followed by a chain of N
+/// elementwise combining ops:
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
+/// ...
+/// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
+/// ```
+///
+/// If the op is masked, the mask is sliced along dimension 0 in the same way
+/// and each mask slice is passed to the corresponding combining step.
+///
+/// The pattern does not match when:
+///  * the source rank is < 2,
+///  * dimension 0 is not reduced, or another dimension is reduced as well,
+///  * dimension 0 is scalable (its runtime size is not known statically, so
+///    it cannot be unrolled a fixed number of times).
+struct UnrollMultiReductionInnerParallelBaseCase
+    : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    VectorType srcType = multiReductionOp.getSourceVectorType();
+    int64_t srcRank = srcType.getRank();
+    if (srcRank < 2)
+      return rewriter.notifyMatchFailure(multiReductionOp,
+                                         "expected source rank >= 2.");
+
+    if (!multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected outermost dimension to be reduced dimension.");
+
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+    if (reductionDims.size() > 1)
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected only one reduction dimension.");
+
+    // A scalable dimension's runtime size is vscale times its static size, so
+    // unrolling it by the static size alone would silently drop rows from the
+    // reduction.
+    if (srcType.getScalableDims().front())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "cannot unroll a scalable outermost dimension.");
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+    int64_t numElementwiseOps = srcShape.front();
+
+    Value mask = maskingOp ? maskingOp.getMask() : nullptr;
+
+    // Emit all source extracts first, then all mask extracts, then the
+    // combining chain; this keeps the emitted IR order stable.
+    SmallVector<Value> vectors(numElementwiseOps);
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i);
+
+    // Left as null values when unmasked, so the zip below stays balanced and
+    // makeArithReduction receives no mask.
+    SmallVector<Value> masks(numElementwiseOps);
+    if (mask)
+      for (int64_t i = 0; i < numElementwiseOps; ++i)
+        masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
+
+    // Fold every slice into the accumulator, in order along dimension 0.
+    Value result = multiReductionOp.getAcc();
+    for (auto [innerVector, innerMask] : llvm::zip_equal(vectors, masks))
+      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
+                                  innerVector, result, /*fastmath=*/nullptr,
+                                  innerMask);
+
+    return result;
+  }
+};
+
+/// Unrolls the outermost dimension of a vector.multi_reduction.
+/// Matches when the outermost dimension (index 0) is reduced together with
+/// at least one other dimension.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add> %src, %acc [0, [[REDUCTION_DIMS]] ] :
+/// vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// is rewritten into N vector.extract ops followed by a chain of N
+/// lower-rank vector.multi_reduction ops. In the rewritten ops,
+/// [[REDUCTION_DIMS]] denotes the same dimensions renumbered one lower, since
+/// the leading dimension has been extracted away:
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction <add> %0, %acc [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction <add> %Nminus1, %redNminus2
+/// [ [[REDUCTION_DIMS]] ] : vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
+struct UnrollMultiReductionInnerParallelGeneralCase
----------------
amd-eochoalo wrote:
https://github.com/llvm/llvm-project/pull/182301/commits/a7c870b37f5e0fc12c00bf0b0112cd1bdea6fc2a
https://github.com/llvm/llvm-project/pull/182301
More information about the Mlir-commits
mailing list