[llvm] [llvm/llvm-project][Coroutines] ABI Object (PR #106306)

Tyler Nowicki via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 3 07:28:29 PDT 2024


https://github.com/TylerNowicki updated https://github.com/llvm/llvm-project/pull/106306

>From 06cde7ca59c29bfb0a96ce1acad6e3ff6b75b8f9 Mon Sep 17 00:00:00 2001
From: tnowicki <tnowicki.nowicki at amd.com>
Date: Fri, 23 Aug 2024 13:12:17 -0400
Subject: [PATCH] [llvm/llvm-project][Coroutines] Major refactoring of
 SuspendCrossingInfo

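* Move SuspendCrossingInfo (and BlockToIndexMapping) into their own files to
  clean up CoroFrame
* The SuspendCrossingInfo constructor now takes the coro.suspend and coro.end
  lists directly instead of the whole coro::Shape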
---
 llvm/lib/Transforms/Coroutines/CMakeLists.txt |   1 +
 llvm/lib/Transforms/Coroutines/CoroFrame.cpp  | 322 +-----------------
 .../Coroutines/SuspendCrossingInfo.cpp        | 195 +++++++++++
 .../Coroutines/SuspendCrossingInfo.h          | 182 ++++++++++
 4 files changed, 390 insertions(+), 310 deletions(-)
 create mode 100644 llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.cpp
 create mode 100644 llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.h

diff --git a/llvm/lib/Transforms/Coroutines/CMakeLists.txt b/llvm/lib/Transforms/Coroutines/CMakeLists.txt
index 2139446e5ff957..57359acc81ee4c 100644
--- a/llvm/lib/Transforms/Coroutines/CMakeLists.txt
+++ b/llvm/lib/Transforms/Coroutines/CMakeLists.txt
@@ -6,6 +6,7 @@ add_llvm_component_library(LLVMCoroutines
   CoroElide.cpp
   CoroFrame.cpp
   CoroSplit.cpp
+  SuspendCrossingInfo.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/Coroutines
diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index f76cfe01b58cfd..ca5e9504264b25 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -17,6 +17,7 @@
 
 #include "CoroInternal.h"
+#include "SuspendCrossingInfo.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallString.h"
@@ -51,315 +52,6 @@ extern cl::opt<bool> UseNewDbgInfoFormat;
 // "coro-frame", which results in leaner debug spew.
 #define DEBUG_TYPE "coro-suspend-crossing"
 
-enum { SmallVectorThreshold = 32 };
-
-// Provides two way mapping between the blocks and numbers.
-namespace {
-class BlockToIndexMapping {
-  SmallVector<BasicBlock *, SmallVectorThreshold> V;
-
-public:
-  size_t size() const { return V.size(); }
-
-  BlockToIndexMapping(Function &F) {
-    for (BasicBlock &BB : F)
-      V.push_back(&BB);
-    llvm::sort(V);
-  }
-
-  size_t blockToIndex(BasicBlock const *BB) const {
-    auto *I = llvm::lower_bound(V, BB);
-    assert(I != V.end() && *I == BB && "BasicBlockNumberng: Unknown block");
-    return I - V.begin();
-  }
-
-  BasicBlock *indexToBlock(unsigned Index) const { return V[Index]; }
-};
-} // end anonymous namespace
-
-// The SuspendCrossingInfo maintains data that allows to answer a question
-// whether given two BasicBlocks A and B there is a path from A to B that
-// passes through a suspend point.
-//
-// For every basic block 'i' it maintains a BlockData that consists of:
-//   Consumes:  a bit vector which contains a set of indices of blocks that can
-//              reach block 'i'. A block can trivially reach itself.
-//   Kills: a bit vector which contains a set of indices of blocks that can
-//          reach block 'i' but there is a path crossing a suspend point
-//          not repeating 'i' (path to 'i' without cycles containing 'i').
-//   Suspend: a boolean indicating whether block 'i' contains a suspend point.
-//   End: a boolean indicating whether block 'i' contains a coro.end intrinsic.
-//   KillLoop: There is a path from 'i' to 'i' not otherwise repeating 'i' that
-//             crosses a suspend point.
-//
-namespace {
-class SuspendCrossingInfo {
-  BlockToIndexMapping Mapping;
-
-  struct BlockData {
-    BitVector Consumes;
-    BitVector Kills;
-    bool Suspend = false;
-    bool End = false;
-    bool KillLoop = false;
-    bool Changed = false;
-  };
-  SmallVector<BlockData, SmallVectorThreshold> Block;
-
-  iterator_range<pred_iterator> predecessors(BlockData const &BD) const {
-    BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]);
-    return llvm::predecessors(BB);
-  }
-
-  BlockData &getBlockData(BasicBlock *BB) {
-    return Block[Mapping.blockToIndex(BB)];
-  }
-
-  /// Compute the BlockData for the current function in one iteration.
-  /// Initialize - Whether this is the first iteration, we can optimize
-  /// the initial case a little bit by manual loop switch.
-  /// Returns whether the BlockData changes in this iteration.
-  template <bool Initialize = false>
-  bool computeBlockData(const ReversePostOrderTraversal<Function *> &RPOT);
-
-public:
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-  void dump() const;
-  void dump(StringRef Label, BitVector const &BV,
-            const ReversePostOrderTraversal<Function *> &RPOT) const;
-#endif
-
-  SuspendCrossingInfo(Function &F, coro::Shape &Shape);
-
-  /// Returns true if there is a path from \p From to \p To crossing a suspend
-  /// point without crossing \p From a 2nd time.
-  bool hasPathCrossingSuspendPoint(BasicBlock *From, BasicBlock *To) const {
-    size_t const FromIndex = Mapping.blockToIndex(From);
-    size_t const ToIndex = Mapping.blockToIndex(To);
-    bool const Result = Block[ToIndex].Kills[FromIndex];
-    LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName()
-                      << " answer is " << Result << "\n");
-    return Result;
-  }
-
-  /// Returns true if there is a path from \p From to \p To crossing a suspend
-  /// point without crossing \p From a 2nd time. If \p From is the same as \p To
-  /// this will also check if there is a looping path crossing a suspend point.
-  bool hasPathOrLoopCrossingSuspendPoint(BasicBlock *From,
-                                         BasicBlock *To) const {
-    size_t const FromIndex = Mapping.blockToIndex(From);
-    size_t const ToIndex = Mapping.blockToIndex(To);
-    bool Result = Block[ToIndex].Kills[FromIndex] ||
-                  (From == To && Block[ToIndex].KillLoop);
-    LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName()
-                      << " answer is " << Result << " (path or loop)\n");
-    return Result;
-  }
-
-  bool isDefinitionAcrossSuspend(BasicBlock *DefBB, User *U) const {
-    auto *I = cast<Instruction>(U);
-
-    // We rewrote PHINodes, so that only the ones with exactly one incoming
-    // value need to be analyzed.
-    if (auto *PN = dyn_cast<PHINode>(I))
-      if (PN->getNumIncomingValues() > 1)
-        return false;
-
-    BasicBlock *UseBB = I->getParent();
-
-    // As a special case, treat uses by an llvm.coro.suspend.retcon or an
-    // llvm.coro.suspend.async as if they were uses in the suspend's single
-    // predecessor: the uses conceptually occur before the suspend.
-    if (isa<CoroSuspendRetconInst>(I) || isa<CoroSuspendAsyncInst>(I)) {
-      UseBB = UseBB->getSinglePredecessor();
-      assert(UseBB && "should have split coro.suspend into its own block");
-    }
-
-    return hasPathCrossingSuspendPoint(DefBB, UseBB);
-  }
-
-  bool isDefinitionAcrossSuspend(Argument &A, User *U) const {
-    return isDefinitionAcrossSuspend(&A.getParent()->getEntryBlock(), U);
-  }
-
-  bool isDefinitionAcrossSuspend(Instruction &I, User *U) const {
-    auto *DefBB = I.getParent();
-
-    // As a special case, treat values produced by an llvm.coro.suspend.*
-    // as if they were defined in the single successor: the uses
-    // conceptually occur after the suspend.
-    if (isa<AnyCoroSuspendInst>(I)) {
-      DefBB = DefBB->getSingleSuccessor();
-      assert(DefBB && "should have split coro.suspend into its own block");
-    }
-
-    return isDefinitionAcrossSuspend(DefBB, U);
-  }
-
-  bool isDefinitionAcrossSuspend(Value &V, User *U) const {
-    if (auto *Arg = dyn_cast<Argument>(&V))
-      return isDefinitionAcrossSuspend(*Arg, U);
-    if (auto *Inst = dyn_cast<Instruction>(&V))
-      return isDefinitionAcrossSuspend(*Inst, U);
-
-    llvm_unreachable(
-        "Coroutine could only collect Argument and Instruction now.");
-  }
-};
-} // end anonymous namespace
-
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-static std::string getBasicBlockLabel(const BasicBlock *BB) {
-  if (BB->hasName())
-    return BB->getName().str();
-
-  std::string S;
-  raw_string_ostream OS(S);
-  BB->printAsOperand(OS, false);
-  return OS.str().substr(1);
-}
-
-LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(
-    StringRef Label, BitVector const &BV,
-    const ReversePostOrderTraversal<Function *> &RPOT) const {
-  dbgs() << Label << ":";
-  for (const BasicBlock *BB : RPOT) {
-    auto BBNo = Mapping.blockToIndex(BB);
-    if (BV[BBNo])
-      dbgs() << " " << getBasicBlockLabel(BB);
-  }
-  dbgs() << "\n";
-}
-
-LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
-  if (Block.empty())
-    return;
-
-  BasicBlock *const B = Mapping.indexToBlock(0);
-  Function *F = B->getParent();
-
-  ReversePostOrderTraversal<Function *> RPOT(F);
-  for (const BasicBlock *BB : RPOT) {
-    auto BBNo = Mapping.blockToIndex(BB);
-    dbgs() << getBasicBlockLabel(BB) << ":\n";
-    dump("   Consumes", Block[BBNo].Consumes, RPOT);
-    dump("      Kills", Block[BBNo].Kills, RPOT);
-  }
-  dbgs() << "\n";
-}
-#endif
-
-template <bool Initialize>
-bool SuspendCrossingInfo::computeBlockData(
-    const ReversePostOrderTraversal<Function *> &RPOT) {
-  bool Changed = false;
-
-  for (const BasicBlock *BB : RPOT) {
-    auto BBNo = Mapping.blockToIndex(BB);
-    auto &B = Block[BBNo];
-
-    // We don't need to count the predecessors when initialization.
-    if constexpr (!Initialize)
-      // If all the predecessors of the current Block don't change,
-      // the BlockData for the current block must not change too.
-      if (all_of(predecessors(B), [this](BasicBlock *BB) {
-            return !Block[Mapping.blockToIndex(BB)].Changed;
-          })) {
-        B.Changed = false;
-        continue;
-      }
-
-    // Saved Consumes and Kills bitsets so that it is easy to see
-    // if anything changed after propagation.
-    auto SavedConsumes = B.Consumes;
-    auto SavedKills = B.Kills;
-
-    for (BasicBlock *PI : predecessors(B)) {
-      auto PrevNo = Mapping.blockToIndex(PI);
-      auto &P = Block[PrevNo];
-
-      // Propagate Kills and Consumes from predecessors into B.
-      B.Consumes |= P.Consumes;
-      B.Kills |= P.Kills;
-
-      // If block P is a suspend block, it should propagate kills into block
-      // B for every block P consumes.
-      if (P.Suspend)
-        B.Kills |= P.Consumes;
-    }
-
-    if (B.Suspend) {
-      // If block B is a suspend block, it should kill all of the blocks it
-      // consumes.
-      B.Kills |= B.Consumes;
-    } else if (B.End) {
-      // If block B is an end block, it should not propagate kills as the
-      // blocks following coro.end() are reached during initial invocation
-      // of the coroutine while all the data are still available on the
-      // stack or in the registers.
-      B.Kills.reset();
-    } else {
-      // This is reached when B block it not Suspend nor coro.end and it
-      // need to make sure that it is not in the kill set.
-      B.KillLoop |= B.Kills[BBNo];
-      B.Kills.reset(BBNo);
-    }
-
-    if constexpr (!Initialize) {
-      B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
-      Changed |= B.Changed;
-    }
-  }
-
-  return Changed;
-}
-
-SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
-    : Mapping(F) {
-  const size_t N = Mapping.size();
-  Block.resize(N);
-
-  // Initialize every block so that it consumes itself
-  for (size_t I = 0; I < N; ++I) {
-    auto &B = Block[I];
-    B.Consumes.resize(N);
-    B.Kills.resize(N);
-    B.Consumes.set(I);
-    B.Changed = true;
-  }
-
-  // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
-  // the code beyond coro.end is reachable during initial invocation of the
-  // coroutine.
-  for (auto *CE : Shape.CoroEnds)
-    getBlockData(CE->getParent()).End = true;
-
-  // Mark all suspend blocks and indicate that they kill everything they
-  // consume. Note, that crossing coro.save also requires a spill, as any code
-  // between coro.save and coro.suspend may resume the coroutine and all of the
-  // state needs to be saved by that time.
-  auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) {
-    BasicBlock *SuspendBlock = BarrierInst->getParent();
-    auto &B = getBlockData(SuspendBlock);
-    B.Suspend = true;
-    B.Kills |= B.Consumes;
-  };
-  for (auto *CSI : Shape.CoroSuspends) {
-    markSuspendBlock(CSI);
-    if (auto *Save = CSI->getCoroSave())
-      markSuspendBlock(Save);
-  }
-
-  // It is considered to be faster to use RPO traversal for forward-edges
-  // dataflow analysis.
-  ReversePostOrderTraversal<Function *> RPOT(&F);
-  computeBlockData</*Initialize=*/true>(RPOT);
-  while (computeBlockData</*Initialize*/ false>(RPOT))
-    ;
-
-  LLVM_DEBUG(dump());
-}
-
 namespace {
 
 // RematGraph is used to construct a DAG for rematerializable instructions
@@ -438,6 +130,16 @@ struct RematGraph {
   }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  // Label for debug dumps: BB's name, or its operand text minus the '%'.
+  static std::string getBasicBlockLabel(const BasicBlock *BB) {
+    if (BB->hasName())
+      return BB->getName().str();
+    std::string Slot;
+    raw_string_ostream Stream(Slot);
+    BB->printAsOperand(Stream, /*PrintType=*/false);
+    return Stream.str().substr(1);
+  }
+
   void dump() const {
     dbgs() << "Entry (";
     dbgs() << getBasicBlockLabel(EntryNode->Node->getParent());
@@ -3159,7 +2861,7 @@ void coro::buildCoroutineFrame(
   rewritePHIs(F);
 
   // Build suspend crossing info.
-  SuspendCrossingInfo Checker(F, Shape);
+  SuspendCrossingInfo Checker(F, Shape.CoroSuspends, Shape.CoroEnds);
 
   doRematerializations(F, Checker, MaterializableCallback);
 
diff --git a/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.cpp b/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.cpp
new file mode 100644
index 00000000000000..ff3b32e958edac
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.cpp
@@ -0,0 +1,195 @@
+//===- SuspendCrossingInfo.cpp - Utility for suspend crossing values ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// The SuspendCrossingInfo maintains data that allows to answer a question
+// whether given two BasicBlocks A and B there is a path from A to B that
+// passes through a suspend point.
+//===----------------------------------------------------------------------===//
+
+#include "SuspendCrossingInfo.h"
+
+// The "coro-suspend-crossing" flag is very noisy. There is another debug type,
+// "coro-frame", which results in leaner debug spew.
+#define DEBUG_TYPE "coro-suspend-crossing"
+
+namespace llvm {
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+// Label for debug dumps: BB's name, or its operand text minus the '%'.
+static std::string getBasicBlockLabel(const BasicBlock *BB) {
+  if (BB->hasName())
+    return BB->getName().str();
+  std::string Slot;
+  raw_string_ostream Stream(Slot);
+  BB->printAsOperand(Stream, /*PrintType=*/false);
+  return Stream.str().substr(1);
+}
+
+// Print "Label:" then the label of every block set in BV, in RPOT order.
+LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(
+    StringRef Label, BitVector const &BV,
+    const ReversePostOrderTraversal<Function *> &RPOT) const {
+  raw_ostream &OS = dbgs();
+  OS << Label << ":";
+  for (const BasicBlock *BB : RPOT)
+    if (BV[Mapping.blockToIndex(BB)])
+      OS << " " << getBasicBlockLabel(BB);
+  OS << "\n";
+}
+
+// Print the Consumes and Kills sets of every block, visiting blocks in RPO.
+LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
+  if (Block.empty())
+    return;
+
+  // Recover the function from any mapped block; index 0 exists when non-empty.
+  Function *Fn = Mapping.indexToBlock(0)->getParent();
+  ReversePostOrderTraversal<Function *> Order(Fn);
+  for (const BasicBlock *BB : Order) {
+    const BlockData &Data = Block[Mapping.blockToIndex(BB)];
+    dbgs() << getBasicBlockLabel(BB) << ":\n";
+    dump("   Consumes", Data.Consumes, Order);
+    dump("      Kills", Data.Kills, Order);
+  }
+  dbgs() << "\n";
+}
+#endif
+
+// From -> To crosses a suspend point iff From is in To's Kills set.
+bool SuspendCrossingInfo::hasPathCrossingSuspendPoint(BasicBlock *From,
+                                                      BasicBlock *To) const {
+  const size_t Src = Mapping.blockToIndex(From);
+  const bool Crosses = Block[Mapping.blockToIndex(To)].Kills[Src];
+  LLVM_DEBUG(if (Crosses) dbgs() << From->getName() << " => " << To->getName()
+                                 << " crosses suspend point\n");
+  return Crosses;
+}
+
+// As hasPathCrossingSuspendPoint, plus the From == To case via KillLoop.
+bool SuspendCrossingInfo::hasPathOrLoopCrossingSuspendPoint(
+    BasicBlock *From, BasicBlock *To) const {
+  const size_t Src = Mapping.blockToIndex(From);
+  const BlockData &Dest = Block[Mapping.blockToIndex(To)];
+  const bool Crosses = Dest.Kills[Src] || (From == To && Dest.KillLoop);
+  LLVM_DEBUG(if (Crosses) dbgs() << From->getName() << " => " << To->getName()
+                                 << " crosses suspend point (path or loop)\n");
+  return Crosses;
+}
+
+template <bool Initialize>
+bool SuspendCrossingInfo::computeBlockData(
+    const ReversePostOrderTraversal<Function *> &RPOT) {
+  bool Changed = false;
+
+  for (const BasicBlock *BB : RPOT) {
+    auto BBNo = Mapping.blockToIndex(BB);
+    auto &B = Block[BBNo];
+
+    // The initial pass visits every block unconditionally.
+    if constexpr (!Initialize)
+      // Later passes: if no predecessor changed when it was last processed,
+      // B's data cannot change either, so mark it unchanged and skip it.
+      if (all_of(predecessors(B), [this](BasicBlock *BB) {
+            return !Block[Mapping.blockToIndex(BB)].Changed;
+          })) {
+        B.Changed = false;
+        continue;
+      }
+
+    // Snapshot Consumes and Kills so that the end of this iteration can tell
+    // whether propagation changed anything.
+    auto SavedConsumes = B.Consumes;
+    auto SavedKills = B.Kills;
+
+    for (BasicBlock *PI : predecessors(B)) {
+      auto PrevNo = Mapping.blockToIndex(PI);
+      auto &P = Block[PrevNo];
+
+      // Merge P's Consumes and Kills sets into B's (set union).
+      B.Consumes |= P.Consumes;
+      B.Kills |= P.Kills;
+
+      // Leaving a suspend block P crosses its suspend point, so every block
+      // that P consumes is killed on entry to B.
+      if (P.Suspend)
+        B.Kills |= P.Consumes;
+    }
+
+    if (B.Suspend) {
+      // A suspend block kills every block it consumes (Consumes always
+      // includes B itself, so B ends up in its own Kills set).
+      B.Kills |= B.Consumes;
+    } else if (B.End) {
+      // An end block does not propagate kills: the blocks following
+      // coro.end() are reached during the initial invocation of the
+      // coroutine, while all the data are still available on the stack or in
+      // the registers.
+      B.Kills.reset();
+    } else {
+      // B is neither a suspend nor a coro.end block. B's own bit in Kills
+      // marks a suspend-crossing cycle through B: move it into KillLoop.
+      B.KillLoop |= B.Kills[BBNo];
+      B.Kills.reset(BBNo);
+    }
+
+    if constexpr (!Initialize) {
+      B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
+      Changed |= B.Changed;
+    }
+  }
+  // Always false on the initial pass, which does not track changes.
+  return Changed;
+}
+
+SuspendCrossingInfo::SuspendCrossingInfo(
+    Function &F, const SmallVectorImpl<AnyCoroSuspendInst *> &CoroSuspends,
+    const SmallVectorImpl<AnyCoroEndInst *> &CoroEnds)
+    : Mapping(F) {
+  const size_t NumBlocks = Mapping.size();
+  Block.resize(NumBlocks);
+
+  // Seed the dataflow: every block consumes only itself, kills nothing yet,
+  // and starts out marked Changed.
+  for (size_t Idx = 0; Idx != NumBlocks; ++Idx) {
+    BlockData &Data = Block[Idx];
+    Data.Consumes.resize(NumBlocks);
+    Data.Consumes.set(Idx);
+    Data.Kills.resize(NumBlocks);
+    Data.Changed = true;
+  }
+
+  // Flag every block holding a coro.end. Kills are not propagated past a
+  // coro.end, because the code after it is reachable during the initial
+  // invocation of the coroutine.
+  for (auto *CoroEnd : CoroEnds)
+    getBlockData(CoroEnd->getParent()).End = true;
+
+  // A suspend block kills everything it consumes. The block holding the
+  // matching coro.save is treated the same way: code between coro.save and
+  // coro.suspend may resume the coroutine, so all state has to be spilled
+  // by then.
+  auto MarkSuspend = [&](IntrinsicInst *Barrier) {
+    BlockData &Data = getBlockData(Barrier->getParent());
+    Data.Suspend = true;
+    Data.Kills |= Data.Consumes;
+  };
+  for (auto *CoroSuspend : CoroSuspends) {
+    MarkSuspend(CoroSuspend);
+    if (auto *Save = CoroSuspend->getCoroSave())
+      MarkSuspend(Save);
+  }
+
+  // Iterate to a fixed point in RPO, which is generally the faster visiting
+  // order for a forward dataflow problem. The first pass is unconditional.
+  ReversePostOrderTraversal<Function *> RPOT(&F);
+  computeBlockData</*Initialize=*/true>(RPOT);
+  for (bool Changed = true; Changed;)
+    Changed = computeBlockData</*Initialize=*/false>(RPOT);
+
+  LLVM_DEBUG(dump());
+}
+
+} // namespace llvm
diff --git a/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.h b/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.h
new file mode 100644
index 00000000000000..d1551af0b54975
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.h
@@ -0,0 +1,182 @@
+//===- SuspendCrossingInfo.h - Suspend crossing value utility ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// The SuspendCrossingInfo maintains data that allows to answer a question
+// whether given two BasicBlocks A and B there is a path from A to B that
+// passes through a suspend point.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TRANSFORMS_COROUTINES_SUSPENDCROSSINGINFO_H
+#define LLVM_LIB_TRANSFORMS_COROUTINES_SUSPENDCROSSINGINFO_H
+
+#include "CoroInstr.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instruction.h"
+
+namespace llvm {
+
+// Two-way mapping between a function's blocks and dense indices [0, size()).
+class BlockToIndexMapping {
+  SmallVector<BasicBlock *, 32> V;
+
+public:
+  size_t size() const { return V.size(); }
+
+  explicit BlockToIndexMapping(Function &F) {
+    for (BasicBlock &BB : F)
+      V.push_back(&BB);
+    llvm::sort(V); // Sorted by address so blockToIndex can binary-search.
+  }
+
+  size_t blockToIndex(BasicBlock const *BB) const {
+    auto *I = llvm::lower_bound(V, BB);
+    assert(I != V.end() && *I == BB && "BasicBlockNumberng: Unknown block");
+    return I - V.begin();
+  }
+
+  BasicBlock *indexToBlock(size_t Index) const { return V[Index]; }
+};
+
+// The SuspendCrossingInfo maintains data that allows to answer a question
+// whether given two BasicBlocks A and B there is a path from A to B that
+// passes through a suspend point.
+//
+// For every basic block 'i' it maintains a BlockData that consists of:
+//   Consumes:  a bit vector which contains a set of indices of blocks that can
+//              reach block 'i'. A block can trivially reach itself.
+//   Kills: a bit vector which contains a set of indices of blocks that can
+//          reach block 'i' but there is a path crossing a suspend point
+//          not repeating 'i' (path to 'i' without cycles containing 'i').
+//   Suspend: a boolean indicating whether block 'i' contains a suspend point.
+//   End: a boolean indicating whether block 'i' contains a coro.end intrinsic.
+//   KillLoop: There is a path from 'i' to 'i' not otherwise repeating 'i' that
+//             crosses a suspend point.
+//
+class SuspendCrossingInfo {
+  BlockToIndexMapping Mapping;
+
+  struct BlockData {
+    BitVector Consumes;
+    BitVector Kills;
+    bool Suspend = false;
+    bool End = false;
+    bool KillLoop = false;
+    bool Changed = false;
+  };
+  SmallVector<BlockData, 32> Block;
+
+  iterator_range<pred_iterator> predecessors(BlockData const &BD) const {
+    BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]);
+    return llvm::predecessors(BB);
+  }
+
+  BlockData &getBlockData(BasicBlock *BB) {
+    return Block[Mapping.blockToIndex(BB)];
+  }
+
+  /// Run one RPO pass of the dataflow over the function.
+  /// Initialize - true for the first pass, which visits every block
+  /// unconditionally and does not track changes.
+  /// Returns whether any BlockData changed (always false if Initialize).
+  template <bool Initialize = false>
+  bool computeBlockData(const ReversePostOrderTraversal<Function *> &RPOT);
+
+public:
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  // Print order is in RPO
+  void dump() const;
+  void dump(StringRef Label, BitVector const &BV,
+            const ReversePostOrderTraversal<Function *> &RPOT) const;
+#endif
+
+  SuspendCrossingInfo(Function &F,
+                      const SmallVectorImpl<AnyCoroSuspendInst *> &CoroSuspends,
+                      const SmallVectorImpl<AnyCoroEndInst *> &CoroEnds);
+
+  /// Returns true if there is a path from \p From to \p To crossing a suspend
+  /// point without crossing \p From a 2nd time.
+  bool hasPathCrossingSuspendPoint(BasicBlock *From, BasicBlock *To) const;
+
+  /// Returns true if there is a path from \p From to \p To crossing a suspend
+  /// point without crossing \p From a 2nd time. If \p From is the same as \p To
+  /// this will also check if there is a looping path crossing a suspend point.
+  bool hasPathOrLoopCrossingSuspendPoint(BasicBlock *From,
+                                         BasicBlock *To) const;
+
+  bool isDefinitionAcrossSuspend(BasicBlock *DefBB, User *U) const {
+    auto *I = cast<Instruction>(U);
+
+    // We rewrote PHINodes, so that only the ones with exactly one incoming
+    // value need to be analyzed.
+    if (auto *PN = dyn_cast<PHINode>(I))
+      if (PN->getNumIncomingValues() > 1)
+        return false;
+
+    BasicBlock *UseBB = I->getParent();
+
+    // As a special case, treat uses by an llvm.coro.suspend.retcon or an
+    // llvm.coro.suspend.async as if they were uses in the suspend's single
+    // predecessor: the uses conceptually occur before the suspend.
+    if (isa<CoroSuspendRetconInst>(I) || isa<CoroSuspendAsyncInst>(I)) {
+      UseBB = UseBB->getSinglePredecessor();
+      assert(UseBB && "should have split coro.suspend into its own block");
+    }
+
+    return hasPathCrossingSuspendPoint(DefBB, UseBB);
+  }
+
+  bool isDefinitionAcrossSuspend(Argument &A, User *U) const {
+    return isDefinitionAcrossSuspend(&A.getParent()->getEntryBlock(), U);
+  }
+
+  bool isDefinitionAcrossSuspend(Instruction &I, User *U) const {
+    auto *DefBB = I.getParent();
+
+    // As a special case, treat values produced by an llvm.coro.suspend.*
+    // as if they were defined in the single successor: the uses
+    // conceptually occur after the suspend.
+    if (isa<AnyCoroSuspendInst>(I)) {
+      DefBB = DefBB->getSingleSuccessor();
+      assert(DefBB && "should have split coro.suspend into its own block");
+    }
+
+    return isDefinitionAcrossSuspend(DefBB, U);
+  }
+
+  bool isDefinitionAcrossSuspend(Value &V, User *U) const {
+    if (auto *Arg = dyn_cast<Argument>(&V))
+      return isDefinitionAcrossSuspend(*Arg, U);
+    if (auto *Inst = dyn_cast<Instruction>(&V))
+      return isDefinitionAcrossSuspend(*Inst, U);
+
+    llvm_unreachable(
+        "Coroutine could only collect Argument and Instruction now.");
+  }
+
+  /// Returns true if any user of \p V (an Argument or Instruction) sees its
+  /// definition across a suspend point; false if no user does.
+  bool isDefinitionAcrossSuspend(Value &V) const {
+    if (auto *Arg = dyn_cast<Argument>(&V))
+      return llvm::any_of(Arg->users(), [&](User *U) {
+        return isDefinitionAcrossSuspend(*Arg, U);
+      });
+    if (auto *Inst = dyn_cast<Instruction>(&V))
+      return llvm::any_of(Inst->users(), [&](User *U) {
+        return isDefinitionAcrossSuspend(*Inst, U);
+      });
+    llvm_unreachable(
+        "Coroutine could only collect Argument and Instruction now.");
+  }
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_TRANSFORMS_COROUTINES_SUSPENDCROSSINGINFO_H



More information about the llvm-commits mailing list