[llvm] ada6d78 - [LoopFlatten] Address FIXME about getTripCountFromExitCount. NFC.
Sjoerd Meijer via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 24 05:57:03 PST 2022
Author: Sjoerd Meijer
Date: 2022-01-24T13:46:19Z
New Revision: ada6d78a7802f8057f1ab7cee0bed25f91fcc4b4
URL: https://github.com/llvm/llvm-project/commit/ada6d78a7802f8057f1ab7cee0bed25f91fcc4b4
DIFF: https://github.com/llvm/llvm-project/commit/ada6d78a7802f8057f1ab7cee0bed25f91fcc4b4.diff
LOG: [LoopFlatten] Address FIXME about getTripCountFromExitCount. NFC.
Together with the previous commit, which mainly improves the documentation of
LoopFlatten's overall strategy, this addresses a concern raised in a FIXME
comment added in D110587. The (NFC) refactoring introduces helper functions,
including one for the SCEV-based trip count check, to make the code clearer.
Added:
Modified:
llvm/lib/Transforms/Scalar/LoopFlatten.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index 4d9578934d9e..c46db4e63bfe 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -149,6 +149,118 @@ struct FlattenInfo {
return false;
return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi;
}
+  // Is U the inner loop's recorded increment instruction (InnerIncrement)?
+  bool isInnerLoopIncrement(User *U) {
+    return U == InnerIncrement;
+  }
+  // Is U the outer loop's recorded increment instruction (OuterIncrement)?
+  bool isOuterLoopIncrement(User *U) {
+    return U == OuterIncrement;
+  }
+  // Is U the condition of the inner loop's branch (InnerBranch)?
+  bool isInnerLoopTest(User *U) {
+    Value *Cond = InnerBranch->getCondition();
+    return U == Cond;
+  }
+
+  // Check that every use of the outer induction PHI, apart from the outer
+  // increment, is one of the values collected in ValidOuterPHIUses (the
+  // multiplies recorded by matchLinearIVUser). For a trunc of the PHI, each
+  // user of the trunc is checked instead of the trunc itself. Returns false
+  // at the first use that is not in the set, true otherwise.
+  bool checkOuterInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
+    // Logs the use and reports whether it was recorded as a valid outer use.
+    // The predicate does not depend on the loop below, so build it once.
+    auto IsRecordedOuterUse = [&](User *Candidate) -> bool {
+      LLVM_DEBUG(dbgs() << "Found use of outer induction variable: ";
+                 Candidate->dump());
+      if (ValidOuterPHIUses.count(Candidate)) {
+        LLVM_DEBUG(dbgs() << "Use is optimisable\n");
+        return true;
+      }
+      LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
+      return false;
+    };
+
+    for (User *OuterUse : OuterInductionPHI->users()) {
+      if (isOuterLoopIncrement(OuterUse))
+        continue;
+
+      auto *Trunc = dyn_cast<TruncInst>(OuterUse);
+      if (!Trunc) {
+        if (!IsRecordedOuterUse(OuterUse))
+          return false;
+        continue;
+      }
+
+      // Look through the trunc: every one of its users must be recorded.
+      for (User *TruncUser : Trunc->users())
+        if (!IsRecordedOuterUse(TruncUser))
+          return false;
+    }
+    return true;
+  }
+
+  // Check whether U (a user of the inner induction PHI, possibly reached
+  // through a trunc by the caller) computes the linear index
+  //
+  //   InnerPHI + OuterPHI * InnerTripCount
+  //
+  // either on the PHIs directly or on truncs of both PHIs (as can appear after
+  // IV widening). When Widened, a sext/zext around the matched multiplier is
+  // looked through before comparing it with InnerTripCount.
+  //
+  // Note: the InnerTripCount parameter shadows the member of the same name;
+  // the caller passes the trip count with any widening extend stripped.
+  //
+  // On a match, the multiply is inserted into ValidOuterPHIUses, U is inserted
+  // into LinearIVUses, and true is returned. Otherwise returns false.
+  bool matchLinearIVUser(User *U, Value *InnerTripCount,
+                         SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
+    LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
+    Value *MatchedMul = nullptr;
+    Value *MatchedItCount = nullptr;
+
+    bool IsAdd = match(U, m_c_Add(m_Specific(InnerInductionPHI),
+                                  m_Value(MatchedMul))) &&
+                 match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),
+                                           m_Value(MatchedItCount)));
+
+    // Matches the same pattern as above, except it also looks for truncs
+    // on the phi, which can be the result of widening the induction variables.
+    bool IsAddTrunc =
+        match(U, m_c_Add(m_Trunc(m_Specific(InnerInductionPHI)),
+                         m_Value(MatchedMul))) &&
+        match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)),
+                                  m_Value(MatchedItCount)));
+
+    // Neither pattern bound a multiplier: report it like any other mismatch.
+    if (!MatchedItCount) {
+      LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
+      return false;
+    }
+
+    // Look through extends if the IV has been widened. The isa<> checks
+    // guarantee an SExtInst or ZExtInst, so use cast<> (asserting) rather than
+    // dereferencing an unchecked dyn_cast<> result.
+    if (Widened &&
+        (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
+      assert(MatchedItCount->getType() == InnerInductionPHI->getType() &&
+             "Unexpected type mismatch in types after widening");
+      MatchedItCount = cast<Instruction>(MatchedItCount)->getOperand(0);
+    }
+
+    if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
+      LLVM_DEBUG(dbgs() << "Use is optimisable\n");
+      ValidOuterPHIUses.insert(MatchedMul);
+      LinearIVUses.insert(U);
+      return true;
+    }
+
+    LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
+    return false;
+  }
+
+  // Check that every use of the inner induction PHI, apart from the inner
+  // increment and the inner loop test, matches the linear index pattern via
+  // matchLinearIVUser, which also collects the outer-IV multiplies into
+  // ValidOuterPHIUses. Returns false at the first use that does not match.
+  bool checkInnerInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
+    // After widening, the recorded trip count may be a sext/zext; the pattern
+    // is matched against the value being extended.
+    Value *NarrowTripCount = InnerTripCount;
+    if (Widened &&
+        (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
+      NarrowTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
+
+    for (User *InnerUse : InnerInductionPHI->users()) {
+      if (isInnerLoopIncrement(InnerUse))
+        continue;
+
+      // After widening the IVs, a trunc instruction might have been
+      // introduced; step over it, but only when it has a single user.
+      User *Candidate = InnerUse;
+      if (isa<TruncInst>(Candidate)) {
+        if (!Candidate->hasOneUse())
+          return false;
+        Candidate = *Candidate->user_begin();
+      }
+
+      // If the use is in the compare (which is also the condition of the inner
+      // branch) then the compare has been altered by another transformation,
+      // e.g. icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where
+      // tripcount is a constant. Ignore this use as the compare gets removed
+      // later anyway.
+      if (isInnerLoopTest(Candidate))
+        continue;
+
+      if (!matchLinearIVUser(Candidate, NarrowTripCount, ValidOuterPHIUses))
+        return false;
+    }
+    return true;
+  }
};
static bool
@@ -162,6 +274,77 @@ setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment,
return true;
}
+// Given the RHS of the loop latch compare instruction, verify with SCEV that
+// this is indeed the loop trip count, and on success record the trip count,
+// increment and iteration instructions via setLoopComponents.
+//
+// RHS is accepted when:
+//  - it is SCEV-equal to the trip count derived from the backedge-taken count
+//    (BTC); RHS is used as the trip count;
+//  - it is a constant equal to the BTC (or, when IsWidened, the zero-extended
+//    BTC); RHS + 1 is used as the trip count, unless that would wrap;
+//  - it is any other constant when not IsWidened, or a constant equal to the
+//    zero-extended trip count when IsWidened; RHS is used unchanged.
+//    NOTE(review): the non-widened case accepts a constant matching neither
+//    count; confirm this is intended;
+//  - IsWidened and RHS is a sext/zext whose operand is SCEV-equal to the trip
+//    count; RHS is used as the trip count.
+// InductionPHI and BackBranch are currently neither read nor written here.
+//
+// TODO: This used to be a straightforward check but has grown to be quite
+// complicated now. It is therefore worth revisiting what the additional
+// benefits are of this (compared to relying on canonical loops and pattern
+// matching).
+static bool verifyTripCount(Value *RHS, Loop *L,
+                            SmallPtrSetImpl<Instruction *> &IterationInstructions,
+                            PHINode *&InductionPHI, Value *&TripCount,
+                            BinaryOperator *&Increment, BranchInst *&BackBranch,
+                            ScalarEvolution *SE, bool IsWidened) {
+  const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
+  if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
+    LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
+    return false;
+  }
+
+  // Extend=false keeps the trip count (BTC + 1) in the backedge-taken count's
+  // type, so that it can be compared directly with the pattern-matched trip
+  // count RHS. Overflow checks are performed later in checkOverflow, after
+  // first trying to avoid overflow by widening the IV.
+  const SCEV *SCEVTripCount =
+      SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false);
+
+  const SCEV *SCEVRHS = SE->getSCEV(RHS);
+  if (SCEVRHS == SCEVTripCount)
+    return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
+
+  if (auto *ConstantRHS = dyn_cast<ConstantInt>(RHS)) {
+    const SCEV *BackedgeTCExt = nullptr;
+    if (IsWidened) {
+      // Find the extended backedge taken count and extended trip count using
+      // SCEV. One of these should now match the RHS of the compare.
+      BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
+      const SCEV *SCEVTripCountExt =
+          SE->getTripCountFromExitCount(BackedgeTCExt, /*Extend=*/false);
+      if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
+        LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+        return false;
+      }
+    }
+    // If the RHS of the compare is equal to the backedge taken count we need
+    // to add one to get the trip count.
+    if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
+      // An all-ones backedge-taken count means a trip count of 2^BitWidth,
+      // which RHS's type cannot hold: RHS + 1 would silently wrap to 0.
+      if (ConstantRHS->getValue().isMaxValue()) {
+        LLVM_DEBUG(dbgs() << "Trip count is not representable\n");
+        return false;
+      }
+      Value *NewRHS = ConstantInt::get(ConstantRHS->getContext(),
+                                       ConstantRHS->getValue() + 1);
+      return setLoopComponents(NewRHS, TripCount, Increment,
+                               IterationInstructions);
+    }
+    return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
+  }
+
+  // If the RHS isn't a constant then check that the reason it doesn't match
+  // the SCEV trip count is because the RHS is a ZExt or SExt instruction
+  // (and take the trip count to be the RHS).
+  if (!IsWidened) {
+    LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+    return false;
+  }
+  auto *TripCountInst = dyn_cast<Instruction>(RHS);
+  if (!TripCountInst) {
+    LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+    return false;
+  }
+  if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
+      SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
+    LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
+    return false;
+  }
+  return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
+}
+
// Finds the induction variable, increment and trip count for a simple loop that
// we can flatten.
static bool findLoopComponents(
@@ -238,63 +421,9 @@ static bool findLoopComponents(
// another transformation has changed the compare (e.g. icmp ult %inc,
// tripcount -> icmp ult %j, tripcount-1), or both.
Value *RHS = Compare->getOperand(1);
- const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
- if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
- LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
- return false;
- }
- // The use of the Extend=false flag on getTripCountFromExitCount was added
- // during a refactoring to preserve existing behavior. However, there's
- // nothing obvious in the surrounding code when handles the overflow case.
- // FIXME: audit code to establish whether there's a latent bug here.
- const SCEV *SCEVTripCount =
- SE->getTripCountFromExitCount(BackedgeTakenCount, false);
- const SCEV *SCEVRHS = SE->getSCEV(RHS);
- if (SCEVRHS == SCEVTripCount)
- return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
- ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);
- if (ConstantRHS) {
- const SCEV *BackedgeTCExt = nullptr;
- if (IsWidened) {
- const SCEV *SCEVTripCountExt;
- // Find the extended backedge taken count and extended trip count using
- // SCEV. One of these should now match the RHS of the compare.
- BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
- SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false);
- if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
- LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
- return false;
- }
- }
- // If the RHS of the compare is equal to the backedge taken count we need
- // to add one to get the trip count.
- if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
- ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
- Value *NewRHS = ConstantInt::get(
- ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
- return setLoopComponents(NewRHS, TripCount, Increment,
- IterationInstructions);
- }
- return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
- }
- // If the RHS isn't a constant then check that the reason it doesn't match
- // the SCEV trip count is because the RHS is a ZExt or SExt instruction
- // (and take the trip count to be the RHS).
- if (!IsWidened) {
- LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
- return false;
- }
- auto *TripCountInst = dyn_cast<Instruction>(RHS);
- if (!TripCountInst) {
- LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
- return false;
- }
- if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
- SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
- LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
- return false;
- }
- return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
+
+ return verifyTripCount(RHS, L, IterationInstructions, InductionPHI, TripCount,
+ Increment, BackBranch, SE, IsWidened);
}
static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) {
@@ -440,108 +569,26 @@ checkOuterLoopInsts(FlattenInfo &FI,
return true;
}
-static bool checkIVUsers(FlattenInfo &FI) {
- // We require all uses of both induction variables to match this pattern:
- //
- // (OuterPHI * InnerTripCount) + InnerPHI
- //
- // Any uses of the induction variables not matching that pattern would
- // require a div/mod to reconstruct in the flattened loop, so the
- // transformation wouldn't be profitable.
-
- Value *InnerTripCount = FI.InnerTripCount;
- if (FI.Widened &&
- (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
- InnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
+
+// We require all uses of both induction variables to match this pattern:
+//
+// (OuterPHI * InnerTripCount) + InnerPHI
+//
+// Any uses of the induction variables not matching that pattern would
+// require a div/mod to reconstruct in the flattened loop, so the
+// transformation wouldn't be profitable.
+static bool checkIVUsers(FlattenInfo &FI) {
// Check that all uses of the inner loop's induction variable match the
// expected pattern, recording the uses of the outer IV.
SmallPtrSet<Value *, 4> ValidOuterPHIUses;
- for (User *U : FI.InnerInductionPHI->users()) {
- if (U == FI.InnerIncrement)
- continue;
-
- // After widening the IVs, a trunc instruction might have been introduced,
- // so look through truncs.
- if (isa<TruncInst>(U)) {
- if (!U->hasOneUse())
- return false;
- U = *U->user_begin();
- }
-
- // If the use is in the compare (which is also the condition of the inner
- // branch) then the compare has been altered by another transformation e.g
- // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is
- // a constant. Ignore this use as the compare gets removed later anyway.
- if (U == FI.InnerBranch->getCondition())
- continue;
-
- LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
-
- Value *MatchedMul = nullptr;
- Value *MatchedItCount = nullptr;
- bool IsAdd = match(U, m_c_Add(m_Specific(FI.InnerInductionPHI),
- m_Value(MatchedMul))) &&
- match(MatchedMul, m_c_Mul(m_Specific(FI.OuterInductionPHI),
- m_Value(MatchedItCount)));
-
- // Matches the same pattern as above, except it also looks for truncs
- // on the phi, which can be the result of widening the induction variables.
- bool IsAddTrunc =
- match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)),
- m_Value(MatchedMul))) &&
- match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
- m_Value(MatchedItCount)));
-
- if (!MatchedItCount)
- return false;
- // Look through extends if the IV has been widened.
- if (FI.Widened &&
- (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
- assert(MatchedItCount->getType() == FI.InnerInductionPHI->getType() &&
- "Unexpected type mismatch in types after widening");
- MatchedItCount = isa<SExtInst>(MatchedItCount)
- ? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0)
- : dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0);
- }
-
- if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
- LLVM_DEBUG(dbgs() << "Use is optimisable\n");
- ValidOuterPHIUses.insert(MatchedMul);
- FI.LinearIVUses.insert(U);
- } else {
- LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
- return false;
- }
- }
+ if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses))
+ return false;
// Check that there are no uses of the outer IV other than the ones found
// as part of the pattern above.
- for (User *U : FI.OuterInductionPHI->users()) {
- if (U == FI.OuterIncrement)
- continue;
-
- auto IsValidOuterPHIUses = [&] (User *U) -> bool {
- LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
- if (!ValidOuterPHIUses.count(U)) {
- LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
- return false;
- }
- LLVM_DEBUG(dbgs() << "Use is optimisable\n");
- return true;
- };
-
- if (auto *V = dyn_cast<TruncInst>(U)) {
- for (auto *K : V->users()) {
- if (!IsValidOuterPHIUses(K))
- return false;
- }
- continue;
- }
-
- if (!IsValidOuterPHIUses(U))
- return false;
- }
+ if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses))
+ return false;
LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n";
dbgs() << "Found " << FI.LinearIVUses.size()
More information about the llvm-commits
mailing list