[llvm] [LoopVectorizer] Add support for partial reductions (PR #92418)
Sam Tebbs via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 19 07:01:28 PDT 2024
================
@@ -1556,7 +1556,130 @@ class LoopVectorizationCostModel {
getReductionPatternCost(Instruction *I, ElementCount VF, Type *VectorTy,
TTI::TargetCostKind CostKind) const;
+ /// A chain of instructions that forms a partial reduction.
+ /// Designed to match:
+ ///   reduction_bin_op (bin_op (extend (A), extend (B)), accumulator)
+ /// All members are default-initialized so that a default-constructed chain
+ /// never carries indeterminate pointers or an indeterminate scale factor.
+ struct PartialReductionChain {
+   /// The top-level binary operation that forms the reduction to a scalar
+   /// after the loop body.
+   Instruction *Reduction = nullptr;
+   /// The inner binary operation that forms the reduction to a vector value
+   /// within the loop body.
+   Instruction *BinOp = nullptr;
+   /// The extension of each of the inner binary operation's operands.
+   Instruction *ExtendA = nullptr;
+   Instruction *ExtendB = nullptr;
+
+   /// The accumulator that is reduced to a scalar after the loop body.
+   Value *Accumulator = nullptr;
+
+   /// The scaling factor between the size of the reduction type and the
+   /// (possibly extended) inputs.
+   unsigned ScaleFactor = 0;
+ };
+
+ /// Map from an instruction to the partial reduction chain recorded for it.
+ using PartialReductionList = DenseMap<Instruction *, PartialReductionChain>;
+
+ /// Returns the partial reduction chains held by this cost model.
+ /// Returned by const reference: returning the DenseMap by value would copy
+ /// every entry on each call. Callers that need their own copy can still
+ /// write `PartialReductionList Copy = getPartialReductionChains();`.
+ const PartialReductionList &getPartialReductionChains() const {
+   return PartialReductionChains;
+ }
+
+ /// Looks up the partial reduction chain stored for \p I.
+ /// \returns a copy of the chain if PartialReductionChains has an entry keyed
+ /// by \p I, otherwise std::nullopt.
+ std::optional<PartialReductionChain>
+ getInstructionsPartialReduction(Instruction *I) const {
+   if (const auto Found = PartialReductionChains.find(I);
+       Found != PartialReductionChains.end())
+     return Found->second;
+   return std::nullopt;
+ }
+
+ void addPartialReductionIfSupported(Instruction *Instr, ElementCount VF) {
+ Value *ExpectedPhi;
+ Value *A, *B;
+
+ using namespace llvm::PatternMatch;
+
+ unsigned BinOpIdx = 0;
+
+ // The binary operator can be commutative
+ if (match(Instr, m_BinOp(m_OneUse(m_BinOp(
+ m_ZExtOrSExt(m_Value(A)),
+ m_ZExtOrSExt(m_Value(B)))),
+ m_Value(ExpectedPhi))))
+ BinOpIdx = 0;
+ else if (match(Instr,
+ m_BinOp(m_Value(ExpectedPhi),
+ m_OneUse(m_BinOp(
+ m_ZExtOrSExt(m_Value(A)),
+ m_ZExtOrSExt(m_Value(B)))))))
+ BinOpIdx = 1;
+ else
+ return;
+
+ // Check that the extends extend from the same type
+ if (A->getType() != B->getType()) {
+ LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot "
+ "create a partial reduction.\n");
+ return;
+ }
+
+ // A and B are one-use, so the first user of each should be the respective
----------------
SamTebbs33 wrote:
Yeah I think it is. I've re-organised it to match manually.
https://github.com/llvm/llvm-project/pull/92418
More information about the llvm-commits
mailing list