[llvm] b115ba0 - [NFC] Introduce range based singleton searches for loop queries.

Jamie Schmeiser via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 26 10:50:59 PDT 2022


Author: Jamie Schmeiser
Date: 2022-10-26T13:50:11-04:00
New Revision: b115ba005030281d9dc85655fbfddf99c0dd82c1

URL: https://github.com/llvm/llvm-project/commit/b115ba005030281d9dc85655fbfddf99c0dd82c1
DIFF: https://github.com/llvm/llvm-project/commit/b115ba005030281d9dc85655fbfddf99c0dd82c1.diff

LOG: [NFC] Introduce range based singleton searches for loop queries.

Summary:
Several loop queries look for a singleton by collecting all instances and
then returning the instance only if exactly one was found. This can be
improved by stopping the search as soon as a second instance is found.
Introduce generic range-based singleton searches that stop after finding a
second value, and use them for these loop queries as well as for two region
queries that already stopped early with hand-written loops.

There is no intended functional change other than improved compile-time
efficiency.

Author: Jamie Schmeiser <schmeise at ca.ibm.com>
Reviewed By: Meinersbur (Michael Kruse)
Differential Revision: https://reviews.llvm.org/D136261

Added: 
    

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    llvm/include/llvm/Analysis/LoopInfoImpl.h
    llvm/include/llvm/Analysis/RegionInfoImpl.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 2c242b317470b..1fe6609986f6a 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1636,6 +1636,57 @@ OutputIt copy_if(R &&Range, OutputIt Out, UnaryPredicate P) {
   return std::copy_if(adl_begin(Range), adl_end(Range), Out, P);
 }
 
+/// Return the single value in \p Range that satisfies
+/// \p P(<member of \p Range>, AllowRepeats)->T *, returning nullptr
+/// when no values or multiple values were found.
+/// When \p AllowRepeats is true, multiple values that compare equal are
+/// allowed.  Elements are bound by reference, so they need not be pointers.
+template <typename T, typename R, typename Predicate>
+T *find_singleton(R &&Range, Predicate P, bool AllowRepeats = false) {
+  T *RC = nullptr;
+  for (auto &&A : Range) {
+    if (T *PRC = P(A, AllowRepeats)) {
+      if (RC) {
+        if (!AllowRepeats || PRC != RC)
+          return nullptr;
+      } else
+        RC = PRC;
+    }
+  }
+  return RC;
+}
+
+/// Return a pair whose first member is the single value in \p Range that
+/// satisfies \p P(<member of \p Range>, AllowRepeats)->std::pair<T *, bool>,
+/// or nullptr when no values or multiple values were found, and whose second
+/// member is true exactly when multiple values caused that nullptr.
+/// When \p AllowRepeats is true, multiple values that compare equal are
+/// allowed.  The predicate \p P returns a pair<T *, bool> whose first member
+/// is its singleton candidate and whose second member indicates that multiples
+/// have already been found; first should be nullptr when second is true.
+/// This allows using find_singleton_nested within the predicate \p P.
+template <typename T, typename R, typename Predicate>
+std::pair<T *, bool> find_singleton_nested(R &&Range, Predicate P,
+                                           bool AllowRepeats = false) {
+  T *RC = nullptr;
+  for (auto &&A : Range) {
+    std::pair<T *, bool> PRC = P(A, AllowRepeats);
+    if (PRC.second) {
+      assert(PRC.first == nullptr &&
+             "Inconsistent return values in find_singleton_nested.");
+      return PRC;
+    }
+    if (PRC.first) {
+      if (RC) {
+        if (!AllowRepeats || PRC.first != RC)
+          return {nullptr, true};
+      } else
+        RC = PRC.first;
+    }
+  }
+  return {RC, false};
+}
+
 template <typename R, typename OutputIt>
 OutputIt copy(R &&Range, OutputIt Out) {
   return std::copy(adl_begin(Range), adl_end(Range), Out);

diff  --git a/llvm/include/llvm/Analysis/LoopInfoImpl.h b/llvm/include/llvm/Analysis/LoopInfoImpl.h
index 0a21c45f12694..c509ee67cbacc 100644
--- a/llvm/include/llvm/Analysis/LoopInfoImpl.h
+++ b/llvm/include/llvm/Analysis/LoopInfoImpl.h
@@ -47,11 +47,14 @@ void LoopBase<BlockT, LoopT>::getExitingBlocks(
 template <class BlockT, class LoopT>
 BlockT *LoopBase<BlockT, LoopT>::getExitingBlock() const {
   assert(!isInvalid() && "Loop not in a valid state!");
-  SmallVector<BlockT *, 8> ExitingBlocks;
-  getExitingBlocks(ExitingBlocks);
-  if (ExitingBlocks.size() == 1)
-    return ExitingBlocks[0];
-  return nullptr;
+  auto notInLoop = [&](BlockT *BB) { return !contains(BB); };
+  auto isExitingBlock = [&](BlockT *BB, bool AllowRepeats) -> BlockT * {
+    assert(!AllowRepeats && "Unexpected parameter value.");
+    // A successor outside the loop makes BB an exiting block.
+    return any_of(children<BlockT *>(BB), notInLoop) ? BB : nullptr;
+  };
+
+  return find_singleton<BlockT>(blocks(), isExitingBlock);
 }
 
 /// getExitBlocks - Return all of the successor blocks of this loop.  These
@@ -68,23 +71,41 @@ void LoopBase<BlockT, LoopT>::getExitBlocks(
         ExitBlocks.push_back(Succ);
 }
 
+/// Return {the sole exit block of \p L or nullptr, true if several exist}.
+/// If \p Unique is false every exit edge counts, else only distinct blocks.
+template <class BlockT, class LoopT>
+std::pair<BlockT *, bool> getExitBlockHelper(const LoopBase<BlockT, LoopT> *L,
+                                             bool Unique) {
+  assert(!L->isInvalid() && "Loop not in a valid state!");
+  auto notInLoop = [&](BlockT *BB,
+                       bool AllowRepeats) -> std::pair<BlockT *, bool> {
+    assert(AllowRepeats == Unique && "Unexpected parameter value.");
+    return {!L->contains(BB) ? BB : nullptr, false};
+  };
+  auto singleExitBlock = [&](BlockT *BB,
+                             bool AllowRepeats) -> std::pair<BlockT *, bool> {
+    assert(AllowRepeats == Unique && "Unexpected parameter value.");
+    return find_singleton_nested<BlockT>(children<BlockT *>(BB), notInLoop,
+                                         AllowRepeats);
+  };
+  return find_singleton_nested<BlockT>(L->blocks(), singleExitBlock, Unique);
+}
+
 template <class BlockT, class LoopT>
 bool LoopBase<BlockT, LoopT>::hasNoExitBlocks() const {
-  SmallVector<BlockT *, 8> ExitBlocks;
-  getExitBlocks(ExitBlocks);
-  return ExitBlocks.empty();
+  // A lone exit (first) or several exits (second) both mean the loop can be
+  // left; the answer is true only when the helper found neither.
+  auto ExitInfo = getExitBlockHelper(this, /*Unique=*/false);
+  bool HasSingleExit = ExitInfo.first != nullptr;
+  bool HasMultipleExits = ExitInfo.second;
+  return !HasSingleExit && !HasMultipleExits;
 }
 
 /// getExitBlock - If getExitBlocks would return exactly one block,
 /// return that block. Otherwise return null.
 template <class BlockT, class LoopT>
 BlockT *LoopBase<BlockT, LoopT>::getExitBlock() const {
-  assert(!isInvalid() && "Loop not in a valid state!");
-  SmallVector<BlockT *, 8> ExitBlocks;
-  getExitBlocks(ExitBlocks);
-  if (ExitBlocks.size() == 1)
-    return ExitBlocks[0];
-  return nullptr;
+  return getExitBlockHelper(this, /*Unique=*/false).first;
 }
 
 template <class BlockT, class LoopT>
@@ -135,11 +156,7 @@ void LoopBase<BlockT, LoopT>::getUniqueNonLatchExitBlocks(
 
 template <class BlockT, class LoopT>
 BlockT *LoopBase<BlockT, LoopT>::getUniqueExitBlock() const {
-  SmallVector<BlockT *, 8> UniqueExitBlocks;
-  getUniqueExitBlocks(UniqueExitBlocks);
-  if (UniqueExitBlocks.size() == 1)
-    return UniqueExitBlocks[0];
-  return nullptr;
+  return getExitBlockHelper(this, /*Unique=*/true).first;
 }
 
 /// getExitEdges - Return all pairs of (_inside_block_,_outside_block_).

diff  --git a/llvm/include/llvm/Analysis/RegionInfoImpl.h b/llvm/include/llvm/Analysis/RegionInfoImpl.h
index 9a18e72eebec7..7fdfdd0efba84 100644
--- a/llvm/include/llvm/Analysis/RegionInfoImpl.h
+++ b/llvm/include/llvm/Analysis/RegionInfoImpl.h
@@ -159,20 +159,14 @@ typename Tr::LoopT *RegionBase<Tr>::outermostLoopInRegion(LoopInfoT *LI,
 
 template <class Tr>
 typename RegionBase<Tr>::BlockT *RegionBase<Tr>::getEnteringBlock() const {
+  auto asEnteringBlock = [&](BlockT *Pred, bool AllowRepeats) -> BlockT * {
+    assert(!AllowRepeats && "Unexpected parameter value.");
+    return (DT->getNode(Pred) && !contains(Pred)) ? Pred : nullptr;
+  };
   BlockT *entry = getEntry();
-  BlockT *enteringBlock = nullptr;
-
-  for (BlockT *Pred : make_range(InvBlockTraits::child_begin(entry),
-                                 InvBlockTraits::child_end(entry))) {
-    if (DT->getNode(Pred) && !contains(Pred)) {
-      if (enteringBlock)
-        return nullptr;
-
-      enteringBlock = Pred;
-    }
-  }
-
-  return enteringBlock;
+  auto Preds = make_range(InvBlockTraits::child_begin(entry),
+                          InvBlockTraits::child_end(entry));
+  return find_singleton<BlockT>(Preds, asEnteringBlock);
 }
 
 template <class Tr>
@@ -201,22 +195,16 @@ bool RegionBase<Tr>::getExitingBlocks(
 template <class Tr>
 typename RegionBase<Tr>::BlockT *RegionBase<Tr>::getExitingBlock() const {
   BlockT *exit = getExit();
-  BlockT *exitingBlock = nullptr;
-
   if (!exit)
     return nullptr;
 
-  for (BlockT *Pred : make_range(InvBlockTraits::child_begin(exit),
-                                 InvBlockTraits::child_end(exit))) {
-    if (contains(Pred)) {
-      if (exitingBlock)
-        return nullptr;
-
-      exitingBlock = Pred;
-    }
-  }
-
-  return exitingBlock;
+  auto asExitingBlock = [&](BlockT *Pred, bool AllowRepeats) -> BlockT * {
+    assert(!AllowRepeats && "Unexpected parameter value.");
+    return contains(Pred) ? Pred : nullptr;
+  };
+  auto Preds = make_range(InvBlockTraits::child_begin(exit),
+                          InvBlockTraits::child_end(exit));
+  return find_singleton<BlockT>(Preds, asExitingBlock);
 }
 
 template <class Tr>


        


More information about the llvm-commits mailing list