[llvm] [LV] Transform to handle exits in the scalar loop (PR #148626)

Graham Hunter via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 29 09:38:04 PDT 2025


https://github.com/huntergr-arm updated https://github.com/llvm/llvm-project/pull/148626

>From 3e5b9d9aa18afaf6aa9f71f374a0eafc9e4e10ca Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Tue, 1 Jul 2025 13:08:48 +0000
Subject: [PATCH] Transform code

---
 .../Transforms/Vectorize/LoopVectorize.cpp    |  20 +-
 llvm/lib/Transforms/Vectorize/VPlan.cpp       |   2 +
 llvm/lib/Transforms/Vectorize/VPlan.h         |  24 ++
 .../Vectorize/VPlanConstruction.cpp           |   3 +
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 236 +++++++++++++++---
 .../Transforms/Vectorize/VPlanTransforms.h    |  10 +
 llvm/lib/Transforms/Vectorize/VPlanUtils.cpp  |   9 +-
 .../early-exit-handle-exits-in-scalar-loop.ll | 155 ++++++++++++
 8 files changed, 414 insertions(+), 45 deletions(-)
 create mode 100644 llvm/test/Transforms/LoopVectorize/early-exit-handle-exits-in-scalar-loop.ll

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f7968abbe5b6b..233fb64e94fff 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -397,6 +397,10 @@ static cl::opt<bool> ConsiderRegPressure(
     "vectorizer-consider-reg-pressure", cl::init(false), cl::Hidden,
     cl::desc("Discard VFs if their register pressure is too high."));
 
+static cl::opt<bool> HandleEarlyExitsInScalarTail(
+    "handle-early-exits-in-scalar-tail", cl::init(false), cl::Hidden,
+    cl::desc("Use the scalar tail to deal with early exit logic"));
+
 // Likelyhood of bypassing the vectorized loop because there are zero trips left
 // after prolog. See `emitIterationCountCheck`.
 static constexpr uint32_t MinItersBypassWeights[] = {1, 127};
@@ -502,8 +506,7 @@ class InnerLoopVectorizer {
       : OrigLoop(OrigLoop), PSE(PSE), LI(LI), DT(DT), TTI(TTI), AC(AC),
         VF(VecWidth), UF(UnrollFactor), Builder(PSE.getSE()->getContext()),
         Cost(CM), BFI(BFI), PSI(PSI), RTChecks(RTChecks), Plan(Plan),
-        VectorPHVPBB(cast<VPBasicBlock>(
-            Plan.getVectorLoopRegion()->getSinglePredecessor())) {}
+        VectorPHVPBB(cast<VPBasicBlock>(Plan.getVectorPreheader())) {}
 
   virtual ~InnerLoopVectorizer() = default;
 
@@ -4552,7 +4555,7 @@ LoopVectorizationPlanner::selectInterleaveCount(VPlan &Plan, ElementCount VF,
   // We don't attempt to perform interleaving for loops with uncountable early
   // exits because the VPInstruction::AnyOf code cannot currently handle
   // multiple parts.
-  if (Plan.hasEarlyExit())
+  if (Plan.hasEarlyExit() || Plan.shouldEarlyExitContinueInScalarLoop())
     return 1;
 
   const bool HasReductions =
@@ -8202,6 +8205,8 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
   auto VPlan0 = VPlanTransforms::buildVPlan0(
       OrigLoop, *LI, Legal->getWidestInductionType(),
       getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()), PSE);
+  VPlan0->setEarlyExitContinuesInScalarLoop(Legal->hasUncountableEarlyExit() &&
+                                            HandleEarlyExitsInScalarTail);
 
   auto MaxVFTimes2 = MaxVF * 2;
   for (ElementCount VF = MinVF; ElementCount::isKnownLT(VF, MaxVFTimes2);) {
@@ -8216,6 +8221,15 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
       if (CM.foldTailWithEVL())
         VPlanTransforms::runPass(VPlanTransforms::addExplicitVectorLength,
                                  *Plan, CM.getMaxSafeElements());
+
+      // See if we can convert an early-exit VPlan to bail out to the scalar
+      // loop when an exit would be taken in the next vector iteration, so
+      // that state-changing operations (like stores) need no masking in the
+      // vector loop. If the conversion is not possible, discard the plan.
+      if (!Plan->hasScalarVFOnly() && HandleEarlyExitsInScalarTail &&
+          !VPlanTransforms::runPass(
+              VPlanTransforms::handleUncountableExitsInScalarLoop, *Plan))
+        break;
       assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid");
       VPlans.push_back(std::move(Plan));
     }
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 428a8f4c1348f..a14638ddef5a3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -1230,6 +1230,8 @@ VPlan *VPlan::duplicate() {
       NewPlan->ExitBlocks.push_back(cast<VPIRBasicBlock>(VPB));
   }
 
+  NewPlan->setEarlyExitContinuesInScalarLoop(EarlyExitContinuesInScalarLoop);
+
   return NewPlan;
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 1f10058ab4a9a..d475a2daae47d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -4182,6 +4182,14 @@ class VPlan {
   /// VPlan is destroyed.
   SmallVector<VPBlockBase *> CreatedBlocks;
 
+  /// Overriding preheader for the vector loop: a check block created when
+  /// early exits are handled in the scalar loop, which must be wired up in
+  /// the right place relative to any existing check blocks.
+  std::optional<VPBasicBlock *> EarlyExitPreheader;
+
+  /// Indicates that the vector loop of an early-exit loop exits before the
+  /// iteration in which an exit would be taken, and that the scalar loop must
+  /// perform the last few iterations.
+  bool EarlyExitContinuesInScalarLoop = false;
+
   /// Construct a VPlan with \p Entry to the plan and with \p ScalarHeader
   /// wrapping the original header of the scalar loop.
   VPlan(VPBasicBlock *Entry, VPIRBasicBlock *ScalarHeader)
@@ -4224,12 +4232,17 @@ class VPlan {
   /// Returns the preheader of the vector loop region, if one exists, or null
   /// otherwise.
   VPBasicBlock *getVectorPreheader() {
+    if (EarlyExitPreheader)
+      return *EarlyExitPreheader;
     VPRegionBlock *VectorRegion = getVectorLoopRegion();
     return VectorRegion
                ? cast<VPBasicBlock>(VectorRegion->getSinglePredecessor())
                : nullptr;
   }
 
+  /// Overrides the current vplan preheader block.
+  void setEarlyExitPreheader(VPBasicBlock *BB) { EarlyExitPreheader = BB; }
+
   /// Returns the VPRegionBlock of the vector loop.
   LLVM_ABI_FOR_TEST VPRegionBlock *getVectorLoopRegion();
   LLVM_ABI_FOR_TEST const VPRegionBlock *getVectorLoopRegion() const;
@@ -4483,6 +4496,17 @@ class VPlan {
            (ExitBlocks.size() == 1 && ExitBlocks[0]->getNumPredecessors() > 1);
   }
 
+  /// Returns true if the vector iteration containing an exit should be handled
+  /// in the scalar loop instead of by masking.
+  bool shouldEarlyExitContinueInScalarLoop() const {
+    return EarlyExitContinuesInScalarLoop;
+  }
+
+  /// If set to true, early exits should be handled in the scalar loop.
+  void setEarlyExitContinuesInScalarLoop(bool Continues) {
+    EarlyExitContinuesInScalarLoop = Continues;
+  }
+
   /// Returns true if the scalar tail may execute after the vector loop. Note
   /// that this relies on unneeded branches to the scalar tail loop being
   /// removed.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
index 65688a3f0b6be..42cc56c613e13 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
@@ -561,6 +561,9 @@ void VPlanTransforms::handleEarlyExits(VPlan &Plan,
         handleUncountableEarlyExit(cast<VPBasicBlock>(Pred), EB, Plan,
                                    cast<VPBasicBlock>(HeaderVPB), LatchVPBB);
         HandledUncountableEarlyExit = true;
+        if (Plan.shouldEarlyExitContinueInScalarLoop())
+          for (VPRecipeBase &R : EB->phis())
+            cast<VPIRPhi>(&R)->removeIncomingValueFor(Pred);
       } else {
         for (VPRecipeBase &R : EB->phis())
           cast<VPIRPhi>(&R)->removeIncomingValueFor(Pred);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index d9ac26bba7507..af420572324d6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -20,6 +20,7 @@
 #include "VPlanHelpers.h"
 #include "VPlanPatternMatch.h"
 #include "VPlanUtils.h"
+#include "VPlanValue.h"
 #include "VPlanVerifier.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/PostOrderIterator.h"
@@ -31,6 +32,8 @@
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/Support/Casting.h"
@@ -1779,7 +1782,7 @@ static bool simplifyBranchConditionForVFAndUF(VPlan &Plan, ElementCount BestVF,
       HeaderR.eraseFromParent();
     }
 
-    VPBlockBase *Preheader = VectorRegion->getSinglePredecessor();
+    VPBlockBase *Preheader = Plan.getVectorPreheader();
     VPBlockBase *Exit = VectorRegion->getSingleSuccessor();
     VPBlockUtils::disconnectBlocks(Preheader, VectorRegion);
     VPBlockUtils::disconnectBlocks(VectorRegion, Exit);
@@ -2916,8 +2919,7 @@ void VPlanTransforms::replaceSymbolicStrides(
   // evolution.
   auto CanUseVersionedStride = [&Plan](VPUser &U, unsigned) {
     auto *R = cast<VPRecipeBase>(&U);
-    return R->getRegion() ||
-           R->getParent() == Plan.getVectorLoopRegion()->getSinglePredecessor();
+    return R->getRegion() || R->getParent() == Plan.getVectorPreheader();
   };
   ValueToSCEVMapTy RewriteMap;
   for (const SCEV *Stride : StridesMap.values()) {
@@ -3218,7 +3220,11 @@ expandVPWidenIntOrFpInduction(VPWidenIntOrFpInductionRecipe *WidenIVR,
   }
 
   // If the phi is truncated, truncate the start and step values.
-  VPBuilder Builder(Plan->getVectorPreheader());
+  VPBasicBlock *VectorPH = Plan->getVectorPreheader();
+  VPBuilder Builder(VectorPH);
+  if (VPRecipeBase *Br = VectorPH->getTerminator())
+    Builder.setInsertPoint(Br);
+
   Type *StepTy = TypeInfo.inferScalarType(Step);
   if (Ty->getScalarSizeInBits() < StepTy->getScalarSizeInBits()) {
     assert(StepTy->isIntegerTy() && "Truncation requires an integer type");
@@ -3444,8 +3450,9 @@ void VPlanTransforms::handleUncountableEarlyExit(VPBasicBlock *EarlyExitingVPBB,
     // Early exit operand should always be last phi operand. If EarlyExitVPBB
     // has two predecessors and EarlyExitingVPBB is the first, swap the operands
     // of the phis.
-    for (VPRecipeBase &R : EarlyExitVPBB->phis())
-      cast<VPIRPhi>(&R)->swapOperands();
+    if (!Plan.shouldEarlyExitContinueInScalarLoop())
+      for (VPRecipeBase &R : EarlyExitVPBB->phis())
+        cast<VPIRPhi>(&R)->swapOperands();
   }
 
   VPBuilder Builder(LatchVPBB->getTerminator());
@@ -3462,42 +3469,45 @@ void VPlanTransforms::handleUncountableEarlyExit(VPBasicBlock *EarlyExitingVPBB,
   // block if CondToEarlyExit.
   VPValue *IsEarlyExitTaken =
       Builder.createNaryOp(VPInstruction::AnyOf, {CondToEarlyExit});
-  VPBasicBlock *NewMiddle = Plan.createVPBasicBlock("middle.split");
-  VPBasicBlock *VectorEarlyExitVPBB =
-      Plan.createVPBasicBlock("vector.early.exit");
-  VPBlockUtils::insertOnEdge(LatchVPBB, MiddleVPBB, NewMiddle);
-  VPBlockUtils::connectBlocks(NewMiddle, VectorEarlyExitVPBB);
-  NewMiddle->swapSuccessors();
-
-  VPBlockUtils::connectBlocks(VectorEarlyExitVPBB, EarlyExitVPBB);
-
-  // Update the exit phis in the early exit block.
-  VPBuilder MiddleBuilder(NewMiddle);
-  VPBuilder EarlyExitB(VectorEarlyExitVPBB);
-  for (VPRecipeBase &R : EarlyExitVPBB->phis()) {
-    auto *ExitIRI = cast<VPIRPhi>(&R);
-    // Early exit operand should always be last, i.e., 0 if EarlyExitVPBB has
-    // a single predecessor and 1 if it has two.
-    unsigned EarlyExitIdx = ExitIRI->getNumOperands() - 1;
-    if (ExitIRI->getNumOperands() != 1) {
-      // The first of two operands corresponds to the latch exit, via MiddleVPBB
-      // predecessor. Extract its last lane.
-      ExitIRI->extractLastLaneOfFirstOperand(MiddleBuilder);
-    }
+  if (!Plan.shouldEarlyExitContinueInScalarLoop()) {
+    VPBasicBlock *NewMiddle = Plan.createVPBasicBlock("middle.split");
+    VPBasicBlock *VectorEarlyExitVPBB =
+        Plan.createVPBasicBlock("vector.early.exit");
+    VPBlockUtils::insertOnEdge(LatchVPBB, MiddleVPBB, NewMiddle);
+    VPBlockUtils::connectBlocks(NewMiddle, VectorEarlyExitVPBB);
+    NewMiddle->swapSuccessors();
+
+    VPBlockUtils::connectBlocks(VectorEarlyExitVPBB, EarlyExitVPBB);
+
+    // Update the exit phis in the early exit block.
+    VPBuilder MiddleBuilder(NewMiddle);
+    VPBuilder EarlyExitB(VectorEarlyExitVPBB);
+    for (VPRecipeBase &R : EarlyExitVPBB->phis()) {
+      auto *ExitIRI = cast<VPIRPhi>(&R);
+      // Early exit operand should always be last, i.e., 0 if EarlyExitVPBB has
+      // a single predecessor and 1 if it has two.
+      unsigned EarlyExitIdx = ExitIRI->getNumOperands() - 1;
+      if (ExitIRI->getNumOperands() != 1) {
+        // The first of two operands corresponds to the latch exit, via
+        // MiddleVPBB predecessor. Extract its last lane.
+        ExitIRI->extractLastLaneOfFirstOperand(MiddleBuilder);
+      }
 
-    VPValue *IncomingFromEarlyExit = ExitIRI->getOperand(EarlyExitIdx);
-    if (!IncomingFromEarlyExit->isLiveIn()) {
-      // Update the incoming value from the early exit.
-      VPValue *FirstActiveLane = EarlyExitB.createNaryOp(
-          VPInstruction::FirstActiveLane, {CondToEarlyExit}, nullptr,
-          "first.active.lane");
-      IncomingFromEarlyExit = EarlyExitB.createNaryOp(
-          VPInstruction::ExtractLane, {FirstActiveLane, IncomingFromEarlyExit},
-          nullptr, "early.exit.value");
-      ExitIRI->setOperand(EarlyExitIdx, IncomingFromEarlyExit);
+      VPValue *IncomingFromEarlyExit = ExitIRI->getOperand(EarlyExitIdx);
+      if (!IncomingFromEarlyExit->isLiveIn()) {
+        // Update the incoming value from the early exit.
+        VPValue *FirstActiveLane = EarlyExitB.createNaryOp(
+            VPInstruction::FirstActiveLane, {CondToEarlyExit}, nullptr,
+            "first.active.lane");
+        IncomingFromEarlyExit =
+            EarlyExitB.createNaryOp(VPInstruction::ExtractLane,
+                                    {FirstActiveLane, IncomingFromEarlyExit},
+                                    nullptr, "early.exit.value");
+        ExitIRI->setOperand(EarlyExitIdx, IncomingFromEarlyExit);
+      }
     }
+    MiddleBuilder.createNaryOp(VPInstruction::BranchOnCond, {IsEarlyExitTaken});
   }
-  MiddleBuilder.createNaryOp(VPInstruction::BranchOnCond, {IsEarlyExitTaken});
 
   // Replace the condition controlling the non-early exit from the vector loop
   // with one exiting if either the original condition of the vector latch is
@@ -3514,6 +3524,151 @@ void VPlanTransforms::handleUncountableEarlyExit(VPBasicBlock *EarlyExitingVPBB,
   LatchExitingBranch->eraseFromParent();
 }
 
+bool VPlanTransforms::handleUncountableExitsInScalarLoop(VPlan &Plan) {
+  assert(!Plan.hasScalarVFOnly() &&
+         "Cannot transform uncountable exits in scalar loop");
+
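+  // The transform works roughly as follows:
+  //  1. Identify the recipes computing the uncountable exit condition.
+  //  2. Clone them into the preheader and add a bypass branch to the scalar
+  //     loop in case the first vector iteration would already exit.
+  //  3. Rewrite the in-loop condition to test the *next* vector iteration
+  //     (masking its loads), so the vector loop stops one iteration early.
+  //  4. Make the middle block and the resume phis use the IV update rather
+  //     than the precomputed vector trip count.
+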
+  // Returning false here abandons the VPlan entirely, so bail out gracefully
+  // rather than crash if assumptions made on the scalar IR don't hold for the
+  // VPlan version of the loop.
+  VPCanonicalIVPHIRecipe *IV = Plan.getVectorLoopRegion()->getCanonicalIV();
+  VPInstruction *IVUpdate = dyn_cast<VPInstruction>(IV->getBackedgeValue());
+  if (!IVUpdate)
+    return false;
+
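+  // Collect the recipes that feed the uncountable exit condition, along with
+  // the GEPs computing the addresses of the loads used by that condition.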
+  SmallVector<VPRecipeBase *, 2> GEPs;
+  SmallVector<VPRecipeBase *, 8> ConditionRecipes;
+
+  std::optional<VPValue *> Cond =
+      vputils::getRecipesForUncountableExit(Plan, ConditionRecipes, GEPs);
+  if (!Cond)
+    return false;
+
+  // Check GEPs to see if we can link them to the canonical IV.
+  using namespace llvm::VPlanPatternMatch;
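+  // Only a GEP of a loop-invariant (live-in) base address indexed by
+  // unit-stride steps of the canonical IV is supported.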
+  for (auto *GEP : GEPs)
+    if (!match(GEP,
+               m_GetElementPtr(m_LiveIn(),
+                               m_ScalarIVSteps(m_Specific(IV), m_SpecificInt(1),
+                                               m_Specific(&Plan.getVF())))))
+      return false;
+
+  // Clone the condition recipes into the preheader.
+  SmallDenseMap<VPRecipeBase *, VPRecipeBase *, 8> CloneMap;
+  VPBasicBlock *VectorPH = Plan.getVectorPreheader();
+  for (VPRecipeBase *R : reverse(ConditionRecipes)) {
+    VPRecipeBase *Clone = R->clone();
+    VectorPH->appendRecipe(Clone);
+    CloneMap[R] = Clone;
+  }
+
+  // Remap the cloned recipes to use the corresponding operands.
+  for (VPRecipeBase *R : ConditionRecipes) {
+    auto *Clone = CloneMap.at(R);
+    for (unsigned I = 0; I < R->getNumOperands(); ++I)
+      if (VPRecipeBase *OpR =
+              CloneMap.lookup(R->getOperand(I)->getDefiningRecipe()))
+        Clone->setOperand(I, OpR->getVPSingleValue());
+  }
+
+  // Adjust preheader GEPs to match the value they would have for the first
+  // iteration of the vector body.
+  for (auto *GEP : GEPs)
+    CloneMap.at(GEP)->setOperand(1, IV->getStartValue());
+
+  // Split the end of the vector preheader; the original block becomes the
+  // bypass block holding the cloned condition recipes.
+  VPBasicBlock *NewPH = VectorPH->splitAt(VectorPH->end());
+  VPBasicBlock *ScalarPH = Plan.getScalarPreheader();
+
+  // Create bypass block branch.
+  VPRecipeBase *Uncountable = (*Cond)->getDefiningRecipe();
+  VPRecipeBase *PHUncountable = CloneMap.at(Uncountable);
+  VPBuilder PHBuilder(VectorPH, VectorPH->end());
+  VPValue *PHAnyOf = PHBuilder.createNaryOp(
+      VPInstruction::AnyOf, {PHUncountable->getVPSingleValue()});
+  PHBuilder.createNaryOp(VPInstruction::BranchOnCond, {PHAnyOf},
+                         PHUncountable->getDebugLoc());
+  VectorPH->clearSuccessors();
+  NewPH->clearPredecessors();
+  VPBlockUtils::connectBlocks(VectorPH, ScalarPH);
+  VPBlockUtils::connectBlocks(VectorPH, NewPH);
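+
+  // The bypass branch goes straight to the scalar loop when the cloned
+  // condition already holds for the first vector iteration; otherwise
+  // execution falls through to the new preheader block.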
+
+  // Modify plan so that other check blocks (e.g. SCEVs) can be attached to
+  // the correct block.
+  Plan.setEarlyExitPreheader(VectorPH);
+
+  // Fix up the resume phis in the scalar preheader -- the vector loop may
+  // exit before reaching the calculated maximum vector trip count, so use the
+  // next value of the IV instead.
+  VPBasicBlock *MiddleBlock = Plan.getMiddleBlock();
+  VPValue *VecTC = &Plan.getVectorTripCount();
+  for (VPRecipeBase &PHI : ScalarPH->phis()) {
+    VPPhi *ResumePHI = dyn_cast<VPPhi>(&PHI);
+    if (!ResumePHI)
+      return false;
+    VPValue *EntryVal = nullptr;
+    for (unsigned I = 0; I < ResumePHI->getNumIncoming(); ++I) {
+      const VPBasicBlock *Block = ResumePHI->getIncomingBlock(I);
+      if (Block == Plan.getEntry()) {
+        EntryVal = ResumePHI->getIncomingValue(I);
+      } else if (Block == MiddleBlock) {
+        VPValue *V = ResumePHI->getIncomingValue(I);
+        if (V == VecTC) {
+          ResumePHI->setOperand(I, IVUpdate);
+        } else {
+          return false;
+        }
+      } else {
+        return false;
+      }
+    }
+
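+    // The bypass branch above added the vector preheader as a new predecessor
+    // of the scalar preheader, so append a matching incoming value: the value
+    // the phi takes when coming from the plan entry.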
+    if (!EntryVal)
+      return false;
+    ResumePHI->addOperand(EntryVal);
+  }
+
+  // Move the IV update if necessary, then update the index operand of the GEP
+  // so that we load the next vector iteration's exit condition data.
+  VPDominatorTree VPDT(Plan);
+  for (auto *GEP : GEPs) {
+    if (!VPDT.properlyDominates(IVUpdate, GEP))
+      IVUpdate->moveBefore(*GEP->getParent(), GEP->getIterator());
+    GEP->setOperand(1, IVUpdate);
+  }
+
+  // Convert loads for the next vector iteration to use a mask so that we
+  // avoid any accesses that the scalar loop would not have performed.
+  for (VPRecipeBase *R : ConditionRecipes) {
+    if (auto *Load = dyn_cast<VPWidenLoadRecipe>(R)) {
+      // Bail out for now if it's already conditional.
+      if (Load->isMasked())
+        return false;
+      VPBuilder MaskBuilder(R);
+      VPValue *ALMMultiplier = Plan.getOrAddLiveIn(
+          ConstantInt::get(IntegerType::getInt64Ty(Plan.getContext()), 1));
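+      // Mask off lanes at or beyond the vector trip count so the look-ahead
+      // load never touches elements the vector loop would not process.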
+      VPValue *LaneMask = MaskBuilder.createNaryOp(
+          VPInstruction::ActiveLaneMask,
+          {IVUpdate, &Plan.getVectorTripCount(), ALMMultiplier}, nullptr,
+          "uncountable.exit.mask");
+      VPWidenLoadRecipe *NewLoad = new VPWidenLoadRecipe(
+          *(cast<LoadInst>(Load->getUnderlyingValue())), Load->getOperand(0),
+          LaneMask, Load->isConsecutive(), Load->isReverse(), Load->getAlign(),
+          *Load, Load->getDebugLoc());
+      MaskBuilder.insert(NewLoad);
+      Load->replaceAllUsesWith(NewLoad);
+      Load->eraseFromParent();
+    }
+  }
+
+  // Update the middle block branch to compare IVUpdate against the full trip
+  // count, since the vector loop may have exited early.
+  VPRecipeBase *OldTerminator = MiddleBlock->getTerminator();
+  VPBuilder MiddleBuilder(OldTerminator);
+  VPValue *FullTC =
+      MiddleBuilder.createICmp(CmpInst::ICMP_EQ, IVUpdate, Plan.getTripCount());
+  OldTerminator->setOperand(0, FullTC);
+
+  return true;
+}
+
 /// This function tries convert extended in-loop reductions to
 /// VPExpressionRecipe and clamp the \p Range if it is beneficial and
 /// valid. The created recipe must be decomposed to its constituent
@@ -4416,8 +4571,7 @@ void VPlanTransforms::addScalarResumePhis(
   auto *ScalarPH = Plan.getScalarPreheader();
   auto *MiddleVPBB = cast<VPBasicBlock>(ScalarPH->getPredecessors()[0]);
   VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion();
-  VPBuilder VectorPHBuilder(
-      cast<VPBasicBlock>(VectorRegion->getSinglePredecessor()));
+  VPBuilder VectorPHBuilder(Plan.getVectorPreheader());
   VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
   VPBuilder ScalarPHBuilder(ScalarPH);
   for (VPRecipeBase &ScalarPhiR : Plan.getScalarHeader()->phis()) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index b28559b620e13..1ea31e38c9735 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -260,6 +260,16 @@ struct VPlanTransforms {
                                          VPlan &Plan, VPBasicBlock *HeaderVPBB,
                                          VPBasicBlock *LatchVPBB);
 
+  /// Update \p Plan to check whether the next iteration of the vector loop
+  /// would exit (using any exit type) and if so branch to the scalar loop
+  /// instead. This requires identifying the recipes that form the conditions
+  /// for exiting, cloning them to the preheader, then adjusting both the
+  /// preheader recipes (to check the first vector iteration) and those in
+  /// the vector loop (to check the next vector iteration instead of the
+  /// current one). This can be used to avoid complex masking for state-changing
+  /// recipes (like stores).
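+  ///
+  /// Roughly, for a loop with an uncountable early exit this produces:
+  ///   vector preheader: evaluate the exit condition for the first vector
+  ///                     iteration and bypass to the scalar loop if it holds
+  ///   vector loop:      evaluate the next iteration's exit condition (with
+  ///                     masked loads) and exit to the middle block if it
+  ///                     holds or the last vector iteration was reached
+  ///   middle block:     continue in the scalar loop unless the full trip
+  ///                     count was reached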
+  static bool handleUncountableExitsInScalarLoop(VPlan &Plan);
+
   /// Replace loop regions with explicit CFG.
   static void dissolveLoopRegions(VPlan &Plan);
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
index 4db92e7def3ed..7dbe957ab153c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
@@ -294,11 +294,18 @@ vputils::getRecipesForUncountableExit(VPlan &Plan,
       if (Load->isMasked())
         return std::nullopt;
 
+      Recipes.push_back(Load);
+
+      // Look through vector-pointer recipes.
       VPValue *GEP = Load->getAddr();
+      if (auto *VecPtrR = dyn_cast<VPVectorPointerRecipe>(GEP)) {
+        Recipes.push_back(VecPtrR);
+        GEP = VecPtrR->getOperand(0);
+      }
+
       if (!match(GEP, m_GetElementPtr(m_LiveIn(), m_VPValue())))
         return std::nullopt;
 
-      Recipes.push_back(Load);
       Recipes.push_back(GEP->getDefiningRecipe());
       GEPs.push_back(GEP->getDefiningRecipe());
     } else
diff --git a/llvm/test/Transforms/LoopVectorize/early-exit-handle-exits-in-scalar-loop.ll b/llvm/test/Transforms/LoopVectorize/early-exit-handle-exits-in-scalar-loop.ll
new file mode 100644
index 0000000000000..cce23d307a851
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/early-exit-handle-exits-in-scalar-loop.ll
@@ -0,0 +1,155 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals none --version 5
+; RUN: opt -S < %s -p loop-vectorize -handle-early-exits-in-scalar-tail -force-vector-width=4 | FileCheck %s
+
+define i32 @simple_contains(ptr align 4 dereferenceable(100) readonly %array, i32 %elt) {
+; CHECK-LABEL: define i32 @simple_contains(
+; CHECK-SAME: ptr readonly align 4 dereferenceable(100) [[ARRAY:%.*]], i32 [[ELT:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    br label %[[VECTOR_PH:.*]]
+; CHECK:       [[VECTOR_PH]]:
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[ELT]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i32> [[BROADCAST_SPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[ARRAY]], align 4
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq <4 x i32> [[WIDE_LOAD]], [[BROADCAST_SPLAT]]
+; CHECK-NEXT:    [[TMP1:%.*]] = freeze <4 x i1> [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP1]])
+; CHECK-NEXT:    br i1 [[TMP2]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH_SPLIT:.*]]
+; CHECK:       [[VECTOR_PH_SPLIT]]:
+; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
+; CHECK:       [[VECTOR_BODY]]:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH_SPLIT]] ], [ [[IV:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[IV]] = add nuw i64 [[INDEX]], 4
+; CHECK-NEXT:    [[LD_ADDR:%.*]] = getelementptr inbounds i32, ptr [[ARRAY]], i64 [[IV]]
+; CHECK-NEXT:    [[UNCOUNTABLE_EXIT_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i64(i64 [[IV]], i64 24)
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr align 4 [[LD_ADDR]], <4 x i1> [[UNCOUNTABLE_EXIT_MASK]], <4 x i32> poison)
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq <4 x i32> [[WIDE_LOAD1]], [[BROADCAST_SPLAT]]
+; CHECK-NEXT:    [[TMP5:%.*]] = freeze <4 x i1> [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP5]])
+; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i64 [[IV]], 24
+; CHECK-NEXT:    [[TMP8:%.*]] = or i1 [[TMP6]], [[TMP7]]
+; CHECK-NEXT:    br i1 [[TMP8]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       [[MIDDLE_BLOCK]]:
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i64 [[IV]], 25
+; CHECK-NEXT:    br i1 [[TMP9]], label %[[NOT_FOUND:.*]], label %[[SCALAR_PH]]
+; CHECK:       [[SCALAR_PH]]:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IV]], %[[MIDDLE_BLOCK]] ], [ 0, %[[VECTOR_PH]] ]
+; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK:       [[FOR_BODY]]:
+; CHECK-NEXT:    [[IV1:%.*]] = phi i64 [ [[IV_NEXT:%.*]], %[[FOR_INC:.*]] ], [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ]
+; CHECK-NEXT:    [[LD_ADDR1:%.*]] = getelementptr inbounds i32, ptr [[ARRAY]], i64 [[IV1]]
+; CHECK-NEXT:    [[LD:%.*]] = load i32, ptr [[LD_ADDR1]], align 4
+; CHECK-NEXT:    [[CMP_EARLY:%.*]] = icmp eq i32 [[LD]], [[ELT]]
+; CHECK-NEXT:    br i1 [[CMP_EARLY]], label %[[FOUND:.*]], label %[[FOR_INC]]
+; CHECK:       [[FOR_INC]]:
+; CHECK-NEXT:    [[IV_NEXT]] = add nuw nsw i64 [[IV1]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i64 [[IV_NEXT]], 25
+; CHECK-NEXT:    br i1 [[CMP]], label %[[NOT_FOUND]], label %[[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
+; CHECK:       [[FOUND]]:
+; CHECK-NEXT:    ret i32 1
+; CHECK:       [[NOT_FOUND]]:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ %iv.next, %for.inc ], [ 0, %entry ]
+  %ld.addr = getelementptr inbounds i32, ptr %array, i64 %iv
+  %ld = load i32, ptr %ld.addr, align 4
+  %cmp.early = icmp eq i32 %ld, %elt
+  br i1 %cmp.early, label %found, label %for.inc
+
+for.inc:
+  %iv.next = add nsw nuw i64 %iv, 1
+  %cmp = icmp eq i64 %iv.next, 25
+  br i1 %cmp, label %not.found, label %for.body
+
+found:
+  ret i32 1
+
+not.found:
+  ret i32 0
+}
+
+define i32 @contains_with_variable_tc(ptr readonly %array, i8 %elt, i64 %n) nofree nosync {
+; CHECK-LABEL: define i32 @contains_with_variable_tc(
+; CHECK-SAME: ptr readonly [[ARRAY:%.*]], i8 [[ELT:%.*]], i64 [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[ARRAY]], i64 1), "dereferenceable"(ptr [[ARRAY]], i64 [[N]]) ]
+; CHECK-NEXT:    [[ZERO_TC:%.*]] = icmp eq i64 [[N]], 0
+; CHECK-NEXT:    br i1 [[ZERO_TC]], label %[[NOT_FOUND:.*]], label %[[FOR_BODY_PREHEADER:.*]]
+; CHECK:       [[FOR_BODY_PREHEADER]]:
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[N]], 4
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
+; CHECK:       [[VECTOR_PH]]:
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[N]], 4
+; CHECK-NEXT:    [[ITERS:%.*]] = sub i64 [[N]], [[N_MOD_VF]]
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i8> poison, i8 [[ELT]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i8> [[BROADCAST_SPLATINSERT]], <4 x i8> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[ARRAY]], align 1
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq <4 x i8> [[WIDE_LOAD]], [[BROADCAST_SPLAT]]
+; CHECK-NEXT:    [[TMP1:%.*]] = freeze <4 x i1> [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP1]])
+; CHECK-NEXT:    br i1 [[TMP2]], label %[[SCALAR_PH]], label %[[VECTOR_PH_SPLIT:.*]]
+; CHECK:       [[VECTOR_PH_SPLIT]]:
+; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
+; CHECK:       [[VECTOR_BODY]]:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH_SPLIT]] ], [ [[IV_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[IV_NEXT]] = add nuw i64 [[INDEX]], 4
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, ptr [[ARRAY]], i64 [[IV_NEXT]]
+; CHECK-NEXT:    [[UNCOUNTABLE_EXIT_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i64(i64 [[IV_NEXT]], i64 [[ITERS]])
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0(ptr align 1 [[TMP3]], <4 x i1> [[UNCOUNTABLE_EXIT_MASK]], <4 x i8> poison)
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq <4 x i8> [[WIDE_LOAD1]], [[BROADCAST_SPLAT]]
+; CHECK-NEXT:    [[TMP5:%.*]] = freeze <4 x i1> [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP5]])
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i64 [[IV_NEXT]], [[ITERS]]
+; CHECK-NEXT:    [[TMP8:%.*]] = or i1 [[TMP6]], [[CMP]]
+; CHECK-NEXT:    br i1 [[TMP8]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
+; CHECK:       [[MIDDLE_BLOCK]]:
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq i64 [[IV_NEXT]], [[N]]
+; CHECK-NEXT:    br i1 [[TMP9]], label %[[NOT_FOUND_LOOPEXIT:.*]], label %[[SCALAR_PH]]
+; CHECK:       [[SCALAR_PH]]:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IV_NEXT]], %[[MIDDLE_BLOCK]] ], [ 0, %[[FOR_BODY_PREHEADER]] ], [ 0, %[[VECTOR_PH]] ]
+; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK:       [[FOR_BODY]]:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ [[IV_NEXT1:%.*]], %[[FOR_INC:.*]] ], [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ]
+; CHECK-NEXT:    [[LD_ADDR:%.*]] = getelementptr inbounds i8, ptr [[ARRAY]], i64 [[IV]]
+; CHECK-NEXT:    [[LD:%.*]] = load i8, ptr [[LD_ADDR]], align 1
+; CHECK-NEXT:    [[CMP_EARLY:%.*]] = icmp eq i8 [[LD]], [[ELT]]
+; CHECK-NEXT:    br i1 [[CMP_EARLY]], label %[[FOUND:.*]], label %[[FOR_INC]]
+; CHECK:       [[FOR_INC]]:
+; CHECK-NEXT:    [[IV_NEXT1]] = add nuw nsw i64 [[IV]], 1
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i64 [[IV_NEXT1]], [[N]]
+; CHECK-NEXT:    br i1 [[CMP1]], label %[[NOT_FOUND_LOOPEXIT]], label %[[FOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
+; CHECK:       [[FOUND]]:
+; CHECK-NEXT:    ret i32 1
+; CHECK:       [[NOT_FOUND_LOOPEXIT]]:
+; CHECK-NEXT:    br label %[[NOT_FOUND]]
+; CHECK:       [[NOT_FOUND]]:
+; CHECK-NEXT:    ret i32 0
+;
+
+entry:
+  call void @llvm.assume(i1 true) [ "align"(ptr %array, i64 1), "dereferenceable"(ptr %array, i64 %n) ]
+  %zero.tc = icmp eq i64 %n, 0
+  br i1 %zero.tc, label %not.found, label %for.body
+
+for.body:
+  %iv = phi i64 [ %iv.next, %for.inc ], [ 0, %entry ]
+  %ld.addr = getelementptr inbounds i8, ptr %array, i64 %iv
+  %ld = load i8, ptr %ld.addr
+  %cmp.early = icmp eq i8 %ld, %elt
+  br i1 %cmp.early, label %found, label %for.inc
+
+for.inc:
+  %iv.next = add nsw nuw i64 %iv, 1
+  %cmp = icmp eq i64 %iv.next, %n
+  br i1 %cmp, label %not.found, label %for.body
+
+found:
+  ret i32 1
+
+not.found:
+  ret i32 0
+}
+



More information about the llvm-commits mailing list