[llvm] [llvm/llvm-project][Coroutines] ABI Object (PR #106306)
Tyler Nowicki via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 3 07:28:29 PDT 2024
https://github.com/TylerNowicki updated https://github.com/llvm/llvm-project/pull/106306
>From 06cde7ca59c29bfb0a96ce1acad6e3ff6b75b8f9 Mon Sep 17 00:00:00 2001
From: tnowicki <tnowicki.nowicki at amd.com>
Date: Fri, 23 Aug 2024 13:12:17 -0400
Subject: [PATCH] [llvm/llvm-project][Coroutines] Major refactoring of
SuspendCrossingInfo
* Move SuspendCrossingInfo to its own files to clean up CoroFrame
---
llvm/lib/Transforms/Coroutines/CMakeLists.txt | 1 +
llvm/lib/Transforms/Coroutines/CoroFrame.cpp | 322 +-----------------
.../Coroutines/SuspendCrossingInfo.cpp | 195 +++++++++++
.../Coroutines/SuspendCrossingInfo.h | 182 ++++++++++
4 files changed, 390 insertions(+), 310 deletions(-)
create mode 100644 llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.cpp
create mode 100644 llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.h
diff --git a/llvm/lib/Transforms/Coroutines/CMakeLists.txt b/llvm/lib/Transforms/Coroutines/CMakeLists.txt
index 2139446e5ff957..57359acc81ee4c 100644
--- a/llvm/lib/Transforms/Coroutines/CMakeLists.txt
+++ b/llvm/lib/Transforms/Coroutines/CMakeLists.txt
@@ -6,6 +6,7 @@ add_llvm_component_library(LLVMCoroutines
CoroElide.cpp
CoroFrame.cpp
CoroSplit.cpp
+ SuspendCrossingInfo.cpp
ADDITIONAL_HEADER_DIRS
${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/Coroutines
diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index f76cfe01b58cfd..ca5e9504264b25 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -17,6 +17,7 @@
#include "CoroInternal.h"
+#include "SuspendCrossingInfo.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallString.h"
@@ -51,315 +52,6 @@ extern cl::opt<bool> UseNewDbgInfoFormat;
// "coro-frame", which results in leaner debug spew.
#define DEBUG_TYPE "coro-suspend-crossing"
-enum { SmallVectorThreshold = 32 };
-
-// Provides two way mapping between the blocks and numbers.
-namespace {
-class BlockToIndexMapping {
- SmallVector<BasicBlock *, SmallVectorThreshold> V;
-
-public:
- size_t size() const { return V.size(); }
-
- BlockToIndexMapping(Function &F) {
- for (BasicBlock &BB : F)
- V.push_back(&BB);
- llvm::sort(V);
- }
-
- size_t blockToIndex(BasicBlock const *BB) const {
- auto *I = llvm::lower_bound(V, BB);
- assert(I != V.end() && *I == BB && "BasicBlockNumberng: Unknown block");
- return I - V.begin();
- }
-
- BasicBlock *indexToBlock(unsigned Index) const { return V[Index]; }
-};
-} // end anonymous namespace
-
-// The SuspendCrossingInfo maintains data that allows to answer a question
-// whether given two BasicBlocks A and B there is a path from A to B that
-// passes through a suspend point.
-//
-// For every basic block 'i' it maintains a BlockData that consists of:
-// Consumes: a bit vector which contains a set of indices of blocks that can
-// reach block 'i'. A block can trivially reach itself.
-// Kills: a bit vector which contains a set of indices of blocks that can
-// reach block 'i' but there is a path crossing a suspend point
-// not repeating 'i' (path to 'i' without cycles containing 'i').
-// Suspend: a boolean indicating whether block 'i' contains a suspend point.
-// End: a boolean indicating whether block 'i' contains a coro.end intrinsic.
-// KillLoop: There is a path from 'i' to 'i' not otherwise repeating 'i' that
-// crosses a suspend point.
-//
-namespace {
-class SuspendCrossingInfo {
- BlockToIndexMapping Mapping;
-
- struct BlockData {
- BitVector Consumes;
- BitVector Kills;
- bool Suspend = false;
- bool End = false;
- bool KillLoop = false;
- bool Changed = false;
- };
- SmallVector<BlockData, SmallVectorThreshold> Block;
-
- iterator_range<pred_iterator> predecessors(BlockData const &BD) const {
- BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]);
- return llvm::predecessors(BB);
- }
-
- BlockData &getBlockData(BasicBlock *BB) {
- return Block[Mapping.blockToIndex(BB)];
- }
-
- /// Compute the BlockData for the current function in one iteration.
- /// Initialize - Whether this is the first iteration, we can optimize
- /// the initial case a little bit by manual loop switch.
- /// Returns whether the BlockData changes in this iteration.
- template <bool Initialize = false>
- bool computeBlockData(const ReversePostOrderTraversal<Function *> &RPOT);
-
-public:
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
- void dump() const;
- void dump(StringRef Label, BitVector const &BV,
- const ReversePostOrderTraversal<Function *> &RPOT) const;
-#endif
-
- SuspendCrossingInfo(Function &F, coro::Shape &Shape);
-
- /// Returns true if there is a path from \p From to \p To crossing a suspend
- /// point without crossing \p From a 2nd time.
- bool hasPathCrossingSuspendPoint(BasicBlock *From, BasicBlock *To) const {
- size_t const FromIndex = Mapping.blockToIndex(From);
- size_t const ToIndex = Mapping.blockToIndex(To);
- bool const Result = Block[ToIndex].Kills[FromIndex];
- LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName()
- << " answer is " << Result << "\n");
- return Result;
- }
-
- /// Returns true if there is a path from \p From to \p To crossing a suspend
- /// point without crossing \p From a 2nd time. If \p From is the same as \p To
- /// this will also check if there is a looping path crossing a suspend point.
- bool hasPathOrLoopCrossingSuspendPoint(BasicBlock *From,
- BasicBlock *To) const {
- size_t const FromIndex = Mapping.blockToIndex(From);
- size_t const ToIndex = Mapping.blockToIndex(To);
- bool Result = Block[ToIndex].Kills[FromIndex] ||
- (From == To && Block[ToIndex].KillLoop);
- LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName()
- << " answer is " << Result << " (path or loop)\n");
- return Result;
- }
-
- bool isDefinitionAcrossSuspend(BasicBlock *DefBB, User *U) const {
- auto *I = cast<Instruction>(U);
-
- // We rewrote PHINodes, so that only the ones with exactly one incoming
- // value need to be analyzed.
- if (auto *PN = dyn_cast<PHINode>(I))
- if (PN->getNumIncomingValues() > 1)
- return false;
-
- BasicBlock *UseBB = I->getParent();
-
- // As a special case, treat uses by an llvm.coro.suspend.retcon or an
- // llvm.coro.suspend.async as if they were uses in the suspend's single
- // predecessor: the uses conceptually occur before the suspend.
- if (isa<CoroSuspendRetconInst>(I) || isa<CoroSuspendAsyncInst>(I)) {
- UseBB = UseBB->getSinglePredecessor();
- assert(UseBB && "should have split coro.suspend into its own block");
- }
-
- return hasPathCrossingSuspendPoint(DefBB, UseBB);
- }
-
- bool isDefinitionAcrossSuspend(Argument &A, User *U) const {
- return isDefinitionAcrossSuspend(&A.getParent()->getEntryBlock(), U);
- }
-
- bool isDefinitionAcrossSuspend(Instruction &I, User *U) const {
- auto *DefBB = I.getParent();
-
- // As a special case, treat values produced by an llvm.coro.suspend.*
- // as if they were defined in the single successor: the uses
- // conceptually occur after the suspend.
- if (isa<AnyCoroSuspendInst>(I)) {
- DefBB = DefBB->getSingleSuccessor();
- assert(DefBB && "should have split coro.suspend into its own block");
- }
-
- return isDefinitionAcrossSuspend(DefBB, U);
- }
-
- bool isDefinitionAcrossSuspend(Value &V, User *U) const {
- if (auto *Arg = dyn_cast<Argument>(&V))
- return isDefinitionAcrossSuspend(*Arg, U);
- if (auto *Inst = dyn_cast<Instruction>(&V))
- return isDefinitionAcrossSuspend(*Inst, U);
-
- llvm_unreachable(
- "Coroutine could only collect Argument and Instruction now.");
- }
-};
-} // end anonymous namespace
-
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-static std::string getBasicBlockLabel(const BasicBlock *BB) {
- if (BB->hasName())
- return BB->getName().str();
-
- std::string S;
- raw_string_ostream OS(S);
- BB->printAsOperand(OS, false);
- return OS.str().substr(1);
-}
-
-LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(
- StringRef Label, BitVector const &BV,
- const ReversePostOrderTraversal<Function *> &RPOT) const {
- dbgs() << Label << ":";
- for (const BasicBlock *BB : RPOT) {
- auto BBNo = Mapping.blockToIndex(BB);
- if (BV[BBNo])
- dbgs() << " " << getBasicBlockLabel(BB);
- }
- dbgs() << "\n";
-}
-
-LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
- if (Block.empty())
- return;
-
- BasicBlock *const B = Mapping.indexToBlock(0);
- Function *F = B->getParent();
-
- ReversePostOrderTraversal<Function *> RPOT(F);
- for (const BasicBlock *BB : RPOT) {
- auto BBNo = Mapping.blockToIndex(BB);
- dbgs() << getBasicBlockLabel(BB) << ":\n";
- dump(" Consumes", Block[BBNo].Consumes, RPOT);
- dump(" Kills", Block[BBNo].Kills, RPOT);
- }
- dbgs() << "\n";
-}
-#endif
-
-template <bool Initialize>
-bool SuspendCrossingInfo::computeBlockData(
- const ReversePostOrderTraversal<Function *> &RPOT) {
- bool Changed = false;
-
- for (const BasicBlock *BB : RPOT) {
- auto BBNo = Mapping.blockToIndex(BB);
- auto &B = Block[BBNo];
-
- // We don't need to count the predecessors when initialization.
- if constexpr (!Initialize)
- // If all the predecessors of the current Block don't change,
- // the BlockData for the current block must not change too.
- if (all_of(predecessors(B), [this](BasicBlock *BB) {
- return !Block[Mapping.blockToIndex(BB)].Changed;
- })) {
- B.Changed = false;
- continue;
- }
-
- // Saved Consumes and Kills bitsets so that it is easy to see
- // if anything changed after propagation.
- auto SavedConsumes = B.Consumes;
- auto SavedKills = B.Kills;
-
- for (BasicBlock *PI : predecessors(B)) {
- auto PrevNo = Mapping.blockToIndex(PI);
- auto &P = Block[PrevNo];
-
- // Propagate Kills and Consumes from predecessors into B.
- B.Consumes |= P.Consumes;
- B.Kills |= P.Kills;
-
- // If block P is a suspend block, it should propagate kills into block
- // B for every block P consumes.
- if (P.Suspend)
- B.Kills |= P.Consumes;
- }
-
- if (B.Suspend) {
- // If block B is a suspend block, it should kill all of the blocks it
- // consumes.
- B.Kills |= B.Consumes;
- } else if (B.End) {
- // If block B is an end block, it should not propagate kills as the
- // blocks following coro.end() are reached during initial invocation
- // of the coroutine while all the data are still available on the
- // stack or in the registers.
- B.Kills.reset();
- } else {
- // This is reached when B block it not Suspend nor coro.end and it
- // need to make sure that it is not in the kill set.
- B.KillLoop |= B.Kills[BBNo];
- B.Kills.reset(BBNo);
- }
-
- if constexpr (!Initialize) {
- B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
- Changed |= B.Changed;
- }
- }
-
- return Changed;
-}
-
-SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
- : Mapping(F) {
- const size_t N = Mapping.size();
- Block.resize(N);
-
- // Initialize every block so that it consumes itself
- for (size_t I = 0; I < N; ++I) {
- auto &B = Block[I];
- B.Consumes.resize(N);
- B.Kills.resize(N);
- B.Consumes.set(I);
- B.Changed = true;
- }
-
- // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
- // the code beyond coro.end is reachable during initial invocation of the
- // coroutine.
- for (auto *CE : Shape.CoroEnds)
- getBlockData(CE->getParent()).End = true;
-
- // Mark all suspend blocks and indicate that they kill everything they
- // consume. Note, that crossing coro.save also requires a spill, as any code
- // between coro.save and coro.suspend may resume the coroutine and all of the
- // state needs to be saved by that time.
- auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) {
- BasicBlock *SuspendBlock = BarrierInst->getParent();
- auto &B = getBlockData(SuspendBlock);
- B.Suspend = true;
- B.Kills |= B.Consumes;
- };
- for (auto *CSI : Shape.CoroSuspends) {
- markSuspendBlock(CSI);
- if (auto *Save = CSI->getCoroSave())
- markSuspendBlock(Save);
- }
-
- // It is considered to be faster to use RPO traversal for forward-edges
- // dataflow analysis.
- ReversePostOrderTraversal<Function *> RPOT(&F);
- computeBlockData</*Initialize=*/true>(RPOT);
- while (computeBlockData</*Initialize*/ false>(RPOT))
- ;
-
- LLVM_DEBUG(dump());
-}
-
namespace {
// RematGraph is used to construct a DAG for rematerializable instructions
@@ -438,6 +130,16 @@ struct RematGraph {
}
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  static std::string getBasicBlockLabel(const BasicBlock *BB) {
+    if (!BB->hasName()) {
+      std::string Label;
+      raw_string_ostream Stream(Label);
+      BB->printAsOperand(Stream, /*PrintType=*/false);
+      return Stream.str().substr(1); // Strip first char (presumably '%').
+    }
+    return BB->getName().str();
+  }
+
void dump() const {
dbgs() << "Entry (";
dbgs() << getBasicBlockLabel(EntryNode->Node->getParent());
@@ -3159,7 +2861,7 @@ void coro::buildCoroutineFrame(
rewritePHIs(F);
// Build suspend crossing info.
- SuspendCrossingInfo Checker(F, Shape);
+ SuspendCrossingInfo Checker(F, Shape.CoroSuspends, Shape.CoroEnds);
doRematerializations(F, Checker, MaterializableCallback);
diff --git a/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.cpp b/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.cpp
new file mode 100644
index 00000000000000..ff3b32e958edac
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.cpp
@@ -0,0 +1,195 @@
+//===- SuspendCrossingInfo.cpp - Utility for suspend crossing values ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// SuspendCrossingInfo maintains data that answers the question of whether,
+// given two BasicBlocks A and B, there is a path from A to B that passes
+// through a suspend point.
+//===----------------------------------------------------------------------===//
+
+#include "SuspendCrossingInfo.h"
+
+// The "coro-suspend-crossing" flag is very noisy. There is another debug type,
+// "coro-frame", which results in leaner debug spew.
+#define DEBUG_TYPE "coro-suspend-crossing"
+
+namespace llvm {
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+static std::string getBasicBlockLabel(const BasicBlock *BB) {
+  if (!BB->hasName()) {
+    std::string Text;
+    raw_string_ostream Out(Text);
+    BB->printAsOperand(Out, /*PrintType=*/false);
+    return Out.str().substr(1); // Strip first char (presumably '%').
+  }
+  return BB->getName().str();
+}
+
+LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(
+    StringRef Label, BitVector const &BV,
+    const ReversePostOrderTraversal<Function *> &RPOT) const {
+  raw_ostream &OS = dbgs();
+  OS << Label << ":";
+  // List, in RPO, every block whose bit is set in BV.
+  for (const BasicBlock *Blk : RPOT)
+    if (BV[Mapping.blockToIndex(Blk)])
+      OS << " " << getBasicBlockLabel(Blk);
+  OS << "\n";
+}
+
+LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
+  // Nothing to print when the function has no blocks.
+  if (Block.empty())
+    return;
+
+  // Every mapped block belongs to the analyzed function; use the first one.
+  Function *F = Mapping.indexToBlock(0)->getParent();
+  ReversePostOrderTraversal<Function *> RPOT(F);
+  for (const BasicBlock *Blk : RPOT) {
+    const BlockData &Data = Block[Mapping.blockToIndex(Blk)];
+    dbgs() << getBasicBlockLabel(Blk) << ":\n";
+    dump(" Consumes", Data.Consumes, RPOT);
+    dump(" Kills", Data.Kills, RPOT);
+  }
+  dbgs() << "\n";
+}
+#endif
+
+bool SuspendCrossingInfo::hasPathCrossingSuspendPoint(BasicBlock *From,
+                                                      BasicBlock *To) const {
+  // Kills[From] on To's data marks a suspend-crossing path From -> To.
+  const BlockData &Dest = Block[Mapping.blockToIndex(To)];
+  const bool Crosses = Dest.Kills[Mapping.blockToIndex(From)];
+  LLVM_DEBUG(if (Crosses) dbgs() << From->getName() << " => " << To->getName()
+                                 << " crosses suspend point\n");
+  return Crosses;
+}
+
+bool SuspendCrossingInfo::hasPathOrLoopCrossingSuspendPoint(
+    BasicBlock *From, BasicBlock *To) const {
+  const BlockData &Dest = Block[Mapping.blockToIndex(To)];
+  // A self-query also counts a suspend-crossing cycle through the block.
+  const bool Crosses = Dest.Kills[Mapping.blockToIndex(From)] ||
+                       (From == To && Dest.KillLoop);
+  LLVM_DEBUG(if (Crosses) dbgs() << From->getName() << " => " << To->getName()
+                                 << " crosses suspend point (path or loop)\n");
+  return Crosses;
+}
+
+template <bool Initialize>
+bool SuspendCrossingInfo::computeBlockData(
+    const ReversePostOrderTraversal<Function *> &RPOT) {
+  bool Changed = false;
+
+  for (const BasicBlock *BB : RPOT) {
+    auto BBNo = Mapping.blockToIndex(BB);
+    auto &B = Block[BBNo];
+
+    // The initializing pass visits every block; skip the predecessor check.
+    if constexpr (!Initialize)
+      // If all the predecessors of the current Block don't change,
+      // the BlockData for the current block must not change too.
+      if (all_of(predecessors(B), [this](BasicBlock *BB) {
+            return !Block[Mapping.blockToIndex(BB)].Changed;
+          })) {
+        B.Changed = false;
+        continue;
+      }
+
+    // Save the Consumes and Kills bitsets so that it is easy to see whether
+    // anything changed after propagation.
+    auto SavedConsumes = B.Consumes;
+    auto SavedKills = B.Kills;
+
+    for (BasicBlock *PI : predecessors(B)) {
+      auto PrevNo = Mapping.blockToIndex(PI);
+      auto &P = Block[PrevNo];
+
+      // Propagate Kills and Consumes from predecessors into B.
+      B.Consumes |= P.Consumes;
+      B.Kills |= P.Kills;
+
+      // If block P is a suspend block, it should propagate kills into block
+      // B for every block P consumes.
+      if (P.Suspend)
+        B.Kills |= P.Consumes;
+    }
+
+    if (B.Suspend) {
+      // If block B is a suspend block, it should kill all of the blocks it
+      // consumes.
+      B.Kills |= B.Consumes;
+    } else if (B.End) {
+      // If block B is an end block, it should not propagate kills as the
+      // blocks following coro.end() are reached during initial invocation
+      // of the coroutine while all the data are still available on the
+      // stack or in the registers.
+      B.Kills.reset();
+    } else {
+      // B is neither a suspend nor a coro.end block: remove B from its own
+      // kill set, recording a suspend-crossing self-loop in KillLoop instead.
+      B.KillLoop |= B.Kills[BBNo];
+      B.Kills.reset(BBNo);
+    }
+
+    if constexpr (!Initialize) {
+      B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
+      Changed |= B.Changed;
+    }
+  }
+
+  return Changed;
+}
+
+SuspendCrossingInfo::SuspendCrossingInfo(
+    Function &F, const SmallVectorImpl<AnyCoroSuspendInst *> &CoroSuspends,
+    const SmallVectorImpl<AnyCoroEndInst *> &CoroEnds)
+    : Mapping(F) {
+  const size_t N = Mapping.size();
+  Block.resize(N);
+
+  // Every block trivially consumes (reaches) itself; all start as Changed.
+  for (size_t I = 0; I < N; ++I) {
+    auto &B = Block[I];
+    B.Consumes.resize(N);
+    B.Kills.resize(N);
+    B.Consumes.set(I);
+    B.Changed = true;
+  }
+
+  // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
+  // the code beyond coro.end is reachable during initial invocation of the
+  // coroutine.
+  for (auto *CE : CoroEnds)
+    getBlockData(CE->getParent()).End = true;
+
+  // Mark all suspend blocks and indicate that they kill everything they
+  // consume. Note that crossing coro.save also requires a spill, as any code
+  // between coro.save and coro.suspend may resume the coroutine and all of the
+  // state needs to be saved by that time.
+  auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) {
+    BasicBlock *SuspendBlock = BarrierInst->getParent();
+    auto &B = getBlockData(SuspendBlock);
+    B.Suspend = true;
+    B.Kills |= B.Consumes;
+  };
+  for (auto *CSI : CoroSuspends) {
+    markSuspendBlock(CSI);
+    if (auto *Save = CSI->getCoroSave())
+      markSuspendBlock(Save);
+  }
+
+  // Iterate to a fixed point; RPO traversal is generally the faster order for
+  // a forward-edge dataflow analysis such as this one.
+  ReversePostOrderTraversal<Function *> RPOT(&F);
+  computeBlockData</*Initialize=*/true>(RPOT);
+  while (computeBlockData</*Initialize=*/false>(RPOT))
+    ;
+
+  LLVM_DEBUG(dump());
+}
+
+} // namespace llvm
diff --git a/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.h b/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.h
new file mode 100644
index 00000000000000..d1551af0b54975
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/SuspendCrossingInfo.h
@@ -0,0 +1,182 @@
+//===- SuspendCrossingInfo.h - Utility for suspend crossing values --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// SuspendCrossingInfo maintains data that answers the question of whether,
+// given two BasicBlocks A and B, there is a path from A to B that passes
+// through a suspend point.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TRANSFORMS_COROUTINES_SUSPENDCROSSINGINFO_H
+#define LLVM_LIB_TRANSFORMS_COROUTINES_SUSPENDCROSSINGINFO_H
+
+#include "CoroInstr.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instruction.h"
+
+namespace llvm {
+
+// Two-way mapping between a function's blocks and dense indices [0, size()).
+class BlockToIndexMapping {
+  SmallVector<BasicBlock *, 32> V;
+
+public:
+  size_t size() const { return V.size(); }
+
+  explicit BlockToIndexMapping(Function &F) {
+    for (BasicBlock &BB : F)
+      V.push_back(&BB);
+    llvm::sort(V);
+  }
+
+  size_t blockToIndex(BasicBlock const *BB) const {
+    auto *I = llvm::lower_bound(V, BB);
+    assert(I != V.end() && *I == BB && "BasicBlockNumbering: Unknown block");
+    return I - V.begin();
+  }
+
+  BasicBlock *indexToBlock(size_t Index) const { return V[Index]; }
+};
+
+// SuspendCrossingInfo maintains data that answers the question of whether,
+// given two BasicBlocks A and B, there is a path from A to B that passes
+// through a suspend point.
+//
+// For every basic block 'i' it maintains a BlockData that consists of:
+// Consumes: a bit vector which contains a set of indices of blocks that can
+// reach block 'i'. A block can trivially reach itself.
+// Kills: a bit vector which contains a set of indices of blocks that can
+// reach block 'i' but there is a path crossing a suspend point
+// not repeating 'i' (path to 'i' without cycles containing 'i').
+// Suspend: a boolean indicating whether block 'i' contains a suspend point.
+// End: a boolean indicating whether block 'i' contains a coro.end intrinsic.
+// KillLoop: There is a path from 'i' to 'i' not otherwise repeating 'i' that
+// crosses a suspend point.
+//
+class SuspendCrossingInfo {
+  BlockToIndexMapping Mapping;
+
+  struct BlockData {
+    BitVector Consumes;
+    BitVector Kills;
+    bool Suspend = false;
+    bool End = false;
+    bool KillLoop = false;
+    bool Changed = false; // Did Consumes/Kills change in the last pass?
+  };
+  SmallVector<BlockData, 32> Block;
+
+  iterator_range<pred_iterator> predecessors(BlockData const &BD) const {
+    BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]);
+    return llvm::predecessors(BB);
+  }
+
+  BlockData &getBlockData(BasicBlock *BB) {
+    return Block[Mapping.blockToIndex(BB)];
+  }
+
+  /// Compute the BlockData for the current function in one iteration.
+  /// \p Initialize - true only for the first pass, which skips the
+  /// predecessor-Changed shortcut and leaves the Changed flags untouched.
+  /// Returns whether the BlockData changes in this iteration.
+  template <bool Initialize = false>
+  bool computeBlockData(const ReversePostOrderTraversal<Function *> &RPOT);
+
+public:
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  // Blocks are printed in reverse post-order (RPO).
+  void dump() const;
+  void dump(StringRef Label, BitVector const &BV,
+            const ReversePostOrderTraversal<Function *> &RPOT) const;
+#endif
+
+  SuspendCrossingInfo(Function &F,
+                      const SmallVectorImpl<AnyCoroSuspendInst *> &CoroSuspends,
+                      const SmallVectorImpl<AnyCoroEndInst *> &CoroEnds);
+
+  /// Returns true if there is a path from \p From to \p To crossing a suspend
+  /// point without crossing \p From a 2nd time.
+  bool hasPathCrossingSuspendPoint(BasicBlock *From, BasicBlock *To) const;
+
+  /// Returns true if there is a path from \p From to \p To crossing a suspend
+  /// point without crossing \p From a 2nd time. If \p From is the same as \p To
+  /// this will also check if there is a looping path crossing a suspend point.
+  bool hasPathOrLoopCrossingSuspendPoint(BasicBlock *From,
+                                         BasicBlock *To) const;
+
+  bool isDefinitionAcrossSuspend(BasicBlock *DefBB, User *U) const {
+    auto *I = cast<Instruction>(U);
+
+    // We rewrote PHINodes, so that only the ones with exactly one incoming
+    // value need to be analyzed.
+    if (auto *PN = dyn_cast<PHINode>(I))
+      if (PN->getNumIncomingValues() > 1)
+        return false;
+
+    BasicBlock *UseBB = I->getParent();
+
+    // As a special case, treat uses by an llvm.coro.suspend.retcon or an
+    // llvm.coro.suspend.async as if they were uses in the suspend's single
+    // predecessor: the uses conceptually occur before the suspend.
+    if (isa<CoroSuspendRetconInst>(I) || isa<CoroSuspendAsyncInst>(I)) {
+      UseBB = UseBB->getSinglePredecessor();
+      assert(UseBB && "should have split coro.suspend into its own block");
+    }
+
+    return hasPathCrossingSuspendPoint(DefBB, UseBB);
+  }
+
+  bool isDefinitionAcrossSuspend(Argument &A, User *U) const {
+    return isDefinitionAcrossSuspend(&A.getParent()->getEntryBlock(), U);
+  }
+
+  bool isDefinitionAcrossSuspend(Instruction &I, User *U) const {
+    auto *DefBB = I.getParent();
+
+    // As a special case, treat values produced by an llvm.coro.suspend.*
+    // as if they were defined in the single successor: the uses
+    // conceptually occur after the suspend.
+    if (isa<AnyCoroSuspendInst>(I)) {
+      DefBB = DefBB->getSingleSuccessor();
+      assert(DefBB && "should have split coro.suspend into its own block");
+    }
+
+    return isDefinitionAcrossSuspend(DefBB, U);
+  }
+
+  bool isDefinitionAcrossSuspend(Value &V, User *U) const {
+    if (auto *Arg = dyn_cast<Argument>(&V))
+      return isDefinitionAcrossSuspend(*Arg, U);
+    if (auto *Inst = dyn_cast<Instruction>(&V))
+      return isDefinitionAcrossSuspend(*Inst, U);
+
+    llvm_unreachable(
+        "Coroutine could only collect Argument and Instruction now.");
+  }
+
+  /// Returns true if any user of \p V sees it defined across a suspend point.
+  bool isDefinitionAcrossSuspend(Value &V) const {
+    if (auto *Arg = dyn_cast<Argument>(&V))
+      return llvm::any_of(Arg->users(), [&](User *U) {
+        return isDefinitionAcrossSuspend(*Arg, U);
+      });
+    if (auto *Inst = dyn_cast<Instruction>(&V))
+      return llvm::any_of(Inst->users(), [&](User *U) {
+        return isDefinitionAcrossSuspend(*Inst, U);
+      });
+
+    llvm_unreachable(
+        "Coroutine could only collect Argument and Instruction now.");
+  }
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_TRANSFORMS_COROUTINES_SUSPENDCROSSINGINFO_H
More information about the llvm-commits
mailing list