[llvm] 7da6342 - Re-apply "[ORC] Unify task dispatch across ExecutionSession..." with more fixes.

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 24 00:11:46 PDT 2024


Author: Lang Hames
Date: 2024-04-23T23:11:37-08:00
New Revision: 7da63426ac5d9719038842c30ca2a644620be071

URL: https://github.com/llvm/llvm-project/commit/7da63426ac5d9719038842c30ca2a644620be071
DIFF: https://github.com/llvm/llvm-project/commit/7da63426ac5d9719038842c30ca2a644620be071.diff

LOG: Re-apply "[ORC] Unify task dispatch across ExecutionSession..." with more fixes.

This re-applies 6094b3b7db7, which was reverted in e7efd37c229 (and before that
in 1effa19de24) due to bot failures.

The test failures were fixed by having SelfExecutorProcessControl use an
InPlaceTaskDispatcher by default, rather than a DynamicThreadPoolTaskDispatcher.
This shouldn't be necessary (and indicates a concurrency issue elsewhere), but
InPlaceTaskDispatcher is a less surprising default, and better matches the
existing behavior (compilation on current thread by default), so the change
seems reasonable. I've filed https://github.com/llvm/llvm-project/issues/89870
to investigate the concurrency issue as a follow-up.

Coding my way home: 6.25133S 127.94177W

Added: 
    

Modified: 
    llvm/include/llvm/ExecutionEngine/Orc/Core.h
    llvm/include/llvm/ExecutionEngine/Orc/LLJIT.h
    llvm/include/llvm/ExecutionEngine/Orc/TaskDispatch.h
    llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
    llvm/lib/ExecutionEngine/Orc/LLJIT.cpp
    llvm/lib/ExecutionEngine/Orc/TaskDispatch.cpp
    llvm/tools/llvm-jitlink/llvm-jitlink.cpp
    llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp
    llvm/unittests/ExecutionEngine/Orc/OrcTestCommon.cpp
    llvm/unittests/ExecutionEngine/Orc/OrcTestCommon.h
    llvm/unittests/ExecutionEngine/Orc/TaskDispatchTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ExecutionEngine/Orc/Core.h b/llvm/include/llvm/ExecutionEngine/Orc/Core.h
index 7121b3fe762748..bac923aba02afd 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/Core.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/Core.h
@@ -1443,9 +1443,6 @@ class ExecutionSession {
   /// Send a result to the remote.
   using SendResultFunction = unique_function<void(shared::WrapperFunctionResult)>;
 
-  /// For dispatching ORC tasks (typically materialization tasks).
-  using DispatchTaskFunction = unique_function<void(std::unique_ptr<Task> T)>;
-
   /// An asynchronous wrapper-function callable from the executor via
   /// jit-dispatch.
   using JITDispatchHandlerFunction = unique_function<void(
@@ -1568,12 +1565,6 @@ class ExecutionSession {
   /// Unhandled errors can be sent here to log them.
   void reportError(Error Err) { ReportError(std::move(Err)); }
 
-  /// Set the task dispatch function.
-  ExecutionSession &setDispatchTask(DispatchTaskFunction DispatchTask) {
-    this->DispatchTask = std::move(DispatchTask);
-    return *this;
-  }
-
   /// Search the given JITDylibs to find the flags associated with each of the
   /// given symbols.
   void lookupFlags(LookupKind K, JITDylibSearchOrder SearchOrder,
@@ -1648,7 +1639,7 @@ class ExecutionSession {
   void dispatchTask(std::unique_ptr<Task> T) {
     assert(T && "T must be non-null");
     DEBUG_WITH_TYPE("orc", dumpDispatchInfo(*T));
-    DispatchTask(std::move(T));
+    EPC->getDispatcher().dispatch(std::move(T));
   }
 
   /// Run a wrapper function in the executor.
@@ -1762,8 +1753,6 @@ class ExecutionSession {
     logAllUnhandledErrors(std::move(Err), errs(), "JIT session error: ");
   }
 
-  static void runOnCurrentThread(std::unique_ptr<Task> T) { T->run(); }
-
   void dispatchOutstandingMUs();
 
   static std::unique_ptr<MaterializationResponsibility>
@@ -1869,7 +1858,6 @@ class ExecutionSession {
   std::unique_ptr<ExecutorProcessControl> EPC;
   std::unique_ptr<Platform> P;
   ErrorReporter ReportError = logErrorsToStdErr;
-  DispatchTaskFunction DispatchTask = runOnCurrentThread;
 
   std::vector<ResourceManager *> ResourceManagers;
 

diff  --git a/llvm/include/llvm/ExecutionEngine/Orc/LLJIT.h b/llvm/include/llvm/ExecutionEngine/Orc/LLJIT.h
index 810a38f4a6acb8..3a71ddc88ce956 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/LLJIT.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/LLJIT.h
@@ -254,7 +254,6 @@ class LLJIT {
 
   DataLayout DL;
   Triple TT;
-  std::unique_ptr<DefaultThreadPool> CompileThreads;
 
   std::unique_ptr<ObjectLayer> ObjLinkingLayer;
   std::unique_ptr<ObjectTransformLayer> ObjTransformLayer;
@@ -325,6 +324,7 @@ class LLJITBuilderState {
   PlatformSetupFunction SetUpPlatform;
   NotifyCreatedFunction NotifyCreated;
   unsigned NumCompileThreads = 0;
+  std::optional<bool> SupportConcurrentCompilation;
 
   /// Called prior to JIT class construcion to fix up defaults.
   Error prepareForConstruction();
@@ -333,7 +333,7 @@ class LLJITBuilderState {
 template <typename JITType, typename SetterImpl, typename State>
 class LLJITBuilderSetters {
 public:
-  /// Set a ExecutorProcessControl for this instance.
+  /// Set an ExecutorProcessControl for this instance.
   /// This should not be called if ExecutionSession has already been set.
   SetterImpl &
   setExecutorProcessControl(std::unique_ptr<ExecutorProcessControl> EPC) {
@@ -462,19 +462,26 @@ class LLJITBuilderSetters {
   ///
   /// If this method is not called, behavior will be as if it were called with
   /// a zero argument.
+  ///
+  /// This setting should not be used if a custom ExecutionSession or
+  /// ExecutorProcessControl object is set: in those cases a custom
+  /// TaskDispatcher should be used instead.
   SetterImpl &setNumCompileThreads(unsigned NumCompileThreads) {
     impl().NumCompileThreads = NumCompileThreads;
     return impl();
   }
 
-  /// Set an ExecutorProcessControl object.
+  /// If set, this forces LLJIT concurrent compilation support to be either on
+  /// or off. This controls the selection of compile function (concurrent vs
+  /// single threaded) and whether or not sub-modules are cloned to new
+  /// contexts for lazy emission.
   ///
-  /// If the platform uses ObjectLinkingLayer by default and no
-  /// ObjectLinkingLayerCreator has been set then the ExecutorProcessControl
-  /// object will be used to supply the memory manager for the
-  /// ObjectLinkingLayer.
-  SetterImpl &setExecutorProcessControl(ExecutorProcessControl &EPC) {
-    impl().EPC = &EPC;
+  /// If not explicitly set then concurrency support will be turned on if
+  /// NumCompileThreads is set to a non-zero value, or if a custom
+  /// ExecutionSession or ExecutorProcessControl instance is provided.
+  SetterImpl &setSupportConcurrentCompilation(
+      std::optional<bool> SupportConcurrentCompilation) {
+    impl().SupportConcurrentCompilation = SupportConcurrentCompilation;
     return impl();
   }
 

diff  --git a/llvm/include/llvm/ExecutionEngine/Orc/TaskDispatch.h b/llvm/include/llvm/ExecutionEngine/Orc/TaskDispatch.h
index 8c287f9fec0e89..8c65677aae25a4 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/TaskDispatch.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/TaskDispatch.h
@@ -23,6 +23,7 @@
 
 #if LLVM_ENABLE_THREADS
 #include <condition_variable>
+#include <deque>
 #include <mutex>
 #include <thread>
 #endif
@@ -114,6 +115,9 @@ class InPlaceTaskDispatcher : public TaskDispatcher {
 
 class DynamicThreadPoolTaskDispatcher : public TaskDispatcher {
 public:
+  DynamicThreadPoolTaskDispatcher(
+      std::optional<size_t> MaxMaterializationThreads)
+      : MaxMaterializationThreads(MaxMaterializationThreads) {}
   void dispatch(std::unique_ptr<Task> T) override;
   void shutdown() override;
 private:
@@ -121,6 +125,10 @@ class DynamicThreadPoolTaskDispatcher : public TaskDispatcher {
   bool Running = true;
   size_t Outstanding = 0;
   std::condition_variable OutstandingCV;
+
+  std::optional<size_t> MaxMaterializationThreads;
+  size_t NumMaterializationThreads = 0;
+  std::deque<std::unique_ptr<Task>> MaterializationTaskQueue;
 };
 
 #endif // LLVM_ENABLE_THREADS

diff  --git a/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp b/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
index efafca949e61ef..0df7c4f25eb82c 100644
--- a/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
@@ -61,13 +61,8 @@ SelfExecutorProcessControl::Create(
   if (!SSP)
     SSP = std::make_shared<SymbolStringPool>();
 
-  if (!D) {
-#if LLVM_ENABLE_THREADS
-    D = std::make_unique<DynamicThreadPoolTaskDispatcher>();
-#else
+  if (!D)
     D = std::make_unique<InPlaceTaskDispatcher>();
-#endif
-  }
 
   auto PageSize = sys::Process::getPageSize();
   if (!PageSize)

diff  --git a/llvm/lib/ExecutionEngine/Orc/LLJIT.cpp b/llvm/lib/ExecutionEngine/Orc/LLJIT.cpp
index 79adda5b7bc034..53f13a68c7b8b3 100644
--- a/llvm/lib/ExecutionEngine/Orc/LLJIT.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/LLJIT.cpp
@@ -667,6 +667,37 @@ Error LLJITBuilderState::prepareForConstruction() {
       return JTMBOrErr.takeError();
   }
 
+  if ((ES || EPC) && NumCompileThreads)
+    return make_error<StringError>(
+        "NumCompileThreads cannot be used with a custom ExecutionSession or "
+        "ExecutorProcessControl",
+        inconvertibleErrorCode());
+
+#if !LLVM_ENABLE_THREADS
+  if (NumCompileThreads)
+    return make_error<StringError>(
+        "LLJIT num-compile-threads is " + Twine(NumCompileThreads) +
+            " but LLVM was compiled with LLVM_ENABLE_THREADS=Off",
+        inconvertibleErrorCode());
+#endif // !LLVM_ENABLE_THREADS
+
+  bool ConcurrentCompilationSettingDefaulted = !SupportConcurrentCompilation;
+  if (!SupportConcurrentCompilation) {
+#if LLVM_ENABLE_THREADS
+    SupportConcurrentCompilation = NumCompileThreads || ES || EPC;
+#else
+    SupportConcurrentCompilation = false;
+#endif // LLVM_ENABLE_THREADS
+  } else {
+#if !LLVM_ENABLE_THREADS
+    if (*SupportConcurrentCompilation)
+      return make_error<StringError>(
+          "LLJIT concurrent compilation support requested, but LLVM was built "
+          "with LLVM_ENABLE_THREADS=Off",
+          inconvertibleErrorCode());
+#endif // !LLVM_ENABLE_THREADS
+  }
+
   LLVM_DEBUG({
     dbgs() << "  JITTargetMachineBuilder is "
            << JITTargetMachineBuilderPrinter(*JTMB, "  ")
@@ -684,11 +715,13 @@ Error LLJITBuilderState::prepareForConstruction() {
            << (CreateCompileFunction ? "Yes" : "No") << "\n"
            << "  Custom platform-setup function: "
            << (SetUpPlatform ? "Yes" : "No") << "\n"
-           << "  Number of compile threads: " << NumCompileThreads;
-    if (!NumCompileThreads)
-      dbgs() << " (code will be compiled on the execution thread)\n";
+           << "  Support concurrent compilation: "
+           << (*SupportConcurrentCompilation ? "Yes" : "No");
+    if (ConcurrentCompilationSettingDefaulted)
+      dbgs() << " (defaulted based on NumCompileThreads / ES / EPC)\n";
     else
       dbgs() << "\n";
+    dbgs() << "  Number of compile threads: " << NumCompileThreads << "\n";
   });

   // Create DL if not specified.
@@ -705,7 +738,19 @@ Error LLJITBuilderState::prepareForConstruction() {
       dbgs() << "ExecutorProcessControl not specified, "
                 "Creating SelfExecutorProcessControl instance\n";
     });
-    if (auto EPCOrErr = SelfExecutorProcessControl::Create())
+    // A null D makes SelfExecutorProcessControl::Create default to in-place.
+    std::unique_ptr<TaskDispatcher> D;
+#if LLVM_ENABLE_THREADS
+    if (*SupportConcurrentCompilation) {
+      std::optional<size_t> NumThreads = std::nullopt;
+      if (NumCompileThreads)
+        NumThreads = NumCompileThreads;
+      D = std::make_unique<DynamicThreadPoolTaskDispatcher>(NumThreads);
+    } else
+      D = std::make_unique<InPlaceTaskDispatcher>();
+#endif // LLVM_ENABLE_THREADS
+    if (auto EPCOrErr =
+            SelfExecutorProcessControl::Create(nullptr, std::move(D), nullptr))
       EPC = std::move(*EPCOrErr);
     else
       return EPCOrErr.takeError();
@@ -790,8 +835,6 @@ Error LLJITBuilderState::prepareForConstruction() {
 }
 
 LLJIT::~LLJIT() {
-  if (CompileThreads)
-    CompileThreads->wait();
   if (auto Err = ES->endSession())
     ES->reportError(std::move(Err));
 }
@@ -916,9 +959,8 @@ LLJIT::createCompileFunction(LLJITBuilderState &S,
   if (S.CreateCompileFunction)
     return S.CreateCompileFunction(std::move(JTMB));

-  // Otherwise default to creating a SimpleCompiler, or ConcurrentIRCompiler,
-  // depending on the number of threads requested.
-  if (S.NumCompileThreads > 0)
+  // Use a ConcurrentIRCompiler if concurrent compilation is supported.
+  if (*S.SupportConcurrentCompilation)
     return std::make_unique<ConcurrentIRCompiler>(std::move(JTMB));

   auto TM = JTMB.createTargetMachine();
@@ -970,21 +1012,8 @@ LLJIT::LLJIT(LLJITBuilderState &S, Error &Err)
         std::make_unique<IRTransformLayer>(*ES, *TransformLayer);
   }
 
-  if (S.NumCompileThreads > 0) {
+  if (*S.SupportConcurrentCompilation)
     InitHelperTransformLayer->setCloneToNewContextOnEmit(true);
-    CompileThreads = std::make_unique<DefaultThreadPool>(
-        hardware_concurrency(S.NumCompileThreads));
-    ES->setDispatchTask([this](std::unique_ptr<Task> T) {
-      // FIXME: We should be able to use move-capture here, but ThreadPool's
-      // AsyncTaskTys are std::functions rather than unique_functions
-      // (because MSVC's std::packaged_tasks don't support move-only types).
-      // Fix this when all the above gets sorted out.
-      CompileThreads->async([UnownedT = T.release()]() mutable {
-        std::unique_ptr<Task> T(UnownedT);
-        T->run();
-      });
-    });
-  }
 
   if (S.SetupProcessSymbolsJITDylib) {
     if (auto ProcSymsJD = S.SetupProcessSymbolsJITDylib(*this)) {
@@ -1240,7 +1269,7 @@ LLLazyJIT::LLLazyJIT(LLLazyJITBuilderState &S, Error &Err) : LLJIT(S, Err) {
   CODLayer = std::make_unique<CompileOnDemandLayer>(
       *ES, *InitHelperTransformLayer, *LCTMgr, std::move(ISMBuilder));
 
-  if (S.NumCompileThreads > 0)
+  if (*S.SupportConcurrentCompilation)
     CODLayer->setCloneToNewContextOnEmit(true);
 }
 

diff  --git a/llvm/lib/ExecutionEngine/Orc/TaskDispatch.cpp b/llvm/lib/ExecutionEngine/Orc/TaskDispatch.cpp
index 11a99986f2ee92..4ac2a42091858e 100644
--- a/llvm/lib/ExecutionEngine/Orc/TaskDispatch.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/TaskDispatch.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ExecutionEngine/Orc/TaskDispatch.h"
+#include "llvm/ExecutionEngine/Orc/Core.h"
 
 namespace llvm {
 namespace orc {
@@ -24,16 +25,68 @@ void InPlaceTaskDispatcher::shutdown() {}

 #if LLVM_ENABLE_THREADS
+// Runs T on a new detached thread. When MaxMaterializationThreads is set, at
+// most that many MaterializationTasks run at once; the excess are queued and
+// picked up by dispatcher threads as they finish their current task.
 void DynamicThreadPoolTaskDispatcher::dispatch(std::unique_ptr<Task> T) {
+  bool IsMaterializationTask = isa<MaterializationTask>(*T);
+
   {
     std::lock_guard<std::mutex> Lock(DispatchMutex);
+
+    if (IsMaterializationTask) {
+
+      // If this is a materialization task and there are too many running
+      // already then queue this one up and return early.
+      if (MaxMaterializationThreads &&
+          NumMaterializationThreads >= *MaxMaterializationThreads) {
+        MaterializationTaskQueue.push_back(std::move(T));
+        return;
+      }
+
+      // Otherwise record that we have a materialization task running.
+      ++NumMaterializationThreads;
+    }
+
     ++Outstanding;
   }

-  std::thread([this, T = std::move(T)]() mutable {
-    T->run();
-    std::lock_guard<std::mutex> Lock(DispatchMutex);
-    --Outstanding;
-    OutstandingCV.notify_all();
+  std::thread([this, T = std::move(T), IsMaterializationTask]() mutable {
+    while (true) {
+
+      // Run the task.
+      T->run();
+
+      // Destroy the task before reporting this thread as finished, so that
+      // nothing it owns outlives the point where Outstanding can reach zero,
+      // and so that its destructor never runs while DispatchMutex is held.
+      T.reset();
+
+      std::lock_guard<std::mutex> Lock(DispatchMutex);
+
+      // This thread is no longer running a materialization task.
+      if (IsMaterializationTask)
+        --NumMaterializationThreads;
+
+      // If materialization tasks are queued and running one more here would
+      // not exceed MaxMaterializationThreads then steal that work. The limit
+      // check keeps a thread that just finished a non-materialization task
+      // from pushing NumMaterializationThreads past the maximum.
+      if (!MaterializationTaskQueue.empty() &&
+          (!MaxMaterializationThreads ||
+           NumMaterializationThreads < *MaxMaterializationThreads)) {
+        T = std::move(MaterializationTaskQueue.front());
+        MaterializationTaskQueue.pop_front();
+        ++NumMaterializationThreads;
+        IsMaterializationTask = true;
+        continue;
+      }
+
+      // Otherwise this thread is done: drop it from the outstanding count and
+      // wake any waiters on OutstandingCV.
+      --Outstanding;
+      OutstandingCV.notify_all();
+      return;
+    }
   }).detach();
 }
 

diff  --git a/llvm/tools/llvm-jitlink/llvm-jitlink.cpp b/llvm/tools/llvm-jitlink/llvm-jitlink.cpp
index 09b2a5900eb0b7..bff05b9ca4bebc 100644
--- a/llvm/tools/llvm-jitlink/llvm-jitlink.cpp
+++ b/llvm/tools/llvm-jitlink/llvm-jitlink.cpp
@@ -807,8 +807,8 @@ static Expected<std::unique_ptr<ExecutorProcessControl>> launchExecutor() {
     S.CreateMemoryManager = createSharedMemoryManager;
 
   return SimpleRemoteEPC::Create<FDSimpleRemoteEPCTransport>(
-      std::make_unique<DynamicThreadPoolTaskDispatcher>(), std::move(S),
-      FromExecutor[ReadEnd], ToExecutor[WriteEnd]);
+      std::make_unique<DynamicThreadPoolTaskDispatcher>(std::nullopt),
+      std::move(S), FromExecutor[ReadEnd], ToExecutor[WriteEnd]);
 #endif
 }
 
@@ -897,7 +897,7 @@ static Expected<std::unique_ptr<ExecutorProcessControl>> connectToExecutor() {
     S.CreateMemoryManager = createSharedMemoryManager;
 
   return SimpleRemoteEPC::Create<FDSimpleRemoteEPCTransport>(
-      std::make_unique<DynamicThreadPoolTaskDispatcher>(),
+      std::make_unique<DynamicThreadPoolTaskDispatcher>(std::nullopt),
       std::move(S), *SockFD, *SockFD);
 #endif
 }

diff  --git a/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp
index 5e2b5f35bcf471..3b24e29e1ed386 100644
--- a/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp
+++ b/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp
@@ -1005,11 +1005,11 @@ TEST_F(CoreAPIsStandardTest, RedefineBoundWeakSymbol) {
 
 TEST_F(CoreAPIsStandardTest, DefineMaterializingSymbol) {
   bool ExpectNoMoreMaterialization = false;
-  ES.setDispatchTask([&](std::unique_ptr<Task> T) {
+  DispatchOverride = [&](std::unique_ptr<Task> T) {
     if (ExpectNoMoreMaterialization && isa<MaterializationTask>(*T))
       ADD_FAILURE() << "Unexpected materialization";
     T->run();
-  });
+  };
 
   auto MU = std::make_unique<SimpleMaterializationUnit>(
       SymbolFlagsMap({{Foo, FooSym.getFlags()}}),
@@ -1403,7 +1403,7 @@ TEST_F(CoreAPIsStandardTest, TestLookupWithThreadedMaterialization) {
 
   std::mutex WorkThreadsMutex;
   std::vector<std::thread> WorkThreads;
-  ES.setDispatchTask([&](std::unique_ptr<Task> T) {
+  DispatchOverride = [&](std::unique_ptr<Task> T) {
     std::promise<void> WaitP;
     std::lock_guard<std::mutex> Lock(WorkThreadsMutex);
     WorkThreads.push_back(
@@ -1412,7 +1412,7 @@ TEST_F(CoreAPIsStandardTest, TestLookupWithThreadedMaterialization) {
           T->run();
         }));
     WaitP.set_value();
-  });
+  };
 
   cantFail(JD.define(absoluteSymbols({{Foo, FooSym}})));
 

diff  --git a/llvm/unittests/ExecutionEngine/Orc/OrcTestCommon.cpp b/llvm/unittests/ExecutionEngine/Orc/OrcTestCommon.cpp
index bc87df1fe8c6a9..307f14dfe24d03 100644
--- a/llvm/unittests/ExecutionEngine/Orc/OrcTestCommon.cpp
+++ b/llvm/unittests/ExecutionEngine/Orc/OrcTestCommon.cpp
@@ -22,3 +22,18 @@ ModuleBuilder::ModuleBuilder(LLVMContext &Context, StringRef Triple,
   if (Triple != "")
     M->setTargetTriple(Triple);
 }
+
+void llvm::orc::CoreAPIsBasedStandardTest::OverridableDispatcher::dispatch(
+    std::unique_ptr<Task> T) {
+  if (Parent.DispatchOverride)
+    Parent.DispatchOverride(std::move(T));
+  else
+    InPlaceTaskDispatcher::dispatch(std::move(T));
+}
+
+std::unique_ptr<llvm::orc::ExecutorProcessControl>
+llvm::orc::CoreAPIsBasedStandardTest::makeEPC(
+    std::shared_ptr<SymbolStringPool> SSP) {
+  return std::make_unique<UnsupportedExecutorProcessControl>(
+      std::move(SSP), std::make_unique<OverridableDispatcher>(*this));
+}

diff  --git a/llvm/unittests/ExecutionEngine/Orc/OrcTestCommon.h b/llvm/unittests/ExecutionEngine/Orc/OrcTestCommon.h
index ce7da76c9653a3..0981f4b8132bd0 100644
--- a/llvm/unittests/ExecutionEngine/Orc/OrcTestCommon.h
+++ b/llvm/unittests/ExecutionEngine/Orc/OrcTestCommon.h
@@ -52,8 +52,20 @@ class CoreAPIsBasedStandardTest : public testing::Test {
   }
 
 protected:
+  class OverridableDispatcher : public InPlaceTaskDispatcher {
+  public:
+    OverridableDispatcher(CoreAPIsBasedStandardTest &Parent) : Parent(Parent) {}
+    void dispatch(std::unique_ptr<Task> T) override;
+
+  private:
+    CoreAPIsBasedStandardTest &Parent;
+  };
+
+  std::unique_ptr<llvm::orc::ExecutorProcessControl>
+  makeEPC(std::shared_ptr<SymbolStringPool> SSP);
+
   std::shared_ptr<SymbolStringPool> SSP = std::make_shared<SymbolStringPool>();
-  ExecutionSession ES{std::make_unique<UnsupportedExecutorProcessControl>(SSP)};
+  ExecutionSession ES{makeEPC(SSP)};
   JITDylib &JD = ES.createBareJITDylib("JD");
   SymbolStringPtr Foo = ES.intern("foo");
   SymbolStringPtr Bar = ES.intern("bar");
@@ -67,6 +79,7 @@ class CoreAPIsBasedStandardTest : public testing::Test {
   ExecutorSymbolDef BarSym{BarAddr, JITSymbolFlags::Exported};
   ExecutorSymbolDef BazSym{BazAddr, JITSymbolFlags::Exported};
   ExecutorSymbolDef QuxSym{QuxAddr, JITSymbolFlags::Exported};
+  unique_function<void(std::unique_ptr<Task>)> DispatchOverride;
 };
 
 } // end namespace orc

diff  --git a/llvm/unittests/ExecutionEngine/Orc/TaskDispatchTest.cpp b/llvm/unittests/ExecutionEngine/Orc/TaskDispatchTest.cpp
index 83d386c631dddb..6af0d60cf8ae6a 100644
--- a/llvm/unittests/ExecutionEngine/Orc/TaskDispatchTest.cpp
+++ b/llvm/unittests/ExecutionEngine/Orc/TaskDispatchTest.cpp
@@ -24,7 +24,7 @@ TEST(InPlaceTaskDispatchTest, GenericNamedTask) {
 
 #if LLVM_ENABLE_THREADS
 TEST(DynamicThreadPoolDispatchTest, GenericNamedTask) {
-  auto D = std::make_unique<DynamicThreadPoolTaskDispatcher>();
+  auto D = std::make_unique<DynamicThreadPoolTaskDispatcher>(std::nullopt);
   std::promise<bool> P;
   auto F = P.get_future();
   D->dispatch(makeGenericNamedTask(


        


More information about the llvm-commits mailing list