[llvm] [orc-rt] Introduce Task and TaskDispatcher APIs and implementations. (PR #168514)
Lang Hames via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 18 03:19:41 PST 2025
https://github.com/lhames created https://github.com/llvm/llvm-project/pull/168514
Introduces the Task and TaskDispatcher interfaces (TaskDispatcher.h), ThreadPoolTaskDispatcher implementation (ThreadPoolTaskDispatcher.h), and updates Session to include a TaskDispatcher instance that can be used to run tasks.
TaskDispatcher's introduction is motivated by the need to handle calls to JIT'd code initiated from the controller process: Incoming calls will be wrapped in Tasks and dispatched. Session shutdown will wait on TaskDispatcher shutdown, ensuring that all Tasks are run or destroyed prior to the Session being destroyed.
>From 951bec15eec1b941081aacae6c4a30a931e01506 Mon Sep 17 00:00:00 2001
From: Lang Hames <lhames at gmail.com>
Date: Mon, 17 Nov 2025 19:28:49 +1100
Subject: [PATCH] [orc-rt] Introduce Task and TaskDispatcher APIs and
implementations.
Introduces the Task and TaskDispatcher interfaces (TaskDispatcher.h),
ThreadPoolTaskDispatcher implementation (ThreadPoolTaskDispatcher.h), and updates
Session to include a TaskDispatcher instance that can be used to run tasks.
TaskDispatcher's introduction is motivated by the need to handle calls to JIT'd
code initiated from the controller process: Incoming calls will be wrapped in
Tasks and dispatched. Session shutdown will wait on TaskDispatcher shutdown,
ensuring that all Tasks are run or destroyed prior to the Session being
destroyed.
---
orc-rt/include/CMakeLists.txt | 2 +
orc-rt/include/orc-rt/Session.h | 23 +++-
orc-rt/include/orc-rt/TaskDispatcher.h | 64 ++++++++++
.../include/orc-rt/ThreadPoolTaskDispatcher.h | 48 ++++++++
orc-rt/lib/executor/CMakeLists.txt | 2 +
orc-rt/lib/executor/Session.cpp | 58 ++++++---
orc-rt/lib/executor/TaskDispatcher.cpp | 20 ++++
.../lib/executor/ThreadPoolTaskDispatcher.cpp | 70 +++++++++++
orc-rt/unittests/CMakeLists.txt | 1 +
orc-rt/unittests/SessionTest.cpp | 94 ++++++++++++++-
.../ThreadPoolTaskDispatcherTest.cpp | 110 ++++++++++++++++++
11 files changed, 467 insertions(+), 25 deletions(-)
create mode 100644 orc-rt/include/orc-rt/TaskDispatcher.h
create mode 100644 orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h
create mode 100644 orc-rt/lib/executor/TaskDispatcher.cpp
create mode 100644 orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp
create mode 100644 orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp
diff --git a/orc-rt/include/CMakeLists.txt b/orc-rt/include/CMakeLists.txt
index 8ac8a126dd012..35c45e236c023 100644
--- a/orc-rt/include/CMakeLists.txt
+++ b/orc-rt/include/CMakeLists.txt
@@ -22,6 +22,8 @@ set(ORC_RT_HEADERS
orc-rt/SPSMemoryFlags.h
orc-rt/SPSWrapperFunction.h
orc-rt/SPSWrapperFunctionBuffer.h
+ orc-rt/TaskDispatcher.h
+ orc-rt/ThreadPoolTaskDispatcher.h
orc-rt/WrapperFunction.h
orc-rt/bind.h
orc-rt/bit.h
diff --git a/orc-rt/include/orc-rt/Session.h b/orc-rt/include/orc-rt/Session.h
index 78bd92bb0d0c8..367cdb9a97b62 100644
--- a/orc-rt/include/orc-rt/Session.h
+++ b/orc-rt/include/orc-rt/Session.h
@@ -15,10 +15,12 @@
#include "orc-rt/Error.h"
#include "orc-rt/ResourceManager.h"
+#include "orc-rt/TaskDispatcher.h"
#include "orc-rt/move_only_function.h"
#include "orc-rt-c/CoreTypes.h"
+#include <condition_variable>
#include <memory>
#include <mutex>
#include <vector>
@@ -39,7 +41,10 @@ class Session {
///
/// Note that entry into the reporter is not synchronized: it may be
/// called from multiple threads concurrently.
- Session(ErrorReporterFn ReportError) : ReportError(std::move(ReportError)) {}
+ ///
+ /// \p Dispatcher is used to run Tasks passed to dispatch() and is shut down
+ /// as part of Session shutdown. It must not be null: it is dereferenced
+ /// without a null check both by dispatch() and during shutdown.
+ Session(std::unique_ptr<TaskDispatcher> Dispatcher,
+ ErrorReporterFn ReportError)
+ : Dispatcher(std::move(Dispatcher)), ReportError(std::move(ReportError)) {
+ }
// Sessions are not copyable or moveable.
Session(const Session &) = delete;
@@ -49,6 +54,9 @@ class Session {
~Session();
+ /// Dispatch a task using the Session's TaskDispatcher.
+ ///
+ /// NOTE(review): Dispatcher is dereferenced here without holding M and
+ /// without a null check, so this must not race with anything that resets
+ /// Dispatcher during shutdown. Confirm the intended contract for calls made
+ /// during or after Session shutdown.
+ void dispatch(std::unique_ptr<Task> T) { Dispatcher->dispatch(std::move(T)); }
+
/// Report an error via the ErrorReporter function.
void reportError(Error Err) { ReportError(std::move(Err)); }
@@ -67,12 +75,21 @@ class Session {
}
private:
- void shutdownNext(OnShutdownCompleteFn OnShutdownComplete, Error Err,
+ /// Shut down the next ResourceManager in RemainingRMs (reporting Err first,
+ /// if set); calls shutdownComplete() once none remain.
+ void shutdownNext(Error Err,
std::vector<std::unique_ptr<ResourceManager>> RemainingRMs);
- std::mutex M;
+ /// Final shutdown step, run after all ResourceManagers have been shut down:
+ /// shuts down the TaskDispatcher, runs the OnShutdownComplete callbacks, and
+ /// sets State to Shutdown (notifying StateCV).
+ void shutdownComplete();
+
+ /// Used by dispatch() to run Tasks; shut down by shutdownComplete().
+ std::unique_ptr<TaskDispatcher> Dispatcher;
ErrorReporterFn ReportError;
+
+ /// Lifecycle of the Session. The enumerators are ordered: shutdown()
+ /// compares State >= ShuttingDown to detect a shutdown already in progress.
+ enum class SessionState { Running, ShuttingDown, Shutdown };
+
+ /// Protects State, ShutdownCallbacks, and the ResourceMgrs hand-off in
+ /// shutdown().
+ std::mutex M;
+ SessionState State = SessionState::Running;
+ /// Notified when State becomes Shutdown; waited on by waitForShutdown().
+ std::condition_variable StateCV;
std::vector<std::unique_ptr<ResourceManager>> ResourceMgrs;
+ /// Callbacks passed to shutdown(); run by shutdownComplete().
+ std::vector<OnShutdownCompleteFn> ShutdownCallbacks;
};
inline orc_rt_SessionRef wrap(Session *S) noexcept {
diff --git a/orc-rt/include/orc-rt/TaskDispatcher.h b/orc-rt/include/orc-rt/TaskDispatcher.h
new file mode 100644
index 0000000000000..f49d537ef25f7
--- /dev/null
+++ b/orc-rt/include/orc-rt/TaskDispatcher.h
@@ -0,0 +1,64 @@
+//===----------- TaskDispatcher.h - Task dispatch utils ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Task and TaskDispatcher classes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ORC_RT_TASKDISPATCHER_H
+#define ORC_RT_TASKDISPATCHER_H
+
+#include "orc-rt/RTTI.h"
+
+#include <memory>
+#include <utility>
+
+namespace orc_rt {
+
+/// Represents an abstract task to be run.
+class Task : public RTTIExtends<Task, RTTIRoot> {
+public:
+  virtual ~Task();
+
+  /// Perform the work represented by this task.
+  virtual void run() = 0;
+};
+
+/// Base class for generic tasks.
+class GenericTask : public RTTIExtends<GenericTask, Task> {};
+
+/// Generic task implementation: wraps a nullary function object and invokes
+/// it from run().
+template <typename FnT> class GenericTaskImpl : public GenericTask {
+public:
+  // FnT is always a decayed (non-reference) type here, so an `FnT &&`
+  // parameter would be a plain rvalue reference rather than a forwarding
+  // reference, and makeGenericTask could not be called with an lvalue. Taking
+  // the function object by value accepts both: lvalues are copied, rvalues are
+  // moved.
+  explicit GenericTaskImpl(FnT Fn) : Fn(std::move(Fn)) {}
+  void run() override { Fn(); }
+
+private:
+  FnT Fn;
+};
+
+/// Create a generic task from a function object. \p Fn is copied into the
+/// task if it is an lvalue and moved if it is an rvalue.
+template <typename FnT> std::unique_ptr<GenericTask> makeGenericTask(FnT &&Fn) {
+  return std::make_unique<GenericTaskImpl<std::decay_t<FnT>>>(
+      std::forward<FnT>(Fn));
+}
+
+/// Abstract base for classes that dispatch Tasks.
+class TaskDispatcher {
+public:
+  virtual ~TaskDispatcher();
+
+  /// Run the given task. Once the dispatcher has been shut down, the task
+  /// should be rejected (ThreadPoolTaskDispatcher destroys it unrun).
+  virtual void dispatch(std::unique_ptr<Task> T) = 0;
+
+  /// Called by Session. Should cause further dispatches to be rejected, and
+  /// wait until all previously dispatched tasks have completed.
+  virtual void shutdown() = 0;
+};
+
+} // End namespace orc_rt
+
+#endif // ORC_RT_TASKDISPATCHER_H
diff --git a/orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h b/orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h
new file mode 100644
index 0000000000000..227c3500a1321
--- /dev/null
+++ b/orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h
@@ -0,0 +1,48 @@
+//===--- ThreadPoolTaskDispatcher.h - Run tasks in thread pool --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// ThreadPoolTaskDispatcher implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ORC_RT_THREADPOOLTASKDISPATCHER_H
+#define ORC_RT_THREADPOOLTASKDISPATCHER_H
+
+#include "orc-rt/TaskDispatcher.h"
+
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+#include <thread>
+#include <vector>
+
+namespace orc_rt {
+
+/// Thread-pool based TaskDispatcher.
+///
+/// Will spawn NumThreads threads to run dispatched Tasks. NumThreads must be
+/// non-zero: dispatched tasks are only ever run by the pool's own threads.
+///
+/// shutdown() must be called before the dispatcher is destroyed, and must not
+/// be called from one of the pool's threads (i.e. from inside a running Task),
+/// since it joins every pool thread.
+class ThreadPoolTaskDispatcher : public TaskDispatcher {
+public:
+  explicit ThreadPoolTaskDispatcher(size_t NumThreads);
+  ~ThreadPoolTaskDispatcher() override;
+  void dispatch(std::unique_ptr<Task> T) override;
+  void shutdown() override;
+
+private:
+  /// Body of each pool thread: runs pending tasks until shutdown has been
+  /// requested and no tasks remain.
+  void taskLoop();
+
+  std::vector<std::thread> Threads;
+
+  std::mutex M; // Guards AcceptingTasks and PendingTasks.
+  bool AcceptingTasks = true;
+  std::condition_variable CV; // Signalled on new tasks and on shutdown.
+  // A deque so that workers can take the oldest task in O(1).
+  std::deque<std::unique_ptr<Task>> PendingTasks;
+};
+
+} // End namespace orc_rt
+
+#endif // ORC_RT_THREADPOOLTASKDISPATCHER_H
diff --git a/orc-rt/lib/executor/CMakeLists.txt b/orc-rt/lib/executor/CMakeLists.txt
index 9750d8e048f74..58b5ec2189d43 100644
--- a/orc-rt/lib/executor/CMakeLists.txt
+++ b/orc-rt/lib/executor/CMakeLists.txt
@@ -4,6 +4,8 @@ set(files
RTTI.cpp
Session.cpp
SimpleNativeMemoryMap.cpp
+ TaskDispatcher.cpp
+ ThreadPoolTaskDispatcher.cpp
)
add_library(orc-rt-executor STATIC ${files})
diff --git a/orc-rt/lib/executor/Session.cpp b/orc-rt/lib/executor/Session.cpp
index 599bc8705f397..fafa13b1cbb08 100644
--- a/orc-rt/lib/executor/Session.cpp
+++ b/orc-rt/lib/executor/Session.cpp
@@ -12,8 +12,6 @@
#include "orc-rt/Session.h"
-#include <future>
-
namespace orc_rt {
+/// Blocks until Session shutdown (ResourceManagers, then the TaskDispatcher,
+/// then any OnShutdownComplete callbacks) has completed.
Session::~Session() { waitForShutdown(); }
@@ -23,38 +21,62 @@ void Session::shutdown(OnShutdownCompleteFn OnShutdownComplete) {
{
std::scoped_lock<std::mutex> Lock(M);
+ ShutdownCallbacks.push_back(std::move(OnShutdownComplete));
+
+ // If somebody else has already called shutdown then there's nothing further
+ // for us to do here.
+ if (State >= SessionState::ShuttingDown)
+ return;
+
+ State = SessionState::ShuttingDown;
std::swap(ResourceMgrs, ToShutdown);
}
- shutdownNext(std::move(OnShutdownComplete), Error::success(),
- std::move(ToShutdown));
+ shutdownNext(Error::success(), std::move(ToShutdown));
}
void Session::waitForShutdown() {
- std::promise<void> P;
- auto F = P.get_future();
-
- shutdown([P = std::move(P)]() mutable { P.set_value(); });
-
- F.wait();
+ shutdown([]() {});
+ std::unique_lock<std::mutex> Lock(M);
+ StateCV.wait(Lock, [&]() { return State == SessionState::Shutdown; });
}
+/// Shut down the ResourceManagers in RemainingRMs one at a time, last element
+/// first, chaining through each manager's shutdown completion callback. Err
+/// (the previous manager's shutdown result, if any) is reported via
+/// reportError. Once no managers remain, shutdownComplete() is called.
void Session::shutdownNext(
- OnShutdownCompleteFn OnComplete, Error Err,
- std::vector<std::unique_ptr<ResourceManager>> RemainingRMs) {
+ Error Err, std::vector<std::unique_ptr<ResourceManager>> RemainingRMs) {
if (Err)
reportError(std::move(Err));
if (RemainingRMs.empty())
- return OnComplete();
+ return shutdownComplete();
+ // NOTE(review): NextRM is a local, so the ResourceManager is destroyed when
+ // this function returns -- possibly before it invokes the completion
+ // callback if its shutdown is asynchronous. Confirm that ResourceManager
+ // shutdown implementations do not touch the manager after returning.
auto NextRM = std::move(RemainingRMs.back());
RemainingRMs.pop_back();
- NextRM->shutdown([this, RemainingRMs = std::move(RemainingRMs),
- OnComplete = std::move(OnComplete)](Error Err) mutable {
- shutdownNext(std::move(OnComplete), std::move(Err),
- std::move(RemainingRMs));
- });
+ // If the manager calls back synchronously, each manager adds a level of
+ // recursion (shutdownNext -> shutdown -> callback -> shutdownNext).
+ NextRM->shutdown(
+ [this, RemainingRMs = std::move(RemainingRMs)](Error Err) mutable {
+ shutdownNext(std::move(Err), std::move(RemainingRMs));
+ });
+}
+
+void Session::shutdownComplete() {
+  // Shut the dispatcher down in place; this waits for previously dispatched
+  // tasks to complete. The dispatcher object is kept alive until the Session
+  // is destroyed: dispatch() reads Dispatcher without taking M, so resetting
+  // it here would race with (and leave a null pointer for) late dispatch()
+  // calls. A shut-down dispatcher is expected to reject further tasks.
+  Dispatcher->shutdown();
+
+  // Run the OnShutdownComplete callbacks. They are collected only after the
+  // dispatcher has shut down, and collection repeats until none remain, so
+  // callbacks registered by tasks (or by other callbacks) while shutdown is in
+  // progress are run rather than silently dropped. State becomes Shutdown
+  // under the same lock as the final emptiness check, so no callback
+  // registered while State is ShuttingDown can be missed.
+  while (true) {
+    std::vector<OnShutdownCompleteFn> Callbacks;
+    {
+      std::lock_guard<std::mutex> Lock(M);
+      if (ShutdownCallbacks.empty()) {
+        State = SessionState::Shutdown;
+        break;
+      }
+      Callbacks.swap(ShutdownCallbacks);
+    }
+    for (auto &OnShutdownComplete : Callbacks)
+      OnShutdownComplete();
+  }
+
+  // NOTE(review): shutdown() still appends its callback and returns early
+  // once State == Shutdown, so a callback passed to shutdown() after this
+  // point is never run; shutdown() should invoke it directly in that case.
+  StateCV.notify_all();
+}
} // namespace orc_rt
diff --git a/orc-rt/lib/executor/TaskDispatcher.cpp b/orc-rt/lib/executor/TaskDispatcher.cpp
new file mode 100644
index 0000000000000..5f34627fb5150
--- /dev/null
+++ b/orc-rt/lib/executor/TaskDispatcher.cpp
@@ -0,0 +1,20 @@
+//===- TaskDispatcher.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Contains the implementation of APIs in the orc-rt/TaskDispatcher.h header.
+//
+//===----------------------------------------------------------------------===//
+
+#include "orc-rt/TaskDispatcher.h"
+
+namespace orc_rt {
+
+// Out-of-line virtual destructor definitions. Defining them here rather than
+// inline in TaskDispatcher.h gives each class a key function, so (under the
+// Itanium C++ ABI) its vtable is emitted in this translation unit only.
+TaskDispatcher::~TaskDispatcher() = default;
+Task::~Task() = default;
+
+} // namespace orc_rt
diff --git a/orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp b/orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp
new file mode 100644
index 0000000000000..d6d301302220d
--- /dev/null
+++ b/orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp
@@ -0,0 +1,70 @@
+//===- ThreadPoolTaskDispatcher.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Contains the implementation of APIs in the orc-rt/ThreadPoolTaskDispatcher.h
+// header.
+//
+//===----------------------------------------------------------------------===//
+
+#include "orc-rt/ThreadPoolTaskDispatcher.h"
+
+#include <cassert>
+
+namespace orc_rt {
+
+ThreadPoolTaskDispatcher::~ThreadPoolTaskDispatcher() {
+  // Worker threads are only joined by shutdown(). Destroying a still-joinable
+  // std::thread calls std::terminate, so shutdown() must have been run first.
+  assert(!AcceptingTasks && "shutdown was not run");
+}
+
+ThreadPoolTaskDispatcher::ThreadPoolTaskDispatcher(size_t NumThreads) {
+  // Dispatched tasks are only ever run by pool threads: with zero threads
+  // they would never run, and shutdown() could not wait for them.
+  assert(NumThreads > 0 &&
+         "ThreadPoolTaskDispatcher requires at least one thread");
+  Threads.reserve(NumThreads);
+  for (size_t I = 0; I < NumThreads; ++I)
+    Threads.emplace_back([this]() { taskLoop(); });
+}
+
+void ThreadPoolTaskDispatcher::dispatch(std::unique_ptr<Task> T) {
+  {
+    std::scoped_lock<std::mutex> Lock(M);
+    // Once shutdown has begun, reject the task: T is destroyed, without being
+    // run, once this call completes (after the lock has been released).
+    if (!AcceptingTasks)
+      return;
+    PendingTasks.push_back(std::move(T));
+  }
+  CV.notify_one();
+}
+
+void ThreadPoolTaskDispatcher::shutdown() {
+  {
+    std::scoped_lock<std::mutex> Lock(M);
+    assert(AcceptingTasks && "ThreadPoolTaskDispatcher already shut down?");
+    AcceptingTasks = false;
+  }
+  // Wake every idle worker so that it observes !AcceptingTasks. Workers drain
+  // the remaining PendingTasks before exiting, so joining them waits for all
+  // previously dispatched tasks to complete. Must not be called from a pool
+  // thread: a thread cannot join itself.
+  CV.notify_all();
+  for (auto &Thread : Threads)
+    Thread.join();
+}
+
+void ThreadPoolTaskDispatcher::taskLoop() {
+  while (true) {
+    std::unique_ptr<Task> T;
+    {
+      std::unique_lock<std::mutex> Lock(M);
+      CV.wait(Lock,
+              [this]() { return !PendingTasks.empty() || !AcceptingTasks; });
+
+      // Exit only once shutdown has been requested *and* the queue is empty,
+      // so every task accepted before shutdown still runs.
+      if (!AcceptingTasks && PendingTasks.empty())
+        return;
+
+      // Take the oldest task so tasks start in dispatch (FIFO) order. Taking
+      // from the back would let a steady stream of new tasks starve old ones.
+      T = std::move(PendingTasks.front());
+      PendingTasks.erase(PendingTasks.begin());
+    }
+
+    // Run (and then destroy) the task outside the lock.
+    T->run();
+  }
+}
+
+} // namespace orc_rt
diff --git a/orc-rt/unittests/CMakeLists.txt b/orc-rt/unittests/CMakeLists.txt
index 7b943e8039449..c43ec17b54de3 100644
--- a/orc-rt/unittests/CMakeLists.txt
+++ b/orc-rt/unittests/CMakeLists.txt
@@ -31,6 +31,7 @@ add_orc_rt_unittest(CoreTests
SPSMemoryFlagsTest.cpp
SPSWrapperFunctionTest.cpp
SPSWrapperFunctionBufferTest.cpp
+ ThreadPoolTaskDispatcherTest.cpp
WrapperFunctionBufferTest.cpp
bind-test.cpp
bit-test.cpp
diff --git a/orc-rt/unittests/SessionTest.cpp b/orc-rt/unittests/SessionTest.cpp
index 7e6084484e227..85b82e65744b0 100644
--- a/orc-rt/unittests/SessionTest.cpp
+++ b/orc-rt/unittests/SessionTest.cpp
@@ -11,11 +11,17 @@
//===----------------------------------------------------------------------===//
#include "orc-rt/Session.h"
+#include "orc-rt/ThreadPoolTaskDispatcher.h"
+
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include <deque>
+#include <future>
#include <optional>
+#include <iostream>
+
using namespace orc_rt;
using ::testing::Eq;
using ::testing::Optional;
@@ -49,17 +55,47 @@ class MockResourceManager : public ResourceManager {
move_only_function<Error(Op)> GenResult;
};
+/// Test dispatcher for Sessions that must never dispatch anything.
+class NoDispatcher : public TaskDispatcher {
+public:
+  // Any dispatch is a test failure. Report it through gtest rather than
+  // assert() so that it is still caught when assertions are compiled out
+  // (NDEBUG builds).
+  void dispatch(std::unique_ptr<Task> T) override {
+    ADD_FAILURE() << "strictly no dispatching!";
+  }
+  void shutdown() override {}
+};
+
+/// Test dispatcher that appends every dispatched task (unrun) to a
+/// caller-owned queue and, on shutdown, invokes an optional hook.
+class EnqueueingDispatcher : public TaskDispatcher {
+public:
+  using OnShutdownRunFn = move_only_function<void()>;
+
+  EnqueueingDispatcher(std::deque<std::unique_ptr<Task>> &Tasks,
+                       OnShutdownRunFn OnShutdownRun = {})
+      : Queue(Tasks), ShutdownHook(std::move(OnShutdownRun)) {}
+
+  void dispatch(std::unique_ptr<Task> T) override {
+    Queue.push_back(std::move(T));
+  }
+
+  void shutdown() override {
+    if (!ShutdownHook)
+      return;
+    ShutdownHook();
+  }
+
+private:
+  std::deque<std::unique_ptr<Task>> &Queue;
+  OnShutdownRunFn ShutdownHook;
+};
+
// Non-overloaded version of cantFail: allows easy construction of
// move_only_functions<void(Error)>s.
static void noErrors(Error Err) { cantFail(std::move(Err)); }
-TEST(SessionTest, TrivialConstructionAndDestruction) { Session S(noErrors); }
+TEST(SessionTest, TrivialConstructionAndDestruction) {
+  // A Session with no ResourceManagers and no dispatched tasks can be
+  // destroyed without an explicit shutdown call.
+  Session Sess(std::make_unique<NoDispatcher>(), noErrors);
+}
TEST(SessionTest, ReportError) {
Error E = Error::success();
cantFail(std::move(E)); // Force error into checked state.
- Session S([&](Error Err) { E = std::move(Err); });
+ Session S(std::make_unique<NoDispatcher>(),
+ [&](Error Err) { E = std::move(Err); });
S.reportError(make_error<StringError>("foo"));
if (E)
@@ -68,13 +104,27 @@ TEST(SessionTest, ReportError) {
ADD_FAILURE() << "Missing error value";
}
+TEST(SessionTest, DispatchTask) {
+  // Session::dispatch should forward the task to the Session's dispatcher
+  // without running it; running the forwarded task runs the wrapped function.
+  int X = 0;
+  std::deque<std::unique_ptr<Task>> Tasks;
+  Session S(std::make_unique<EnqueueingDispatcher>(Tasks), noErrors);
+
+  EXPECT_TRUE(Tasks.empty());
+  S.dispatch(makeGenericTask([&]() { ++X; }));
+  // ASSERT rather than EXPECT: Tasks.front() below is undefined behavior on an
+  // empty deque.
+  ASSERT_EQ(Tasks.size(), 1U);
+  EXPECT_EQ(X, 0) << "dispatch should only enqueue the task, not run it";
+  auto T = std::move(Tasks.front());
+  Tasks.pop_front();
+  T->run();
+  EXPECT_EQ(X, 1);
+}
+
TEST(SessionTest, SingleResourceManager) {
size_t OpIdx = 0;
std::optional<size_t> DetachOpIdx;
std::optional<size_t> ShutdownOpIdx;
{
- Session S(noErrors);
+ Session S(std::make_unique<NoDispatcher>(), noErrors);
S.addResourceManager(std::make_unique<MockResourceManager>(
DetachOpIdx, ShutdownOpIdx, OpIdx));
}
@@ -90,7 +140,7 @@ TEST(SessionTest, MultipleResourceManagers) {
std::optional<size_t> ShutdownOpIdx[3];
{
- Session S(noErrors);
+ Session S(std::make_unique<NoDispatcher>(), noErrors);
for (size_t I = 0; I != 3; ++I)
S.addResourceManager(std::make_unique<MockResourceManager>(
DetachOpIdx[I], ShutdownOpIdx[I], OpIdx));
@@ -103,3 +153,39 @@ TEST(SessionTest, MultipleResourceManagers) {
EXPECT_THAT(ShutdownOpIdx[I], Optional(Eq(2 - I)));
}
}
+
+TEST(SessionTest, ExpectedShutdownSequence) {
+  // Check that Session shutdown results in...
+  // 1. ResourceManagers being shut down.
+  // 2. The TaskDispatcher being shut down.
+  // 3. A call to OnShutdownComplete.
+
+  size_t OpIdx = 0;
+  std::optional<size_t> DetachOpIdx;
+  std::optional<size_t> ShutdownOpIdx;
+
+  bool DispatcherShutDown = false;
+  bool SessionShutdownComplete = false;
+  std::deque<std::unique_ptr<Task>> Tasks;
+  Session S(std::make_unique<EnqueueingDispatcher>(
+                Tasks,
+                [&]() {
+                  // The ResourceManager must already have been shut down.
+                  // Matching on the optional avoids dereferencing an empty
+                  // one, and 0U avoids a signed/unsigned comparison.
+                  EXPECT_THAT(ShutdownOpIdx, Optional(Eq(0U)));
+                  EXPECT_FALSE(SessionShutdownComplete);
+                  DispatcherShutDown = true;
+                }),
+            noErrors);
+  S.addResourceManager(
+      std::make_unique<MockResourceManager>(DetachOpIdx, ShutdownOpIdx, OpIdx));
+
+  S.shutdown([&]() {
+    EXPECT_TRUE(DispatcherShutDown);
+    SessionShutdownComplete = true;
+  });
+  S.waitForShutdown();
+
+  EXPECT_TRUE(DispatcherShutDown);
+  EXPECT_TRUE(SessionShutdownComplete);
+}
diff --git a/orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp b/orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp
new file mode 100644
index 0000000000000..02cca94a494ff
--- /dev/null
+++ b/orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp
@@ -0,0 +1,110 @@
+//===-- ThreadPoolTaskDispatcherTest.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "orc-rt/ThreadPoolTaskDispatcher.h"
+#include "gtest/gtest.h"
+
+#include <atomic>
+#include <future>
+#include <thread>
+#include <vector>
+
+using namespace orc_rt;
+
+namespace {
+
+TEST(ThreadPoolTaskDispatcherTest, NoTasks) {
+  // Check that immediate shutdown works as expected.
+  ThreadPoolTaskDispatcher Dispatcher(1);
+  Dispatcher.shutdown();
+}
+
+TEST(ThreadPoolTaskDispatcherTest, BasicTaskExecution) {
+  // Smoke test: Check that we can run a single task on a single-threaded pool.
+  ThreadPoolTaskDispatcher Dispatcher(1);
+  std::atomic<bool> TaskRan = false;
+
+  Dispatcher.dispatch(makeGenericTask([&]() { TaskRan = true; }));
+
+  // shutdown() waits for all previously dispatched tasks to complete.
+  Dispatcher.shutdown();
+
+  EXPECT_TRUE(TaskRan.load());
+}
+
+TEST(ThreadPoolTaskDispatcherTest, SingleThreadMultipleTasks) {
+  // Check that multiple tasks in a single threaded pool run as expected.
+  ThreadPoolTaskDispatcher Dispatcher(1);
+  const size_t NumTasksToRun = 10;
+  std::atomic<size_t> TasksRun = 0;
+
+  for (size_t I = 0; I != NumTasksToRun; ++I)
+    Dispatcher.dispatch(makeGenericTask([&]() { ++TasksRun; }));
+
+  Dispatcher.shutdown();
+
+  EXPECT_EQ(TasksRun.load(), NumTasksToRun);
+}
+
+TEST(ThreadPoolTaskDispatcherTest, ConcurrentTasks) {
+  // Check that tasks are run concurrently when multiple workers are available.
+  // Adds two tasks that communicate a value back and forth using futures.
+  // Neither task should be able to complete without the other having started.
+  // (If the tasks were not run concurrently this test would hang, not fail.)
+  ThreadPoolTaskDispatcher Dispatcher(2);
+
+  std::promise<int> PInit;
+  std::future<int> FInit = PInit.get_future();
+  std::promise<int> P1;
+  std::future<int> F1 = P1.get_future();
+  std::promise<int> P2;
+  std::future<int> F2 = P2.get_future();
+  std::promise<int> PResult;
+  std::future<int> FResult = PResult.get_future();
+
+  // Task A gets the initial value, sends it via P1, waits for response on F2.
+  Dispatcher.dispatch(makeGenericTask([&]() {
+    P1.set_value(FInit.get());
+    PResult.set_value(F2.get());
+  }));
+
+  // Task B gets value from F1, sends it back on P2.
+  Dispatcher.dispatch(makeGenericTask([&]() { P2.set_value(F1.get()); }));
+
+  int ExpectedValue = 42;
+  PInit.set_value(ExpectedValue);
+
+  Dispatcher.shutdown();
+
+  EXPECT_EQ(FResult.get(), ExpectedValue);
+}
+
+TEST(ThreadPoolTaskDispatcherTest, TasksRejectedAfterShutdown) {
+  // Records whether the task body ran and whether the task was destroyed.
+  class TaskToReject : public Task {
+  public:
+    TaskToReject(bool &BodyRun, bool &DestructorRun)
+        : BodyRun(BodyRun), DestructorRun(DestructorRun) {}
+    ~TaskToReject() override { DestructorRun = true; }
+    void run() override { BodyRun = true; }
+
+  private:
+    bool &BodyRun;
+    bool &DestructorRun;
+  };
+
+  ThreadPoolTaskDispatcher Dispatcher(1);
+  Dispatcher.shutdown();
+
+  bool BodyRun = false;
+  bool DestructorRun = false;
+
+  // A task dispatched after shutdown must be destroyed without being run.
+  Dispatcher.dispatch(std::make_unique<TaskToReject>(BodyRun, DestructorRun));
+
+  EXPECT_FALSE(BodyRun);
+  EXPECT_TRUE(DestructorRun);
+}
+
+} // end anonymous namespace
More information about the llvm-commits
mailing list