[llvm] 72277ec - Introduce a CallGraph updater helper class

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Sat Feb 8 12:18:03 PST 2020


Author: Johannes Doerfert
Date: 2020-02-08T14:16:48-06:00
New Revision: 72277ecd62e28a01bb98866c1b15d5f172ed30dc

URL: https://github.com/llvm/llvm-project/commit/72277ecd62e28a01bb98866c1b15d5f172ed30dc
DIFF: https://github.com/llvm/llvm-project/commit/72277ecd62e28a01bb98866c1b15d5f172ed30dc.diff

LOG: Introduce a CallGraph updater helper class

The CallGraphUpdater is a helper that simplifies the process of updating
the call graph, both old and new style, while running a CGSCC pass.

The uses are contained in different commits, e.g. D70767.

More functionality is added as we need it.

Reviewed By: modocache, hfinkel

Differential Revision: https://reviews.llvm.org/D70927

Added: 
    llvm/include/llvm/Transforms/Utils/CallGraphUpdater.h
    llvm/lib/Transforms/Utils/CallGraphUpdater.cpp

Modified: 
    llvm/include/llvm/Analysis/CallGraph.h
    llvm/include/llvm/Analysis/LazyCallGraph.h
    llvm/lib/Analysis/CallGraph.cpp
    llvm/lib/Analysis/CallGraphSCCPass.cpp
    llvm/lib/Analysis/LazyCallGraph.cpp
    llvm/lib/Transforms/Utils/CMakeLists.txt
    llvm/unittests/Analysis/CGSCCPassManagerTest.cpp
    llvm/unittests/Analysis/CMakeLists.txt
    llvm/unittests/Analysis/LazyCallGraphTest.cpp
    llvm/unittests/IR/CMakeLists.txt
    llvm/unittests/IR/LegacyPassManagerTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/CallGraph.h b/llvm/include/llvm/Analysis/CallGraph.h
index c148950cd427..768dcea62ea9 100644
--- a/llvm/include/llvm/Analysis/CallGraph.h
+++ b/llvm/include/llvm/Analysis/CallGraph.h
@@ -94,10 +94,6 @@ class CallGraph {
   /// callers from the old function to the new.
   void spliceFunction(const Function *From, const Function *To);
 
-  /// Add a function to the call graph, and link the node to all of the
-  /// functions that it calls.
-  void addToCallGraph(Function *F);
-
 public:
   explicit CallGraph(Module &M);
   CallGraph(CallGraph &&Arg);
@@ -158,6 +154,13 @@ class CallGraph {
   /// Similar to operator[], but this will insert a new CallGraphNode for
   /// \c F if one does not already exist.
   CallGraphNode *getOrInsertFunction(const Function *F);
+
+  /// Populate \p CGN based on the calls inside the associated function.
+  void populateCallGraphNode(CallGraphNode *CGN);
+
+  /// Add a function to the call graph, and link the node to all of the
+  /// functions that it calls.
+  void addToCallGraph(Function *F);
 };
 
 /// A node in the call graph for a module.

diff  --git a/llvm/include/llvm/Analysis/LazyCallGraph.h b/llvm/include/llvm/Analysis/LazyCallGraph.h
index 2dce9e055694..c0fbadb73dcf 100644
--- a/llvm/include/llvm/Analysis/LazyCallGraph.h
+++ b/llvm/include/llvm/Analysis/LazyCallGraph.h
@@ -1058,6 +1058,9 @@ class LazyCallGraph {
   /// fully visited by the DFS prior to calling this routine.
   void removeDeadFunction(Function &F);
 
+  /// Introduce a node for the function \p NewF in the SCC \p C.
+  void addNewFunctionIntoSCC(Function &NewF, SCC &C);
+
   ///@}
 
   ///@{

diff  --git a/llvm/include/llvm/Transforms/Utils/CallGraphUpdater.h b/llvm/include/llvm/Transforms/Utils/CallGraphUpdater.h
new file mode 100644
index 000000000000..728028d0d114
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Utils/CallGraphUpdater.h
@@ -0,0 +1,106 @@
+//===- CallGraphUpdater.h - A (lazy) call graph update helper ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+///
+/// This file provides interfaces used to manipulate a call graph, regardless
+/// of whether it is an "old style" CallGraph or a "new style" LazyCallGraph.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_UTILS_CALLGRAPHUPDATER_H
+#define LLVM_TRANSFORMS_UTILS_CALLGRAPHUPDATER_H
+
+#include "llvm/Analysis/CGSCCPassManager.h"
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/Analysis/CallGraphSCCPass.h"
+#include "llvm/Analysis/LazyCallGraph.h"
+
+namespace llvm {
+
+/// Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph. This
+/// simplifies the interface and the call sites, e.g., new and old pass manager
+/// passes can share the same code.
+class CallGraphUpdater {
+  /// Containers for functions which we did replace or want to delete when
+  /// `finalize` is called. This can happen explicitly or as part of the
+  /// destructor. Dead functions in comdat sections are tracked separately
+  /// because a function with discardable linkage in a COMDAT should only
+  /// be dropped if the entire COMDAT is dropped, see git ac07703842cf.
+  ///{
+  SmallPtrSet<Function *, 16> ReplacedFunctions;
+  SmallVector<Function *, 16> DeadFunctions;
+  SmallVector<Function *, 16> DeadFunctionsInComdats;
+  ///}
+
+  /// Old PM variables
+  ///{
+  CallGraph *CG = nullptr;
+  CallGraphSCC *CGSCC = nullptr;
+  ///}
+
+  /// New PM variables
+  ///{
+  LazyCallGraph *LCG = nullptr;
+  LazyCallGraph::SCC *SCC = nullptr;
+  CGSCCAnalysisManager *AM = nullptr;
+  CGSCCUpdateResult *UR = nullptr;
+  ///}
+
+public:
+  CallGraphUpdater() = default;
+
+  /// Pending function removals are performed on destruction via finalize().
+  /// NOTE(review): the class is implicitly copyable; a copy that still holds
+  /// pending removals would try to erase the same functions a second time.
+  ~CallGraphUpdater() { finalize(); }
+
+  /// Initializers for usage inside a CGSCC pass of the old and the new pass
+  /// manager (PM), respectively. An updater that is never initialized does
+  /// not update any call graph but still applies the IR-level changes.
+  ///{
+  void initialize(CallGraph &CG, CallGraphSCC &SCC) {
+    this->CG = &CG;
+    this->CGSCC = &SCC;
+  }
+  void initialize(LazyCallGraph &LCG, LazyCallGraph::SCC &SCC,
+                  CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR) {
+    this->LCG = &LCG;
+    this->SCC = &SCC;
+    this->AM = &AM;
+    this->UR = &UR;
+  }
+  ///}
+
+  /// Finalizer that performs the pending actions, e.g., the removal of dead
+  /// functions from the module and the call graph.
+  /// \return True if any function was removed, false otherwise.
+  bool finalize();
+
+  /// Remove \p Fn from the call graph. Its body is deleted immediately, the
+  /// function itself is erased when the updater is finalized.
+  void removeFunction(Function &Fn);
+
+  /// After a CGSCC pass changes a function in ways that affect the call
+  /// graph, this method can be called to update it.
+  void reanalyzeFunction(Function &Fn);
+
+  /// If a new function was created by outlining, this method can be called
+  /// to update the call graph for the new function. Note that the old one
+  /// still needs to be re-analyzed or manually updated.
+  void registerOutlinedFunction(Function &NewFn);
+
+  /// Replace \p OldFn in the call graph (and SCC) with \p NewFn. Uses of
+  /// \p OldFn outside the call graph are not modified. Note that \p OldFn is
+  /// also removed (\see removeFunction), which deletes its body.
+  void replaceFunctionWith(Function &OldFn, Function &NewFn);
+
+  /// Remove the call site \p CS from the call graph.
+  void removeCallSite(CallBase &CS);
+
+  /// Replace \p OldCS with the new call site \p NewCS.
+  /// \return True if the replacement was successful, otherwise False. In the
+  /// latter case the parent function of \p OldCS needs to be re-analyzed.
+  bool replaceCallSite(CallBase &OldCS, CallBase &NewCS);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_UTILS_CALLGRAPHUPDATER_H

diff  --git a/llvm/lib/Analysis/CallGraph.cpp b/llvm/lib/Analysis/CallGraph.cpp
index 99cdf3f0db6c..777821cc57cc 100644
--- a/llvm/lib/Analysis/CallGraph.cpp
+++ b/llvm/lib/Analysis/CallGraph.cpp
@@ -74,6 +74,12 @@ void CallGraph::addToCallGraph(Function *F) {
   if (!F->hasLocalLinkage() || F->hasAddressTaken())
     ExternalCallingNode->addCalledFunction(nullptr, Node);
 
+  populateCallGraphNode(Node);
+}
+
+void CallGraph::populateCallGraphNode(CallGraphNode *Node) {
+  Function *F = Node->getFunction();
+
   // If this function is not defined in this translation unit, it could call
   // anything.
   if (F->isDeclaration() && !F->isIntrinsic())

diff  --git a/llvm/lib/Analysis/CallGraphSCCPass.cpp b/llvm/lib/Analysis/CallGraphSCCPass.cpp
index 196ef400bc4e..0c6c398ee1bc 100644
--- a/llvm/lib/Analysis/CallGraphSCCPass.cpp
+++ b/llvm/lib/Analysis/CallGraphSCCPass.cpp
@@ -549,7 +549,10 @@ void CallGraphSCC::ReplaceNode(CallGraphNode *Old, CallGraphNode *New) {
   for (unsigned i = 0; ; ++i) {
     assert(i != Nodes.size() && "Node not in SCC");
     if (Nodes[i] != Old) continue;
-    Nodes[i] = New;
+    if (New)
+      Nodes[i] = New;
+    else
+      Nodes.erase(Nodes.begin() + i);
     break;
   }
 

diff  --git a/llvm/lib/Analysis/LazyCallGraph.cpp b/llvm/lib/Analysis/LazyCallGraph.cpp
index b54ef3154e14..cdf1d55f5ba2 100644
--- a/llvm/lib/Analysis/LazyCallGraph.cpp
+++ b/llvm/lib/Analysis/LazyCallGraph.cpp
@@ -1566,6 +1566,15 @@ void LazyCallGraph::removeDeadFunction(Function &F) {
   // allocators.
 }
 
+/// Create (or fetch) the node for \p NewF, populate its edges from the IR and
+/// make it a member of the existing SCC \p C. The node must not already be a
+/// member of any SCC.
+void LazyCallGraph::addNewFunctionIntoSCC(Function &NewF, SCC &C) {
+  Node &CGNode = get(NewF);
+  // A node that already has an SCC would otherwise end up in the node lists
+  // of two SCCs.
+  assert(!SCCMap.count(&CGNode) && "Function is already part of an SCC!");
+  // NOTE(review): -1 presumably marks the node as already processed by the
+  // SCC-forming DFS, like the other members of C -- confirm.
+  CGNode.DFSNumber = CGNode.LowLink = -1;
+  CGNode.populate();
+  C.Nodes.push_back(&CGNode);
+  SCCMap[&CGNode] = &C;
+  NodeMap[&NewF] = &CGNode;
+}
+
 /// Construct a new Node for \p F in storage obtained from BPA and store the
 /// pointer to it into \p MappedN before returning the node.
 LazyCallGraph::Node &LazyCallGraph::insertInto(Function &F, Node *&MappedN) {
   return *new (MappedN = BPA.Allocate()) Node(*this, F);
 }

diff  --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt
index 9bd49436811f..24a5202c9f3d 100644
--- a/llvm/lib/Transforms/Utils/CMakeLists.txt
+++ b/llvm/lib/Transforms/Utils/CMakeLists.txt
@@ -7,6 +7,7 @@ add_llvm_component_library(LLVMTransformUtils
   BuildLibCalls.cpp
   BypassSlowDivision.cpp
   CallPromotionUtils.cpp
+  CallGraphUpdater.cpp
   CanonicalizeAliases.cpp
   CloneFunction.cpp
   CloneModule.cpp

diff  --git a/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp b/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp
new file mode 100644
index 000000000000..f26facf2216b
--- /dev/null
+++ b/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp
@@ -0,0 +1,152 @@
+//===- CallGraphUpdater.cpp - A (lazy) call graph update helper -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+///
+/// This file provides interfaces used to manipulate a call graph, regardless
+/// of whether it is an "old style" CallGraph or a "new style" LazyCallGraph.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/CallGraphUpdater.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+
+using namespace llvm;
+
+/// Perform all pending removals: dead functions are detached from the active
+/// call graph, their remaining uses are replaced by undef, and they are
+/// erased from the module. Returns true if any function was removed.
+bool CallGraphUpdater::finalize() {
+  // Functions in COMDATs may only be deleted together with their COMDAT;
+  // filterDeadComdatFunctions drops the ones that have to stay.
+  if (!DeadFunctionsInComdats.empty()) {
+    filterDeadComdatFunctions(*DeadFunctionsInComdats.front()->getParent(),
+                              DeadFunctionsInComdats);
+    DeadFunctions.append(DeadFunctionsInComdats.begin(),
+                         DeadFunctionsInComdats.end());
+  }
+
+  if (CG) {
+    // First drop the call graph edges of *all* dead functions. A dead function
+    // can still be recorded as the callee of another dead function; deleting
+    // its node before that edge is gone would violate the "no references"
+    // requirement below, making the result depend on the list order.
+    for (Function *DeadFn : DeadFunctions) {
+      DeadFn->removeDeadConstantUsers();
+      CallGraphNode *DeadCGN = CG->getOrInsertFunction(DeadFn);
+      CG->getExternalCallingNode()->removeAnyCallEdgeTo(DeadCGN);
+      DeadCGN->removeAllCalledFunctions();
+      DeadFn->replaceAllUsesWith(UndefValue::get(DeadFn->getType()));
+    }
+
+    // Then remove the nodes and delete the functions.
+    for (Function *DeadFn : DeadFunctions) {
+      CallGraphNode *DeadCGN = CG->getOrInsertFunction(DeadFn);
+      assert(DeadCGN->getNumReferences() == 0 &&
+             "References should have been handled by now");
+      // Forget the function before it is freed.
+      ReplacedFunctions.erase(DeadFn);
+      delete CG->removeFunctionFromModule(DeadCGN);
+    }
+  } else {
+    for (Function *DeadFn : DeadFunctions) {
+      DeadFn->removeDeadConstantUsers();
+
+      // The old style call graph (CG) has a value handle we do not want to
+      // replace with undef so we only do this in this branch.
+      DeadFn->replaceAllUsesWith(UndefValue::get(DeadFn->getType()));
+
+      // Replaced functions already had their lazy call graph node rebound to
+      // the new function by replaceFunctionWith.
+      if (LCG && !ReplacedFunctions.count(DeadFn)) {
+        // Taken mostly from the inliner:
+        LazyCallGraph::Node &N = LCG->get(*DeadFn);
+        auto *DeadSCC = LCG->lookupSCC(N);
+        assert(DeadSCC && DeadSCC->size() == 1 &&
+               &DeadSCC->begin()->getFunction() == DeadFn);
+        auto &DeadRC = DeadSCC->getOuterRefSCC();
+
+        // Query the function analysis manager through the dead SCC; the SCC
+        // the updater was initialized with may itself have been cleared by an
+        // earlier iteration of this loop.
+        FunctionAnalysisManager &FAM =
+            AM->getResult<FunctionAnalysisManagerCGSCCProxy>(*DeadSCC, *LCG)
+                .getManager();
+
+        FAM.clear(*DeadFn, DeadFn->getName());
+        AM->clear(*DeadSCC, DeadSCC->getName());
+        LCG->removeDeadFunction(*DeadFn);
+
+        // Mark the relevant parts of the call graph as invalid so we don't
+        // visit them.
+        UR->InvalidatedSCCs.insert(DeadSCC);
+        UR->InvalidatedRefSCCs.insert(&DeadRC);
+      }
+
+      // Forget the function before it is freed so a later function allocated
+      // at the same address is not mistaken for a replaced one.
+      ReplacedFunctions.erase(DeadFn);
+
+      // The function is now really dead and de-attached from everything.
+      DeadFn->eraseFromParent();
+    }
+  }
+
+  bool Changed = !DeadFunctions.empty();
+  DeadFunctionsInComdats.clear();
+  DeadFunctions.clear();
+  return Changed;
+}
+
+void CallGraphUpdater::reanalyzeFunction(Function &Fn) {
+  // Old pass manager: rebuild the outgoing edges of the node from the IR.
+  if (CG) {
+    CallGraphNode *Node = CG->getOrInsertFunction(&Fn);
+    Node->removeAllCalledFunctions();
+    CG->populateCallGraphNode(Node);
+    return;
+  }
+
+  // New pass manager: hand the node and its SCC to the generic CGSCC update.
+  if (!LCG)
+    return;
+  LazyCallGraph::Node &LCGNode = LCG->get(Fn);
+  LazyCallGraph::SCC *OwningSCC = LCG->lookupSCC(LCGNode);
+  updateCGAndAnalysisManagerForCGSCCPass(*LCG, *OwningSCC, LCGNode, *AM, *UR);
+}
+
+void CallGraphUpdater::registerOutlinedFunction(Function &NewFn) {
+  // The old-style call graph takes precedence; the lazy call graph gets the
+  // new function placed into the SCC that is currently being visited.
+  if (CG) {
+    CG->addToCallGraph(&NewFn);
+    return;
+  }
+  if (LCG)
+    LCG->addNewFunctionIntoSCC(NewFn, *SCC);
+}
+
+void CallGraphUpdater::removeFunction(Function &DeadFn) {
+  // Strip the body right away; the function object itself is only erased in
+  // finalize(). COMDAT members go to a separate queue because they may only
+  // be erased together with the rest of their COMDAT.
+  DeadFn.deleteBody();
+  DeadFn.setLinkage(GlobalValue::ExternalLinkage);
+  SmallVector<Function *, 16> &Queue =
+      DeadFn.hasComdat() ? DeadFunctionsInComdats : DeadFunctions;
+  Queue.push_back(&DeadFn);
+}
+
+void CallGraphUpdater::replaceFunctionWith(Function &OldFn, Function &NewFn) {
+  // Remember the replacement so finalize() does not try to remove the lazy
+  // call graph node, which now belongs to NewFn.
+  ReplacedFunctions.insert(&OldFn);
+
+  if (CG) {
+    // Move the outgoing edges of the old node to the node of NewFn and swap
+    // the nodes in the SCC that is being iterated.
+    CallGraphNode *FromNode = (*CG)[&OldFn];
+    CallGraphNode *ToNode = CG->getOrInsertFunction(&NewFn);
+    ToNode->stealCalledFunctionsFrom(FromNode);
+    CGSCC->ReplaceNode(FromNode, ToNode);
+  } else if (LCG) {
+    // The lazy call graph rebinds the existing node to the new function.
+    LazyCallGraph::Node &FromLCGNode = LCG->get(OldFn);
+    SCC->getOuterRefSCC().replaceNodeFunction(FromLCGNode, NewFn);
+  }
+
+  removeFunction(OldFn);
+}
+
+/// Redirect the call graph edge recorded for \p OldCS to \p NewCS. Returns
+/// false, without changing anything, if no edge for \p OldCS is recorded.
+bool CallGraphUpdater::replaceCallSite(CallBase &OldCS, CallBase &NewCS) {
+  // This is only necessary in the (old) CG.
+  if (!CG)
+    return true;
+
+  Function *Caller = OldCS.getCaller();
+  CallGraphNode *CallerNode = (*CG)[Caller];
+  if (llvm::none_of(*CallerNode, [&OldCS](const CallGraphNode::CallRecord &CR) {
+        return CR.first == &OldCS;
+      }))
+    return false;
+
+  // Indirect calls (no known callee) are modeled as calls to the "calls
+  // external" node. getOrInsertFunction(nullptr) would instead hand out the
+  // node keyed by nullptr in the function map, which is the external
+  // *calling* node.
+  Function *NewCallee = NewCS.getCalledFunction();
+  CallGraphNode *NewCalleeNode = NewCallee ? CG->getOrInsertFunction(NewCallee)
+                                           : CG->getCallsExternalNode();
+  CallerNode->replaceCallEdge(OldCS, NewCS, NewCalleeNode);
+  return true;
+}
+
+void CallGraphUpdater::removeCallSite(CallBase &CS) {
+  // Only the old-style CallGraph records individual call sites; there is
+  // nothing to do for the lazy call graph.
+  if (CG) {
+    CallGraphNode *CallerNode = (*CG)[CS.getCaller()];
+    CallerNode->removeCallEdgeFor(CS);
+  }
+}

diff  --git a/llvm/unittests/Analysis/CGSCCPassManagerTest.cpp b/llvm/unittests/Analysis/CGSCCPassManagerTest.cpp
index 89ee4f1f0359..b8983901269b 100644
--- a/llvm/unittests/Analysis/CGSCCPassManagerTest.cpp
+++ b/llvm/unittests/Analysis/CGSCCPassManagerTest.cpp
@@ -16,6 +16,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Utils/CallGraphUpdater.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -1315,7 +1316,11 @@ struct LambdaSCCPassNoPreserve : public PassInfoMixin<LambdaSCCPassNoPreserve> {
+  /// Invoke the wrapped callback. Despite the struct's name, the returned set
+  /// keeps FunctionAnalysisManagerCGSCCProxy preserved.
   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                         LazyCallGraph &CG, CGSCCUpdateResult &UR) {
     Func(C, AM, CG, UR);
-    return PreservedAnalyses::none();
+    PreservedAnalyses PA;
+    // We update the core CGSCC data structures and so can preserve the proxy to
+    // the function analysis manager.
+    PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
+    return PA;
   }
 
   std::function<void(LazyCallGraph::SCC &, CGSCCAnalysisManager &,
@@ -1449,5 +1454,236 @@ TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses3) {
   MPM.run(*M, MAM);
 }
 
+// Create a new function while visiting `f`, register it through the
+// CallGraphUpdater, add a call to it from `f` and check that the CGSCC-pass
+// update of `f` handles the new call edge without a fatal failure.
+TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses4) {
+  CGSCCPassManager CGPM(/*DebugLogging*/ true);
+  CGPM.addPass(LambdaSCCPassNoPreserve(
+      [&](LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG,
+          CGSCCUpdateResult &UR) {
+        if (C.getName() != "(f)")
+          return;
+
+        Function *FnF = M->getFunction("f");
+        ASSERT_NE(FnF, nullptr);
+        Function *FnewF = Function::Create(FnF->getFunctionType(),
+                                           FnF->getLinkage(), "newF", *M);
+        BasicBlock *BB = BasicBlock::Create(FnewF->getContext(), "", FnewF);
+        ReturnInst::Create(FnewF->getContext(), BB);
+
+        // Use the CallGraphUpdater to update the call graph for the new
+        // function.
+        CallGraphUpdater CGU;
+        CGU.initialize(CG, C, AM, UR);
+        CGU.registerOutlinedFunction(*FnewF);
+
+        // And insert a call to `newF`
+        Instruction *IP = &FnF->getEntryBlock().front();
+        (void)CallInst::Create(FnewF, {}, "", IP);
+
+        auto &FN = *llvm::find_if(
+            C, [](LazyCallGraph::Node &N) { return N.getName() == "f"; });
+
+        ASSERT_NO_FATAL_FAILURE(
+            updateCGAndAnalysisManagerForCGSCCPass(CG, C, FN, AM, UR));
+      }));
+
+  ModulePassManager MPM(/*DebugLogging*/ true);
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+}
+
+// Same setup as TestUpdateCGAndAnalysisManagerForPasses4, but `f` is updated
+// with the function-pass variant, which is expected to abort with the
+// "Any new calls should be modeled as" message.
+TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses5) {
+  CGSCCPassManager CGPM(/*DebugLogging*/ true);
+  CGPM.addPass(LambdaSCCPassNoPreserve([&](LazyCallGraph::SCC &C,
+                                           CGSCCAnalysisManager &AM,
+                                           LazyCallGraph &CG,
+                                           CGSCCUpdateResult &UR) {
+    if (C.getName() != "(f)")
+      return;
+
+    Function *FnF = M->getFunction("f");
+    ASSERT_NE(FnF, nullptr);
+    Function *FnewF =
+        Function::Create(FnF->getFunctionType(), FnF->getLinkage(), "newF", *M);
+    BasicBlock *BB = BasicBlock::Create(FnewF->getContext(), "", FnewF);
+    ReturnInst::Create(FnewF->getContext(), BB);
+
+    // Use the CallGraphUpdater to update the call graph for the new
+    // function.
+    CallGraphUpdater CGU;
+    CGU.initialize(CG, C, AM, UR);
+    CGU.registerOutlinedFunction(*FnewF);
+
+    // And insert a call to `newF`
+    Instruction *IP = &FnF->getEntryBlock().front();
+    (void)CallInst::Create(FnewF, {}, "", IP);
+
+    auto &FN = *llvm::find_if(
+        C, [](LazyCallGraph::Node &N) { return N.getName() == "f"; });
+
+    ASSERT_DEATH(updateCGAndAnalysisManagerForFunctionPass(CG, C, FN, AM, UR),
+                 "Any new calls should be modeled as");
+  }));
+
+  ModulePassManager MPM(/*DebugLogging*/ true);
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+}
+
+// While visiting the SCC (h3, h1, h2), add calls from `h2` to all three SCC
+// members and let the CallGraphUpdater reanalyze `h2`.
+TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses6) {
+  CGSCCPassManager CGPM(/*DebugLogging*/ true);
+  CGPM.addPass(LambdaSCCPassNoPreserve(
+      [&](LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG,
+          CGSCCUpdateResult &UR) {
+        if (C.getName() != "(h3, h1, h2)")
+          return;
+
+        Function *FnX = M->getFunction("x");
+        Function *FnH1 = M->getFunction("h1");
+        Function *FnH2 = M->getFunction("h2");
+        Function *FnH3 = M->getFunction("h3");
+        ASSERT_NE(FnX, nullptr);
+        ASSERT_NE(FnH1, nullptr);
+        ASSERT_NE(FnH2, nullptr);
+        ASSERT_NE(FnH3, nullptr);
+
+        // Insert calls to `h1`, `h2`, and `h3` at the start of `h2`.
+        Instruction *IP = &FnH2->getEntryBlock().front();
+        (void)CallInst::Create(FnH1, {}, "", IP);
+        (void)CallInst::Create(FnH2, {}, "", IP);
+        (void)CallInst::Create(FnH3, {}, "", IP);
+
+        // Use the CallGraphUpdater to update the call graph for the modified
+        // function `h2`.
+        CallGraphUpdater CGU;
+        CGU.initialize(CG, C, AM, UR);
+        ASSERT_NO_FATAL_FAILURE(CGU.reanalyzeFunction(*FnH2));
+      }));
+
+  ModulePassManager MPM(/*DebugLogging*/ true);
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+}
+
+// While visiting the SCC (f), add a call from `f` to `h2`, which lives in a
+// different SCC, and let the CallGraphUpdater reanalyze `f`.
+TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses7) {
+  CGSCCPassManager CGPM(/*DebugLogging*/ true);
+  CGPM.addPass(LambdaSCCPassNoPreserve(
+      [&](LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG,
+          CGSCCUpdateResult &UR) {
+        if (C.getName() != "(f)")
+          return;
+
+        Function *FnF = M->getFunction("f");
+        Function *FnH2 = M->getFunction("h2");
+        ASSERT_NE(FnF, nullptr);
+        ASSERT_NE(FnH2, nullptr);
+
+        // Insert a call to `h2` at the start of `f`.
+        Instruction *IP = &FnF->getEntryBlock().front();
+        (void)CallInst::Create(FnH2, {}, "", IP);
+
+        // Use the CallGraphUpdater to update the call graph for the modified
+        // function `f`.
+        CallGraphUpdater CGU;
+        CGU.initialize(CG, C, AM, UR);
+        ASSERT_NO_FATAL_FAILURE(CGU.reanalyzeFunction(*FnF));
+      }));
+
+  ModulePassManager MPM(/*DebugLogging*/ true);
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+}
+
+// Move the code of `f` into a new function `newF` and replace `f` with it
+// through the CallGraphUpdater; `f` is left as a declaration without uses.
+TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses8) {
+  CGSCCPassManager CGPM(/*DebugLogging*/ true);
+  CGPM.addPass(LambdaSCCPassNoPreserve(
+      [&](LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG,
+          CGSCCUpdateResult &UR) {
+        if (C.getName() != "(f)")
+          return;
+
+        // Check `f` before it is dereferenced below.
+        Function *FnF = M->getFunction("f");
+        ASSERT_NE(FnF, nullptr);
+        Function *FnewF = Function::Create(FnF->getFunctionType(),
+                                           FnF->getLinkage(), "newF", *M);
+        BasicBlock *BB = BasicBlock::Create(FnewF->getContext(), "", FnewF);
+        auto *RI = ReturnInst::Create(FnewF->getContext(), BB);
+        // Move all but the last instruction (the terminator) of `f`'s entry
+        // block in front of the return of `newF`.
+        while (FnF->getEntryBlock().size() > 1)
+          FnF->getEntryBlock().front().moveBefore(RI);
+
+        // Use the CallGraphUpdater to update the call graph. Replacing `f`
+        // also removes it, which deletes its body right away.
+        CallGraphUpdater CGU;
+        CGU.initialize(CG, C, AM, UR);
+        ASSERT_NO_FATAL_FAILURE(CGU.replaceFunctionWith(*FnF, *FnewF));
+        ASSERT_TRUE(FnF->isDeclaration());
+        ASSERT_EQ(FnF->getNumUses(), 0U);
+      }));
+
+  ModulePassManager MPM(/*DebugLogging*/ true);
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+}
+
+// Remove `f` through the CallGraphUpdater and check that it is only erased
+// from the module once the updater is finalized (here: destroyed).
+TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses9) {
+  CGSCCPassManager CGPM(/*DebugLogging*/ true);
+  CGPM.addPass(LambdaSCCPassNoPreserve(
+      [&](LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG,
+          CGSCCUpdateResult &UR) {
+        if (C.getName() != "(f)")
+          return;
+
+        Function *FnF = M->getFunction("f");
+        ASSERT_NE(FnF, nullptr);
+
+        // Use the CallGraphUpdater to update the call graph. The removal of
+        // `f` is deferred until the updater goes out of scope.
+        {
+          CallGraphUpdater CGU;
+          CGU.initialize(CG, C, AM, UR);
+          ASSERT_NO_FATAL_FAILURE(CGU.removeFunction(*FnF));
+          ASSERT_EQ(M->getFunctionList().size(), 6U);
+        }
+        ASSERT_EQ(M->getFunctionList().size(), 5U);
+      }));
+
+  ModulePassManager MPM(/*DebugLogging*/ true);
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+}
+
+// In `h1`, replace the leading call to `h2` by calls to `h1` and `h3`, then
+// reanalyze `h1` and remove `h2` through the CallGraphUpdater.
+TEST_F(CGSCCPassManagerTest, TestUpdateCGAndAnalysisManagerForPasses10) {
+  CGSCCPassManager CGPM(/*DebugLogging*/ true);
+  CGPM.addPass(LambdaSCCPassNoPreserve(
+      [&](LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG,
+          CGSCCUpdateResult &UR) {
+        if (C.getName() != "(h3, h1, h2)")
+          return;
+
+        Function *FnX = M->getFunction("x");
+        Function *FnH1 = M->getFunction("h1");
+        Function *FnH2 = M->getFunction("h2");
+        Function *FnH3 = M->getFunction("h3");
+        ASSERT_NE(FnX, nullptr);
+        ASSERT_NE(FnH1, nullptr);
+        ASSERT_NE(FnH2, nullptr);
+        ASSERT_NE(FnH3, nullptr);
+
+        // And insert a call to `h1`, and `h3`.
+        Instruction *IP = &FnH1->getEntryBlock().front();
+        (void)CallInst::Create(FnH1, {}, "", IP);
+        (void)CallInst::Create(FnH3, {}, "", IP);
+
+        // Remove the `h2` call, which was the first instruction of `h1`.
+        ASSERT_TRUE(isa<CallBase>(IP));
+        ASSERT_EQ(cast<CallBase>(IP)->getCalledFunction(), FnH2);
+        IP->eraseFromParent();
+
+        // Use the CallGraphUpdater to update the call graph.
+        CallGraphUpdater CGU;
+        CGU.initialize(CG, C, AM, UR);
+        ASSERT_NO_FATAL_FAILURE(CGU.reanalyzeFunction(*FnH1));
+        ASSERT_NO_FATAL_FAILURE(CGU.removeFunction(*FnH2));
+      }));
+
+  ModulePassManager MPM(/*DebugLogging*/ true);
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+}
+
 #endif
 } // namespace

diff  --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
index 669277fb1586..3a631d446691 100644
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -3,6 +3,7 @@ set(LLVM_LINK_COMPONENTS
   AsmParser
   Core
   Support
+  TransformUtils
   )
 
 add_llvm_unittest(AnalysisTests

diff  --git a/llvm/unittests/Analysis/LazyCallGraphTest.cpp b/llvm/unittests/Analysis/LazyCallGraphTest.cpp
index 0da34ebd1d95..be7d458df729 100644
--- a/llvm/unittests/Analysis/LazyCallGraphTest.cpp
+++ b/llvm/unittests/Analysis/LazyCallGraphTest.cpp
@@ -450,6 +450,47 @@ TEST(LazyCallGraphTest, BasicGraphMutation) {
   EXPECT_EQ(0, std::distance(B->begin(), B->end()));
 }
 
+// Check that addNewFunctionIntoSCC puts a newly created function into the
+// SCC of @b while the SCCs of @a and @c keep a single member.
+TEST(LazyCallGraphTest, BasicGraphMutationOutlining) {
+  LLVMContext Context;
+  std::unique_ptr<Module> M = parseAssembly(Context, "define void @a() {\n"
+                                                     "entry:\n"
+                                                     "  call void @b()\n"
+                                                     "  call void @c()\n"
+                                                     "  ret void\n"
+                                                     "}\n"
+                                                     "define void @b() {\n"
+                                                     "entry:\n"
+                                                     "  ret void\n"
+                                                     "}\n"
+                                                     "define void @c() {\n"
+                                                     "entry:\n"
+                                                     "  ret void\n"
+                                                     "}\n");
+  LazyCallGraph CG = buildCG(*M);
+
+  LazyCallGraph::Node &A = CG.get(lookupFunction(*M, "a"));
+  LazyCallGraph::Node &B = CG.get(lookupFunction(*M, "b"));
+  LazyCallGraph::Node &C = CG.get(lookupFunction(*M, "c"));
+  A.populate();
+  B.populate();
+  C.populate();
+  CG.buildRefSCCs();
+
+  // Add a new function that is called from @b and verify it is in the same SCC.
+  // The new function is only a declaration (no body is created).
+  Function &BFn = B.getFunction();
+  Function *NewFn =
+      Function::Create(BFn.getFunctionType(), BFn.getLinkage(), "NewFn", *M);
+  auto IP = BFn.getEntryBlock().getFirstInsertionPt();
+  CallInst::Create(NewFn, "", &*IP);
+  CG.addNewFunctionIntoSCC(*NewFn, *CG.lookupSCC(B));
+
+  EXPECT_EQ(CG.lookupSCC(A)->size(), 1U);
+  EXPECT_EQ(CG.lookupSCC(B)->size(), 2U);
+  EXPECT_EQ(CG.lookupSCC(C)->size(), 1U);
+  EXPECT_EQ(CG.lookupSCC(*CG.lookup(*NewFn))->size(), 2U);
+  EXPECT_EQ(CG.lookupSCC(*CG.lookup(*NewFn))->size(), CG.lookupSCC(B)->size());
+}
+
 TEST(LazyCallGraphTest, InnerSCCFormation) {
   LLVMContext Context;
   std::unique_ptr<Module> M = parseAssembly(Context, DiamondOfTriangles);

diff  --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt
index d27c6d969f17..3ef7cee98dcd 100644
--- a/llvm/unittests/IR/CMakeLists.txt
+++ b/llvm/unittests/IR/CMakeLists.txt
@@ -4,6 +4,7 @@ set(LLVM_LINK_COMPONENTS
   Core
   Support
   Passes
+  TransformUtils
   )
 
 add_llvm_unittest(IRTests

diff  --git a/llvm/unittests/IR/LegacyPassManagerTest.cpp b/llvm/unittests/IR/LegacyPassManagerTest.cpp
index aa02cbe96e0b..b716d656bf74 100644
--- a/llvm/unittests/IR/LegacyPassManagerTest.cpp
+++ b/llvm/unittests/IR/LegacyPassManagerTest.cpp
@@ -29,6 +29,7 @@
 #include "llvm/InitializePasses.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/CallGraphUpdater.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -559,6 +560,72 @@ namespace llvm {
       return mod;
     }
 
+    // Legacy CGSCC pass that, in the SCC containing test1, test2 and test3,
+    // redirects the leading call in test1 from test2 to test3, removes test2
+    // and reanalyzes test1 via a CallGraphUpdater. The pending removal is
+    // performed in doFinalization.
+    struct CGModifierPass : public CGPass {
+      // Number of visited SCCs.
+      unsigned NumSCCs = 0;
+      // Number of visited call graph nodes that have a function.
+      unsigned NumFns = 0;
+      // Set to false if the expected module shape was not found.
+      bool SetupWorked = true;
+
+      CallGraphUpdater CGU;
+
+      bool runOnSCC(CallGraphSCC &SCMM) override {
+        ++NumSCCs;
+        for (CallGraphNode *N : SCMM)
+          if (N->getFunction())
+            ++NumFns;
+
+        CGPass::run();
+
+        // Only the multi-node SCC is modified.
+        if (SCMM.size() <= 1)
+          return false;
+
+        CallGraphNode *N = *(SCMM.begin());
+        Function *F = N->getFunction();
+        Module *M = F->getParent();
+        Function *Test1F = M->getFunction("test1");
+        Function *Test2F = M->getFunction("test2");
+        Function *Test3F = M->getFunction("test3");
+        auto InSCC = [&](Function *Fn) {
+          return llvm::any_of(SCMM, [Fn](CallGraphNode *CGN) {
+            return CGN->getFunction() == Fn;
+          });
+        };
+
+        if (!Test1F || !Test2F || !Test3F || !InSCC(Test1F) || !InSCC(Test2F) ||
+            !InSCC(Test3F))
+          return SetupWorked = false;
+
+        CallInst *CI = dyn_cast<CallInst>(&Test1F->getEntryBlock().front());
+        if (!CI || CI->getCalledFunction() != Test2F)
+          return SetupWorked = false;
+
+        CI->setCalledFunction(Test3F);
+
+        CGU.initialize(const_cast<CallGraph &>(SCMM.getCallGraph()), SCMM);
+        CGU.removeFunction(*Test2F);
+        CGU.reanalyzeFunction(*Test1F);
+        return true;
+      }
+
+      // Erase the functions queued for removal in runOnSCC.
+      bool doFinalization(CallGraph &CG) override { return CGU.finalize(); }
+    };
+
+    // Run CGModifierPass over the four-function test module and check that
+    // exactly one function (test2) has been erased afterwards.
+    TEST(PassManager, CallGraphUpdater0) {
+      // SCC#1: test1->test2->test3->test1
+      // SCC#2: test4
+      // SCC#3: indirect call node
+
+      LLVMContext Context;
+      std::unique_ptr<Module> M(makeLLVMModule(Context));
+      ASSERT_EQ(M->getFunctionList().size(), 4U);
+      CGModifierPass *P = new CGModifierPass();
+      legacy::PassManager Passes;
+      Passes.add(P);
+      Passes.run(*M);
+      ASSERT_TRUE(P->SetupWorked);
+      ASSERT_EQ(P->NumSCCs, 3U);
+      ASSERT_EQ(P->NumFns, 4U);
+      ASSERT_EQ(M->getFunctionList().size(), 3U);
+    }
   }
 }
 


        


More information about the llvm-commits mailing list