[llvm] [VPlan] Introduce multi-branch recipe, use for multi-exit loops (WIP). (PR #109193)

via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 18 13:38:25 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Florian Hahn (fhahn)

<details>
<summary>Changes</summary>

This patch introduces a new BranchMultipleConds VPInstruction that takes multiple conditions and branches to the first successor if the first condition is true, to the second successor if the second condition is true, and to the region header if neither is true. At the moment it only supports 2 conditions, but it can be extended in the future.

This may serve as an alternative to changing VPRegionBlock to allow multiple exiting blocks, and it keeps the region single-entry-single-exit. With BranchMultipleConds, we still leave a region via a single exiting block, but can have more than 2 destinations (similar idea to switch in LLVM IR). The new recipe allows precisely modeling the edges and conditions leaving the vector loop region.

BranchMultipleConds also allows predicating instructions in blocks after any early exit, i.e. also allows later stores.

See llvm/test/Transforms/LoopVectorize/X86/multi-exit-vplan.ll for an example VPlan and llvm/test/Transforms/LoopVectorize/X86/multi-exit-codegen.ll for example predicated codegen.

The patch also contains logic to construct VPlans using BranchMultipleConds for simple loops with 2 exit blocks instead of requiring a scalar tail. The logic to detect such cases is a bit rough around the edges and is mainly there to test the new recipes end-to-end.

This may serve as an alternative to https://github.com/llvm/llvm-project/pull/108563 that would allow us to keep the single-entry-single-exit property and support predication between early exits and latches.

---

Patch is 41.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/109193.diff


9 Files Affected:

- (modified) llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h (+2) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp (+36) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+130-41) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.cpp (+43-14) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+6-1) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+57-4) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp (+7-13) 
- (added) llvm/test/Transforms/LoopVectorize/X86/multi-exit-codegen.ll (+117) 
- (added) llvm/test/Transforms/LoopVectorize/X86/multi-exit-vplan.ll (+69) 


``````````diff
diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
index 0f4d1355dd2bfe..d2c754a106cf7f 100644
--- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
@@ -275,6 +275,8 @@ class LoopVectorizationLegality {
   /// we can use in-order reductions.
   bool canVectorizeFPMath(bool EnableStrictReductions);
 
+  bool canVectorizeMultiCond() const;
+
   /// Return true if we can vectorize this loop while folding its tail by
   /// masking.
   bool canFoldTailByMasking() const;
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index 7062e21383a5fc..4d3dfc0838f466 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -43,6 +43,9 @@ AllowStridedPointerIVs("lv-strided-pointer-ivs", cl::init(false), cl::Hidden,
                        cl::desc("Enable recognition of non-constant strided "
                                 "pointer induction variables."));
 
+static cl::opt<bool> EnableMultiCond("enable-multi-cond-vectorization",
+                                     cl::init(false), cl::Hidden, cl::desc(""));
+
 namespace llvm {
 cl::opt<bool>
     HintsAllowReordering("hints-allow-reordering", cl::init(true), cl::Hidden,
@@ -1247,6 +1250,8 @@ bool LoopVectorizationLegality::isFixedOrderRecurrence(
 }
 
 bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) const {
+  if (canVectorizeMultiCond() && BB != TheLoop->getHeader())
+    return true;
   return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
 }
 
@@ -1377,6 +1382,37 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {
   return true;
 }
 
+bool LoopVectorizationLegality::canVectorizeMultiCond() const {
+  if (!EnableMultiCond)
+    return false;
+  if (TheLoop->getUniqueExitBlock())
+    return false;
+  SmallVector<BasicBlock *> Exiting;
+  TheLoop->getExitingBlocks(Exiting);
+  if (Exiting.size() != 2 || Exiting[0] != TheLoop->getHeader() ||
+      Exiting[1] != TheLoop->getLoopLatch() ||
+      any_of(*TheLoop->getHeader(), [](Instruction &I) {
+        return I.mayReadFromMemory() || I.mayHaveSideEffects();
+      }))
+    return false;
+  CmpInst::Predicate Pred;
+  Value *A, *B;
+  if (!match(
+          TheLoop->getHeader()->getTerminator(),
+          m_Br(m_ICmp(Pred, m_Value(A), m_Value(B)), m_Value(), m_Value())) ||
+      Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE)
+    return false;
+  if (any_of(TheLoop->getBlocks(), [this](BasicBlock *BB) {
+        return any_of(*BB, [this](Instruction &I) {
+          return any_of(I.users(), [this](User *U) {
+            return !TheLoop->contains(cast<Instruction>(U)->getParent());
+          });
+        });
+      }))
+    return false;
+  return true;
+}
+
 // Helper function to canVectorizeLoopNestCFG.
 bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp,
                                                     bool UseVPlanNativePath) {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 9fb684427cfe9d..b2188ad8b2e4b9 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1387,9 +1387,11 @@ class LoopVectorizationCostModel {
     // If we might exit from anywhere but the latch, must run the exiting
     // iteration in scalar form.
     if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) {
-      LLVM_DEBUG(
-          dbgs() << "LV: Loop requires scalar epilogue: multiple exits\n");
-      return true;
+      if (!Legal->canVectorizeMultiCond()) {
+        LLVM_DEBUG(
+            dbgs() << "LV: Loop requires scalar epilogue: multiple exits\n");
+        return true;
+      }
     }
     if (IsVectorizing && InterleaveInfo.requiresScalarEpilogue()) {
       LLVM_DEBUG(dbgs() << "LV: Loop requires scalar epilogue: "
@@ -2571,8 +2573,17 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) {
   LoopVectorPreHeader = OrigLoop->getLoopPreheader();
   assert(LoopVectorPreHeader && "Invalid loop structure");
   LoopExitBlock = OrigLoop->getUniqueExitBlock(); // may be nullptr
-  assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF.isVector())) &&
-         "multiple exit loop without required epilogue?");
+  if (Legal->canVectorizeMultiCond()) {
+    BasicBlock *Latch = OrigLoop->getLoopLatch();
+    BasicBlock *TrueSucc =
+        cast<BranchInst>(Latch->getTerminator())->getSuccessor(0);
+    BasicBlock *FalseSucc =
+        cast<BranchInst>(Latch->getTerminator())->getSuccessor(1);
+    LoopExitBlock = OrigLoop->contains(TrueSucc) ? FalseSucc : TrueSucc;
+  } else {
+    assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF.isVector())) &&
+           "multiple exit loop without required epilogue?");
+  }
 
   LoopMiddleBlock =
       SplitBlock(LoopVectorPreHeader, LoopVectorPreHeader->getTerminator(), DT,
@@ -2943,24 +2954,26 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State,
   VPRegionBlock *VectorRegion = State.Plan->getVectorLoopRegion();
   VPBasicBlock *LatchVPBB = VectorRegion->getExitingBasicBlock();
   Loop *VectorLoop = LI->getLoopFor(State.CFG.VPBB2IRBB[LatchVPBB]);
-  if (Cost->requiresScalarEpilogue(VF.isVector())) {
-    // No edge from the middle block to the unique exit block has been inserted
-    // and there is nothing to fix from vector loop; phis should have incoming
-    // from scalar loop only.
-  } else {
-    // TODO: Check VPLiveOuts to see if IV users need fixing instead of checking
-    // the cost model.
-
-    // If we inserted an edge from the middle block to the unique exit block,
-    // update uses outside the loop (phis) to account for the newly inserted
-    // edge.
-
-    // Fix-up external users of the induction variables.
-    for (const auto &Entry : Legal->getInductionVars())
-      fixupIVUsers(Entry.first, Entry.second,
-                   getOrCreateVectorTripCount(VectorLoop->getLoopPreheader()),
-                   IVEndValues[Entry.first], LoopMiddleBlock,
-                   VectorLoop->getHeader(), Plan, State);
+  if (OrigLoop->getUniqueExitBlock()) {
+    if (Cost->requiresScalarEpilogue(VF.isVector())) {
+      // No edge from the middle block to the unique exit block has been
+      // inserted and there is nothing to fix from vector loop; phis should have
+      // incoming from scalar loop only.
+    } else {
+      // TODO: Check VPLiveOuts to see if IV users need fixing instead of
+      // checking the cost model.
+
+      // If we inserted an edge from the middle block to the unique exit block,
+      // update uses outside the loop (phis) to account for the newly inserted
+      // edge.
+
+      // Fix-up external users of the induction variables.
+      for (const auto &Entry : Legal->getInductionVars())
+        fixupIVUsers(Entry.first, Entry.second,
+                     getOrCreateVectorTripCount(VectorLoop->getLoopPreheader()),
+                     IVEndValues[Entry.first], LoopMiddleBlock,
+                     VectorLoop->getHeader(), Plan, State);
+    }
   }
 
   // Fix live-out phis not already fixed earlier.
@@ -3584,7 +3597,8 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) {
   TheLoop->getExitingBlocks(Exiting);
   for (BasicBlock *E : Exiting) {
     auto *Cmp = dyn_cast<Instruction>(E->getTerminator()->getOperand(0));
-    if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse())
+    if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse() &&
+        (TheLoop->getLoopLatch() == E || !Legal->canVectorizeMultiCond()))
       AddToWorklistIfAllowed(Cmp);
   }
 
@@ -7515,7 +7529,8 @@ LoopVectorizationPlanner::executePlan(
   LLVM_DEBUG(BestVPlan.dump());
 
   // Perform the actual loop transformation.
-  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan);
+  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan,
+                         OrigLoop);
 
   // 0. Generate SCEV-dependent code into the preheader, including TripCount,
   // before making any changes to the CFG.
@@ -7577,12 +7592,15 @@ LoopVectorizationPlanner::executePlan(
 
   // 2.5 Collect reduction resume values.
   DenseMap<const RecurrenceDescriptor *, Value *> ReductionResumeValues;
-  auto *ExitVPBB =
-      cast<VPBasicBlock>(BestVPlan.getVectorLoopRegion()->getSingleSuccessor());
-  for (VPRecipeBase &R : *ExitVPBB) {
-    createAndCollectMergePhiForReduction(
-        dyn_cast<VPInstruction>(&R), ReductionResumeValues, State, OrigLoop,
-        State.CFG.VPBB2IRBB[ExitVPBB], ExpandedSCEVs);
+  VPBasicBlock *ExitVPBB = nullptr;
+  if (BestVPlan.getVectorLoopRegion()->getSingleSuccessor()) {
+    ExitVPBB = cast<VPBasicBlock>(
+        BestVPlan.getVectorLoopRegion()->getSingleSuccessor());
+    for (VPRecipeBase &R : *ExitVPBB) {
+      createAndCollectMergePhiForReduction(
+          dyn_cast<VPInstruction>(&R), ReductionResumeValues, State, OrigLoop,
+          State.CFG.VPBB2IRBB[ExitVPBB], ExpandedSCEVs);
+    }
   }
 
   // 2.6. Maintain Loop Hints
@@ -7608,6 +7626,7 @@ LoopVectorizationPlanner::executePlan(
     LoopVectorizeHints Hints(L, true, *ORE);
     Hints.setAlreadyVectorized();
   }
+
   TargetTransformInfo::UnrollingPreferences UP;
   TTI.getUnrollingPreferences(L, *PSE.getSE(), UP, ORE);
   if (!UP.UnrollVectorizedLoop || CanonicalIVStartValue)
@@ -7620,15 +7639,17 @@ LoopVectorizationPlanner::executePlan(
   ILV.printDebugTracesAtEnd();
 
   // 4. Adjust branch weight of the branch in the middle block.
-  auto *MiddleTerm =
-      cast<BranchInst>(State.CFG.VPBB2IRBB[ExitVPBB]->getTerminator());
-  if (MiddleTerm->isConditional() &&
-      hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
-    // Assume that `Count % VectorTripCount` is equally distributed.
-    unsigned TripCount = State.UF * State.VF.getKnownMinValue();
-    assert(TripCount > 0 && "trip count should not be zero");
-    const uint32_t Weights[] = {1, TripCount - 1};
-    setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
+  if (ExitVPBB) {
+    auto *MiddleTerm =
+        cast<BranchInst>(State.CFG.VPBB2IRBB[ExitVPBB]->getTerminator());
+    if (MiddleTerm->isConditional() &&
+        hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
+      // Assume that `Count % VectorTripCount` is equally distributed.
+      unsigned TripCount = State.UF * State.VF.getKnownMinValue();
+      assert(TripCount > 0 && "trip count should not be zero");
+      const uint32_t Weights[] = {1, TripCount - 1};
+      setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
+    }
   }
 
   return {State.ExpandedSCEVs, ReductionResumeValues};
@@ -8013,7 +8034,7 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
   // If source is an exiting block, we know the exit edge is dynamically dead
   // in the vector loop, and thus we don't need to restrict the mask.  Avoid
   // adding uses of an otherwise potentially dead instruction.
-  if (OrigLoop->isLoopExiting(Src))
+  if (!Legal->canVectorizeMultiCond() && OrigLoop->isLoopExiting(Src))
     return EdgeMaskCache[Edge] = SrcMask;
 
   VPValue *EdgeMask = getVPValueOrAddLiveIn(BI->getCondition());
@@ -8630,6 +8651,8 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, bool HasNUW,
 static SetVector<VPIRInstruction *> collectUsersInExitBlock(
     Loop *OrigLoop, VPRecipeBuilder &Builder, VPlan &Plan,
     const MapVector<PHINode *, InductionDescriptor> &Inductions) {
+  if (!Plan.getVectorLoopRegion()->getSingleSuccessor())
+    return {};
   auto *MiddleVPBB =
       cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSingleSuccessor());
   // No edge from the middle block to the unique exit block has been inserted
@@ -8717,6 +8740,8 @@ static void addLiveOutsForFirstOrderRecurrences(
   // TODO: Should be replaced by
   // Plan->getScalarLoopRegion()->getSinglePredecessor() in the future once the
   // scalar region is modeled as well.
+  if (!VectorRegion->getSingleSuccessor())
+    return;
   auto *MiddleVPBB = cast<VPBasicBlock>(VectorRegion->getSingleSuccessor());
   VPBasicBlock *ScalarPHVPBB = nullptr;
   if (MiddleVPBB->getNumSuccessors() == 2) {
@@ -8991,6 +9016,67 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
          "VPBasicBlock");
   RecipeBuilder.fixHeaderPhis();
 
+  SmallVector<BasicBlock *> Exiting;
+  OrigLoop->getExitingBlocks(Exiting);
+
+  if (Legal->canVectorizeMultiCond()) {
+    auto *LatchVPBB =
+        cast<VPBasicBlock>(Plan->getVectorLoopRegion()->getExiting());
+    VPBuilder::InsertPointGuard Guard(Builder);
+    Builder.setInsertPoint(LatchVPBB->getTerminator());
+    auto *MiddleVPBB =
+        cast<VPBasicBlock>(Plan->getVectorLoopRegion()->getSingleSuccessor());
+
+    VPValue *EarlyExitTaken = nullptr;
+    SmallVector<VPValue *> ExitTaken;
+    SmallVector<PHINode *> ExitPhis;
+    SmallVector<Value *> ExitValues;
+    BasicBlock *ExitBlock;
+    for (BasicBlock *E : Exiting) {
+      if (E == OrigLoop->getLoopLatch()) {
+        BasicBlock *TrueSucc =
+            cast<BranchInst>(E->getTerminator())->getSuccessor(0);
+        BasicBlock *FalseSucc =
+            cast<BranchInst>(E->getTerminator())->getSuccessor(1);
+        auto EB = !OrigLoop->contains(TrueSucc) ? TrueSucc : FalseSucc;
+
+        auto *VPExitBlock = new VPIRBasicBlock(EB);
+        VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
+        VPBlockUtils::connectBlocks(MiddleVPBB, VPExitBlock);
+        VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
+        continue;
+      }
+      BasicBlock *TrueSucc =
+          cast<BranchInst>(E->getTerminator())->getSuccessor(0);
+      BasicBlock *FalseSucc =
+          cast<BranchInst>(E->getTerminator())->getSuccessor(1);
+      VPValue *M = RecipeBuilder.getBlockInMask(
+          OrigLoop->contains(TrueSucc) ? TrueSucc : FalseSucc);
+
+      auto *N = Builder.createNot(M);
+      auto *EC = Builder.createNaryOp(VPInstruction::AnyOf, {N});
+      ExitTaken.push_back(EC);
+      if (EarlyExitTaken)
+        EarlyExitTaken = Builder.createOr(EarlyExitTaken, EC);
+      else
+        EarlyExitTaken = EC;
+      ExitBlock = !OrigLoop->contains(TrueSucc) ? TrueSucc : FalseSucc;
+    }
+
+    auto *Term = dyn_cast<VPInstruction>(LatchVPBB->getTerminator());
+    auto *IsLatchExiting = Builder.createICmp(
+        CmpInst::ICMP_EQ, Term->getOperand(0), Term->getOperand(1));
+    Builder.createNaryOp(VPInstruction::BranchMultipleConds,
+                         {EarlyExitTaken, IsLatchExiting});
+    Term->eraseFromParent();
+
+    auto *EA = new VPIRBasicBlock(ExitBlock);
+    VPRegionBlock *LoopRegion = Plan->getVectorLoopRegion();
+    VPBlockUtils::disconnectBlocks(LoopRegion, MiddleVPBB);
+    VPBlockUtils::connectBlocks(LoopRegion, EA);
+    VPBlockUtils::connectBlocks(LoopRegion, MiddleVPBB);
+  }
+
   SetVector<VPIRInstruction *> ExitUsersToFix = collectUsersInExitBlock(
       OrigLoop, RecipeBuilder, *Plan, Legal->getInductionVars());
   addLiveOutsForFirstOrderRecurrences(*Plan, ExitUsersToFix);
@@ -9062,6 +9148,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
     VPlanTransforms::addActiveLaneMask(*Plan, ForControlFlow,
                                        WithoutRuntimeCheck);
   }
+
   return Plan;
 }
 
@@ -9286,6 +9373,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
   }
   VPBasicBlock *LatchVPBB = VectorLoopRegion->getExitingBasicBlock();
   Builder.setInsertPoint(&*LatchVPBB->begin());
+  if (!VectorLoopRegion->getSingleSuccessor())
+    return;
   VPBasicBlock *MiddleVPBB =
       cast<VPBasicBlock>(VectorLoopRegion->getSingleSuccessor());
   VPBasicBlock::iterator IP = MiddleVPBB->getFirstNonPhi();
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 2169d78542cbaf..c608c2c1cd3f69 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -224,9 +224,11 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() {
 
 VPTransformState::VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
                                    DominatorTree *DT, IRBuilderBase &Builder,
-                                   InnerLoopVectorizer *ILV, VPlan *Plan)
+                                   InnerLoopVectorizer *ILV, VPlan *Plan,
+                                   Loop *OrigLoop)
     : VF(VF), UF(UF), CFG(DT), LI(LI), Builder(Builder), ILV(ILV), Plan(Plan),
-      LVer(nullptr), TypeAnalysis(Plan->getCanonicalIV()->getScalarType()) {}
+      LVer(nullptr), TypeAnalysis(Plan->getCanonicalIV()->getScalarType()),
+      OrigLoop(OrigLoop) {}
 
 Value *VPTransformState::get(VPValue *Def, const VPIteration &Instance) {
   if (Def->isLiveIn())
@@ -477,6 +479,14 @@ void VPIRBasicBlock::execute(VPTransformState *State) {
     // backedges. A backward successor is set when the branch is created.
     const auto &PredVPSuccessors = PredVPBB->getHierarchicalSuccessors();
     unsigned idx = PredVPSuccessors.front() == this ? 0 : 1;
+    if (TermBr->getSuccessor(idx) &&
+        PredVPBlock == getPlan()->getVectorLoopRegion() &&
+        PredVPBlock->getNumSuccessors()) {
+      // Update PRedBB and TermBr for BranchOnMultiCond in predecessor.
+      PredBB = TermBr->getSuccessor(1);
+      TermBr = cast<BranchInst>(PredBB->getTerminator());
+      idx = 0;
+    }
     assert(!TermBr->getSuccessor(idx) &&
            "Trying to reset an existing successor block.");
     TermBr->setSuccessor(idx, IRBB);
@@ -595,9 +605,11 @@ static bool hasConditionalTerminator(const VPBasicBlock *VPBB) {
   }
 
   const VPRecipeBase *R = &VPBB->back();
-  bool IsCondBranch = isa<VPBranchOnMaskRecipe>(R) ||
-                      match(R, m_BranchOnCond(m_VPValue())) ||
-                      match(R, m_BranchOnCount(m_VPValue(), m_VPValue()));
+  bool IsCondBranch =
+      isa<VPBranchOnMaskRecipe>(R) || match(R, m_BranchOnCond(m_VPValue())) ||
+      match(R, m_BranchOnCount(m_VPValue(), m_VPValue())) ||
+      (isa<VPInstruction>(R) && cast<VPInstruction>(R)->getOpcode() ==
+                                    VPInstruction::BranchMultipleConds);
   (void)IsCondBranch;
 
   if (VPBB->getNumSuccessors() >= 2 ||
@@ -878,7 +890,10 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
   auto Plan = std::make_unique<VPlan>(Entry, VecPreheader);
 
   // Create SCEV and VPValue for the trip count.
-  const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount();
+  BasicBlock *IRExitBlock = TheLoop->getUniqueExitBlock();
+  const SCEV *BackedgeTakenCount =
+      IRExitBlock ? PSE.getBackedgeTakenCount()
+                  : PSE.getSE()->getExitCount(TheLoop, TheLoop->getLoopLatch());
   assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count");
   ScalarEvolution &SE = *PSE.getSE();
   const SCEV *TripCount =
@@ -898,8 +913,8 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
   VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block");
   VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion);
 
-  VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
   if (!RequiresScalarEpilogueCheck) {
+    VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
     VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
     return Plan;
   }
@@ -912,11 +927,14 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
   // 2) If we require a scalar epilogue, there is no conditional branch as
   //    we unconditionally branch to the scalar preheader.  Do nothing.
   // 3) Otherwise, construct a runtime check.
-  BasicBlock *IRExitBlock = TheLoop->getUniqueExitBlock();
-  auto *VPExitBlock = createVPIRBasicBlockFor(IRExitBlock);
-  // The connection order corresponds to the operands of the conditional branch.
-  VPBlockUtils::insertBlockAfter(VPExitBlock, MiddleVPBB);
-  VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
+  if (IRExitBlock) {
+    auto *VPExitBlock = createVPIRBasicBlockFor(IRExitBlock);
+    // The connection order corresponds to the operands of the conditional
+    // branch.
+    VPBlockUtils::insertBlockAfter(VPExitBlock, MiddleVPBB);
+    VPBasicBlock *ScalarPH =...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/109193


More information about the llvm-commits mailing list