[llvm] [LoopVectorizer] Add support for partial reductions (PR #92418)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 16 08:18:34 PDT 2024


================
@@ -1556,7 +1556,105 @@ class LoopVectorizationCostModel {
   getReductionPatternCost(Instruction *I, ElementCount VF, Type *VectorTy,
                           TTI::TargetCostKind CostKind) const;
 
+  /// Partial-reduction chains recorded by the cost model, keyed by
+  /// instruction (presumably the reduction's accumulating binop -- confirm
+  /// against the insertion site).
+  using PartialReductionList = DenseMap<Instruction *, PartialReductionChain>;
+
+  /// Return the recorded partial-reduction chains. Returned by const reference
+  /// so read-only callers do not copy the entire DenseMap on every call.
+  const PartialReductionList &getPartialReductionChains() const {
+    return PartialReductionChains;
+  }
+
+  /// Look up the partial-reduction chain recorded for \p I.
+  /// If one exists, copy it into \p Chain and return true; otherwise return
+  /// false and leave \p Chain unmodified.
+  bool getInstructionsPartialReduction(Instruction *I,
+                                       PartialReductionChain &Chain) const {
+    auto It = PartialReductionChains.find(I);
+    const bool Found = It != PartialReductionChains.end();
+    if (Found)
+      Chain = It->second;
+    return Found;
+  }
+
+  void addPartialReductionIfSupported(Instruction *Instr, ElementCount VF) {
+    Value *ExpectedPhi;
+    Value *A, *B;
+
+    using namespace llvm::PatternMatch;
+    auto Pattern =
+        m_BinOp(m_OneUse(m_BinOp(m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(A)))),
+                                 m_OneUse(m_ZExtOrSExt(m_OneUse(m_Value(B)))))),
+                m_Value(ExpectedPhi));
+
+    bool Matches = match(Instr, Pattern);
+
+    if (!Matches)
+      return;
+
+    // Check that the extends extend from the same type
+    if (A->getType() != B->getType()) {
+      LLVM_DEBUG(dbgs() << "Extends don't extend from the same type, cannot "
+                           "create a partial reduction.\n");
+      return;
+    }
+
+    // A and B are one-use, so the first user of each should be the respective
+    // extend
+    Instruction *Ext0 = cast<CastInst>(*A->user_begin());
+    Instruction *Ext1 = cast<CastInst>(*B->user_begin());
+
+    // Check that the extends extend to the same type
+    if (Ext0->getType() != Ext1->getType()) {
+      LLVM_DEBUG(
+          dbgs() << "Extends don't extend to the same type, cannot create "
+                    "a partial reduction.\n");
+      return;
+    }
+
+    // Check that the add feeds into ExpectedPhi
+    PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
+    if (!PhiNode) {
+      LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a "
+                           "partial reduction.\n");
+      return;
+    }
+
+    // Check that the second phi value is the instruction we're looking at
+    Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
----------------
SamTebbs33 wrote:

I'm hoping that using `PhiNode->getIncomingValueForBlock(Instr->getParent())` will be more accurate than hard-coding incoming value index 1. Thank you.

https://github.com/llvm/llvm-project/pull/92418


More information about the llvm-commits mailing list