[llvm] ada6d78 - [LoopFlatten] Address FIXME about getTripCountFromExitCount. NFC.

Sjoerd Meijer via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 24 05:57:03 PST 2022


Author: Sjoerd Meijer
Date: 2022-01-24T13:46:19Z
New Revision: ada6d78a7802f8057f1ab7cee0bed25f91fcc4b4

URL: https://github.com/llvm/llvm-project/commit/ada6d78a7802f8057f1ab7cee0bed25f91fcc4b4
DIFF: https://github.com/llvm/llvm-project/commit/ada6d78a7802f8057f1ab7cee0bed25f91fcc4b4.diff

LOG: [LoopFlatten] Address FIXME about getTripCountFromExitCount. NFC.

Together with the previous commit, which mainly improves the documentation of
LoopFlatten's overall strategy, this addresses a concern raised in a FIXME
comment in D110587; the NFC refactoring introduces helper functions (including
one for the SCEV usage) to make this clearer.

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LoopFlatten.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index 4d9578934d9e..c46db4e63bfe 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -149,6 +149,118 @@ struct FlattenInfo {
       return false;
     return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi;
   }
+  /// Returns true if \p U is the recorded increment of the inner loop's
+  /// induction variable (InnerIncrement).
+  bool isInnerLoopIncrement(User *U) const {
+    return InnerIncrement == U;
+  }
+  /// Returns true if \p U is the recorded increment of the outer loop's
+  /// induction variable (OuterIncrement).
+  bool isOuterLoopIncrement(User *U) const {
+    return OuterIncrement == U;
+  }
+  /// Returns true if \p U is the condition of the inner loop's branch
+  /// (InnerBranch), i.e. the inner loop's exit test.
+  bool isInnerLoopTest(User *U) const {
+    return InnerBranch->getCondition() == U;
+  }
+
+  /// Verifies that every user of the outer induction PHI is either the outer
+  /// loop increment or a value previously recorded in \p ValidOuterPHIUses.
+  /// A trunc user is looked through: all of the trunc's users must then be
+  /// recorded values. Returns false on the first unexpected use.
+  bool checkOuterInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
+    // Accept a single user only if it was recorded while matching the inner
+    // IV's linear uses.
+    auto IsRecordedUse = [&ValidOuterPHIUses](User *Candidate) -> bool {
+      LLVM_DEBUG(dbgs() << "Found use of outer induction variable: ";
+                 Candidate->dump());
+      if (ValidOuterPHIUses.count(Candidate)) {
+        LLVM_DEBUG(dbgs() << "Use is optimisable\n");
+        return true;
+      }
+      LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
+      return false;
+    };
+
+    for (User *OuterUser : OuterInductionPHI->users()) {
+      if (isOuterLoopIncrement(OuterUser))
+        continue;
+
+      // The outer IV may be reached through a trunc (e.g. after widening);
+      // in that case each user of the trunc has to be a recorded use.
+      if (auto *Trunc = dyn_cast<TruncInst>(OuterUser)) {
+        for (User *TruncUser : Trunc->users())
+          if (!IsRecordedUse(TruncUser))
+            return false;
+        continue;
+      }
+
+      if (!IsRecordedUse(OuterUser))
+        return false;
+    }
+    return true;
+  }
+
+  /// Checks whether \p U, a user of the inner induction PHI, has the form
+  ///
+  ///   InnerPHI + (OuterPHI * InnerTripCount)
+  ///
+  /// with operands in either order, or the same form with both PHIs
+  /// truncated (which can result from widening the IVs). When the IVs were
+  /// widened, a sext/zext on the matched trip count is looked through. On a
+  /// match, the multiply is added to \p ValidOuterPHIUses, \p U is added to
+  /// LinearIVUses, and true is returned; otherwise false is returned.
+  bool matchLinearIVUser(User *U, Value *InnerTripCount,
+                         SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
+    LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
+    Value *MatchedMul = nullptr;
+    Value *MatchedItCount = nullptr;
+
+    bool IsAdd = match(U, m_c_Add(m_Specific(InnerInductionPHI),
+                                  m_Value(MatchedMul))) &&
+                 match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),
+                                           m_Value(MatchedItCount)));
+
+    // Matches the same pattern as above, except it also looks for truncs
+    // on the phi, which can be the result of widening the induction variables.
+    bool IsAddTrunc =
+        match(U, m_c_Add(m_Trunc(m_Specific(InnerInductionPHI)),
+                         m_Value(MatchedMul))) &&
+        match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)),
+                                  m_Value(MatchedItCount)));
+
+    // MatchedItCount is only bound when one of the two patterns matched in
+    // full, so test the match results directly. This also makes this failure
+    // path report that it bailed, like the other failure paths do.
+    if (!IsAdd && !IsAddTrunc) {
+      LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
+      return false;
+    }
+
+    // Look through extends if the IV has been widened.
+    if (Widened &&
+        (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
+      assert(MatchedItCount->getType() == InnerInductionPHI->getType() &&
+             "Unexpected type mismatch in types after widening");
+      // The isa<> checks above guarantee this is an instruction, so cast<>
+      // is correct; the unchecked dyn_cast<> dereference was not.
+      MatchedItCount = cast<Instruction>(MatchedItCount)->getOperand(0);
+    }
+
+    if (MatchedItCount == InnerTripCount) {
+      LLVM_DEBUG(dbgs() << "Use is optimisable\n");
+      ValidOuterPHIUses.insert(MatchedMul);
+      LinearIVUses.insert(U);
+      return true;
+    }
+
+    LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
+    return false;
+  }
+
+  /// Verifies that every user of the inner induction PHI is the inner loop
+  /// increment, the inner loop's exit test, or a linear use accepted by
+  /// matchLinearIVUser (which also records the outer-IV values it expects in
+  /// \p ValidOuterPHIUses). A trunc user is looked through only when it has
+  /// exactly one user. Returns false on the first use that does not qualify.
+  bool checkInnerInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
+    // After widening, the trip count may be a sext/zext of the original
+    // value; match the linear uses against the value underneath the extend.
+    Value *NarrowInnerTripCount = InnerTripCount;
+    if (Widened) {
+      if (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount))
+        NarrowInnerTripCount =
+            cast<Instruction>(InnerTripCount)->getOperand(0);
+    }
+
+    for (User *InnerUser : InnerInductionPHI->users()) {
+      if (isInnerLoopIncrement(InnerUser))
+        continue;
+
+      // Widening the IVs can leave a trunc between the PHI and the user that
+      // matters; step through it, but only when the trunc has a single user.
+      User *Candidate = InnerUser;
+      if (isa<TruncInst>(Candidate)) {
+        if (!Candidate->hasOneUse())
+          return false;
+        Candidate = *Candidate->user_begin();
+      }
+
+      // Another transformation may have rewritten the inner branch's compare
+      // to test the PHI directly (e.g. icmp ult %inc, tripcount ->
+      // icmp ult %j, tripcount-1 for a constant tripcount). That compare is
+      // removed later anyway, so this use does not need to match the pattern.
+      if (isInnerLoopTest(Candidate))
+        continue;
+
+      if (!matchLinearIVUser(Candidate, NarrowInnerTripCount,
+                             ValidOuterPHIUses))
+        return false;
+    }
+    return true;
+  }
 };
 
 static bool
@@ -162,6 +274,77 @@ setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment,
   return true;
 }
 
+// Given the RHS of the loop latch compare instruction, verify with SCEV that
+// this is indeed the loop tripcount. On success the chosen trip count value
+// (RHS itself, or RHS + 1 when RHS equals the backedge-taken count) is passed
+// to setLoopComponents and its result is returned; otherwise returns false.
+// NOTE(review): InductionPHI and BackBranch are not used by this function;
+// they are kept so the signature matches the call in findLoopComponents.
+// TODO: This used to be a straightforward check but has grown to be quite
+// complicated now. It is therefore worth revisiting what the additional
+// benefits are of this (compared to relying on canonical loops and pattern
+// matching).
+static bool
+verifyTripCount(Value *RHS, Loop *L,
+                SmallPtrSetImpl<Instruction *> &IterationInstructions,
+                PHINode *&InductionPHI, Value *&TripCount,
+                BinaryOperator *&Increment, BranchInst *&BackBranch,
+                ScalarEvolution *SE, bool IsWidened) {
+  const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
+  if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
+    LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
+    return false;
+  }
+
+  // The Extend=false flag is used for getTripCountFromExitCount because we
+  // want to verify the result against, and match it with, the pattern-matched
+  // tripcount. Overflow checks are performed in checkOverflow; where possible,
+  // overflow is first avoided by widening the IV.
+  // NOTE(review): without extension, a backedge-taken count of all-ones wraps
+  // to a trip count of 0 here (and in the "+ 1" below); confirm that
+  // checkOverflow or an earlier check rules this case out.
+  const SCEV *SCEVTripCount =
+      SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false);
+
+  const SCEV *SCEVRHS = SE->getSCEV(RHS);
+  if (SCEVRHS == SCEVTripCount)
+    return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
+
+  if (auto *ConstantRHS = dyn_cast<ConstantInt>(RHS)) {
+    const SCEV *BackedgeTCExt = nullptr;
+    if (IsWidened) {
+      // Find the extended backedge taken count and extended trip count using
+      // SCEV. One of these should now match the RHS of the compare.
+      BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
+      const SCEV *SCEVTripCountExt =
+          SE->getTripCountFromExitCount(BackedgeTCExt, /*Extend=*/false);
+      if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
+        LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+        return false;
+      }
+    }
+    // If the RHS of the compare is equal to the backedge taken count we need
+    // to add one to get the trip count.
+    if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
+      Value *NewRHS = ConstantInt::get(ConstantRHS->getContext(),
+                                       ConstantRHS->getValue() + 1);
+      return setLoopComponents(NewRHS, TripCount, Increment,
+                               IterationInstructions);
+    }
+    return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
+  }
+
+  // If the RHS isn't a constant then check that the reason it doesn't match
+  // the SCEV trip count is because the RHS is a ZExt or SExt instruction
+  // (and take the trip count to be the RHS).
+  if (!IsWidened) {
+    LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+    return false;
+  }
+  auto *TripCountInst = dyn_cast<Instruction>(RHS);
+  if (!TripCountInst) {
+    LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+    return false;
+  }
+  if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
+      SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
+    LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
+    return false;
+  }
+  return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
+}
+
 // Finds the induction variable, increment and trip count for a simple loop that
 // we can flatten.
 static bool findLoopComponents(
@@ -238,63 +421,9 @@ static bool findLoopComponents(
   // another transformation has changed the compare (e.g. icmp ult %inc,
   // tripcount -> icmp ult %j, tripcount-1), or both.
   Value *RHS = Compare->getOperand(1);
-  const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
-  if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
-    LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
-    return false;
-  }
-  // The use of the Extend=false flag on getTripCountFromExitCount was added
-  // during a refactoring to preserve existing behavior.  However, there's
-  // nothing obvious in the surrounding code when handles the overflow case.
-  // FIXME: audit code to establish whether there's a latent bug here.
-  const SCEV *SCEVTripCount =
-      SE->getTripCountFromExitCount(BackedgeTakenCount, false);
-  const SCEV *SCEVRHS = SE->getSCEV(RHS);
-  if (SCEVRHS == SCEVTripCount)
-    return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
-  ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);
-  if (ConstantRHS) {
-    const SCEV *BackedgeTCExt = nullptr;
-    if (IsWidened) {
-      const SCEV *SCEVTripCountExt;
-      // Find the extended backedge taken count and extended trip count using
-      // SCEV. One of these should now match the RHS of the compare.
-      BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
-      SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false);
-      if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
-        LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
-        return false;
-      }
-    }
-    // If the RHS of the compare is equal to the backedge taken count we need
-    // to add one to get the trip count.
-    if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
-      ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
-      Value *NewRHS = ConstantInt::get(
-          ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
-      return setLoopComponents(NewRHS, TripCount, Increment,
-                               IterationInstructions);
-    }
-    return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
-  }
-  // If the RHS isn't a constant then check that the reason it doesn't match
-  // the SCEV trip count is because the RHS is a ZExt or SExt instruction
-  // (and take the trip count to be the RHS).
-  if (!IsWidened) {
-    LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
-    return false;
-  }
-  auto *TripCountInst = dyn_cast<Instruction>(RHS);
-  if (!TripCountInst) {
-    LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
-    return false;
-  }
-  if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
-      SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
-    LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
-    return false;
-  }
-  return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
+
+  return verifyTripCount(RHS, L, IterationInstructions, InductionPHI, TripCount,
+                         Increment, BackBranch, SE, IsWidened);
 }
 
 static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) {
@@ -440,108 +569,26 @@ checkOuterLoopInsts(FlattenInfo &FI,
   return true;
 }
 
-static bool checkIVUsers(FlattenInfo &FI) {
-  // We require all uses of both induction variables to match this pattern:
-  //
-  //   (OuterPHI * InnerTripCount) + InnerPHI
-  //
-  // Any uses of the induction variables not matching that pattern would
-  // require a div/mod to reconstruct in the flattened loop, so the
-  // transformation wouldn't be profitable.
-
-  Value *InnerTripCount = FI.InnerTripCount;
-  if (FI.Widened &&
-      (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
-    InnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
 
+
+// We require all uses of both induction variables to match this pattern:
+//
+//   (OuterPHI * InnerTripCount) + InnerPHI
+//
+// Any uses of the induction variables not matching that pattern would
+// require a div/mod to reconstruct in the flattened loop, so the
+// transformation wouldn't be profitable.
+static bool checkIVUsers(FlattenInfo &FI) {
   // Check that all uses of the inner loop's induction variable match the
   // expected pattern, recording the uses of the outer IV.
   SmallPtrSet<Value *, 4> ValidOuterPHIUses;
-  for (User *U : FI.InnerInductionPHI->users()) {
-    if (U == FI.InnerIncrement)
-      continue;
-
-    // After widening the IVs, a trunc instruction might have been introduced,
-    // so look through truncs.
-    if (isa<TruncInst>(U)) {
-      if (!U->hasOneUse())
-        return false;
-      U = *U->user_begin();
-    }
-
-    // If the use is in the compare (which is also the condition of the inner
-    // branch) then the compare has been altered by another transformation e.g
-    // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is
-    // a constant. Ignore this use as the compare gets removed later anyway.
-    if (U == FI.InnerBranch->getCondition())
-      continue;
-
-    LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
-
-    Value *MatchedMul = nullptr;
-    Value *MatchedItCount = nullptr;
-    bool IsAdd = match(U, m_c_Add(m_Specific(FI.InnerInductionPHI),
-                                  m_Value(MatchedMul))) &&
-                 match(MatchedMul, m_c_Mul(m_Specific(FI.OuterInductionPHI),
-                                           m_Value(MatchedItCount)));
-
-    // Matches the same pattern as above, except it also looks for truncs
-    // on the phi, which can be the result of widening the induction variables.
-    bool IsAddTrunc =
-        match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)),
-                         m_Value(MatchedMul))) &&
-        match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
-                                  m_Value(MatchedItCount)));
-
-    if (!MatchedItCount)
-      return false;
-    // Look through extends if the IV has been widened.
-    if (FI.Widened &&
-        (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
-      assert(MatchedItCount->getType() == FI.InnerInductionPHI->getType() &&
-             "Unexpected type mismatch in types after widening");
-      MatchedItCount = isa<SExtInst>(MatchedItCount)
-                           ? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0)
-                           : dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0);
-    }
-
-    if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
-      LLVM_DEBUG(dbgs() << "Use is optimisable\n");
-      ValidOuterPHIUses.insert(MatchedMul);
-      FI.LinearIVUses.insert(U);
-    } else {
-      LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
-      return false;
-    }
-  }
+  if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses))
+    return false;
 
   // Check that there are no uses of the outer IV other than the ones found
   // as part of the pattern above.
-  for (User *U : FI.OuterInductionPHI->users()) {
-    if (U == FI.OuterIncrement)
-      continue;
-
-    auto IsValidOuterPHIUses = [&] (User *U) -> bool {
-      LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
-      if (!ValidOuterPHIUses.count(U)) {
-        LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
-        return false;
-      }
-      LLVM_DEBUG(dbgs() << "Use is optimisable\n");
-      return true;
-    };
-
-    if (auto *V = dyn_cast<TruncInst>(U)) {
-      for (auto *K : V->users()) {
-        if (!IsValidOuterPHIUses(K))
-          return false;
-      }
-      continue;
-    }
-
-    if (!IsValidOuterPHIUses(U))
-      return false;
-  }
+  if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses))
+    return false;
 
   LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n";
              dbgs() << "Found " << FI.LinearIVUses.size()


        


More information about the llvm-commits mailing list