[llvm] [VPlan] Use ResumePhi to create reduction resume phis. (PR #110004)

via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 27 15:25:27 PDT 2024


================
@@ -7562,67 +7562,62 @@ static void addRuntimeUnrollDisableMetaData(Loop *L) {
   }
 }
 
-// Check if \p RedResult is a ComputeReductionResult instruction, and if it is
-// create a merge phi node for it.
-static void createAndCollectMergePhiForReduction(
-    VPInstruction *RedResult,
-    VPTransformState &State, Loop *OrigLoop, BasicBlock *LoopMiddleBlock,
-    bool VectorizingEpilogue) {
-  if (!RedResult ||
-      RedResult->getOpcode() != VPInstruction::ComputeReductionResult)
+// If \p R is a ComputeReductionResult when vectorizing the epilog loop,
+// fix the reduction's scalar PHI node by adding the incoming value from the
+// main vector loop.
+static void fixReductionScalarResumeWhenVectorizingEpilog(
+    VPRecipeBase *R, VPTransformState &State, Loop *OrigLoop,
+    BasicBlock *LoopMiddleBlock) {
+  auto *EpiRedResult = dyn_cast<VPInstruction>(R);
+  if (!EpiRedResult ||
+      EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult)
     return;
 
-  auto *PhiR = cast<VPReductionPHIRecipe>(RedResult->getOperand(0));
-  const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
-
-  Value *FinalValue = State.get(RedResult, VPLane(VPLane::getFirstLane()));
-  auto *ResumePhi =
-      dyn_cast<PHINode>(PhiR->getStartValue()->getUnderlyingValue());
-  if (VectorizingEpilogue && RecurrenceDescriptor::isAnyOfRecurrenceKind(
-                                 RdxDesc.getRecurrenceKind())) {
-    auto *Cmp = cast<ICmpInst>(PhiR->getStartValue()->getUnderlyingValue());
-    assert(Cmp->getPredicate() == CmpInst::ICMP_NE);
-    assert(Cmp->getOperand(1) == RdxDesc.getRecurrenceStartValue());
-    ResumePhi = cast<PHINode>(Cmp->getOperand(0));
-  }
-  assert((!VectorizingEpilogue || ResumePhi) &&
-         "when vectorizing the epilogue loop, we need a resume phi from main "
-         "vector loop");
-
-  // TODO: bc.merge.rdx should not be created here, instead it should be
-  // modeled in VPlan.
+  auto *EpiRedHeaderPhi =
+      cast<VPReductionPHIRecipe>(EpiRedResult->getOperand(0));
+  const RecurrenceDescriptor &RdxDesc =
+      EpiRedHeaderPhi->getRecurrenceDescriptor();
+  Value *MainResumeValue =
+      EpiRedHeaderPhi->getStartValue()->getUnderlyingValue();
+  if (RecurrenceDescriptor::isAnyOfRecurrenceKind(
+          RdxDesc.getRecurrenceKind())) {
+    auto *Cmp = cast<ICmpInst>(MainResumeValue);
+    assert(Cmp->getPredicate() == CmpInst::ICMP_NE &&
+           "AnyOf expected to start with ICMP_NE");
+    assert(Cmp->getOperand(1) == RdxDesc.getRecurrenceStartValue() &&
+           "AnyOf expected to start by comparing main resume value to original "
+           "start value");
+    MainResumeValue = Cmp->getOperand(0);
+  }
+  PHINode *MainResumePhi = cast<PHINode>(MainResumeValue);
+
+  // When fixing reductions in the epilogue loop we should already have
+  // created a bc.merge.rdx Phi after the main vector body. Ensure that we carry
+  // over the incoming values correctly.
+  using namespace VPlanPatternMatch;
+  auto IsResumePhi = [](VPUser *U) {
+    return match(
+        U, m_VPInstruction<VPInstruction::ResumePhi>(m_VPValue(), m_VPValue()));
+  };
+  assert(count_if(EpiRedResult->users(), IsResumePhi) == 1 &&
+         "ResumePhi must have a single user");
+  auto *EpiResumePhiVPI =
+      cast<VPInstruction>(*find_if(EpiRedResult->users(), IsResumePhi));
+  auto *EpiResumePhi = cast<PHINode>(State.get(EpiResumePhiVPI, true));
   BasicBlock *LoopScalarPreHeader = OrigLoop->getLoopPreheader();
-  // Create a phi node that merges control-flow from the backedge-taken check
-  // block and the middle block.
-  auto *BCBlockPhi =
-      PHINode::Create(FinalValue->getType(), 2, "bc.merge.rdx",
-                      LoopScalarPreHeader->getTerminator()->getIterator());
-
-  // If we are fixing reductions in the epilogue loop then we should already
-  // have created a bc.merge.rdx Phi after the main vector body. Ensure that
-  // we carry over the incoming values correctly.
+  unsigned UpdateCnt = 0;
   for (auto *Incoming : predecessors(LoopScalarPreHeader)) {
-    if (Incoming == LoopMiddleBlock)
-      BCBlockPhi->addIncoming(FinalValue, Incoming);
-    else if (ResumePhi && is_contained(ResumePhi->blocks(), Incoming))
-      BCBlockPhi->addIncoming(ResumePhi->getIncomingValueForBlock(Incoming),
-                              Incoming);
-    else
-      BCBlockPhi->addIncoming(RdxDesc.getRecurrenceStartValue(), Incoming);
+    if (is_contained(MainResumePhi->blocks(), Incoming)) {
+      assert(EpiResumePhi->getIncomingValueForBlock(Incoming) ==
+                 RdxDesc.getRecurrenceStartValue() &&
+             "Trying to reset unexpected value");
+      EpiResumePhi->setIncomingValueForBlock(
+          Incoming, MainResumePhi->getIncomingValueForBlock(Incoming));
+      UpdateCnt++;
----------------
ayalz wrote:

```suggestion
      assert(!Updated && "Should update at most 1 incoming value");
      EpiResumePhi->setIncomingValueForBlock(
          Incoming, MainResumePhi->getIncomingValueForBlock(Incoming));
      Updated = true;
```

https://github.com/llvm/llvm-project/pull/110004


More information about the llvm-commits mailing list