[llvm] [VPlan] Use ResumePhi to create reduction resume phis. (PR #110004)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Oct 27 15:25:27 PDT 2024
================
@@ -7562,67 +7562,62 @@ static void addRuntimeUnrollDisableMetaData(Loop *L) {
}
}
-// Check if \p RedResult is a ComputeReductionResult instruction, and if it is
-// create a merge phi node for it.
-static void createAndCollectMergePhiForReduction(
- VPInstruction *RedResult,
- VPTransformState &State, Loop *OrigLoop, BasicBlock *LoopMiddleBlock,
- bool VectorizingEpilogue) {
- if (!RedResult ||
- RedResult->getOpcode() != VPInstruction::ComputeReductionResult)
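+// If \p R is a ComputeReductionResult when vectorizing the epilog loop,
+// fix the reduction's scalar resume PHI node: for each scalar-preheader
+// predecessor that is also an incoming block of the main loop's resume PHI,
+// replace the start value with the main vector loop's incoming value.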
+static void fixReductionScalarResumeWhenVectorizingEpilog(
+ VPRecipeBase *R, VPTransformState &State, Loop *OrigLoop,
+ BasicBlock *LoopMiddleBlock) {
+ auto *EpiRedResult = dyn_cast<VPInstruction>(R);
+ if (!EpiRedResult ||
+ EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult)
return;
- auto *PhiR = cast<VPReductionPHIRecipe>(RedResult->getOperand(0));
- const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
-
- Value *FinalValue = State.get(RedResult, VPLane(VPLane::getFirstLane()));
- auto *ResumePhi =
- dyn_cast<PHINode>(PhiR->getStartValue()->getUnderlyingValue());
- if (VectorizingEpilogue && RecurrenceDescriptor::isAnyOfRecurrenceKind(
- RdxDesc.getRecurrenceKind())) {
- auto *Cmp = cast<ICmpInst>(PhiR->getStartValue()->getUnderlyingValue());
- assert(Cmp->getPredicate() == CmpInst::ICMP_NE);
- assert(Cmp->getOperand(1) == RdxDesc.getRecurrenceStartValue());
- ResumePhi = cast<PHINode>(Cmp->getOperand(0));
- }
- assert((!VectorizingEpilogue || ResumePhi) &&
- "when vectorizing the epilogue loop, we need a resume phi from main "
- "vector loop");
-
- // TODO: bc.merge.rdx should not be created here, instead it should be
- // modeled in VPlan.
+ auto *EpiRedHeaderPhi =
+ cast<VPReductionPHIRecipe>(EpiRedResult->getOperand(0));
+ const RecurrenceDescriptor &RdxDesc =
+ EpiRedHeaderPhi->getRecurrenceDescriptor();
+ Value *MainResumeValue =
+ EpiRedHeaderPhi->getStartValue()->getUnderlyingValue();
+ if (RecurrenceDescriptor::isAnyOfRecurrenceKind(
+ RdxDesc.getRecurrenceKind())) {
+ auto *Cmp = cast<ICmpInst>(MainResumeValue);
+ assert(Cmp->getPredicate() == CmpInst::ICMP_NE &&
+ "AnyOf expected to start with ICMP_NE");
+ assert(Cmp->getOperand(1) == RdxDesc.getRecurrenceStartValue() &&
+ "AnyOf expected to start by comparing main resume value to original "
+ "start value");
+ MainResumeValue = Cmp->getOperand(0);
+ }
+ PHINode *MainResumePhi = cast<PHINode>(MainResumeValue);
+
+ // When fixing reductions in the epilogue loop we should already have
+ // created a bc.merge.rdx Phi after the main vector body. Ensure that we carry
+ // over the incoming values correctly.
+ using namespace VPlanPatternMatch;
+ auto IsResumePhi = [](VPUser *U) {
+ return match(
+ U, m_VPInstruction<VPInstruction::ResumePhi>(m_VPValue(), m_VPValue()));
+ };
+ assert(count_if(EpiRedResult->users(), IsResumePhi) == 1 &&
+ "ResumePhi must have a single user");
+ auto *EpiResumePhiVPI =
+ cast<VPInstruction>(*find_if(EpiRedResult->users(), IsResumePhi));
+ auto *EpiResumePhi = cast<PHINode>(State.get(EpiResumePhiVPI, true));
BasicBlock *LoopScalarPreHeader = OrigLoop->getLoopPreheader();
- // Create a phi node that merges control-flow from the backedge-taken check
- // block and the middle block.
- auto *BCBlockPhi =
- PHINode::Create(FinalValue->getType(), 2, "bc.merge.rdx",
- LoopScalarPreHeader->getTerminator()->getIterator());
-
- // If we are fixing reductions in the epilogue loop then we should already
- // have created a bc.merge.rdx Phi after the main vector body. Ensure that
- // we carry over the incoming values correctly.
+ unsigned UpdateCnt = 0;
for (auto *Incoming : predecessors(LoopScalarPreHeader)) {
- if (Incoming == LoopMiddleBlock)
- BCBlockPhi->addIncoming(FinalValue, Incoming);
- else if (ResumePhi && is_contained(ResumePhi->blocks(), Incoming))
- BCBlockPhi->addIncoming(ResumePhi->getIncomingValueForBlock(Incoming),
- Incoming);
- else
- BCBlockPhi->addIncoming(RdxDesc.getRecurrenceStartValue(), Incoming);
+ if (is_contained(MainResumePhi->blocks(), Incoming)) {
+ assert(EpiResumePhi->getIncomingValueForBlock(Incoming) ==
+ RdxDesc.getRecurrenceStartValue() &&
+ "Trying to reset unexpected value");
+ EpiResumePhi->setIncomingValueForBlock(
+ Incoming, MainResumePhi->getIncomingValueForBlock(Incoming));
+ UpdateCnt++;
----------------
ayalz wrote:
```suggestion
assert(!Updated && "Should update at most 1 incoming value");
EpiResumePhi->setIncomingValueForBlock(
Incoming, MainResumePhi->getIncomingValueForBlock(Incoming));
Updated = true;
```
https://github.com/llvm/llvm-project/pull/110004
More information about the llvm-commits
mailing list