[llvm] bec77ec - [CallGraph] Preserve call records vector when replacing call edge

Sergey Dmitriev via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 27 06:16:43 PDT 2020


Author: Sergey Dmitriev
Date: 2020-07-27T06:02:55-07:00
New Revision: bec77ece14890d2aa40c76eedc6a7a406d84f1fc

URL: https://github.com/llvm/llvm-project/commit/bec77ece14890d2aa40c76eedc6a7a406d84f1fc
DIFF: https://github.com/llvm/llvm-project/commit/bec77ece14890d2aa40c76eedc6a7a406d84f1fc.diff

LOG: [CallGraph] Preserve call records vector when replacing call edge

Summary:
Try not to resize the vector of call records in a call graph node when
replacing a call edge. Keeping its size unchanged prevents invalidation of
iterators stored in the CG SCC pass manager's scc_iterator.

Reviewers: jdoerfert

Reviewed By: jdoerfert

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D84295

Added: 
    

Modified: 
    llvm/lib/Analysis/CallGraph.cpp
    llvm/unittests/IR/LegacyPassManagerTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/CallGraph.cpp b/llvm/lib/Analysis/CallGraph.cpp
index 55adb454b733..19c128b6633c 100644
--- a/llvm/lib/Analysis/CallGraph.cpp
+++ b/llvm/lib/Analysis/CallGraph.cpp
@@ -281,13 +281,37 @@ void CallGraphNode::replaceCallEdge(CallBase &Call, CallBase &NewCall,
       I->second = NewNode;
       NewNode->AddRef();
 
-      // Refresh callback references.
-      forEachCallbackFunction(Call, [=](Function *CB) {
-        removeOneAbstractEdgeTo(CG->getOrInsertFunction(CB));
+      // Refresh callback references. Do not resize CalledFunctions if the
+      // number of callbacks is the same for new and old call sites.
+      SmallVector<CallGraphNode *, 4u> OldCBs;
+      SmallVector<CallGraphNode *, 4u> NewCBs;
+      forEachCallbackFunction(Call, [this, &OldCBs](Function *CB) {
+        OldCBs.push_back(CG->getOrInsertFunction(CB));
       });
-      forEachCallbackFunction(NewCall, [=](Function *CB) {
-        addCalledFunction(nullptr, CG->getOrInsertFunction(CB));
+      forEachCallbackFunction(NewCall, [this, &NewCBs](Function *CB) {
+        NewCBs.push_back(CG->getOrInsertFunction(CB));
       });
+      if (OldCBs.size() == NewCBs.size()) {
+        for (unsigned N = 0; N < OldCBs.size(); ++N) {
+          CallGraphNode *OldNode = OldCBs[N];
+          CallGraphNode *NewNode = NewCBs[N];
+          for (auto J = CalledFunctions.begin();; ++J) {
+            assert(J != CalledFunctions.end() &&
+                   "Cannot find callsite to update!");
+            if (!J->first && J->second == OldNode) {
+              J->second = NewNode;
+              OldNode->DropRef();
+              NewNode->AddRef();
+              break;
+            }
+          }
+        }
+      } else {
+        for (auto *CGN : OldCBs)
+          removeOneAbstractEdgeTo(CGN);
+        for (auto *CGN : NewCBs)
+          addCalledFunction(nullptr, CGN);
+      }
       return;
     }
   }

diff  --git a/llvm/unittests/IR/LegacyPassManagerTest.cpp b/llvm/unittests/IR/LegacyPassManagerTest.cpp
index 72ac4be22997..f461bcc8c776 100644
--- a/llvm/unittests/IR/LegacyPassManagerTest.cpp
+++ b/llvm/unittests/IR/LegacyPassManagerTest.cpp
@@ -16,6 +16,8 @@
 #include "llvm/Analysis/CallGraphSCCPass.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopPass.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/AbstractCallSite.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CallingConv.h"
 #include "llvm/IR/DataLayout.h"
@@ -28,6 +30,7 @@
 #include "llvm/IR/OptBisect.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
 #include "gtest/gtest.h"
@@ -694,6 +697,89 @@ namespace llvm {
       ASSERT_EQ(P->NumExtCalledBefore, /* test1, 2a, 2b, 3, 4 */ 5U);
       ASSERT_EQ(P->NumExtCalledAfter, /* test1, 3repl, 4 */ 3U);
     }
+
+    // Call graph SCC pass that, for every defined function F in the SCC,
+    // replaces each call instruction invoking F as a callback with a clone and
+    // updates the caller's node via CallGraphNode::replaceCallEdge(). The test
+    // expects running it over all SCCs of the module to complete successfully.
+    struct CallbackCallsModifierPass : public CGPass {
+      bool runOnSCC(CallGraphSCC &SCC) override {
+        CGPass::run();
+        // getCallGraph() yields a const reference, but the pass edits nodes.
+        CallGraph &CG = const_cast<CallGraph &>(SCC.getCallGraph());
+        // Set once any callback call site in this SCC has been rewritten.
+        bool Changed = false;
+        for (CallGraphNode *CGN : SCC) {
+          Function *F = CGN->getFunction();
+          if (!F || F->isDeclaration())
+            continue;
+          // Collect first: cloning and erasing calls below edits F's use list.
+          SmallVector<CallBase *, 4u> Calls;
+          for (Use &U : F->uses()) {
+            AbstractCallSite ACS(&U);
+            if (!ACS || !ACS.isCallbackCall() || !ACS.isCallee(&U))
+              continue;
+            Calls.push_back(cast<CallBase>(ACS.getInstruction()));
+          }
+          if (Calls.empty())
+            continue;
+          // replaceCallEdge() must run while OldCB is still in the IR.
+          for (CallBase *OldCB : Calls) {
+            CallGraphNode *CallerCGN = CG[OldCB->getParent()->getParent()];
+            assert(any_of(*CallerCGN,
+                          [CGN](const CallGraphNode::CallRecord &CallRecord) {
+                            return CallRecord.second == CGN;
+                          }) &&
+                   "function is not a callee");
+            // Clone the call that carries the callback (same callee and args).
+            CallBase *NewCB = cast<CallBase>(OldCB->clone());
+            CallGraphNode *CalleeCGN = CG[NewCB->getCalledFunction()];
+            NewCB->insertBefore(OldCB);
+            NewCB->takeName(OldCB);
+            // NewNode must be the new call's own callee (e.g. @broker), not F.
+            CallerCGN->replaceCallEdge(*OldCB, *NewCB, CalleeCGN);
+            // Redirect uses of the original call to the clone, then erase it.
+            OldCB->replaceAllUsesWith(NewCB);
+            OldCB->eraseFromParent();
+          }
+          Changed = true;
+        }
+        return Changed;
+      }
+    };
+
+    TEST(PassManager, CallbackCallsModifier0) {
+      LLVMContext Context;
+      // !callback on @broker: arg 0 is a callback callee, invoked with arg 1.
+      const char *IR = "define void @foo() {\n"
+                       "  call void @broker(void (i8*)* @callback0, i8* null)\n"
+                       "  call void @broker(void (i8*)* @callback1, i8* null)\n"
+                       "  ret void\n"
+                       "}\n"
+                       "\n"
+                       "declare !callback !0 void @broker(void (i8*)*, i8*)\n"
+                       "\n"
+                       "define internal void @callback0(i8* %arg) {\n"
+                       "  ret void\n"
+                       "}\n"
+                       "\n"
+                       "define internal void @callback1(i8* %arg) {\n"
+                       "  ret void\n"
+                       "}\n"
+                       "\n"
+                       "!0 = !{!1}\n"
+                       "!1 = !{i64 0, i64 1, i1 false}";
+      // Abort on parse failure rather than dereferencing a null module below.
+      SMDiagnostic Err;
+      std::unique_ptr<Module> M = parseAssemblyString(IR, Err, Context);
+      if (!M)
+        Err.print("LegacyPassManagerTest", errs());
+      ASSERT_TRUE(M);
+      // The pass manager takes ownership of the pass.
+      legacy::PassManager Passes;
+      Passes.add(new CallbackCallsModifierPass());
+      Passes.run(*M);
+    }
   }
 }
 


        


More information about the llvm-commits mailing list