[llvm] 3e51af9 - [Coroutines] Improve rematerialization stage

David Stuttard via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 13 03:06:54 PST 2023


Author: David Stuttard
Date: 2023-02-13T11:02:20Z
New Revision: 3e51af9b5b3a2a1e4793fdaba9aa57c09d0b77cc

URL: https://github.com/llvm/llvm-project/commit/3e51af9b5b3a2a1e4793fdaba9aa57c09d0b77cc
DIFF: https://github.com/llvm/llvm-project/commit/3e51af9b5b3a2a1e4793fdaba9aa57c09d0b77cc.diff

LOG: [Coroutines] Improve rematerialization stage

As originally implemented, the rematerialization of valid instructions across
the suspend point would iterate 4 times, meaning that only chains of up to 4
dependent instructions could be rematerialized.

This implementation changes that approach to instead build a graph of
rematerializable instructions, then move all of them. This is faster than the
original approach and is not bound by an arbitrary iteration limit.

Differential Revision: https://reviews.llvm.org/D142620

Added: 
    

Modified: 
    llvm/lib/Transforms/Coroutines/CoroFrame.cpp
    llvm/test/Transforms/Coroutines/coro-materialize.ll
    llvm/test/Transforms/Coroutines/coro-retcon-remat.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index 814c839cea739..dc14f68dd6b6f 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -16,6 +16,7 @@
 
 #include "CoroInternal.h"
 #include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/Analysis/PtrUseVisitor.h"
@@ -37,6 +38,7 @@
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/PromoteMemToReg.h"
 #include <algorithm>
+#include <deque>
 #include <optional>
 
 using namespace llvm;
@@ -108,8 +110,10 @@ struct SuspendCrossingInfo {
     return Block[Mapping.blockToIndex(BB)];
   }
 
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   void dump() const;
   void dump(StringRef Label, BitVector const &BV) const;
+#endif
 
   SuspendCrossingInfo(Function &F, coro::Shape &Shape);
 
@@ -314,6 +318,115 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
   LLVM_DEBUG(dump());
 }
 
+static bool materializable(Instruction &V); // used by RematGraph::addNode
+
+namespace {
+
+// RematGraph is a DAG of rematerializable instructions built for one use.
+// The constructor records the given instruction as the entry node, then walks
+// operands breadth-first, adding every operand def that is materializable and
+// defined across a suspend point relative to that first instruction. Edges
+// run from a user to the nodes defining its operands. Typically one RematGraph
+// is created per user of a value that is materializable across a suspend.
+struct RematGraph {
+  // Each RematNode wraps one instruction; Operands holds edges to the nodes
+  // for the rematerializable instructions that supply its operands.
+  struct RematNode {
+    Instruction *Node;
+    SmallVector<RematNode *> Operands;
+    RematNode() = default;
+    RematNode(Instruction *V) : Node(V) {}
+  };
+
+  RematNode *EntryNode; // graph root; owned by Remats like every node
+  using RematNodeMap =
+      SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>;
+  RematNodeMap Remats; // owns all expanded nodes, keyed by instruction
+  SuspendCrossingInfo &Checker;
+
+  RematGraph(Instruction *I, SuspendCrossingInfo &Checker) : Checker(Checker) {
+    std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(I);
+    EntryNode = FirstNode.get();
+    std::deque<std::unique_ptr<RematNode>> WorkList; // FIFO: breadth-first
+    addNode(std::move(FirstNode), WorkList, cast<User>(I));
+    while (WorkList.size()) {
+      std::unique_ptr<RematNode> N = std::move(WorkList.front());
+      WorkList.pop_front();
+      addNode(std::move(N), WorkList, cast<User>(I));
+    }
+  }
+
+  void addNode(std::unique_ptr<RematNode> NUPtr,
+               std::deque<std::unique_ptr<RematNode>> &WorkList,
+               User *FirstUse) {
+    RematNode *N = NUPtr.get();
+    if (Remats.count(N->Node))
+      return;
+
+    // We haven't seen this node yet - take ownership of it in Remats.
+    Remats[N->Node] = std::move(NUPtr);
+    for (auto &Def : N->Node->operands()) {
+      Instruction *D = dyn_cast<Instruction>(Def.get());
+      if (!D || !materializable(*D) ||
+          !Checker.isDefinitionAcrossSuspend(*D, FirstUse))
+        continue; // not a remat candidate relative to FirstUse
+
+      if (Remats.count(D)) {
+        // Already have this in the graph
+        N->Operands.push_back(Remats[D].get());
+        continue;
+      }
+      // Reuse a node that is already queued for expansion, if there is one.
+      bool NoMatch = true;
+      for (auto &I : WorkList) {
+        if (I->Node == D) {
+          NoMatch = false;
+          N->Operands.push_back(I.get());
+          break;
+        }
+      }
+      if (NoMatch) {
+        // First sighting: create a node and queue it for later expansion.
+        std::unique_ptr<RematNode> ChildNode = std::make_unique<RematNode>(D);
+        N->Operands.push_back(ChildNode.get());
+        WorkList.push_back(std::move(ChildNode));
+      }
+    }
+  }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  void dump() const { // entry, then every node followed by its operands
+    dbgs() << "Entry (";
+    if (EntryNode->Node->getParent()->hasName())
+      dbgs() << EntryNode->Node->getParent()->getName();
+    else
+      EntryNode->Node->getParent()->printAsOperand(dbgs(), false);
+    dbgs() << ") : " << *EntryNode->Node << "\n";
+    for (auto &E : Remats) {
+      dbgs() << *(E.first) << "\n";
+      for (RematNode *U : E.second->Operands)
+        dbgs() << "  " << *U->Node << "\n";
+    }
+  }
+#endif
+};
+} // end anonymous namespace
+
+namespace llvm {
+// Exposes RematGraph to generic graph iterators; children are Operands.
+template <> struct GraphTraits<RematGraph *> {
+  using NodeRef = RematGraph::RematNode *;
+  using ChildIteratorType = RematGraph::RematNode **;
+
+  static NodeRef getEntryNode(RematGraph *G) { return G->EntryNode; }
+  static ChildIteratorType child_begin(NodeRef N) {
+    return N->Operands.begin();
+  }
+  static ChildIteratorType child_end(NodeRef N) { return N->Operands.end(); }
+};
+
+} // end namespace llvm
+
 #undef DEBUG_TYPE // "coro-suspend-crossing"
 #define DEBUG_TYPE "coro-frame"
 
@@ -425,6 +538,15 @@ static void dumpSpills(StringRef Title, const SpillInfo &Spills) {
       I->dump();
   }
 }
+static void dumpRemats(
+    StringRef Title,
+    const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> &RM) {
+  dbgs() << "------------- " << Title << "--------------\n";
+  for (const auto &E : RM) { // one RematGraph per key instruction
+    E.second->dump();
+    dbgs() << "--\n"; // separator between graphs
+  }
+}
 
 static void dumpAllocas(const SmallVectorImpl<AllocaInfo> &Allocas) {
   dbgs() << "------------- Allocas --------------\n";
@@ -2103,41 +2225,82 @@ static bool isCoroutineStructureIntrinsic(Instruction &I) {
          isa<CoroSuspendInst>(&I);
 }
 
-// For every use of the value that is across suspend point, recreate that value
-// after a suspend point.
-static void rewriteMaterializableInstructions(IRBuilder<> &IRB,
-                                              const SpillInfo &Spills) {
-  for (const auto &E : Spills) {
-    Value *Def = E.first;
-    BasicBlock *CurrentBlock = nullptr;
+// For each use of a value that is materializable across a suspend point (the
+// keys of AllRemats), clone its RematGraph of rematerializable instructions
+// after the suspend point and rewrite the use to read the clones.
+static void rewriteMaterializableInstructions(
+    const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8>
+        &AllRemats) {
+  // This has to be done in 2 phases:
+  // 1) Do the remats, recording for each original use which def must be
+  //    replaced by which clone.
+  // 2) Once all the remats are complete, replace the uses in those original
+  //    instructions with the new defs.
+  typedef struct {
+    Instruction *Use;   // original instruction whose operand is rewritten
+    Instruction *Def;   // original def currently used by Use
+    Instruction *Remat; // clone of Def that replaces it in Use
+  } ProcessNode;
+
+  SmallVector<ProcessNode> FinalInstructionsToProcess;
+
+  for (const auto &E : AllRemats) {
+    Instruction *Use = E.first;
     Instruction *CurrentMaterialization = nullptr;
-    for (Instruction *U : E.second) {
-      // If we have not seen this block, materialize the value.
-      if (CurrentBlock != U->getParent()) {
+    RematGraph *RG = E.second.get();
+    ReversePostOrderTraversal<RematGraph *> RPOT(RG);
+    SmallVector<Instruction *> InstructionsToProcess; // clones made so far
+
+    // If the target use is actually a suspend instruction then we have to
+    // insert the remats into the end of the predecessor (there should only be
+    // one). This is so that suspend blocks always have the suspend instruction
+    // as the first instruction.
+    auto InsertPoint = &*Use->getParent()->getFirstInsertionPt();
+    if (isa<AnyCoroSuspendInst>(Use)) {
+      BasicBlock *SuspendPredecessorBlock =
+          Use->getParent()->getSinglePredecessor();
+      assert(SuspendPredecessorBlock && "malformed coro suspend instruction");
+      InsertPoint = SuspendPredecessorBlock->getTerminator();
+    }

-        bool IsInCoroSuspendBlock = isa<AnyCoroSuspendInst>(U);
-        CurrentBlock = U->getParent();
-        auto *InsertBlock = IsInCoroSuspendBlock
-                                ? CurrentBlock->getSinglePredecessor()
-                                : CurrentBlock;
-        CurrentMaterialization = cast<Instruction>(Def)->clone();
-        CurrentMaterialization->setName(Def->getName());
-        CurrentMaterialization->insertBefore(
-            IsInCoroSuspendBlock ? InsertBlock->getTerminator()
-                                 : &*InsertBlock->getFirstInsertionPt());
-      }
-      if (auto *PN = dyn_cast<PHINode>(U)) {
-        assert(PN->getNumIncomingValues() == 1 &&
-               "unexpected number of incoming "
-               "values in the PHINode");
-        PN->replaceAllUsesWith(CurrentMaterialization);
-        PN->eraseFromParent();
-        continue;
-      }
-      // Replace all uses of Def in the current instruction with the
-      // CurrentMaterialization for the block.
-      U->replaceUsesOfWith(Def, CurrentMaterialization);
+    // RPOT visits users before their operands. Skip the first node: it is the
+    // actual use that we're rematerializing everything for.
+    auto I = RPOT.begin();
+    ++I;
+    for (; I != RPOT.end(); ++I) {
+      Instruction *D = (*I)->Node;
+      CurrentMaterialization = D->clone();
+      CurrentMaterialization->setName(D->getName());
+      CurrentMaterialization->insertBefore(InsertPoint);
+      InsertPoint = CurrentMaterialization; // operands go above users
+
+      // Rewrite clones made earlier in this group that use D to use its clone.
+      // NOTE(review): this loop variable I shadows the RPOT iterator I.
+      for (auto &I : InstructionsToProcess)
+        I->replaceUsesOfWith(D, CurrentMaterialization);
+
+      // Don't replace the final use at this point as this can cause problems
+      // for other materializations. Instead, for any final use that uses a
+      // define that's being rematerialized, record the replacement values.
+      for (unsigned i = 0, E = Use->getNumOperands(); i != E; ++i)
+        if (Use->getOperand(i) == D) // Does this operand refer to D?
+          FinalInstructionsToProcess.push_back(
+              {Use, D, CurrentMaterialization});
+
+      InstructionsToProcess.push_back(CurrentMaterialization);
+    }
+  }
+
+  // Phase 2: replace the recorded uses with their rematerialized defs.
+  for (auto &R : FinalInstructionsToProcess) {
+    if (auto *PN = dyn_cast<PHINode>(R.Use)) {
+      assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming "
+                                                "values in the PHINode");
+      PN->replaceAllUsesWith(R.Remat);
+      PN->eraseFromParent();
+      continue;
     }
+    R.Use->replaceUsesOfWith(R.Def, R.Remat);
   }
 }
 
@@ -2724,6 +2887,62 @@ void coro::salvageDebugInfo(
   }
 }
 
+static void doRematerializations(Function &F, SuspendCrossingInfo &Checker) {
+  SpillInfo Spills; // materializable def -> its users across a suspend point
+
+  // See if there are materializable instructions used across suspend points.
+  // Their users are the starting points from which we also identify the
+  // materializable defs of those users' operands.
+  for (Instruction &I : instructions(F)) {
+    if (!materializable(I))
+      continue;
+    for (User *U : I.users())
+      if (Checker.isDefinitionAcrossSuspend(I, U))
+        Spills[&I].push_back(cast<Instruction>(U));
+  }
+
+  // Process each of the identified rematerializable instructions and add the
+  // instructions defining their operands that can also be rematerialized.
+  // This is actually a graph of instructions, since a def may be used by more
+  // than one instruction in the set.
+  // The approach here is to maintain one graph per bottom-level instruction
+  // (the use), with a unique set of instructions (nodes) and edges from users
+  // to the defs of their operands. We then walk each graph in reverse
+  // post-order to insert the clones past the suspend point so that every def
+  // is placed before its users. We also rely on CSE removing duplicate defs
+  // for remats of different instructions with a def in common (rather than
+  // maintaining more complex graphs for each suspend point).
+
+  // AllRemats maps each such use to its RematGraph. GraphTraits<RematGraph *>
+  // supplies the reverse post-order traversal used when the nodes are
+  // inserted after the suspend point.
+  SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> AllRemats;
+  for (auto &E : Spills) {
+    for (Instruction *U : E.second) {
+      // Don't process a user twice (this can happen if the instruction uses
+      // more than one rematerializable def)
+      if (AllRemats.count(U))
+        continue;
+
+      // Constructor creates the whole RematGraph for the given Use
+      auto RematUPtr = std::make_unique<RematGraph>(U, Checker);
+
+      LLVM_DEBUG(dbgs() << "***** Next remat group *****\n";
+                 ReversePostOrderTraversal<RematGraph *> RPOT(RematUPtr.get());
+                 for (auto I = RPOT.begin(); I != RPOT.end();
+                      ++I) { (*I)->Node->dump(); } dbgs()
+                 << "\n";);
+
+      AllRemats[U] = std::move(RematUPtr);
+    }
+  }
+
+  // Rewrite materializable instructions to be materialized at the use
+  // point.
+  LLVM_DEBUG(dumpRemats("Materializations", AllRemats));
+  rewriteMaterializableInstructions(AllRemats);
+}
+
 void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
   // Don't eliminate swifterror in async functions that won't be split.
   if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty())
@@ -2775,35 +2994,11 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
   // Build suspend crossing info.
   SuspendCrossingInfo Checker(F, Shape);
 
-  IRBuilder<> Builder(F.getContext());
+  doRematerializations(F, Checker);
+
   FrameDataInfo FrameData;
   SmallVector<CoroAllocaAllocInst*, 4> LocalAllocas;
   SmallVector<Instruction*, 4> DeadInstructions;
-
-  {
-    SpillInfo Spills;
-    for (int Repeat = 0; Repeat < 4; ++Repeat) {
-      // See if there are materializable instructions across suspend points.
-      // FIXME: We can use a worklist to track the possible materialize
-      // instructions instead of iterating the whole function again and again.
-      for (Instruction &I : instructions(F))
-        if (materializable(I)) {
-          for (User *U : I.users())
-            if (Checker.isDefinitionAcrossSuspend(I, U))
-              Spills[&I].push_back(cast<Instruction>(U));
-        }
-
-      if (Spills.empty())
-        break;
-
-      // Rewrite materializable instructions to be materialized at the use
-      // point.
-      LLVM_DEBUG(dumpSpills("Materializations", Spills));
-      rewriteMaterializableInstructions(Builder, Spills);
-      Spills.clear();
-    }
-  }
-
   if (Shape.ABI != coro::ABI::Async && Shape.ABI != coro::ABI::Retcon &&
       Shape.ABI != coro::ABI::RetconOnce)
     sinkLifetimeStartMarkers(F, Shape, Checker);

diff  --git a/llvm/test/Transforms/Coroutines/coro-materialize.ll b/llvm/test/Transforms/Coroutines/coro-materialize.ll
index c1002b0bf1c38..b45df16eb8c4b 100644
--- a/llvm/test/Transforms/Coroutines/coro-materialize.ll
+++ b/llvm/test/Transforms/Coroutines/coro-materialize.ll
@@ -4,9 +4,9 @@
 ; See that we only spilled one value for f
 ; CHECK: %f.Frame = type { ptr, ptr, i32, i1 }
 ; Check other variants where different levels of materialization are achieved
-; CHECK: %f_multiple_remat.Frame = type { ptr, ptr, i32, i32, i32, i1 }
-; CHECK: %f_common_def.Frame = type { ptr, ptr, i32, i32, i32, i1 }
-; CHECK: %f_common_def_multi_result.Frame = type { ptr, ptr, i32, i32, i32, i32, i32, i32, i32, i1 }
+; CHECK: %f_multiple_remat.Frame = type { ptr, ptr, i32, i1 }
+; CHECK: %f_common_def.Frame = type { ptr, ptr, i32, i1 }
+; CHECK: %f_common_def_multi_result.Frame = type { ptr, ptr, i32, i1 }
 ; CHECK-LABEL: @f(
 ; CHECK-LABEL: @f_multiple_remat(
 ; CHECK-LABEL: @f_common_def(

diff  --git a/llvm/test/Transforms/Coroutines/coro-retcon-remat.ll b/llvm/test/Transforms/Coroutines/coro-retcon-remat.ll
index 3f73fdbaa52b0..584c4e0da87b7 100644
--- a/llvm/test/Transforms/Coroutines/coro-retcon-remat.ll
+++ b/llvm/test/Transforms/Coroutines/coro-retcon-remat.ll
@@ -2,7 +2,7 @@
 ; as expected
 ; RUN: opt < %s -O0 -S | FileCheck %s
 
-; CHECK: %f.Frame = type { i32, i32 }
+; CHECK: %f.Frame = type { i32 }
 
 define { i8*, i32 } @f(i8* %buffer, i32 %n) {
 entry:


        


More information about the llvm-commits mailing list