[llvm] 7f9a89f - [ORC] Use the new dispatchTask API to run query callbacks.
Lang Hames via llvm-commits
llvm-commits at lists.llvm.org
Sun May 9 19:38:25 PDT 2021
Author: Lang Hames
Date: 2021-05-09T19:19:40-07:00
New Revision: 7f9a89f9a2cc55dbfc315aa11416fe3609918199
URL: https://github.com/llvm/llvm-project/commit/7f9a89f9a2cc55dbfc315aa11416fe3609918199
DIFF: https://github.com/llvm/llvm-project/commit/7f9a89f9a2cc55dbfc315aa11416fe3609918199.diff
LOG: [ORC] Use the new dispatchTask API to run query callbacks.
Dispatching query callbacks, rather than running them on the current thread,
will allow them to be distributed across multiple threads.
Added:
Modified:
llvm/include/llvm/ExecutionEngine/Orc/Core.h
llvm/lib/ExecutionEngine/Orc/Core.cpp
llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/Core.h b/llvm/include/llvm/ExecutionEngine/Orc/Core.h
index c37361fbe57e..f8dc03923c5e 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/Core.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/Core.h
@@ -819,13 +819,10 @@ class AsynchronousSymbolQuery {
/// resolved.
bool isComplete() const { return OutstandingSymbolsCount == 0; }
- /// Call the NotifyComplete callback.
- ///
- /// This should only be called if all symbols covered by the query have
- /// reached the specified state.
- void handleComplete();
private:
+ void handleComplete(ExecutionSession &ES);
+
SymbolState getRequiredState() { return RequiredState; }
void addQueryDependence(JITDylib &JD, SymbolStringPtr Name);
diff --git a/llvm/lib/ExecutionEngine/Orc/Core.cpp b/llvm/lib/ExecutionEngine/Orc/Core.cpp
index 4300a0bbd1bc..270ef7cc37dc 100644
--- a/llvm/lib/ExecutionEngine/Orc/Core.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/Core.cpp
@@ -170,13 +170,30 @@ void AsynchronousSymbolQuery::notifySymbolMetRequiredState(
--OutstandingSymbolsCount;
}
-void AsynchronousSymbolQuery::handleComplete() {
+/// Package the NotifyComplete callback together with the resolved symbols
+/// into a Task and pass it to ES.dispatchTask, instead of invoking the
+/// callback directly on the current thread. Depending on the dispatcher
+/// installed on ES, the callback may run inline or on another thread, and
+/// possibly only after this function has returned.
+///
+/// This should only be called once all symbols covered by the query have
+/// reached the required state (checked by the assert below). On return,
+/// NotifyComplete has been reset to an empty callback and ResolvedSymbols
+/// has been moved from.
+void AsynchronousSymbolQuery::handleComplete(ExecutionSession &ES) {
assert(OutstandingSymbolsCount == 0 &&
"Symbols remain, handleComplete called prematurely");
- auto TmpNotifyComplete = std::move(NotifyComplete);
+ // The task takes ownership of the resolved symbols and the callback, so
+ // running it does not touch this query's own members.
+ class RunQueryCompleteTask : public Task {
+ public:
+ RunQueryCompleteTask(SymbolMap ResolvedSymbols,
+ SymbolsResolvedCallback NotifyComplete)
+ : ResolvedSymbols(std::move(ResolvedSymbols)),
+ NotifyComplete(std::move(NotifyComplete)) {}
+ void printDescription(raw_ostream &OS) override {
+ OS << "Execute query complete callback for " << ResolvedSymbols;
+ }
+ // Moves ResolvedSymbols into the callback; a second call to run() would
+ // pass moved-from symbols.
+ void run() override { NotifyComplete(std::move(ResolvedSymbols)); }
+
+ private:
+ SymbolMap ResolvedSymbols;
+ SymbolsResolvedCallback NotifyComplete;
+ };
+
+ auto T = std::make_unique<RunQueryCompleteTask>(std::move(ResolvedSymbols),
+ std::move(NotifyComplete));
+ // Reset NotifyComplete explicitly so it is left empty rather than in a
+ // moved-from state.
NotifyComplete = SymbolsResolvedCallback();
- TmpNotifyComplete(std::move(ResolvedSymbols));
+ ES.dispatchTask(std::move(T));
}
void AsynchronousSymbolQuery::handleFailed(Error Err) {
@@ -969,7 +986,7 @@ Error JITDylib::resolve(MaterializationResponsibility &MR,
// Otherwise notify all the completed queries.
for (auto &Q : CompletedQueries) {
assert(Q->isComplete() && "Q not completed");
- Q->handleComplete();
+ Q->handleComplete(ES);
}
return Error::success();
@@ -1120,7 +1137,7 @@ Error JITDylib::emit(MaterializationResponsibility &MR,
// Otherwise notify all the completed queries.
for (auto &Q : CompletedQueries) {
assert(Q->isComplete() && "Q is not complete");
- Q->handleComplete();
+ Q->handleComplete(ES);
}
return Error::success();
@@ -2541,7 +2558,7 @@ void ExecutionSession::OL_completeLookup(
if (QueryComplete) {
LLVM_DEBUG(dbgs() << "Completing query\n");
- Q->handleComplete();
+ Q->handleComplete(*this);
}
dispatchOutstandingMUs();
diff --git a/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp
index 5128cc9f558a..8935ea4c3345 100644
--- a/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp
+++ b/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp
@@ -1019,12 +1019,11 @@ TEST_F(CoreAPIsStandardTest, TestBasicWeakSymbolMaterialization) {
TEST_F(CoreAPIsStandardTest, DefineMaterializingSymbol) {
bool ExpectNoMoreMaterialization = false;
- ES.setDispatchTask(
- [&](std::unique_ptr<Task> T) {
- if (ExpectNoMoreMaterialization)
- ADD_FAILURE() << "Unexpected materialization";
- T->run();
- });
+ ES.setDispatchTask([&](std::unique_ptr<Task> T) {
+ if (ExpectNoMoreMaterialization && isa<MaterializationTask>(*T))
+ ADD_FAILURE() << "Unexpected materialization";
+ T->run();
+ });
auto MU = std::make_unique<SimpleMaterializationUnit>(
SymbolFlagsMap({{Foo, FooSym.getFlags()}}),
@@ -1250,14 +1249,11 @@ TEST_F(CoreAPIsStandardTest, TestLookupWithUnthreadedMaterialization) {
TEST_F(CoreAPIsStandardTest, TestLookupWithThreadedMaterialization) {
#if LLVM_ENABLE_THREADS
- std::thread MaterializationThread;
- ES.setDispatchTask(
- [&](std::unique_ptr<Task> T) {
- MaterializationThread =
- std::thread([T = std::move(T)]() mutable {
- T->run();
- });
- });
+ std::vector<std::thread> WorkThreads;
+ ES.setDispatchTask([&](std::unique_ptr<Task> T) {
+ WorkThreads.push_back(
+ std::thread([T = std::move(T)]() mutable { T->run(); }));
+ });
cantFail(JD.define(absoluteSymbols({{Foo, FooSym}})));
@@ -1267,7 +1263,9 @@ TEST_F(CoreAPIsStandardTest, TestLookupWithThreadedMaterialization) {
<< "lookup returned an incorrect address";
EXPECT_EQ(FooLookupResult.getFlags(), FooSym.getFlags())
<< "lookup returned incorrect flags";
- MaterializationThread.join();
+
+ for (auto &WT : WorkThreads)
+ WT.join();
#endif
}
More information about the llvm-commits mailing list