[llvm] [Coroutines] Move materialization code into its own utils (PR #108240)

Tyler Nowicki via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 11 08:33:18 PDT 2024


https://github.com/TylerNowicki created https://github.com/llvm/llvm-project/pull/108240

* Move materialization out of CoroFrame into MaterializationUtils.h/.cpp
* Move spill-related utilities that were used by materialization to SpillUtils
* Move isSuspendBlock (needed by materialization) to CoroInternal

See RFC for more info: https://discourse.llvm.org/t/rfc-abi-objects-for-coroutines/81057

>From 1a33655db85050d1ac1eaa31bcbc367b3d5be657 Mon Sep 17 00:00:00 2001
From: tnowicki <tnowicki.nowicki at amd.com>
Date: Fri, 23 Aug 2024 18:24:54 -0400
Subject: [PATCH] [Coroutines] Move materialization code into its own utils

* Move materialization out of CoroFrame into MaterializationUtils.h/.cpp
* Move spill-related utilities that were used by materialization to SpillUtils
* Move isSuspendBlock (needed by materialization) to CoroInternal
---
 llvm/lib/Transforms/Coroutines/CMakeLists.txt |   1 +
 llvm/lib/Transforms/Coroutines/CoroFrame.cpp  | 296 +----------------
 llvm/lib/Transforms/Coroutines/CoroInternal.h |   1 +
 llvm/lib/Transforms/Coroutines/Coroutines.cpp |   4 +
 .../Coroutines/MaterializationUtils.cpp       | 308 ++++++++++++++++++
 .../Coroutines/MaterializationUtils.h         |  30 ++
 llvm/lib/Transforms/Coroutines/SpillUtils.cpp |   6 +-
 llvm/lib/Transforms/Coroutines/SpillUtils.h   |   2 -
 8 files changed, 351 insertions(+), 297 deletions(-)
 create mode 100644 llvm/lib/Transforms/Coroutines/MaterializationUtils.cpp
 create mode 100644 llvm/lib/Transforms/Coroutines/MaterializationUtils.h

diff --git a/llvm/lib/Transforms/Coroutines/CMakeLists.txt b/llvm/lib/Transforms/Coroutines/CMakeLists.txt
index c6508174a7f109..46ef5cd4e8cfe8 100644
--- a/llvm/lib/Transforms/Coroutines/CMakeLists.txt
+++ b/llvm/lib/Transforms/Coroutines/CMakeLists.txt
@@ -9,6 +9,7 @@ add_llvm_component_library(LLVMCoroutines
   CoroSplit.cpp
   SuspendCrossingInfo.cpp
   SpillUtils.cpp
+  MaterializationUtils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/Coroutines
diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index 8ee4bfa3b888df..b74c9f01cd2395 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -16,10 +16,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "CoroInternal.h"
+#include "MaterializationUtils.h"
 #include "SpillUtils.h"
 #include "SuspendCrossingInfo.h"
 #include "llvm/ADT/BitVector.h"
-#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/Analysis/StackLifetime.h"
@@ -36,135 +36,12 @@
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/PromoteMemToReg.h"
 #include <algorithm>
-#include <deque>
 #include <optional>
 
 using namespace llvm;
 
 extern cl::opt<bool> UseNewDbgInfoFormat;
 
-// The "coro-suspend-crossing" flag is very noisy. There is another debug type,
-// "coro-frame", which results in leaner debug spew.
-#define DEBUG_TYPE "coro-suspend-crossing"
-
-namespace {
-
-// RematGraph is used to construct a DAG for rematerializable instructions
-// When the constructor is invoked with a candidate instruction (which is
-// materializable) it builds a DAG of materializable instructions from that
-// point.
-// Typically, for each instruction identified as re-materializable across a
-// suspend point, a RematGraph will be created.
-struct RematGraph {
-  // Each RematNode in the graph contains the edges to instructions providing
-  // operands in the current node.
-  struct RematNode {
-    Instruction *Node;
-    SmallVector<RematNode *> Operands;
-    RematNode() = default;
-    RematNode(Instruction *V) : Node(V) {}
-  };
-
-  RematNode *EntryNode;
-  using RematNodeMap =
-      SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>;
-  RematNodeMap Remats;
-  const std::function<bool(Instruction &)> &MaterializableCallback;
-  SuspendCrossingInfo &Checker;
-
-  RematGraph(const std::function<bool(Instruction &)> &MaterializableCallback,
-             Instruction *I, SuspendCrossingInfo &Checker)
-      : MaterializableCallback(MaterializableCallback), Checker(Checker) {
-    std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(I);
-    EntryNode = FirstNode.get();
-    std::deque<std::unique_ptr<RematNode>> WorkList;
-    addNode(std::move(FirstNode), WorkList, cast<User>(I));
-    while (WorkList.size()) {
-      std::unique_ptr<RematNode> N = std::move(WorkList.front());
-      WorkList.pop_front();
-      addNode(std::move(N), WorkList, cast<User>(I));
-    }
-  }
-
-  void addNode(std::unique_ptr<RematNode> NUPtr,
-               std::deque<std::unique_ptr<RematNode>> &WorkList,
-               User *FirstUse) {
-    RematNode *N = NUPtr.get();
-    if (Remats.count(N->Node))
-      return;
-
-    // We haven't see this node yet - add to the list
-    Remats[N->Node] = std::move(NUPtr);
-    for (auto &Def : N->Node->operands()) {
-      Instruction *D = dyn_cast<Instruction>(Def.get());
-      if (!D || !MaterializableCallback(*D) ||
-          !Checker.isDefinitionAcrossSuspend(*D, FirstUse))
-        continue;
-
-      if (Remats.count(D)) {
-        // Already have this in the graph
-        N->Operands.push_back(Remats[D].get());
-        continue;
-      }
-
-      bool NoMatch = true;
-      for (auto &I : WorkList) {
-        if (I->Node == D) {
-          NoMatch = false;
-          N->Operands.push_back(I.get());
-          break;
-        }
-      }
-      if (NoMatch) {
-        // Create a new node
-        std::unique_ptr<RematNode> ChildNode = std::make_unique<RematNode>(D);
-        N->Operands.push_back(ChildNode.get());
-        WorkList.push_back(std::move(ChildNode));
-      }
-    }
-  }
-
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-  static std::string getBasicBlockLabel(const BasicBlock *BB) {
-    if (BB->hasName())
-      return BB->getName().str();
-
-    std::string S;
-    raw_string_ostream OS(S);
-    BB->printAsOperand(OS, false);
-    return OS.str().substr(1);
-  }
-
-  void dump() const {
-    dbgs() << "Entry (";
-    dbgs() << getBasicBlockLabel(EntryNode->Node->getParent());
-    dbgs() << ") : " << *EntryNode->Node << "\n";
-    for (auto &E : Remats) {
-      dbgs() << *(E.first) << "\n";
-      for (RematNode *U : E.second->Operands)
-        dbgs() << "  " << *U->Node << "\n";
-    }
-  }
-#endif
-};
-} // end anonymous namespace
-
-namespace llvm {
-
-template <> struct GraphTraits<RematGraph *> {
-  using NodeRef = RematGraph::RematNode *;
-  using ChildIteratorType = RematGraph::RematNode **;
-
-  static NodeRef getEntryNode(RematGraph *G) { return G->EntryNode; }
-  static ChildIteratorType child_begin(NodeRef N) {
-    return N->Operands.begin();
-  }
-  static ChildIteratorType child_end(NodeRef N) { return N->Operands.end(); }
-};
-
-} // end namespace llvm
-
-#undef DEBUG_TYPE // "coro-suspend-crossing"
 #define DEBUG_TYPE "coro-frame"
 
 namespace {
@@ -268,15 +145,6 @@ static void dumpSpills(StringRef Title, const coro::SpillInfo &Spills) {
       I->dump();
   }
 }
-static void dumpRemats(
-    StringRef Title,
-    const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> &RM) {
-  dbgs() << "------------- " << Title << "--------------\n";
-  for (const auto &E : RM) {
-    E.second->dump();
-    dbgs() << "--\n";
-  }
-}
 
 static void dumpAllocas(const SmallVectorImpl<coro::AllocaInfo> &Allocas) {
   dbgs() << "------------- Allocas --------------\n";
@@ -1634,93 +1502,6 @@ static void rewritePHIs(Function &F) {
     rewritePHIs(*BB);
 }
 
-/// Default materializable callback
-// Check for instructions that we can recreate on resume as opposed to spill
-// the result into a coroutine frame.
-bool coro::defaultMaterializable(Instruction &V) {
-  return (isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
-          isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V));
-}
-
-// For each instruction identified as materializable across the suspend point,
-// and its associated DAG of other rematerializable instructions,
-// recreate the DAG of instructions after the suspend point.
-static void rewriteMaterializableInstructions(
-    const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8>
-        &AllRemats) {
-  // This has to be done in 2 phases
-  // Do the remats and record the required defs to be replaced in the
-  // original use instructions
-  // Once all the remats are complete, replace the uses in the final
-  // instructions with the new defs
-  typedef struct {
-    Instruction *Use;
-    Instruction *Def;
-    Instruction *Remat;
-  } ProcessNode;
-
-  SmallVector<ProcessNode> FinalInstructionsToProcess;
-
-  for (const auto &E : AllRemats) {
-    Instruction *Use = E.first;
-    Instruction *CurrentMaterialization = nullptr;
-    RematGraph *RG = E.second.get();
-    ReversePostOrderTraversal<RematGraph *> RPOT(RG);
-    SmallVector<Instruction *> InstructionsToProcess;
-
-    // If the target use is actually a suspend instruction then we have to
-    // insert the remats into the end of the predecessor (there should only be
-    // one). This is so that suspend blocks always have the suspend instruction
-    // as the first instruction.
-    auto InsertPoint = &*Use->getParent()->getFirstInsertionPt();
-    if (isa<AnyCoroSuspendInst>(Use)) {
-      BasicBlock *SuspendPredecessorBlock =
-          Use->getParent()->getSinglePredecessor();
-      assert(SuspendPredecessorBlock && "malformed coro suspend instruction");
-      InsertPoint = SuspendPredecessorBlock->getTerminator();
-    }
-
-    // Note: skip the first instruction as this is the actual use that we're
-    // rematerializing everything for.
-    auto I = RPOT.begin();
-    ++I;
-    for (; I != RPOT.end(); ++I) {
-      Instruction *D = (*I)->Node;
-      CurrentMaterialization = D->clone();
-      CurrentMaterialization->setName(D->getName());
-      CurrentMaterialization->insertBefore(InsertPoint);
-      InsertPoint = CurrentMaterialization;
-
-      // Replace all uses of Def in the instructions being added as part of this
-      // rematerialization group
-      for (auto &I : InstructionsToProcess)
-        I->replaceUsesOfWith(D, CurrentMaterialization);
-
-      // Don't replace the final use at this point as this can cause problems
-      // for other materializations. Instead, for any final use that uses a
-      // define that's being rematerialized, record the replace values
-      for (unsigned i = 0, E = Use->getNumOperands(); i != E; ++i)
-        if (Use->getOperand(i) == D) // Is this operand pointing to oldval?
-          FinalInstructionsToProcess.push_back(
-              {Use, D, CurrentMaterialization});
-
-      InstructionsToProcess.push_back(CurrentMaterialization);
-    }
-  }
-
-  // Finally, replace the uses with the defines that we've just rematerialized
-  for (auto &R : FinalInstructionsToProcess) {
-    if (auto *PN = dyn_cast<PHINode>(R.Use)) {
-      assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming "
-                                                "values in the PHINode");
-      PN->replaceAllUsesWith(R.Remat);
-      PN->eraseFromParent();
-      continue;
-    }
-    R.Use->replaceUsesOfWith(R.Def, R.Remat);
-  }
-}
-
 // Splits the block at a particular instruction unless it is the first
 // instruction in the block with a single predecessor.
 static BasicBlock *splitBlockIfNotFirst(Instruction *I, const Twine &Name) {
@@ -1741,10 +1522,6 @@ static void splitAround(Instruction *I, const Twine &Name) {
   splitBlockIfNotFirst(I->getNextNode(), "After" + Name);
 }
 
-static bool isSuspendBlock(BasicBlock *BB) {
-  return isa<AnyCoroSuspendInst>(BB->front());
-}
-
 /// After we split the coroutine, will the given basic block be along
 /// an obvious exit path for the resumption function?
 static bool willLeaveFunctionImmediatelyAfter(BasicBlock *BB,
@@ -1754,7 +1531,7 @@ static bool willLeaveFunctionImmediatelyAfter(BasicBlock *BB,
   if (depth == 0) return false;
 
   // If this is a suspend block, we're about to exit the resumption function.
-  if (isSuspendBlock(BB))
+  if (coro::isSuspendBlock(BB))
     return true;
 
   // Recurse into the successors.
@@ -1995,7 +1772,8 @@ static void sinkLifetimeStartMarkers(Function &F, coro::Shape &Shape,
   DomSet.insert(&F.getEntryBlock());
   for (auto *CSI : Shape.CoroSuspends) {
     BasicBlock *SuspendBlock = CSI->getParent();
-    assert(isSuspendBlock(SuspendBlock) && SuspendBlock->getSingleSuccessor() &&
+    assert(coro::isSuspendBlock(SuspendBlock) &&
+           SuspendBlock->getSingleSuccessor() &&
            "should have split coro.suspend into its own block");
     DomSet.insert(SuspendBlock->getSingleSuccessor());
   }
@@ -2227,68 +2005,6 @@ void coro::salvageDebugInfo(
   }
 }
 
-static void doRematerializations(
-    Function &F, SuspendCrossingInfo &Checker,
-    const std::function<bool(Instruction &)> &MaterializableCallback) {
-  if (F.hasOptNone())
-    return;
-
-  coro::SpillInfo Spills;
-
-  // See if there are materializable instructions across suspend points
-  // We record these as the starting point to also identify materializable
-  // defs of uses in these operations
-  for (Instruction &I : instructions(F)) {
-    if (!MaterializableCallback(I))
-      continue;
-    for (User *U : I.users())
-      if (Checker.isDefinitionAcrossSuspend(I, U))
-        Spills[&I].push_back(cast<Instruction>(U));
-  }
-
-  // Process each of the identified rematerializable instructions
-  // and add predecessor instructions that can also be rematerialized.
-  // This is actually a graph of instructions since we could potentially
-  // have multiple uses of a def in the set of predecessor instructions.
-  // The approach here is to maintain a graph of instructions for each bottom
-  // level instruction - where we have a unique set of instructions (nodes)
-  // and edges between them. We then walk the graph in reverse post-dominator
-  // order to insert them past the suspend point, but ensure that ordering is
-  // correct. We also rely on CSE removing duplicate defs for remats of
-  // different instructions with a def in common (rather than maintaining more
-  // complex graphs for each suspend point)
-
-  // We can do this by adding new nodes to the list for each suspend
-  // point. Then using standard GraphTraits to give a reverse post-order
-  // traversal when we insert the nodes after the suspend
-  SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> AllRemats;
-  for (auto &E : Spills) {
-    for (Instruction *U : E.second) {
-      // Don't process a user twice (this can happen if the instruction uses
-      // more than one rematerializable def)
-      if (AllRemats.count(U))
-        continue;
-
-      // Constructor creates the whole RematGraph for the given Use
-      auto RematUPtr =
-          std::make_unique<RematGraph>(MaterializableCallback, U, Checker);
-
-      LLVM_DEBUG(dbgs() << "***** Next remat group *****\n";
-                 ReversePostOrderTraversal<RematGraph *> RPOT(RematUPtr.get());
-                 for (auto I = RPOT.begin(); I != RPOT.end();
-                      ++I) { (*I)->Node->dump(); } dbgs()
-                 << "\n";);
-
-      AllRemats[U] = std::move(RematUPtr);
-    }
-  }
-
-  // Rewrite materializable instructions to be materialized at the use
-  // point.
-  LLVM_DEBUG(dumpRemats("Materializations", AllRemats));
-  rewriteMaterializableInstructions(AllRemats);
-}
-
 void coro::normalizeCoroutine(Function &F, coro::Shape &Shape,
                               TargetTransformInfo &TTI) {
   // Don't eliminate swifterror in async functions that won't be split.
@@ -2324,8 +2040,8 @@ void coro::normalizeCoroutine(Function &F, coro::Shape &Shape,
       IRBuilder<> Builder(AsyncEnd);
       SmallVector<Value *, 8> Args(AsyncEnd->args());
       auto Arguments = ArrayRef<Value *>(Args).drop_front(3);
-      auto *Call = createMustTailCall(AsyncEnd->getDebugLoc(), MustTailCallFn,
-                                      TTI, Arguments, Builder);
+      auto *Call = coro::createMustTailCall(
+          AsyncEnd->getDebugLoc(), MustTailCallFn, TTI, Arguments, Builder);
       splitAround(Call, "MustTailCall.Before.CoroEnd");
     }
   }
diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h
index 698c21a797420a..891798f53b2d00 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInternal.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -21,6 +21,7 @@ class CallGraph;
 
 namespace coro {
 
+bool isSuspendBlock(BasicBlock *BB);
 bool declaresAnyIntrinsic(const Module &M);
 bool declaresIntrinsics(const Module &M,
                         const std::initializer_list<StringRef>);
diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
index be257339e0ac49..cdc442bc819c37 100644
--- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
@@ -100,6 +100,10 @@ static bool isCoroutineIntrinsicName(StringRef Name) {
 }
 #endif
 
+bool coro::isSuspendBlock(BasicBlock *BB) {
+  return isa<AnyCoroSuspendInst>(BB->front());
+}
+
 bool coro::declaresAnyIntrinsic(const Module &M) {
   for (StringRef Name : CoroIntrinsics) {
     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
diff --git a/llvm/lib/Transforms/Coroutines/MaterializationUtils.cpp b/llvm/lib/Transforms/Coroutines/MaterializationUtils.cpp
new file mode 100644
index 00000000000000..708e8734175f93
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/MaterializationUtils.cpp
@@ -0,0 +1,312 @@
+//===- MaterializationUtils.cpp - Utilities for doing materialization -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This file contains utilities to rematerialize insts after suspend points.
+//===----------------------------------------------------------------------===//
+
+#include "MaterializationUtils.h"
+#include "SpillUtils.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
+#include <deque>
+
+using namespace llvm;
+
+using namespace coro;
+
+// The "coro-suspend-crossing" flag is very noisy. There is another debug type,
+// "coro-frame", which results in leaner debug spew.
+#define DEBUG_TYPE "coro-suspend-crossing"
+
+namespace {
+
+// RematGraph is used to construct a DAG for rematerializable instructions.
+// When the constructor is invoked with a candidate instruction (which is
+// materializable) it builds a DAG of materializable instructions from that
+// point.
+// Typically, for each instruction identified as re-materializable across a
+// suspend point, a RematGraph will be created.
+struct RematGraph {
+  // Each RematNode in the graph contains the edges to instructions providing
+  // operands in the current node.
+  struct RematNode {
+    Instruction *Node = nullptr;
+    SmallVector<RematNode *> Operands;
+    RematNode() = default;
+    RematNode(Instruction *V) : Node(V) {}
+  };
+
+  RematNode *EntryNode;
+  using RematNodeMap =
+      SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>;
+  RematNodeMap Remats;
+  // Held by reference: the callback must outlive this graph.
+  const std::function<bool(Instruction &)> &MaterializableCallback;
+  SuspendCrossingInfo &Checker;
+
+  // Builds the graph breadth-first from I, following operands that are
+  // materializable and defined across a suspend point with respect to I.
+  RematGraph(const std::function<bool(Instruction &)> &MaterializableCallback,
+             Instruction *I, SuspendCrossingInfo &Checker)
+      : MaterializableCallback(MaterializableCallback), Checker(Checker) {
+    std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(I);
+    EntryNode = FirstNode.get();
+    std::deque<std::unique_ptr<RematNode>> WorkList;
+    addNode(std::move(FirstNode), WorkList, cast<User>(I));
+    while (!WorkList.empty()) {
+      std::unique_ptr<RematNode> N = std::move(WorkList.front());
+      WorkList.pop_front();
+      addNode(std::move(N), WorkList, cast<User>(I));
+    }
+  }
+
+  void addNode(std::unique_ptr<RematNode> NUPtr,
+               std::deque<std::unique_ptr<RematNode>> &WorkList,
+               User *FirstUse) {
+    RematNode *N = NUPtr.get();
+    if (Remats.count(N->Node))
+      return;
+
+    // We haven't seen this node yet - add to the list
+    Remats[N->Node] = std::move(NUPtr);
+    for (auto &Def : N->Node->operands()) {
+      Instruction *D = dyn_cast<Instruction>(Def.get());
+      if (!D || !MaterializableCallback(*D) ||
+          !Checker.isDefinitionAcrossSuspend(*D, FirstUse))
+        continue;
+
+      if (Remats.count(D)) {
+        // Already have this in the graph
+        N->Operands.push_back(Remats[D].get());
+        continue;
+      }
+
+      bool NoMatch = true;
+      for (auto &I : WorkList) {
+        if (I->Node == D) {
+          NoMatch = false;
+          N->Operands.push_back(I.get());
+          break;
+        }
+      }
+      if (NoMatch) {
+        // Create a new node
+        std::unique_ptr<RematNode> ChildNode = std::make_unique<RematNode>(D);
+        N->Operands.push_back(ChildNode.get());
+        WorkList.push_back(std::move(ChildNode));
+      }
+    }
+  }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  static std::string getBasicBlockLabel(const BasicBlock *BB) {
+    if (BB->hasName())
+      return BB->getName().str();
+
+    std::string S;
+    raw_string_ostream OS(S);
+    BB->printAsOperand(OS, false);
+    return OS.str().substr(1);
+  }
+
+  void dump() const {
+    dbgs() << "Entry (";
+    dbgs() << getBasicBlockLabel(EntryNode->Node->getParent());
+    dbgs() << ") : " << *EntryNode->Node << "\n";
+    for (auto &E : Remats) {
+      dbgs() << *(E.first) << "\n";
+      for (RematNode *U : E.second->Operands)
+        dbgs() << "  " << *U->Node << "\n";
+    }
+  }
+#endif
+};
+
+} // namespace
+
+namespace llvm {
+template <> struct GraphTraits<RematGraph *> {
+  using NodeRef = RematGraph::RematNode *;
+  using ChildIteratorType = RematGraph::RematNode **;
+
+  static NodeRef getEntryNode(RematGraph *G) { return G->EntryNode; }
+  static ChildIteratorType child_begin(NodeRef N) {
+    return N->Operands.begin();
+  }
+  static ChildIteratorType child_end(NodeRef N) { return N->Operands.end(); }
+};
+
+} // end namespace llvm
+
+// For each instruction identified as materializable across the suspend point,
+// and its associated DAG of other rematerializable instructions,
+// recreate the DAG of instructions after the suspend point.
+static void rewriteMaterializableInstructions(
+    const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8>
+        &AllRemats) {
+  // This has to be done in 2 phases
+  // Do the remats and record the required defs to be replaced in the
+  // original use instructions
+  // Once all the remats are complete, replace the uses in the final
+  // instructions with the new defs
+  struct ProcessNode {
+    Instruction *Use;
+    Instruction *Def;
+    Instruction *Remat;
+  };
+
+  SmallVector<ProcessNode> FinalInstructionsToProcess;
+
+  for (const auto &E : AllRemats) {
+    Instruction *Use = E.first;
+    Instruction *CurrentMaterialization = nullptr;
+    RematGraph *RG = E.second.get();
+    ReversePostOrderTraversal<RematGraph *> RPOT(RG);
+    SmallVector<Instruction *> InstructionsToProcess;
+
+    // If the target use is actually a suspend instruction then we have to
+    // insert the remats into the end of the predecessor (there should only be
+    // one). This is so that suspend blocks always have the suspend instruction
+    // as the first instruction.
+    auto InsertPoint = &*Use->getParent()->getFirstInsertionPt();
+    if (isa<AnyCoroSuspendInst>(Use)) {
+      BasicBlock *SuspendPredecessorBlock =
+          Use->getParent()->getSinglePredecessor();
+      assert(SuspendPredecessorBlock && "malformed coro suspend instruction");
+      InsertPoint = SuspendPredecessorBlock->getTerminator();
+    }
+
+    // Note: skip the first instruction as this is the actual use that we're
+    // rematerializing everything for.
+    auto I = RPOT.begin();
+    ++I;
+    for (; I != RPOT.end(); ++I) {
+      Instruction *D = (*I)->Node;
+      CurrentMaterialization = D->clone();
+      CurrentMaterialization->setName(D->getName());
+      CurrentMaterialization->insertBefore(InsertPoint);
+      InsertPoint = CurrentMaterialization;
+
+      // Replace all uses of Def in the instructions being added as part of this
+      // rematerialization group
+      for (Instruction *Prev : InstructionsToProcess)
+        Prev->replaceUsesOfWith(D, CurrentMaterialization);
+
+      // Don't replace the final use at this point as this can cause problems
+      // for other materializations. Instead, for any final use that uses a
+      // define that's being rematerialized, record the replace values
+      for (Value *Op : Use->operand_values())
+        if (Op == D) // Is this operand pointing to oldval?
+          FinalInstructionsToProcess.push_back(
+              {Use, D, CurrentMaterialization});
+
+      InstructionsToProcess.push_back(CurrentMaterialization);
+    }
+  }
+
+  // Finally, replace the uses with the defines that we've just rematerialized
+  for (auto &R : FinalInstructionsToProcess) {
+    if (auto *PN = dyn_cast<PHINode>(R.Use)) {
+      assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming "
+                                                "values in the PHINode");
+      PN->replaceAllUsesWith(R.Remat);
+      PN->eraseFromParent();
+      continue;
+    }
+    R.Use->replaceUsesOfWith(R.Def, R.Remat);
+  }
+}
+
+/// Default materializable callback
+// Check for instructions that we can recreate on resume as opposed to spill
+// the result into a coroutine frame.
+bool llvm::coro::defaultMaterializable(Instruction &V) {
+  return (isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
+          isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V));
+}
+
+bool llvm::coro::isTriviallyMaterializable(Instruction &V) {
+  return defaultMaterializable(V);
+}
+
+#ifndef NDEBUG
+static void dumpRemats(
+    StringRef Title,
+    const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> &RM) {
+  dbgs() << "------------- " << Title << "--------------\n";
+  for (const auto &E : RM) {
+    E.second->dump();
+    dbgs() << "--\n";
+  }
+}
+#endif
+
+void coro::doRematerializations(
+    Function &F, SuspendCrossingInfo &Checker,
+    std::function<bool(Instruction &)> IsMaterializable) {
+  if (F.hasOptNone())
+    return;
+
+  coro::SpillInfo Spills;
+
+  // See if there are materializable instructions across suspend points
+  // We record these as the starting point to also identify materializable
+  // defs of uses in these operations
+  for (Instruction &I : instructions(F)) {
+    if (!IsMaterializable(I))
+      continue;
+    for (User *U : I.users())
+      if (Checker.isDefinitionAcrossSuspend(I, U))
+        Spills[&I].push_back(cast<Instruction>(U));
+  }
+
+  // Process each of the identified rematerializable instructions
+  // and add predecessor instructions that can also be rematerialized.
+  // This is actually a graph of instructions since we could potentially
+  // have multiple uses of a def in the set of predecessor instructions.
+  // The approach here is to maintain a graph of instructions for each bottom
+  // level instruction - where we have a unique set of instructions (nodes)
+  // and edges between them. We then walk the graph in reverse post-order
+  // to insert them past the suspend point, but ensure that ordering is
+  // correct. We also rely on CSE removing duplicate defs for remats of
+  // different instructions with a def in common (rather than maintaining more
+  // complex graphs for each suspend point)
+
+  // We can do this by adding new nodes to the list for each suspend
+  // point. Then using standard GraphTraits to give a reverse post-order
+  // traversal when we insert the nodes after the suspend
+  SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> AllRemats;
+  for (auto &E : Spills) {
+    for (Instruction *U : E.second) {
+      // Don't process a user twice (this can happen if the instruction uses
+      // more than one rematerializable def)
+      if (AllRemats.count(U))
+        continue;
+
+      // Constructor creates the whole RematGraph for the given Use
+      auto RematUPtr =
+          std::make_unique<RematGraph>(IsMaterializable, U, Checker);
+
+      LLVM_DEBUG({
+        dbgs() << "***** Next remat group *****\n";
+        ReversePostOrderTraversal<RematGraph *> RPOT(RematUPtr.get());
+        for (RematGraph::RematNode *N : RPOT)
+          N->Node->dump();
+        dbgs() << "\n";
+      });
+
+      AllRemats[U] = std::move(RematUPtr);
+    }
+  }
+
+  // Rewrite materializable instructions to be materialized at the use
+  // point.
+  LLVM_DEBUG(dumpRemats("Materializations", AllRemats));
+  rewriteMaterializableInstructions(AllRemats);
+}
diff --git a/llvm/lib/Transforms/Coroutines/MaterializationUtils.h b/llvm/lib/Transforms/Coroutines/MaterializationUtils.h
new file mode 100644
index 00000000000000..f391851c97b3b6
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/MaterializationUtils.h
@@ -0,0 +1,32 @@
+//===- MaterializationUtils.h - Utilities for doing materialization -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_TRANSFORMS_COROUTINES_MATERIALIZATIONUTILS_H
+#define LIB_TRANSFORMS_COROUTINES_MATERIALIZATIONUTILS_H
+
+#include "SuspendCrossingInfo.h"
+#include "llvm/IR/Instruction.h"
+#include <functional>
+
+namespace llvm {
+
+namespace coro {
+
+// True if I is trivially rematerializable. Currently equivalent to
+// coro::defaultMaterializable.
+bool isTriviallyMaterializable(Instruction &I);
+
+// Performs rematerialization, invoked from buildCoroutineFrame.
+void doRematerializations(Function &F, SuspendCrossingInfo &Checker,
+                          std::function<bool(Instruction &)> IsMaterializable);
+
+} // namespace coro
+
+} // namespace llvm
+
+#endif // LIB_TRANSFORMS_COROUTINES_MATERIALIZATIONUTILS_H
diff --git a/llvm/lib/Transforms/Coroutines/SpillUtils.cpp b/llvm/lib/Transforms/Coroutines/SpillUtils.cpp
index d71b0a336d4715..f213ac1c8d7d57 100644
--- a/llvm/lib/Transforms/Coroutines/SpillUtils.cpp
+++ b/llvm/lib/Transforms/Coroutines/SpillUtils.cpp
@@ -23,10 +23,6 @@ namespace {
 
 typedef SmallPtrSet<BasicBlock *, 8> VisitedBlocksSet;
 
-static bool isSuspendBlock(BasicBlock *BB) {
-  return isa<AnyCoroSuspendInst>(BB->front());
-}
-
 // Check for structural coroutine intrinsics that should not be spilled into
 // the coroutine frame.
 static bool isCoroutineStructureIntrinsic(Instruction &I) {
@@ -45,7 +41,7 @@ static bool isSuspendReachableFrom(BasicBlock *From,
     return false;
 
   // We assume that we'll already have split suspends into their own blocks.
-  if (isSuspendBlock(From))
+  if (coro::isSuspendBlock(From))
     return true;
 
   // Recurse on the successors.
diff --git a/llvm/lib/Transforms/Coroutines/SpillUtils.h b/llvm/lib/Transforms/Coroutines/SpillUtils.h
index de0ff0bcd3a4fd..8843b611e08424 100644
--- a/llvm/lib/Transforms/Coroutines/SpillUtils.h
+++ b/llvm/lib/Transforms/Coroutines/SpillUtils.h
@@ -29,8 +29,6 @@ struct AllocaInfo {
         MayWriteBeforeCoroBegin(MayWriteBeforeCoroBegin) {}
 };
 
-bool isSuspendBlock(BasicBlock *BB);
-
 void collectSpillsFromArgs(SpillInfo &Spills, Function &F,
                            const SuspendCrossingInfo &Checker);
 void collectSpillsAndAllocasFromInsts(



More information about the llvm-commits mailing list