[llvm] 4247806 - [NFC] [Coroutines] Add a fastpath when computing the cross suspend point information

Chuanqi Xu via llvm-commits llvm-commits at lists.llvm.org
Fri May 5 02:01:34 PDT 2023


Author: Chuanqi Xu
Date: 2023-05-05T17:00:40+08:00
New Revision: 4247806690801cba0699552b711dbfdb943d05d2

URL: https://github.com/llvm/llvm-project/commit/4247806690801cba0699552b711dbfdb943d05d2
DIFF: https://github.com/llvm/llvm-project/commit/4247806690801cba0699552b711dbfdb943d05d2.diff

LOG: [NFC] [Coroutines] Add a fastpath when computing the cross suspend point information

Mitigate https://github.com/llvm/llvm-project/issues/62348

The root cause of the above issue is that we use a textbook dataflow
analysis to compute the cross-suspend-point information. The analysis is
powerful but does not scale.

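It is not easy to improve the current algorithm, so this patch instead
prunes redundant work to mitigate the problem: propagation now pulls data
from each block's predecessors, and a block is recomputed only if at least
one of its predecessors changed since the block was last visited.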

Before the patch:

```
n: 20000

real	0m11.081s
user	0m10.597s
sys	0m0.320s

n: 40000

real	0m32.927s
user	0m31.403s
sys	0m1.043s

n: 60000

real	1m2.145s
user	0m58.903s
sys	0m2.268s

n: 80000

real	1m47.143s
user	1m41.630s
sys	0m3.857s

n: 100000

real	2m34.758s
user	2m26.587s
sys	0m5.922s
```

After the patch:

```
n: 20000

real	0m10.418s
user	0m9.945s
sys	0m0.311s

n: 40000

real	0m27.884s
user	0m26.430s
sys	0m1.036s

n: 60000

real	0m52.420s
user	0m49.321s
sys	0m2.267s

n: 80000

real	1m25.389s
user	1m20.247s
sys	0m3.856s

n: 100000

real	2m4.275s
user	1m56.405s
sys	0m5.975s
```

This patch is intended to be an NFC (no functional change) patch.

Added: 
    

Modified: 
    llvm/lib/Transforms/Coroutines/CoroFrame.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index c2c6294ef60e8..c81f1e7f21fc8 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -89,7 +89,7 @@ class BlockToIndexMapping {
 //             crosses a suspend point.
 //
 namespace {
-struct SuspendCrossingInfo {
+class SuspendCrossingInfo {
   BlockToIndexMapping Mapping;
 
   struct BlockData {
@@ -98,18 +98,26 @@ struct SuspendCrossingInfo {
     bool Suspend = false;
     bool End = false;
     bool KillLoop = false;
+    bool Changed = false;
   };
   SmallVector<BlockData, SmallVectorThreshold> Block;
 
-  iterator_range<succ_iterator> successors(BlockData const &BD) const {
+  iterator_range<pred_iterator> predecessors(BlockData const &BD) const {
     BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]);
-    return llvm::successors(BB);
+    return llvm::predecessors(BB);
   }
 
   BlockData &getBlockData(BasicBlock *BB) {
     return Block[Mapping.blockToIndex(BB)];
   }
 
+  /// Compute the BlockData for the current function in one iteration.
+  /// Returns whether the BlockData changes in this iteration.
+  /// Initialize - Whether this is the first iteration, we can optimize
+  /// the initial case a little bit by manual loop switch.
+  template <bool Initialize = false> bool computeBlockData();
+
+public:
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   void dump() const;
   void dump(StringRef Label, BitVector const &BV) const;
@@ -215,6 +223,72 @@ LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
 }
 #endif
 
+template <bool Initialize> bool SuspendCrossingInfo::computeBlockData() {
+  const size_t N = Mapping.size();
+  bool Changed = false;
+
+  for (size_t I = 0; I < N; ++I) {
+    auto &B = Block[I];
+
+    // The initialization pass computes every block unconditionally.
+    if constexpr (!Initialize)
+      // If no predecessor changed since B was last visited, B's inputs, and
+      // hence its BlockData, cannot change either: skip recomputing it.
+      if (all_of(predecessors(B), [this](BasicBlock *BB) {
+            return !Block[Mapping.blockToIndex(BB)].Changed;
+          })) {
+        B.Changed = false;
+        continue;
+      }
+
+    // Save the Consumes and Kills bitsets so that it is easy to see
+    // whether anything changed after propagation.
+    auto SavedConsumes = B.Consumes;
+    auto SavedKills = B.Kills;
+
+    for (BasicBlock *PI : predecessors(B)) {
+      auto PrevNo = Mapping.blockToIndex(PI);
+      auto &P = Block[PrevNo];
+
+      // Propagate Kills and Consumes from predecessor P into B.
+      B.Consumes |= P.Consumes;
+      B.Kills |= P.Kills;
+
+      // If block P is a suspend block, it should propagate kills into block
+      // B for every block P consumes.
+      if (P.Suspend)
+        B.Kills |= P.Consumes;
+    }
+
+    if (B.Suspend) {
+      // If block B is a suspend block, it should kill all of the blocks it
+      // consumes.
+      B.Kills |= B.Consumes;
+    } else if (B.End) {
+      // If block B is an end block, it should not propagate kills as the
+      // blocks following coro.end() are reached during initial invocation
+      // of the coroutine while all the data are still available on the
+      // stack or in the registers.
+      B.Kills.reset();
+    } else {
+      // Block B is neither a suspend nor a coro.end block: record in KillLoop
+      // whether B is in its own kill set, then remove it from that set.
+      B.KillLoop |= B.Kills[I];
+      B.Kills.reset(I);
+    }
+
+    if constexpr (!Initialize) {
+      B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
+      Changed |= B.Changed;
+    }
+  }
+  // The initialization pass always reports a change.
+  if constexpr (Initialize)
+    return true;
+
+  return Changed;
+}
+
 SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
     : Mapping(F) {
   const size_t N = Mapping.size();
@@ -226,6 +300,7 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
     B.Consumes.resize(N);
     B.Kills.resize(N);
     B.Consumes.set(I);
+    B.Changed = true;
   }
 
   // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
@@ -250,71 +325,11 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
       markSuspendBlock(Save);
   }
 
-  // Iterate propagating consumes and kills until they stop changing.
-  int Iteration = 0;
-  (void)Iteration;
+  computeBlockData</*Initialize=*/true>();
 
-  bool Changed;
-  do {
-    LLVM_DEBUG(dbgs() << "iteration " << ++Iteration);
-    LLVM_DEBUG(dbgs() << "==============\n");
-
-    Changed = false;
-    for (size_t I = 0; I < N; ++I) {
-      auto &B = Block[I];
-      for (BasicBlock *SI : successors(B)) {
-
-        auto SuccNo = Mapping.blockToIndex(SI);
-
-        // Saved Consumes and Kills bitsets so that it is easy to see
-        // if anything changed after propagation.
-        auto &S = Block[SuccNo];
-        auto SavedConsumes = S.Consumes;
-        auto SavedKills = S.Kills;
-
-        // Propagate Kills and Consumes from block B into its successor S.
-        S.Consumes |= B.Consumes;
-        S.Kills |= B.Kills;
-
-        // If block B is a suspend block, it should propagate kills into the
-        // its successor for every block B consumes.
-        if (B.Suspend) {
-          S.Kills |= B.Consumes;
-        }
-        if (S.Suspend) {
-          // If block S is a suspend block, it should kill all of the blocks it
-          // consumes.
-          S.Kills |= S.Consumes;
-        } else if (S.End) {
-          // If block S is an end block, it should not propagate kills as the
-          // blocks following coro.end() are reached during initial invocation
-          // of the coroutine while all the data are still available on the
-          // stack or in the registers.
-          S.Kills.reset();
-        } else {
-          // This is reached when S block it not Suspend nor coro.end and it
-          // need to make sure that it is not in the kill set.
-          S.KillLoop |= S.Kills[SuccNo];
-          S.Kills.reset(SuccNo);
-        }
-
-        // See if anything changed.
-        Changed |= (S.Kills != SavedKills) || (S.Consumes != SavedConsumes);
+  while (computeBlockData())
+    ;
 
-        if (S.Kills != SavedKills) {
-          LLVM_DEBUG(dbgs() << "\nblock " << I << " follower " << SI->getName()
-                            << "\n");
-          LLVM_DEBUG(dump("S.Kills", S.Kills));
-          LLVM_DEBUG(dump("SavedKills", SavedKills));
-        }
-        if (S.Consumes != SavedConsumes) {
-          LLVM_DEBUG(dbgs() << "\nblock " << I << " follower " << SI << "\n");
-          LLVM_DEBUG(dump("S.Consume", S.Consumes));
-          LLVM_DEBUG(dump("SavedCons", SavedConsumes));
-        }
-      }
-    }
-  } while (Changed);
   LLVM_DEBUG(dump());
 }
 


        


More information about the llvm-commits mailing list