[llvm] [LoopVectorizer] Add support for partial reductions (PR #92418)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 19 07:01:28 PDT 2024


================
@@ -1556,7 +1556,130 @@ class LoopVectorizationCostModel {
   getReductionPatternCost(Instruction *I, ElementCount VF, Type *VectorTy,
                           TTI::TargetCostKind CostKind) const;
 
+  /// A chain of instructions that form a partial reduction.
+  /// Designed to match:
+  ///   reduction_bin_op (bin_op (extend (A), extend (B)), accumulator)
+  /// where the inner bin_op may appear as either operand of reduction_bin_op.
+  ///
+  /// Every member has a default initializer so that a default-initialized
+  /// chain (e.g. `PartialReductionChain C;`) holds null pointers and a zero
+  /// scale factor instead of indeterminate values. The type remains an
+  /// aggregate, so brace-initialization of all members still works.
+  struct PartialReductionChain {
+    /// The top-level binary operation that forms the reduction to a scalar
+    /// after the loop body.
+    Instruction *Reduction = nullptr;
+    /// The inner binary operation that forms the reduction to a vector value
+    /// within the loop body.
+    Instruction *BinOp = nullptr;
+    /// The extension (zext or sext) of each of the inner binary operation's
+    /// operands.
+    Instruction *ExtendA = nullptr;
+    Instruction *ExtendB = nullptr;
+
+    /// The accumulator that is reduced to a scalar after the loop body.
+    Value *Accumulator = nullptr;
+
+    /// The scaling factor between the size of the reduction type and the
+    /// (possibly extended) inputs.
+    unsigned ScaleFactor = 0;
+  };
+
+  using PartialReductionList = DenseMap<Instruction *, PartialReductionChain>;
+
+  /// Returns the map of partial reduction chains, keyed by Instruction *.
+  /// Returned by const reference (and callable on a const cost model) so that
+  /// querying the chains does not copy the whole DenseMap on every call.
+  const PartialReductionList &getPartialReductionChains() const {
+    return PartialReductionChains;
+  }
+
+  /// Looks up \p I in the recorded partial reduction chains.
+  /// \returns a copy of the chain stored for \p I, or std::nullopt when no
+  /// chain has been recorded for that instruction.
+  std::optional<PartialReductionChain>
+  getInstructionsPartialReduction(Instruction *I) const {
+    if (auto ChainIt = PartialReductionChains.find(I);
+        ChainIt != PartialReductionChains.end())
+      return ChainIt->second;
+    return std::nullopt;
+  }
+
+  void addPartialReductionIfSupported(Instruction *Instr, ElementCount VF) {
+    Value *ExpectedPhi;
+    Value *A, *B;
+
+    using namespace llvm::PatternMatch;
+
+    unsigned BinOpIdx = 0;
+
+    // The binary operator can be commutative
+    if (match(Instr, m_BinOp(m_OneUse(m_BinOp(
+                                 m_ZExtOrSExt(m_Value(A)),
+                                 m_ZExtOrSExt(m_Value(B)))),
+                             m_Value(ExpectedPhi))))
+      BinOpIdx = 0;
+    else if (match(Instr,
+                   m_BinOp(m_Value(ExpectedPhi),
+                           m_OneUse(m_BinOp(
+                               m_ZExtOrSExt(m_Value(A)),
+                               m_ZExtOrSExt(m_Value(B)))))))
+      BinOpIdx = 1;
+    else
+      return;
+
+    // Check that the extends extend from the same type
+    if (A->getType() != B->getType()) {
+      LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot "
+                           "create a partial reduction.\n");
+      return;
+    }
+
+    // A and B are one-use, so the first user of each should be the respective
----------------
SamTebbs33 wrote:

Yeah, I think it is. I've re-organised it to match manually.

https://github.com/llvm/llvm-project/pull/92418


More information about the llvm-commits mailing list