[Mlir-commits] [mlir] 6bbbd7b - Update MLIRContext to allow injecting an external ThreadPool (NFC)

Mehdi Amini llvmlistbot at llvm.org
Thu Jul 1 15:17:59 PDT 2021


Author: Mehdi Amini
Date: 2021-07-01T22:17:47Z
New Revision: 6bbbd7b499f2d5e1d716f33fdf5c072083f007da

URL: https://github.com/llvm/llvm-project/commit/6bbbd7b499f2d5e1d716f33fdf5c072083f007da
DIFF: https://github.com/llvm/llvm-project/commit/6bbbd7b499f2d5e1d716f33fdf5c072083f007da.diff

LOG: Update MLIRContext to allow injecting an external ThreadPool (NFC)

The context can be created with threading disabled, to avoid creating a thread pool
that may be destroyed when injecting another one later.

Differential Revision: https://reviews.llvm.org/D105302

Added: 
    

Modified: 
    mlir/include/mlir/IR/MLIRContext.h
    mlir/lib/IR/MLIRContext.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 7b0fcd66dc2b0..196c6ef441a6a 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -38,11 +38,27 @@ class StorageUniquer;
 /// a very generic name ("Context") and because it is uncommon for clients to
 /// interact with it.
 ///
+/// The context wraps some multi-threading facilities, and in particular by
+/// default it will implicitly create a thread pool.
+/// This can be undesirable if multiple contexts exist at the same time or if a
+/// process will be long-lived and create and destroy contexts.
+/// To better control thread spawning, an externally owned ThreadPool can be
+/// injected in the context. For example:
+///
+///  llvm::ThreadPool myThreadPool;
+///  while (auto *request = nextCompilationRequests()) {
+///    MLIRContext ctx(registry, MLIRContext::Threading::DISABLED);
+///    ctx.setThreadPool(myThreadPool);
+///    processRequest(request, ctx);
+///  }
+///
 class MLIRContext {
 public:
+  enum class Threading { DISABLED, ENABLED };
   /// Create a new Context.
-  explicit MLIRContext();
-  explicit MLIRContext(const DialectRegistry &registry);
+  explicit MLIRContext(Threading multithreading = Threading::ENABLED);
+  explicit MLIRContext(const DialectRegistry &registry,
+                       Threading multithreading = Threading::ENABLED);
   ~MLIRContext();
 
   /// Return information about all IR dialects loaded in the context.
@@ -118,7 +134,15 @@ class MLIRContext {
     disableMultithreading(!enable);
   }
 
-  /// Return the thread pool owned by this context. This method requires that
+  /// Set a new thread pool to be used in this context. This method requires
+  /// that multithreading is disabled for this context prior to the call. This
+  /// allows sharing a thread pool across multiple contexts, as well as
+  /// decoupling the lifetime of the threads from the contexts. The thread pool
+  /// must outlive the context. Multi-threading will be enabled as part of this
+  /// method.
+  void setThreadPool(llvm::ThreadPool &pool);
+
+  /// Return the thread pool used by this context. This method requires that
   /// multithreading be enabled within the context, and should generally not be
   /// used directly. Users should instead prefer the threading utilities within
   /// Threading.h.

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index ddb909949cdfc..7e4ec8261a2fc 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -261,8 +261,15 @@ class MLIRContextImpl {
   // Other
   //===--------------------------------------------------------------------===//
 
-  /// The thread pool to use when processing MLIR tasks in parallel.
-  llvm::Optional<llvm::ThreadPool> threadPool;
+  /// This points to the ThreadPool used when processing MLIR tasks in parallel.
+  /// It can't be nullptr when multi-threading is enabled. Otherwise if
+  /// multi-threading is disabled, and the threadpool wasn't externally provided
+  /// using `setThreadPool`, this will be nullptr.
+  llvm::ThreadPool *threadPool = nullptr;
+
+  /// In the case where the thread pool is owned by the context, this ensures
+  /// destruction with the context.
+  std::unique_ptr<llvm::ThreadPool> ownedThreadPool;
 
   /// This is a list of dialects that are created referring to this context.
   /// The MLIRContext owns the objects.
@@ -334,9 +341,13 @@ class MLIRContextImpl {
   StringAttr emptyStringAttr;
 
 public:
-  MLIRContextImpl() : identifiers(identifierAllocator) {
-    if (threadingIsEnabled)
-      threadPool.emplace();
+  MLIRContextImpl(bool threadingIsEnabled)
+      : threadingIsEnabled(threadingIsEnabled),
+        identifiers(identifierAllocator) {
+    if (threadingIsEnabled) {
+      ownedThreadPool = std::make_unique<llvm::ThreadPool>();
+      threadPool = ownedThreadPool.get();
+    }
   }
   ~MLIRContextImpl() {
     for (auto typeMapping : registeredTypes)
@@ -347,10 +358,11 @@ class MLIRContextImpl {
 };
 } // end namespace mlir
 
-MLIRContext::MLIRContext() : MLIRContext(DialectRegistry()) {}
+MLIRContext::MLIRContext(Threading setting)
+    : MLIRContext(DialectRegistry(), setting) {}
 
-MLIRContext::MLIRContext(const DialectRegistry &registry)
-    : impl(new MLIRContextImpl) {
+MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
+    : impl(new MLIRContextImpl(setting == Threading::ENABLED)) {
   // Initialize values based on the command line flags if they were provided.
   if (clOptions.isConstructed()) {
     disableMultithreading(clOptions->disableThreading);
@@ -579,15 +591,36 @@ void MLIRContext::disableMultithreading(bool disable) {
 
   // Destroy thread pool (stop all threads) if it is no longer needed, or create
   // a new one if multithreading was re-enabled.
-  if (!impl->threadingIsEnabled)
-    impl->threadPool.reset();
-  else if (!impl->threadPool.hasValue())
-    impl->threadPool.emplace();
+  if (disable) {
+    // If the thread pool is owned, explicitly set it to nullptr to avoid
+    // keeping a dangling pointer around. If the thread pool is externally
+    // owned, we don't do anything.
+    if (impl->ownedThreadPool) {
+      assert(impl->threadPool);
+      impl->threadPool = nullptr;
+      impl->ownedThreadPool.reset();
+    }
+  } else if (!impl->threadPool) {
+    // The thread pool isn't externally provided.
+    assert(!impl->ownedThreadPool);
+    impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>();
+    impl->threadPool = impl->ownedThreadPool.get();
+  }
+}
+
+void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
+  assert(!isMultithreadingEnabled() &&
+         "expected multi-threading to be disabled when setting a ThreadPool");
+  impl->threadPool = &pool;
+  impl->ownedThreadPool.reset();
+  enableMultithreading();
 }
 
 llvm::ThreadPool &MLIRContext::getThreadPool() {
   assert(isMultithreadingEnabled() &&
          "expected multi-threading to be enabled within the context");
+  assert(impl->threadPool &&
+         "multi-threading is enabled but threadpool not set");
   return *impl->threadPool;
 }
 


        


More information about the Mlir-commits mailing list