[llvm] [LV] Transform to handle exits in the scalar loop (PR #148626)
Graham Hunter via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 14 06:02:31 PDT 2025
https://github.com/huntergr-arm created https://github.com/llvm/llvm-project/pull/148626
In preparation for supporting stores in early exit loops, this transform replicates the uncounted exit condition in a new check block and adjusts the in-loop recipes so that if the next vector iteration would encounter an exit, we instead transfer to the scalar loop.
This can be used with architectures that do not have advanced masking features for state-changing operations (like stores), or in cases where creating multiple mask partitions in the vector loop could be expensive.
>From d2579774c1402317d5f0679476c1d5800a3a59bf Mon Sep 17 00:00:00 2001
From: Graham Hunter <graham.hunter at arm.com>
Date: Tue, 1 Jul 2025 13:08:48 +0000
Subject: [PATCH] Transform code
---
.../Transforms/Vectorize/LoopVectorize.cpp | 24 +-
llvm/lib/Transforms/Vectorize/VPlan.cpp | 2 +
llvm/lib/Transforms/Vectorize/VPlan.h | 24 ++
.../Vectorize/VPlanConstruction.cpp | 3 +
.../Transforms/Vectorize/VPlanPatternMatch.h | 31 +++
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 2 +-
.../Transforms/Vectorize/VPlanTransforms.cpp | 255 +++++++++++++++---
.../Transforms/Vectorize/VPlanTransforms.h | 11 +
llvm/lib/Transforms/Vectorize/VPlanValue.h | 2 +
.../AArch64/simple_early_exit_scalar_exits.ll | 79 ++++++
10 files changed, 383 insertions(+), 50 deletions(-)
create mode 100644 llvm/test/Transforms/LoopVectorize/AArch64/simple_early_exit_scalar_exits.ll
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 5380a0fc6498a..368ed47a8e590 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -400,6 +400,10 @@ static cl::opt<bool> EnableEarlyExitVectorization(
cl::desc(
"Enable vectorization of early exit loops with uncountable exits."));
+static cl::opt<bool> HandleEarlyExitsInScalarTail(
+ "handle-early-exits-in-scalar-tail", cl::init(false), cl::Hidden,
+ cl::desc("Use the scalar tail to deal with early exit logic"));
+
// Likelyhood of bypassing the vectorized loop because there are zero trips left
// after prolog. See `emitIterationCountCheck`.
static constexpr uint32_t MinItersBypassWeights[] = {1, 127};
@@ -491,8 +495,8 @@ class InnerLoopVectorizer {
AC(AC), ORE(ORE), VF(VecWidth),
MinProfitableTripCount(MinProfitableTripCount), UF(UnrollFactor),
Builder(PSE.getSE()->getContext()), Cost(CM), BFI(BFI), PSI(PSI),
- RTChecks(RTChecks), Plan(Plan),
- VectorPHVPB(Plan.getVectorLoopRegion()->getSinglePredecessor()) {}
+ RTChecks(RTChecks), Plan(Plan), VectorPHVPB(Plan.getVectorPreheader()) {
+ }
virtual ~InnerLoopVectorizer() = default;
@@ -8322,6 +8326,7 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
auto MaxVFTimes2 = MaxVF * 2;
auto VPlan0 = VPlanTransforms::buildPlainCFG(OrigLoop, *LI);
+ VPlan0->setEarlyExitContinuesInScalarLoop(HandleEarlyExitsInScalarTail);
for (ElementCount VF = MinVF; ElementCount::isKnownLT(VF, MaxVFTimes2);) {
VFRange SubRange = {VF, MaxVFTimes2};
if (auto Plan = tryToBuildVPlanWithVPRecipes(
@@ -8338,6 +8343,14 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
!VPlanTransforms::runPass(VPlanTransforms::tryAddExplicitVectorLength,
*Plan, CM.getMaxSafeElements()))
break;
+ // See if we can convert an early exit vplan to bail out to a scalar
+ // loop if state-changing operations (like stores) are present and
+ // an exit will be taken in the next vector iteration.
+ // If not, discard the plan.
+ if (HandleEarlyExitsInScalarTail && !HasScalarVF &&
+ !VPlanTransforms::runPass(VPlanTransforms::handleExitsInScalarLoop,
+ *Plan))
+ break;
assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid");
VPlans.push_back(std::move(Plan));
}
@@ -8391,8 +8404,7 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan,
auto *ScalarPH = Plan.getScalarPreheader();
auto *MiddleVPBB = cast<VPBasicBlock>(ScalarPH->getPredecessors()[0]);
VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion();
- VPBuilder VectorPHBuilder(
- cast<VPBasicBlock>(VectorRegion->getSinglePredecessor()));
+ VPBuilder VectorPHBuilder(cast<VPBasicBlock>(Plan.getVectorPreheader()));
VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
VPBuilder ScalarPHBuilder(ScalarPH);
for (VPRecipeBase &ScalarPhiR : Plan.getScalarHeader()->phis()) {
@@ -8839,8 +8851,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
auto CanUseVersionedStride = [&Plan](VPUser &U, unsigned) {
auto *R = cast<VPRecipeBase>(&U);
return R->getParent()->getParent() ||
- R->getParent() ==
- Plan->getVectorLoopRegion()->getSinglePredecessor();
+ R->getParent() == Plan->getVectorPreheader();
};
for (auto [_, Stride] : Legal->getLAI()->getSymbolicStrides()) {
auto *StrideV = cast<SCEVUnknown>(Stride)->getValue();
@@ -8906,6 +8917,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlan(VFRange &Range) {
assert(EnableVPlanNativePath && "VPlan-native path is not enabled.");
auto Plan = VPlanTransforms::buildPlainCFG(OrigLoop, *LI);
+ Plan->setEarlyExitContinuesInScalarLoop(HandleEarlyExitsInScalarTail);
VPlanTransforms::prepareForVectorization(
*Plan, Legal->getWidestInductionType(), PSE, true, false, OrigLoop,
getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()), false,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 40a55656bfa7e..2d966702148e0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -1291,6 +1291,8 @@ VPlan *VPlan::duplicate() {
NewPlan->ExitBlocks.push_back(cast<VPIRBasicBlock>(VPB));
}
+ NewPlan->setEarlyExitContinuesInScalarLoop(EarlyExitContinuesInScalarLoop);
+
return NewPlan;
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 85741b977bb77..834ede74a5b50 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -3883,6 +3883,14 @@ class VPlan {
/// VPlan is destroyed.
SmallVector<VPBlockBase *> CreatedBlocks;
+ /// The entry block in a vplan, which may be a check block that needs to
+ /// be wired up in the right place with existing check blocks.
+ std::optional<VPBasicBlock *> EarlyExitPreheader;
+
+ /// Indicates that an early exit loop will exit before the condition is
+ /// reached, and that the scalar loop must perform the last few iterations.
+ bool EarlyExitContinuesInScalarLoop = false;
+
/// Construct a VPlan with \p Entry to the plan and with \p ScalarHeader
/// wrapping the original header of the scalar loop.
VPlan(VPBasicBlock *Entry, VPIRBasicBlock *ScalarHeader)
@@ -3929,12 +3937,17 @@ class VPlan {
/// Returns the preheader of the vector loop region, if one exists, or null
/// otherwise.
VPBasicBlock *getVectorPreheader() {
+ if (EarlyExitPreheader)
+ return *EarlyExitPreheader;
VPRegionBlock *VectorRegion = getVectorLoopRegion();
return VectorRegion
? cast<VPBasicBlock>(VectorRegion->getSinglePredecessor())
: nullptr;
}
+ /// Overrides the current vplan preheader block.
+ void setEarlyExitPreheader(VPBasicBlock *BB) { EarlyExitPreheader = BB; }
+
/// Returns the VPRegionBlock of the vector loop.
LLVM_ABI_FOR_TEST VPRegionBlock *getVectorLoopRegion();
LLVM_ABI_FOR_TEST const VPRegionBlock *getVectorLoopRegion() const;
@@ -4187,6 +4200,17 @@ class VPlan {
(ExitBlocks.size() == 1 && ExitBlocks[0]->getNumPredecessors() > 1);
}
+ /// Returns true if the vector iteration containing an exit should be handled
+ /// in the scalar loop instead of by masking.
+ bool shouldEarlyExitContinueInScalarLoop() const {
+ return EarlyExitContinuesInScalarLoop;
+ }
+
+ /// If set to true, early exits should be handled in the scalar loop.
+ void setEarlyExitContinuesInScalarLoop(bool Continues) {
+ EarlyExitContinuesInScalarLoop = Continues;
+ }
+
/// Returns true if the scalar tail may execute after the vector loop. Note
/// that this relies on unneeded branches to the scalar tail loop being
/// removed.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
index 52eecb000d0c2..decf7a01c3ae9 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
@@ -502,6 +502,9 @@ void VPlanTransforms::prepareForVectorization(
cast<VPBasicBlock>(HeaderVPB),
cast<VPBasicBlock>(LatchVPB), Range);
HandledUncountableEarlyExit = true;
+ if (Plan.shouldEarlyExitContinueInScalarLoop())
+ for (VPRecipeBase &R : EB->phis())
+ cast<VPIRPhi>(&R)->removeIncomingValueFor(Pred);
} else {
for (VPRecipeBase &R : EB->phis())
cast<VPIRPhi>(&R)->removeIncomingValueFor(Pred);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index d133610ef4f75..26fca9ae99cca 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -632,6 +632,37 @@ m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2, const T3 &Op3) {
return m_CombineAnd(m_Intrinsic<IntrID>(Op0, Op1, Op2), m_Argument<3>(Op3));
}
+struct loop_invariant_vpvalue {
+ template <typename ITy> bool match(ITy *V) const {
+ VPValue *Val = dyn_cast<VPValue>(V);
+ return Val && Val->isDefinedOutsideLoopRegions();
+ }
+};
+
+inline loop_invariant_vpvalue m_LoopInvVPValue() {
+ return loop_invariant_vpvalue();
+}
+
+template <typename Op0_t>
+inline UnaryVPInstruction_match<Op0_t, VPInstruction::AnyOf>
+m_AnyOf(const Op0_t &Op0) {
+ return m_VPInstruction<VPInstruction::AnyOf>(Op0);
+}
+
+template <typename SubPattern_t> struct OneUse_match {
+ SubPattern_t SubPattern;
+
+ OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {}
+
+ template <typename OpTy> bool match(OpTy *V) {
+ return V->hasOneUse() && SubPattern.match(V);
+ }
+};
+
+template <typename T> inline OneUse_match<T> m_OneUse(const T &SubPattern) {
+ return SubPattern;
+}
+
} // namespace VPlanPatternMatch
} // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 3c367664a0988..cf30852394f21 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -285,7 +285,7 @@ bool VPRecipeBase::isPhi() const {
return (getVPDefID() >= VPFirstPHISC && getVPDefID() <= VPLastPHISC) ||
(isa<VPInstruction>(this) &&
cast<VPInstruction>(this)->getOpcode() == Instruction::PHI) ||
- isa<VPIRPhi>(this);
+ isa<VPPhi>(this) || isa<VPIRPhi>(this);
}
bool VPRecipeBase::isScalarCast() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 6a3b3e6e41955..af420911ae6b7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -1490,7 +1490,7 @@ static bool simplifyBranchConditionForVFAndUF(VPlan &Plan, ElementCount BestVF,
HeaderPhiR->eraseFromParent();
}
- VPBlockBase *Preheader = VectorRegion->getSinglePredecessor();
+ VPBlockBase *Preheader = Plan.getVectorPreheader();
VPBlockBase *Exit = VectorRegion->getSingleSuccessor();
VPBlockUtils::disconnectBlocks(Preheader, VectorRegion);
VPBlockUtils::disconnectBlocks(VectorRegion, Exit);
@@ -2612,7 +2612,11 @@ expandVPWidenIntOrFpInduction(VPWidenIntOrFpInductionRecipe *WidenIVR,
}
// If the phi is truncated, truncate the start and step values.
- VPBuilder Builder(Plan->getVectorPreheader());
+ VPBasicBlock *VectorPH = Plan->getVectorPreheader();
+ VPBuilder Builder(VectorPH);
+ if (VPRecipeBase *Br = VectorPH->getTerminator())
+ Builder.setInsertPoint(Br);
+
Type *StepTy = TypeInfo.inferScalarType(Step);
if (Ty->getScalarSizeInBits() < StepTy->getScalarSizeInBits()) {
assert(StepTy->isIntegerTy() && "Truncation requires an integer type");
@@ -2776,8 +2780,9 @@ void VPlanTransforms::handleUncountableEarlyExit(
// Early exit operand should always be last phi operand. If EarlyExitVPBB
// has two predecessors and EarlyExitingVPBB is the first, swap the operands
// of the phis.
- for (VPRecipeBase &R : EarlyExitVPBB->phis())
- cast<VPIRPhi>(&R)->swapOperands();
+ if (!Plan.shouldEarlyExitContinueInScalarLoop())
+ for (VPRecipeBase &R : EarlyExitVPBB->phis())
+ cast<VPIRPhi>(&R)->swapOperands();
}
VPBuilder Builder(LatchVPBB->getTerminator());
@@ -2795,48 +2800,51 @@ void VPlanTransforms::handleUncountableEarlyExit(
// block if CondToEarlyExit.
VPValue *IsEarlyExitTaken =
Builder.createNaryOp(VPInstruction::AnyOf, {CondToEarlyExit});
- VPBasicBlock *NewMiddle = Plan.createVPBasicBlock("middle.split");
- VPBasicBlock *VectorEarlyExitVPBB =
- Plan.createVPBasicBlock("vector.early.exit");
- VPBlockUtils::insertOnEdge(LatchVPBB, MiddleVPBB, NewMiddle);
- VPBlockUtils::connectBlocks(NewMiddle, VectorEarlyExitVPBB);
- NewMiddle->swapSuccessors();
-
- VPBlockUtils::connectBlocks(VectorEarlyExitVPBB, EarlyExitVPBB);
-
- // Update the exit phis in the early exit block.
- VPBuilder MiddleBuilder(NewMiddle);
- VPBuilder EarlyExitB(VectorEarlyExitVPBB);
- for (VPRecipeBase &R : EarlyExitVPBB->phis()) {
- auto *ExitIRI = cast<VPIRPhi>(&R);
- // Early exit operand should always be last, i.e., 0 if EarlyExitVPBB has
- // a single predecessor and 1 if it has two.
- unsigned EarlyExitIdx = ExitIRI->getNumOperands() - 1;
- if (ExitIRI->getNumOperands() != 1) {
- // The first of two operands corresponds to the latch exit, via MiddleVPBB
- // predecessor. Extract its last lane.
- ExitIRI->extractLastLaneOfFirstOperand(MiddleBuilder);
- }
+ if (!Plan.shouldEarlyExitContinueInScalarLoop()) {
+ VPBasicBlock *NewMiddle = Plan.createVPBasicBlock("middle.split");
+ VPBasicBlock *VectorEarlyExitVPBB =
+ Plan.createVPBasicBlock("vector.early.exit");
+ VPBlockUtils::insertOnEdge(LatchVPBB, MiddleVPBB, NewMiddle);
+ VPBlockUtils::connectBlocks(NewMiddle, VectorEarlyExitVPBB);
+ NewMiddle->swapSuccessors();
+
+ VPBlockUtils::connectBlocks(VectorEarlyExitVPBB, EarlyExitVPBB);
+
+ // Update the exit phis in the early exit block.
+ VPBuilder MiddleBuilder(NewMiddle);
+ VPBuilder EarlyExitB(VectorEarlyExitVPBB);
+ for (VPRecipeBase &R : EarlyExitVPBB->phis()) {
+ auto *ExitIRI = cast<VPIRPhi>(&R);
+ // Early exit operand should always be last, i.e., 0 if EarlyExitVPBB has
+ // a single predecessor and 1 if it has two.
+ unsigned EarlyExitIdx = ExitIRI->getNumOperands() - 1;
+ if (ExitIRI->getNumOperands() != 1) {
+ // The first of two operands corresponds to the latch exit, via
+ // MiddleVPBB predecessor. Extract its last lane.
+ ExitIRI->extractLastLaneOfFirstOperand(MiddleBuilder);
+ }
- VPValue *IncomingFromEarlyExit = ExitIRI->getOperand(EarlyExitIdx);
- auto IsVector = [](ElementCount VF) { return VF.isVector(); };
- // When the VFs are vectors, need to add `extract` to get the incoming value
- // from early exit. When the range contains scalar VF, limit the range to
- // scalar VF to prevent mis-compilation for the range containing both scalar
- // and vector VFs.
- if (!IncomingFromEarlyExit->isLiveIn() &&
- LoopVectorizationPlanner::getDecisionAndClampRange(IsVector, Range)) {
- // Update the incoming value from the early exit.
- VPValue *FirstActiveLane = EarlyExitB.createNaryOp(
- VPInstruction::FirstActiveLane, {CondToEarlyExit}, nullptr,
- "first.active.lane");
- IncomingFromEarlyExit = EarlyExitB.createNaryOp(
- Instruction::ExtractElement, {IncomingFromEarlyExit, FirstActiveLane},
- nullptr, "early.exit.value");
- ExitIRI->setOperand(EarlyExitIdx, IncomingFromEarlyExit);
+ VPValue *IncomingFromEarlyExit = ExitIRI->getOperand(EarlyExitIdx);
+ auto IsVector = [](ElementCount VF) { return VF.isVector(); };
+ // When the VFs are vectors, need to add `extract` to get the incoming
+ // value from early exit. When the range contains scalar VF, limit the
+ // range to scalar VF to prevent mis-compilation for the range containing
+ // both scalar and vector VFs.
+ if (!IncomingFromEarlyExit->isLiveIn() &&
+ LoopVectorizationPlanner::getDecisionAndClampRange(IsVector, Range)) {
+ // Update the incoming value from the early exit.
+ VPValue *FirstActiveLane = EarlyExitB.createNaryOp(
+ VPInstruction::FirstActiveLane, {CondToEarlyExit}, nullptr,
+ "first.active.lane");
+ IncomingFromEarlyExit =
+ EarlyExitB.createNaryOp(Instruction::ExtractElement,
+ {IncomingFromEarlyExit, FirstActiveLane},
+ nullptr, "early.exit.value");
+ ExitIRI->setOperand(EarlyExitIdx, IncomingFromEarlyExit);
+ }
}
+ MiddleBuilder.createNaryOp(VPInstruction::BranchOnCond, {IsEarlyExitTaken});
}
- MiddleBuilder.createNaryOp(VPInstruction::BranchOnCond, {IsEarlyExitTaken});
// Replace the condition controlling the non-early exit from the vector loop
// with one exiting if either the original condition of the vector latch is
@@ -2853,6 +2861,167 @@ void VPlanTransforms::handleUncountableEarlyExit(
LatchExitingBranch->eraseFromParent();
}
+bool VPlanTransforms::handleExitsInScalarLoop(VPlan &Plan) {
+ // We can abandon a vplan entirely if we return false here, so we shouldn't
+ // crash if some earlier assumptions on scalar IR don't hold for the vplan
+ // version of the loop.
+ if (Plan.hasScalarVFOnly())
+ return false;
+ auto *Region = Plan.getVectorLoopRegion();
+ using namespace llvm::VPlanPatternMatch;
+ VPCanonicalIVPHIRecipe *IV = Plan.getCanonicalIV();
+ VPInstruction *IVUpdate = dyn_cast<VPInstruction>(IV->getBackedgeValue());
+ if (!IVUpdate)
+ return false;
+
+ // Find the uncounted loop exit condition.
+ VPValue *Uncounted = nullptr;
+ if (!match(Region->getExitingBasicBlock()->getTerminator(),
+ m_BranchOnCond(m_OneUse(m_c_BinaryOr(
+ m_OneUse(m_AnyOf(m_VPValue(Uncounted))), m_VPValue())))))
+ return false;
+
+ // Extract the recipes needed to create the uncountable exit condition.
+ // Looking for icmp(load(gep(base, iv)), loop_inv) or similar
+ SmallVector<VPValue *, 4> Worklist;
+ SmallDenseMap<VPValue *, VPRecipeBase *, 8> CloneMap;
+ SmallVector<VPReplicateRecipe *, 2> GEPs;
+ SmallVector<VPRecipeBase *, 8> ConditionRecipes;
+ bool LoadFound = false;
+ bool IVLinkedGEP = false;
+
+ Worklist.push_back(Uncounted);
+ while (!Worklist.empty()) {
+ VPValue *V = Worklist.pop_back_val();
+ if (V->isDefinedOutsideLoopRegions())
+ continue;
+ if (V->getNumUsers() > 1)
+ return false;
+ if (auto *Cmp = dyn_cast<VPWidenRecipe>(V)) {
+ if (Cmp->getOpcode() != Instruction::ICmp)
+ return false;
+ Worklist.push_back(Cmp->getOperand(0));
+ Worklist.push_back(Cmp->getOperand(1));
+ ConditionRecipes.push_back(Cmp);
+ } else if (auto *Load = dyn_cast<VPWidenLoadRecipe>(V)) {
+ if (!Load->isConsecutive() || Load->isMasked())
+ return false;
+ Worklist.push_back(Load->getAddr());
+ ConditionRecipes.push_back(Load);
+ LoadFound = true;
+ } else if (auto *VecPtr = dyn_cast<VPVectorPointerRecipe>(V)) {
+ Worklist.push_back(VecPtr->getOperand(0));
+ ConditionRecipes.push_back(VecPtr);
+ } else if (auto *GEP = dyn_cast<VPReplicateRecipe>(V)) {
+ if (GEP->getNumOperands() != 2)
+ return false;
+ if (!match(GEP, m_GetElementPtr(
+ m_LoopInvVPValue(),
+ m_ScalarIVSteps(m_Specific(IV), m_SpecificInt(1),
+ m_Specific(&Plan.getVF())))))
+ return false;
+ GEPs.push_back(GEP);
+ ConditionRecipes.push_back(GEP);
+ IVLinkedGEP = true;
+ } else
+ return false;
+ }
+
+ // If we didn't find any recipes, didn't find a load, or didn't find a
+ // GEP linked to the IV, bail out.
+ if (ConditionRecipes.empty() || !LoadFound || !IVLinkedGEP)
+ return false;
+
+ // Clone the condition recipes into the preheader
+ VPBasicBlock *VectorPH = Plan.getVectorPreheader();
+ for (VPRecipeBase *R : reverse(ConditionRecipes)) {
+ VPRecipeBase *Clone = nullptr;
+ Clone = R->clone();
+ VectorPH->appendRecipe(Clone);
+ CloneMap[R->getVPSingleValue()] = Clone;
+ }
+
+ // Remap the cloned recipes to use the corresponding operands.
+ for (VPRecipeBase *R : ConditionRecipes) {
+ auto *Clone = CloneMap.at(R->getVPSingleValue());
+ for (unsigned I = 0; I < R->getNumOperands(); ++I)
+ if (VPRecipeBase *OpR = CloneMap.lookup(R->getOperand(I)))
+ Clone->setOperand(I, OpR->getVPSingleValue());
+ }
+
+ // Adjust preheader GEPs
+ for (auto *GEP : GEPs)
+ CloneMap[GEP]->setOperand(
+ 1, Plan.getOrAddLiveIn(ConstantInt::get(IV->getScalarType(), 0)));
+
+ // Split vector preheader to form a new bypass block.
+ VPBasicBlock *NewPH = VectorPH->splitAt(VectorPH->end());
+ VPBasicBlock *ScalarPH = Plan.getScalarPreheader();
+
+ // Create bypass block branch.
+ VPBuilder PHBuilder(VectorPH, VectorPH->end());
+ VPValue *PHAnyOf = PHBuilder.createNaryOp(
+ VPInstruction::AnyOf, {CloneMap[Uncounted]->getVPSingleValue()});
+ PHBuilder.createNaryOp(VPInstruction::BranchOnCond, {PHAnyOf},
+ CloneMap[Uncounted]->getDebugLoc());
+ VectorPH->clearSuccessors();
+ NewPH->clearPredecessors();
+ VPBlockUtils::connectBlocks(VectorPH, ScalarPH);
+ VPBlockUtils::connectBlocks(VectorPH, NewPH);
+
+ // Modify plan so that other check blocks (e.g. SCEVs) can be attached to
+ // the correct block.
+ Plan.setEarlyExitPreheader(VectorPH);
+
+ // Fix up the resume phi in scalar preheader -- we might not have reached
+ // the calculated maximum vector tripcount, so just use the next value of IV.
+ VPBasicBlock *MiddleBlock = Plan.getMiddleBlock();
+ VPValue *VecTC = &Plan.getVectorTripCount();
+ for (VPRecipeBase &PHI : ScalarPH->phis()) {
+ VPPhi *ResumePHI = dyn_cast<VPPhi>(&PHI);
+ VPValue *EntryVal = nullptr;
+ for (unsigned I = 0; I < ResumePHI->getNumIncoming(); ++I) {
+ const VPBasicBlock *Block = ResumePHI->getIncomingBlock(I);
+ if (Block == Plan.getEntry()) {
+ EntryVal = ResumePHI->getIncomingValue(I);
+ } else if (Block == MiddleBlock) {
+ VPValue *V = ResumePHI->getIncomingValue(I);
+ if (V == VecTC) {
+ ResumePHI->setOperand(I, IVUpdate);
+ } else {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+
+ if (!EntryVal)
+ return false;
+ ResumePHI->addOperand(EntryVal);
+ }
+
+ // Move the IV update if necessary, then update the index operand of the GEP
+ // so that we load the next vector iteration's exit condition data.
+ VPDominatorTree VPDT;
+ VPDT.recalculate(Plan);
+ for (auto *GEP : GEPs) {
+ if (!VPDT.properlyDominates(IVUpdate, GEP))
+ IVUpdate->moveBefore(*GEP->getParent(), GEP->getIterator());
+ GEP->setOperand(1, IVUpdate);
+ }
+
+ // Update middle block branch to use IVUpdate vs. the full trip count,
+ // since we may be exiting the vector loop early.
+ VPRecipeBase *OldTerminator = MiddleBlock->getTerminator();
+ VPBuilder MiddleBuilder(OldTerminator);
+ VPValue *FullTC =
+ MiddleBuilder.createICmp(CmpInst::ICMP_EQ, IVUpdate, Plan.getTripCount());
+ OldTerminator->setOperand(0, FullTC);
+
+ return true;
+}
+
/// This function tries convert extended in-loop reductions to
/// VPExpressionRecipe and clamp the \p Range if it is beneficial and
/// valid. The created recipe must be decomposed to its constituent
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 870b1bb68b79a..8f7a10fb0ae20 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -196,6 +196,17 @@ struct VPlanTransforms {
VPBasicBlock *LatchVPBB,
VFRange &Range);
+ /// Update \p Plan to check whether the next iteration of the vector loop
+ /// would exit (using any exit type) and if so branch to the scalar loop
+ /// instead. This requires identifying the recipes that form the conditions
+ /// for exiting, cloning them to the preheader, then adjusting both the
+ /// preheader recipes (to check the first vector iteration) and those in
+ /// the vector loop (to check the next vector iteration instead of the
+ /// current one). This can be used to avoid complex masking for state-changing
+ /// recipes (like stores).
+ static bool handleExitsInScalarLoop(VPlan &Plan);
+
+
/// Replace loop regions with explicit CFG.
static void dissolveLoopRegions(VPlan &Plan);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 24f6d61512ef6..5fecbbdef4b5b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -148,6 +148,8 @@ class LLVM_ABI_FOR_TEST VPValue {
return Current != user_end();
}
+ bool hasOneUse() const { return getNumUsers() == 1; }
+
void replaceAllUsesWith(VPValue *New);
/// Go through the uses list for this VPValue and make each use point to \p
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/simple_early_exit_scalar_exits.ll b/llvm/test/Transforms/LoopVectorize/AArch64/simple_early_exit_scalar_exits.ll
new file mode 100644
index 0000000000000..8301544129ed2
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/simple_early_exit_scalar_exits.ll
@@ -0,0 +1,79 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S < %s -p loop-vectorize -handle-early-exits-in-scalar-tail | FileCheck %s --check-prefixes=CHECK
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define i32 @simple_contains(ptr align 4 dereferenceable(100) readonly %array, i32 %elt) {
+; CHECK-LABEL: define i32 @simple_contains(
+; CHECK-SAME: ptr readonly align 4 dereferenceable(100) [[ARRAY:%.*]], i32 [[ELT:%.*]]) {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK: vector.ph:
+; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[ELT]], i64 0
+; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i32> [[BROADCAST_SPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i32, ptr [[ARRAY]], i32 0
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP0]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = icmp eq <4 x i32> [[WIDE_LOAD]], [[BROADCAST_SPLAT]]
+; CHECK-NEXT: [[TMP2:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP1]])
+; CHECK-NEXT: br i1 [[TMP2]], label [[SCALAR_PH]], label [[VECTOR_PH_SPLIT:%.*]]
+; CHECK: vector.ph.split:
+; CHECK-NEXT: br label [[FOR_BODY:%.*]]
+; CHECK: vector.body:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH_SPLIT]] ], [ [[INDEX_NEXT:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[ARRAY]], i64 [[INDEX_NEXT]]
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[TMP3]], i32 0
+; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <4 x i32>, ptr [[TMP4]], align 4
+; CHECK-NEXT: [[TMP5:%.*]] = icmp eq <4 x i32> [[WIDE_LOAD1]], [[BROADCAST_SPLAT]]
+; CHECK-NEXT: [[TMP6:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP5]])
+; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], 24
+; CHECK-NEXT: [[TMP8:%.*]] = or i1 [[TMP6]], [[TMP7]]
+; CHECK-NEXT: br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[FOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK: middle.block:
+; CHECK-NEXT: [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], 25
+; CHECK-NEXT: br i1 [[TMP9]], label [[NOT_FOUND:%.*]], label [[SCALAR_PH]]
+; CHECK: scalar.ph:
+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[INDEX_NEXT]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ], [ 0, [[VECTOR_PH]] ]
+; CHECK-NEXT: br label [[FOR_BODY1:%.*]]
+; CHECK: for.body:
+; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[IV_NEXT:%.*]], [[FOR_INC:%.*]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
+; CHECK-NEXT: [[LD_ADDR:%.*]] = getelementptr inbounds i32, ptr [[ARRAY]], i64 [[IV]]
+; CHECK-NEXT: [[LD:%.*]] = load i32, ptr [[LD_ADDR]], align 4
+; CHECK-NEXT: [[CMP_EARLY:%.*]] = icmp eq i32 [[LD]], [[ELT]]
+; CHECK-NEXT: br i1 [[CMP_EARLY]], label [[FOUND:%.*]], label [[FOR_INC]]
+; CHECK: for.inc:
+; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[IV_NEXT]], 25
+; CHECK-NEXT: br i1 [[CMP]], label [[NOT_FOUND]], label [[FOR_BODY1]], !llvm.loop [[LOOP3:![0-9]+]]
+; CHECK: found:
+; CHECK-NEXT: ret i32 1
+; CHECK: not.found:
+; CHECK-NEXT: ret i32 0
+;
+entry:
+ br label %for.body
+
+for.body:
+ %iv = phi i64 [ %iv.next, %for.inc ], [ 0, %entry ]
+ %ld.addr = getelementptr inbounds i32, ptr %array, i64 %iv
+ %ld = load i32, ptr %ld.addr, align 4
+ %cmp.early = icmp eq i32 %ld, %elt
+ br i1 %cmp.early, label %found, label %for.inc
+
+for.inc:
+ %iv.next = add nsw nuw i64 %iv, 1
+ %cmp = icmp eq i64 %iv.next, 25
+ br i1 %cmp, label %not.found, label %for.body
+
+found:
+ ret i32 1
+
+not.found:
+ ret i32 0
+}
+;.
+; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
+; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
+; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
+; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
+;.
More information about the llvm-commits
mailing list