[llvm] [orc-rt] Introduce Task and TaskDispatcher APIs and implementations. (PR #168514)

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 18 03:19:41 PST 2025


https://github.com/lhames created https://github.com/llvm/llvm-project/pull/168514

Introduces the Task and TaskDispatcher interfaces (TaskDispatcher.h), a ThreadPoolTaskDispatcher implementation (ThreadPoolTaskDispatcher.h), and updates Session to include a TaskDispatcher instance that can be used to run tasks.

TaskDispatcher's introduction is motivated by the need to handle calls to JIT'd code initiated from the controller process: Incoming calls will be wrapped in Tasks and dispatched. Session shutdown will wait on TaskDispatcher shutdown, ensuring that all Tasks are run or destroyed prior to the Session being destroyed.

>From 951bec15eec1b941081aacae6c4a30a931e01506 Mon Sep 17 00:00:00 2001
From: Lang Hames <lhames at gmail.com>
Date: Mon, 17 Nov 2025 19:28:49 +1100
Subject: [PATCH] [orc-rt] Introduce Task and TaskDispatcher APIs and
 implementations.

Introduces the Task and TaskDispatcher interfaces (TaskDispatcher.h), a
ThreadPoolTaskDispatcher implementation (ThreadPoolTaskDispatcher.h), and
updates Session to include a TaskDispatcher instance that can be used to run
tasks.

TaskDispatcher's introduction is motivated by the need to handle calls to JIT'd
code initiated from the controller process: Incoming calls will be wrapped in
Tasks and dispatched. Session shutdown will wait on TaskDispatcher shutdown,
ensuring that all Tasks are run or destroyed prior to the Session being
destroyed.
---
 orc-rt/include/CMakeLists.txt                 |   2 +
 orc-rt/include/orc-rt/Session.h               |  23 +++-
 orc-rt/include/orc-rt/TaskDispatcher.h        |  64 ++++++++++
 .../include/orc-rt/ThreadPoolTaskDispatcher.h |  48 ++++++++
 orc-rt/lib/executor/CMakeLists.txt            |   2 +
 orc-rt/lib/executor/Session.cpp               |  58 ++++++---
 orc-rt/lib/executor/TaskDispatcher.cpp        |  20 ++++
 .../lib/executor/ThreadPoolTaskDispatcher.cpp |  70 +++++++++++
 orc-rt/unittests/CMakeLists.txt               |   1 +
 orc-rt/unittests/SessionTest.cpp              |  94 ++++++++++++++-
 .../ThreadPoolTaskDispatcherTest.cpp          | 110 ++++++++++++++++++
 11 files changed, 467 insertions(+), 25 deletions(-)
 create mode 100644 orc-rt/include/orc-rt/TaskDispatcher.h
 create mode 100644 orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h
 create mode 100644 orc-rt/lib/executor/TaskDispatcher.cpp
 create mode 100644 orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp
 create mode 100644 orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp

diff --git a/orc-rt/include/CMakeLists.txt b/orc-rt/include/CMakeLists.txt
index 8ac8a126dd012..35c45e236c023 100644
--- a/orc-rt/include/CMakeLists.txt
+++ b/orc-rt/include/CMakeLists.txt
@@ -22,6 +22,8 @@ set(ORC_RT_HEADERS
     orc-rt/SPSMemoryFlags.h
     orc-rt/SPSWrapperFunction.h
     orc-rt/SPSWrapperFunctionBuffer.h
+    orc-rt/TaskDispatcher.h
+    orc-rt/ThreadPoolTaskDispatcher.h
     orc-rt/WrapperFunction.h
     orc-rt/bind.h
     orc-rt/bit.h
diff --git a/orc-rt/include/orc-rt/Session.h b/orc-rt/include/orc-rt/Session.h
index 78bd92bb0d0c8..367cdb9a97b62 100644
--- a/orc-rt/include/orc-rt/Session.h
+++ b/orc-rt/include/orc-rt/Session.h
@@ -15,10 +15,12 @@
 
 #include "orc-rt/Error.h"
 #include "orc-rt/ResourceManager.h"
+#include "orc-rt/TaskDispatcher.h"
 #include "orc-rt/move_only_function.h"
 
 #include "orc-rt-c/CoreTypes.h"
 
+#include <condition_variable>
 #include <memory>
 #include <mutex>
 #include <vector>
@@ -39,7 +41,14 @@ class Session {
   ///
   /// Note that entry into the reporter is not synchronized: it may be
   /// called from multiple threads concurrently.
-  Session(ErrorReporterFn ReportError) : ReportError(std::move(ReportError)) {}
+  ///
+  /// Dispatcher must be non-null. It is used to run Tasks passed to
+  /// dispatch(), and is shut down after all ResourceManagers have been shut
+  /// down and before any shutdown callbacks are run.
+  Session(std::unique_ptr<TaskDispatcher> Dispatcher,
+          ErrorReporterFn ReportError)
+      : Dispatcher(std::move(Dispatcher)), ReportError(std::move(ReportError)) {
+  }
 
   // Sessions are not copyable or moveable.
   Session(const Session &) = delete;
@@ -49,6 +58,9 @@ class Session {
 
   ~Session();
 
+  /// Dispatch a task using the Session's TaskDispatcher.
+  void dispatch(std::unique_ptr<Task> T) { Dispatcher->dispatch(std::move(T)); }
+
   /// Report an error via the ErrorReporter function.
   void reportError(Error Err) { ReportError(std::move(Err)); }
 
@@ -67,12 +79,25 @@ class Session {
   }
 
  private:
-  void shutdownNext(OnShutdownCompleteFn OnShutdownComplete, Error Err,
+  void shutdownNext(Error Err,
                     std::vector<std::unique_ptr<ResourceManager>> RemainingRMs);
 
-  std::mutex M;
+  void shutdownComplete();
+
+  std::unique_ptr<TaskDispatcher> Dispatcher;
   ErrorReporterFn ReportError;
+
+  // Shutdown progresses through these states in declaration order.
+  enum class SessionState { Running, ShuttingDown, Shutdown };
+
+  // M guards State, ResourceMgrs and ShutdownCallbacks. StateCV is notified
+  // once State reaches Shutdown.
+  std::mutex M;
+  SessionState State = SessionState::Running;
+  std::condition_variable StateCV;
   std::vector<std::unique_ptr<ResourceManager>> ResourceMgrs;
+  // Callbacks to run when shutdown completes.
+  std::vector<OnShutdownCompleteFn> ShutdownCallbacks;
 };
 
 inline orc_rt_SessionRef wrap(Session *S) noexcept {
diff --git a/orc-rt/include/orc-rt/TaskDispatcher.h b/orc-rt/include/orc-rt/TaskDispatcher.h
new file mode 100644
index 0000000000000..f49d537ef25f7
--- /dev/null
+++ b/orc-rt/include/orc-rt/TaskDispatcher.h
@@ -0,0 +1,66 @@
+//===----------- TaskDispatcher.h - Task dispatch utils ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Task and TaskDispatcher classes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ORC_RT_TASKDISPATCHER_H
+#define ORC_RT_TASKDISPATCHER_H
+
+#include "orc-rt/RTTI.h"
+
+#include <memory>
+#include <utility>
+
+namespace orc_rt {
+
+/// Represents an abstract task to be run.
+class Task : public RTTIExtends<Task, RTTIRoot> {
+public:
+  virtual ~Task();
+  virtual void run() = 0;
+};
+
+/// Base class for generic tasks.
+class GenericTask : public RTTIExtends<GenericTask, Task> {};
+
+/// Generic task implementation.
+template <typename FnT> class GenericTaskImpl : public GenericTask {
+public:
+  // Take Fn by value so that tasks can be built from both lvalue (copied) and
+  // rvalue (moved) function objects; see makeGenericTask.
+  explicit GenericTaskImpl(FnT Fn) : Fn(std::move(Fn)) {}
+  void run() override { Fn(); }
+
+private:
+  FnT Fn;
+};
+
+/// Create a generic task from a function object.
+template <typename FnT> std::unique_ptr<GenericTask> makeGenericTask(FnT &&Fn) {
+  return std::make_unique<GenericTaskImpl<std::decay_t<FnT>>>(
+      std::forward<FnT>(Fn));
+}
+
+/// Abstract base for classes that dispatch Tasks.
+class TaskDispatcher {
+public:
+  virtual ~TaskDispatcher();
+
+  /// Run the given task.
+  virtual void dispatch(std::unique_ptr<Task> T) = 0;
+
+  /// Called by Session. Should cause further dispatches to be rejected, and
+  /// wait until all previously dispatched tasks have completed.
+  virtual void shutdown() = 0;
+};
+
+} // End namespace orc_rt
+
+#endif // ORC_RT_TASKDISPATCHER_H
diff --git a/orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h b/orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h
new file mode 100644
index 0000000000000..227c3500a1321
--- /dev/null
+++ b/orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h
@@ -0,0 +1,59 @@
+//===--- ThreadPoolTaskDispatcher.h - Run tasks in thread pool --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// ThreadPoolTaskDispatcher implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ORC_RT_THREADPOOLTASKDISPATCHER_H
+#define ORC_RT_THREADPOOLTASKDISPATCHER_H
+
+#include "orc-rt/TaskDispatcher.h"
+
+#include <condition_variable>
+#include <cstddef>
+#include <mutex>
+#include <thread>
+#include <vector>
+
+namespace orc_rt {
+
+/// Thread-pool based TaskDispatcher.
+///
+/// Will spawn NumThreads threads to run dispatched Tasks. Tasks are not
+/// guaranteed to run in the order in which they were dispatched.
+///
+/// shutdown() must be called exactly once before the dispatcher is destroyed.
+/// It stops new tasks from being accepted, waits for the worker threads to
+/// run all previously dispatched tasks, then joins them. Tasks dispatched
+/// once shutdown() has been called are destroyed without being run.
+class ThreadPoolTaskDispatcher : public TaskDispatcher {
+public:
+  explicit ThreadPoolTaskDispatcher(size_t NumThreads);
+  ~ThreadPoolTaskDispatcher() override;
+  void dispatch(std::unique_ptr<Task> T) override;
+  void shutdown() override;
+
+private:
+  /// Worker thread body: runs pending tasks until shutdown() has been called
+  /// and no tasks remain.
+  void taskLoop();
+
+  std::vector<std::thread> Threads;
+
+  // M guards AcceptingTasks and PendingTasks. CV is signalled when a task is
+  // queued and when shutdown() is called.
+  std::mutex M;
+  bool AcceptingTasks = true;
+  std::condition_variable CV;
+  std::vector<std::unique_ptr<Task>> PendingTasks;
+};
+
+} // End namespace orc_rt
+
+#endif // ORC_RT_THREADPOOLTASKDISPATCHER_H
diff --git a/orc-rt/lib/executor/CMakeLists.txt b/orc-rt/lib/executor/CMakeLists.txt
index 9750d8e048f74..58b5ec2189d43 100644
--- a/orc-rt/lib/executor/CMakeLists.txt
+++ b/orc-rt/lib/executor/CMakeLists.txt
@@ -4,6 +4,8 @@ set(files
   RTTI.cpp
   Session.cpp
   SimpleNativeMemoryMap.cpp
+  TaskDispatcher.cpp
+  ThreadPoolTaskDispatcher.cpp
   )
 
 add_library(orc-rt-executor STATIC ${files})
diff --git a/orc-rt/lib/executor/Session.cpp b/orc-rt/lib/executor/Session.cpp
index 599bc8705f397..fafa13b1cbb08 100644
--- a/orc-rt/lib/executor/Session.cpp
+++ b/orc-rt/lib/executor/Session.cpp
@@ -12,8 +12,6 @@
 
 #include "orc-rt/Session.h"
 
-#include <future>
-
 namespace orc_rt {
 
 Session::~Session() { waitForShutdown(); }
@@ -23,38 +21,79 @@ void Session::shutdown(OnShutdownCompleteFn OnShutdownComplete) {
 
   {
-    std::scoped_lock<std::mutex> Lock(M);
+    std::unique_lock<std::mutex> Lock(M);
+
+    // If shutdown has already completed then the callback queue will never be
+    // drained again: run the callback immediately (outside the lock) instead.
+    if (State == SessionState::Shutdown) {
+      Lock.unlock();
+      OnShutdownComplete();
+      return;
+    }
+
+    ShutdownCallbacks.push_back(std::move(OnShutdownComplete));
+
+    // If somebody else has already started shutdown then there's nothing
+    // further for us to do here: our callback will be run when that shutdown
+    // completes.
+    if (State == SessionState::ShuttingDown)
+      return;
+
+    State = SessionState::ShuttingDown;
     std::swap(ResourceMgrs, ToShutdown);
   }
 
-  shutdownNext(std::move(OnShutdownComplete), Error::success(),
-               std::move(ToShutdown));
+  shutdownNext(Error::success(), std::move(ToShutdown));
 }
 
 void Session::waitForShutdown() {
-  std::promise<void> P;
-  auto F = P.get_future();
-
-  shutdown([P = std::move(P)]() mutable { P.set_value(); });
-
-  F.wait();
+  shutdown([]() {});
+  std::unique_lock<std::mutex> Lock(M);
+  StateCV.wait(Lock, [&]() { return State == SessionState::Shutdown; });
 }
 
 void Session::shutdownNext(
-    OnShutdownCompleteFn OnComplete, Error Err,
-    std::vector<std::unique_ptr<ResourceManager>> RemainingRMs) {
+    Error Err, std::vector<std::unique_ptr<ResourceManager>> RemainingRMs) {
   if (Err)
     reportError(std::move(Err));
 
   if (RemainingRMs.empty())
-    return OnComplete();
+    return shutdownComplete();
 
   auto NextRM = std::move(RemainingRMs.back());
   RemainingRMs.pop_back();
-  NextRM->shutdown([this, RemainingRMs = std::move(RemainingRMs),
-                    OnComplete = std::move(OnComplete)](Error Err) mutable {
-    shutdownNext(std::move(OnComplete), std::move(Err),
-                 std::move(RemainingRMs));
-  });
+  NextRM->shutdown(
+      [this, RemainingRMs = std::move(RemainingRMs)](Error Err) mutable {
+        shutdownNext(std::move(Err), std::move(RemainingRMs));
+      });
+}
+
+void Session::shutdownComplete() {
+  // Shut down the dispatcher, waiting for previously dispatched tasks. The
+  // dispatcher object itself is kept alive until the Session is destroyed, so
+  // concurrent or later calls to dispatch() reach a shut-down dispatcher
+  // (which rejects them) rather than a moved-from null pointer.
+  Dispatcher->shutdown();
+
+  // Run the shutdown callbacks. Callbacks registered while these are running
+  // (State is still ShuttingDown, so shutdown() queues them) are picked up by
+  // the next iteration. State only becomes Shutdown once the queue is seen to
+  // be empty under the lock, so no callback can be dropped.
+  while (true) {
+    std::vector<OnShutdownCompleteFn> Callbacks;
+    {
+      std::lock_guard<std::mutex> Lock(M);
+      if (ShutdownCallbacks.empty()) {
+        State = SessionState::Shutdown;
+        break;
+      }
+      Callbacks.swap(ShutdownCallbacks);
+    }
+
+    for (auto &OnShutdownComplete : Callbacks)
+      OnShutdownComplete();
+  }
+
+  StateCV.notify_all();
 }
 
 } // namespace orc_rt
diff --git a/orc-rt/lib/executor/TaskDispatcher.cpp b/orc-rt/lib/executor/TaskDispatcher.cpp
new file mode 100644
index 0000000000000..5f34627fb5150
--- /dev/null
+++ b/orc-rt/lib/executor/TaskDispatcher.cpp
@@ -0,0 +1,20 @@
+//===- TaskDispatcher.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Contains the implementation of APIs in the orc-rt/TaskDispatcher.h header.
+//
+//===----------------------------------------------------------------------===//
+
+#include "orc-rt/TaskDispatcher.h"
+
+namespace orc_rt {
+
+Task::~Task() = default;
+TaskDispatcher::~TaskDispatcher() = default;
+
+} // namespace orc_rt
diff --git a/orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp b/orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp
new file mode 100644
index 0000000000000..d6d301302220d
--- /dev/null
+++ b/orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp
@@ -0,0 +1,70 @@
+//===- ThreadPoolTaskDispatcher.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Contains the implementation of APIs in the orc-rt/ThreadPoolTaskDispatcher.h
+// header.
+//
+//===----------------------------------------------------------------------===//
+
+#include "orc-rt/ThreadPoolTaskDispatcher.h"
+
+#include <cassert>
+
+namespace orc_rt {
+
+ThreadPoolTaskDispatcher::~ThreadPoolTaskDispatcher() {
+  assert(!AcceptingTasks && "shutdown was not run");
+}
+
+ThreadPoolTaskDispatcher::ThreadPoolTaskDispatcher(size_t NumThreads) {
+  Threads.reserve(NumThreads);
+  for (size_t I = 0; I < NumThreads; ++I)
+    Threads.emplace_back([this]() { taskLoop(); });
+}
+
+void ThreadPoolTaskDispatcher::dispatch(std::unique_ptr<Task> T) {
+  {
+    std::scoped_lock<std::mutex> Lock(M);
+    if (!AcceptingTasks) // Rejected: T is destroyed without being run.
+      return;
+    PendingTasks.push_back(std::move(T));
+  }
+  CV.notify_one();
+}
+
+void ThreadPoolTaskDispatcher::shutdown() {
+  {
+    std::scoped_lock<std::mutex> Lock(M);
+    assert(AcceptingTasks && "ThreadPoolTaskDispatcher already shut down?");
+    AcceptingTasks = false;
+  }
+  CV.notify_all(); // Wake all workers so they can drain the queue and exit.
+  for (auto &Thread : Threads)
+    Thread.join();
+}
+
+void ThreadPoolTaskDispatcher::taskLoop() {
+  while (true) {
+    std::unique_ptr<Task> T;
+    {
+      std::unique_lock<std::mutex> Lock(M);
+      CV.wait(Lock,
+              [this]() { return !PendingTasks.empty() || !AcceptingTasks; });
+
+      if (!AcceptingTasks && PendingTasks.empty()) // Shut down and drained.
+        return;
+
+      T = std::move(PendingTasks.back()); // NOTE: newest task first (LIFO).
+      PendingTasks.pop_back();
+    }
+
+    T->run(); // Run the task without holding M.
+  }
+}
+
+} // namespace orc_rt
diff --git a/orc-rt/unittests/CMakeLists.txt b/orc-rt/unittests/CMakeLists.txt
index 7b943e8039449..c43ec17b54de3 100644
--- a/orc-rt/unittests/CMakeLists.txt
+++ b/orc-rt/unittests/CMakeLists.txt
@@ -31,6 +31,7 @@ add_orc_rt_unittest(CoreTests
   SPSMemoryFlagsTest.cpp
   SPSWrapperFunctionTest.cpp
   SPSWrapperFunctionBufferTest.cpp
+  ThreadPoolTaskDispatcherTest.cpp
   WrapperFunctionBufferTest.cpp
   bind-test.cpp
   bit-test.cpp
diff --git a/orc-rt/unittests/SessionTest.cpp b/orc-rt/unittests/SessionTest.cpp
index 7e6084484e227..85b82e65744b0 100644
--- a/orc-rt/unittests/SessionTest.cpp
+++ b/orc-rt/unittests/SessionTest.cpp
@@ -11,11 +11,15 @@
 //===----------------------------------------------------------------------===//
 
 #include "orc-rt/Session.h"
+#include "orc-rt/ThreadPoolTaskDispatcher.h"
+
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
+#include <deque>
+#include <future>
 #include <optional>
 
 using namespace orc_rt;
 using ::testing::Eq;
 using ::testing::Optional;
@@ -49,17 +53,50 @@ class MockResourceManager : public ResourceManager {
   move_only_function<Error(Op)> GenResult;
 };
 
+/// TaskDispatcher for tests that never expect a task to be dispatched.
+class NoDispatcher : public TaskDispatcher {
+public:
+  void dispatch(std::unique_ptr<Task>) override {
+    ADD_FAILURE() << "strictly no dispatching!";
+  }
+  void shutdown() override {}
+};
+
+/// TaskDispatcher that appends dispatched tasks to a caller-owned queue
+/// instead of running them. OnShutdownRun, if set, is called by shutdown().
+class EnqueueingDispatcher : public TaskDispatcher {
+public:
+  using OnShutdownRunFn = move_only_function<void()>;
+  EnqueueingDispatcher(std::deque<std::unique_ptr<Task>> &Tasks,
+                       OnShutdownRunFn OnShutdownRun = {})
+      : Tasks(Tasks), OnShutdownRun(std::move(OnShutdownRun)) {}
+  void dispatch(std::unique_ptr<Task> T) override {
+    Tasks.push_back(std::move(T));
+  }
+  void shutdown() override {
+    if (OnShutdownRun)
+      OnShutdownRun();
+  }
+
+private:
+  std::deque<std::unique_ptr<Task>> &Tasks;
+  OnShutdownRunFn OnShutdownRun;
+};
+
 // Non-overloaded version of cantFail: allows easy construction of
 // move_only_functions<void(Error)>s.
 static void noErrors(Error Err) { cantFail(std::move(Err)); }
 
-TEST(SessionTest, TrivialConstructionAndDestruction) { Session S(noErrors); }
+TEST(SessionTest, TrivialConstructionAndDestruction) {
+  Session S(std::make_unique<NoDispatcher>(), noErrors);
+}
 
 TEST(SessionTest, ReportError) {
   Error E = Error::success();
   cantFail(std::move(E)); // Force error into checked state.
 
-  Session S([&](Error Err) { E = std::move(Err); });
+  Session S(std::make_unique<NoDispatcher>(),
+            [&](Error Err) { E = std::move(Err); });
   S.reportError(make_error<StringError>("foo"));
 
   if (E)
@@ -68,13 +105,27 @@ TEST(SessionTest, ReportError) {
     ADD_FAILURE() << "Missing error value";
 }
 
+TEST(SessionTest, DispatchTask) {
+  int X = 0;
+  std::deque<std::unique_ptr<Task>> Tasks;
+  Session S(std::make_unique<EnqueueingDispatcher>(Tasks), noErrors);
+
+  EXPECT_EQ(Tasks.size(), 0U);
+  S.dispatch(makeGenericTask([&]() { ++X; }));
+  EXPECT_EQ(Tasks.size(), 1U);
+  auto T = std::move(Tasks.front());
+  Tasks.pop_front();
+  T->run();
+  EXPECT_EQ(X, 1);
+}
+
 TEST(SessionTest, SingleResourceManager) {
   size_t OpIdx = 0;
   std::optional<size_t> DetachOpIdx;
   std::optional<size_t> ShutdownOpIdx;
 
   {
-    Session S(noErrors);
+    Session S(std::make_unique<NoDispatcher>(), noErrors);
     S.addResourceManager(std::make_unique<MockResourceManager>(
         DetachOpIdx, ShutdownOpIdx, OpIdx));
   }
@@ -90,7 +141,7 @@ TEST(SessionTest, MultipleResourceManagers) {
   std::optional<size_t> ShutdownOpIdx[3];
 
   {
-    Session S(noErrors);
+    Session S(std::make_unique<NoDispatcher>(), noErrors);
     for (size_t I = 0; I != 3; ++I)
       S.addResourceManager(std::make_unique<MockResourceManager>(
           DetachOpIdx[I], ShutdownOpIdx[I], OpIdx));
@@ -103,3 +154,37 @@ TEST(SessionTest, MultipleResourceManagers) {
     EXPECT_THAT(ShutdownOpIdx[I], Optional(Eq(2 - I)));
   }
 }
+
+TEST(SessionTest, ExpectedShutdownSequence) {
+  // Check that Session shutdown results in...
+  // 1. ResourceManagers being shut down.
+  // 2. The TaskDispatcher being shut down.
+  // 3. A call to OnShutdownComplete.
+
+  size_t OpIdx = 0;
+  std::optional<size_t> DetachOpIdx;
+  std::optional<size_t> ShutdownOpIdx;
+
+  bool DispatcherShutDown = false;
+  bool SessionShutdownComplete = false;
+  std::deque<std::unique_ptr<Task>> Tasks;
+  Session S(std::make_unique<EnqueueingDispatcher>(
+                Tasks,
+                [&]() {
+                  // The ResourceManager must already have been shut down.
+                  EXPECT_THAT(ShutdownOpIdx, Optional(Eq(0U)));
+                  EXPECT_FALSE(SessionShutdownComplete);
+                  DispatcherShutDown = true;
+                }),
+            noErrors);
+  S.addResourceManager(
+      std::make_unique<MockResourceManager>(DetachOpIdx, ShutdownOpIdx, OpIdx));
+
+  S.shutdown([&]() {
+    EXPECT_TRUE(DispatcherShutDown);
+    SessionShutdownComplete = true;
+  });
+  S.waitForShutdown();
+
+  EXPECT_TRUE(SessionShutdownComplete);
+}
diff --git a/orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp b/orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp
new file mode 100644
index 0000000000000..02cca94a494ff
--- /dev/null
+++ b/orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp
@@ -0,0 +1,111 @@
+//===-- ThreadPoolTaskDispatcherTest.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "orc-rt/ThreadPoolTaskDispatcher.h"
+#include "gtest/gtest.h"
+
+#include <atomic>
+#include <future>
+#include <thread>
+#include <vector>
+
+using namespace orc_rt;
+
+namespace {
+
+TEST(ThreadPoolTaskDispatcherTest, NoTasks) {
+  // Check that immediate shutdown works as expected.
+  ThreadPoolTaskDispatcher Dispatcher(1);
+  Dispatcher.shutdown();
+}
+
+TEST(ThreadPoolTaskDispatcherTest, BasicTaskExecution) {
+  // Smoke test: Check that we can run a single task on a single-threaded pool.
+  ThreadPoolTaskDispatcher Dispatcher(1);
+  std::atomic<bool> TaskRan = false;
+
+  Dispatcher.dispatch(makeGenericTask([&]() { TaskRan = true; }));
+
+  Dispatcher.shutdown();
+
+  EXPECT_TRUE(TaskRan);
+}
+
+TEST(ThreadPoolTaskDispatcherTest, SingleThreadMultipleTasks) {
+  // Check that multiple tasks in a single threaded pool run as expected.
+  ThreadPoolTaskDispatcher Dispatcher(1);
+  size_t NumTasksToRun = 10;
+  std::atomic<size_t> TasksRun = 0;
+
+  for (size_t I = 0; I != NumTasksToRun; ++I)
+    Dispatcher.dispatch(makeGenericTask([&]() { ++TasksRun; }));
+
+  Dispatcher.shutdown();
+
+  EXPECT_EQ(TasksRun, NumTasksToRun);
+}
+
+TEST(ThreadPoolTaskDispatcherTest, ConcurrentTasks) {
+  // Check that tasks are run concurrently when multiple workers are available.
+  // Adds two tasks that communicate a value back and forth using futures.
+  // Neither task should be able to complete without the other having started.
+  ThreadPoolTaskDispatcher Dispatcher(2);
+
+  std::promise<int> PInit;
+  std::future<int> FInit = PInit.get_future();
+  std::promise<int> P1;
+  std::future<int> F1 = P1.get_future();
+  std::promise<int> P2;
+  std::future<int> F2 = P2.get_future();
+  std::promise<int> PResult;
+  std::future<int> FResult = PResult.get_future();
+
+  // Task A gets the initial value, sends it via P1, waits for response on F2.
+  Dispatcher.dispatch(makeGenericTask([&]() {
+    P1.set_value(FInit.get());
+    PResult.set_value(F2.get());
+  }));
+
+  // Task B gets value from F1, sends it back on P2.
+  Dispatcher.dispatch(makeGenericTask([&]() { P2.set_value(F1.get()); }));
+
+  int ExpectedValue = 42;
+  PInit.set_value(ExpectedValue);
+
+  Dispatcher.shutdown();
+
+  EXPECT_EQ(FResult.get(), ExpectedValue);
+}
+
+TEST(ThreadPoolTaskDispatcherTest, TasksRejectedAfterShutdown) {
+  // Check that tasks dispatched after shutdown are destroyed, not run.
+  class TaskToReject : public Task {
+  public:
+    TaskToReject(bool &BodyRun, bool &DestructorRun)
+        : BodyRun(BodyRun), DestructorRun(DestructorRun) {}
+    ~TaskToReject() override { DestructorRun = true; }
+    void run() override { BodyRun = true; }
+
+  private:
+    bool &BodyRun;
+    bool &DestructorRun;
+  };
+
+  ThreadPoolTaskDispatcher Dispatcher(1);
+  Dispatcher.shutdown();
+
+  bool BodyRun = false;
+  bool DestructorRun = false;
+
+  Dispatcher.dispatch(std::make_unique<TaskToReject>(BodyRun, DestructorRun));
+
+  EXPECT_FALSE(BodyRun);
+  EXPECT_TRUE(DestructorRun);
+}
+
+} // end anonymous namespace



More information about the llvm-commits mailing list