[llvm] 7f9a89f - [ORC] Use the new dispatchTask API to run query callbacks.

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Sun May 9 19:38:25 PDT 2021


Author: Lang Hames
Date: 2021-05-09T19:19:40-07:00
New Revision: 7f9a89f9a2cc55dbfc315aa11416fe3609918199

URL: https://github.com/llvm/llvm-project/commit/7f9a89f9a2cc55dbfc315aa11416fe3609918199
DIFF: https://github.com/llvm/llvm-project/commit/7f9a89f9a2cc55dbfc315aa11416fe3609918199.diff

LOG: [ORC] Use the new dispatchTask API to run query callbacks.

Dispatching query callbacks, rather than running them on the current thread,
will allow them to be distributed across multiple threads.

Added: 
    

Modified: 
    llvm/include/llvm/ExecutionEngine/Orc/Core.h
    llvm/lib/ExecutionEngine/Orc/Core.cpp
    llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ExecutionEngine/Orc/Core.h b/llvm/include/llvm/ExecutionEngine/Orc/Core.h
index c37361fbe57e..f8dc03923c5e 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/Core.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/Core.h
@@ -819,13 +819,10 @@ class AsynchronousSymbolQuery {
   ///        resolved.
   bool isComplete() const { return OutstandingSymbolsCount == 0; }
 
-  /// Call the NotifyComplete callback.
-  ///
-  /// This should only be called if all symbols covered by the query have
-  /// reached the specified state.
-  void handleComplete();
 
 private:
+  void handleComplete(ExecutionSession &ES);
+
   SymbolState getRequiredState() { return RequiredState; }
 
   void addQueryDependence(JITDylib &JD, SymbolStringPtr Name);

diff  --git a/llvm/lib/ExecutionEngine/Orc/Core.cpp b/llvm/lib/ExecutionEngine/Orc/Core.cpp
index 4300a0bbd1bc..270ef7cc37dc 100644
--- a/llvm/lib/ExecutionEngine/Orc/Core.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/Core.cpp
@@ -170,13 +170,30 @@ void AsynchronousSymbolQuery::notifySymbolMetRequiredState(
   --OutstandingSymbolsCount;
 }
 
-void AsynchronousSymbolQuery::handleComplete() {
+void AsynchronousSymbolQuery::handleComplete(ExecutionSession &ES) {
   assert(OutstandingSymbolsCount == 0 &&
          "Symbols remain, handleComplete called prematurely");
 
-  auto TmpNotifyComplete = std::move(NotifyComplete);
+  class RunQueryCompleteTask : public Task {
+  public:
+    RunQueryCompleteTask(SymbolMap ResolvedSymbols,
+                         SymbolsResolvedCallback NotifyComplete)
+        : ResolvedSymbols(std::move(ResolvedSymbols)),
+          NotifyComplete(std::move(NotifyComplete)) {}
+    void printDescription(raw_ostream &OS) override {
+      OS << "Execute query complete callback for " << ResolvedSymbols;
+    }
+    void run() override { NotifyComplete(std::move(ResolvedSymbols)); }
+
+  private:
+    SymbolMap ResolvedSymbols;
+    SymbolsResolvedCallback NotifyComplete;
+  };
+
+  auto T = std::make_unique<RunQueryCompleteTask>(std::move(ResolvedSymbols),
+                                                  std::move(NotifyComplete));
   NotifyComplete = SymbolsResolvedCallback();
-  TmpNotifyComplete(std::move(ResolvedSymbols));
+  ES.dispatchTask(std::move(T));
 }
 
 void AsynchronousSymbolQuery::handleFailed(Error Err) {
@@ -969,7 +986,7 @@ Error JITDylib::resolve(MaterializationResponsibility &MR,
   // Otherwise notify all the completed queries.
   for (auto &Q : CompletedQueries) {
     assert(Q->isComplete() && "Q not completed");
-    Q->handleComplete();
+    Q->handleComplete(ES);
   }
 
   return Error::success();
@@ -1120,7 +1137,7 @@ Error JITDylib::emit(MaterializationResponsibility &MR,
   // Otherwise notify all the completed queries.
   for (auto &Q : CompletedQueries) {
     assert(Q->isComplete() && "Q is not complete");
-    Q->handleComplete();
+    Q->handleComplete(ES);
   }
 
   return Error::success();
@@ -2541,7 +2558,7 @@ void ExecutionSession::OL_completeLookup(
 
   if (QueryComplete) {
     LLVM_DEBUG(dbgs() << "Completing query\n");
-    Q->handleComplete();
+    Q->handleComplete(*this);
   }
 
   dispatchOutstandingMUs();

diff  --git a/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp
index 5128cc9f558a..8935ea4c3345 100644
--- a/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp
+++ b/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp
@@ -1019,12 +1019,11 @@ TEST_F(CoreAPIsStandardTest, TestBasicWeakSymbolMaterialization) {
 
 TEST_F(CoreAPIsStandardTest, DefineMaterializingSymbol) {
   bool ExpectNoMoreMaterialization = false;
-  ES.setDispatchTask(
-      [&](std::unique_ptr<Task> T) {
-        if (ExpectNoMoreMaterialization)
-          ADD_FAILURE() << "Unexpected materialization";
-        T->run();
-      });
+  ES.setDispatchTask([&](std::unique_ptr<Task> T) {
+    if (ExpectNoMoreMaterialization && isa<MaterializationTask>(*T))
+      ADD_FAILURE() << "Unexpected materialization";
+    T->run();
+  });
 
   auto MU = std::make_unique<SimpleMaterializationUnit>(
       SymbolFlagsMap({{Foo, FooSym.getFlags()}}),
@@ -1250,14 +1249,11 @@ TEST_F(CoreAPIsStandardTest, TestLookupWithUnthreadedMaterialization) {
 TEST_F(CoreAPIsStandardTest, TestLookupWithThreadedMaterialization) {
 #if LLVM_ENABLE_THREADS
 
-  std::thread MaterializationThread;
-  ES.setDispatchTask(
-      [&](std::unique_ptr<Task> T) {
-        MaterializationThread =
-            std::thread([T = std::move(T)]() mutable {
-              T->run();
-            });
-      });
+  std::vector<std::thread> WorkThreads;
+  ES.setDispatchTask([&](std::unique_ptr<Task> T) {
+    WorkThreads.push_back(
+        std::thread([T = std::move(T)]() mutable { T->run(); }));
+  });
 
   cantFail(JD.define(absoluteSymbols({{Foo, FooSym}})));
 
@@ -1267,7 +1263,9 @@ TEST_F(CoreAPIsStandardTest, TestLookupWithThreadedMaterialization) {
       << "lookup returned an incorrect address";
   EXPECT_EQ(FooLookupResult.getFlags(), FooSym.getFlags())
       << "lookup returned incorrect flags";
-  MaterializationThread.join();
+
+  for (auto &WT : WorkThreads)
+    WT.join();
 #endif
 }
 


        


More information about the llvm-commits mailing list