[Mlir-commits] [mlir] cb9ae00 - [mlir] Add a new context flag for disabling/enabling multi-threading

River Riddle llvmlistbot at llvm.org
Sat May 2 12:35:45 PDT 2020


Author: River Riddle
Date: 2020-05-02T12:32:25-07:00
New Revision: cb9ae0025c4ed966a3a9b5539a9ff6b6e865516f

URL: https://github.com/llvm/llvm-project/commit/cb9ae0025c4ed966a3a9b5539a9ff6b6e865516f
DIFF: https://github.com/llvm/llvm-project/commit/cb9ae0025c4ed966a3a9b5539a9ff6b6e865516f.diff

LOG: [mlir] Add a new context flag for disabling/enabling multi-threading

This is useful for several reasons:
* In some situations the user can guarantee that thread-safety isn't necessary and doesn't want to pay the cost of synchronization, e.g., when parsing a very large module.

* For things like logging, threading is not desirable as the output is not guaranteed to be in a stable order.

This flag also subsumes the pass manager flag for multi-threading.

Differential Revision: https://reviews.llvm.org/D79266

Added: 
    

Modified: 
    mlir/docs/PassManagement.md
    mlir/include/mlir/IR/MLIRContext.h
    mlir/include/mlir/Pass/PassManager.h
    mlir/include/mlir/Support/StorageUniquer.h
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/Pass/IRPrinting.cpp
    mlir/lib/Pass/Pass.cpp
    mlir/lib/Pass/PassManagerOptions.cpp
    mlir/lib/Support/StorageUniquer.cpp
    mlir/lib/Transforms/Inliner.cpp
    mlir/test/Dialect/SPIRV/availability.mlir
    mlir/test/Dialect/SPIRV/target-env.mlir
    mlir/test/IR/test-matchers.mlir
    mlir/test/Pass/ir-printing.mlir
    mlir/test/Pass/pass-timing.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index 90d30deed914..04a4ca0a7b3c 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -801,7 +801,7 @@ pipeline. This display mode is available in mlir-opt via
 `-pass-timing-display=list`.
 
 ```shell
-$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing -pass-timing-display=list
+$ mlir-opt foo.mlir -mlir-disable-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing -pass-timing-display=list
 
 ===-------------------------------------------------------------------------===
                       ... Pass execution timing report ...
@@ -826,7 +826,7 @@ the most time, and can also be used to identify when analyses are being
 invalidated and recomputed. This is the default display mode.
 
 ```shell
-$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing
+$ mlir-opt foo.mlir -mlir-disable-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing
 
 ===-------------------------------------------------------------------------===
                       ... Pass execution timing report ...
@@ -943,10 +943,10 @@ func @simple_constant() -> (i32, i32) {
     *   Always print the top-level module operation, regardless of pass type or
         operation nesting level.
     *   Note: Printing at module scope should only be used when multi-threading
-        is disabled(`-disable-pass-threading`)
+        is disabled (`-mlir-disable-threading`)
 
 ```shell
-$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse)' -print-ir-after=cse -print-ir-module-scope
+$ mlir-opt foo.mlir -mlir-disable-threading -pass-pipeline='func(cse)' -print-ir-after=cse -print-ir-module-scope
 
 *** IR Dump After CSE ***  ('func' operation: @bar)
 func @bar(%arg0: f32, %arg1: f32) -> f32 {

diff  --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index e1213889fe73..40b332698c44 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -55,6 +55,12 @@ class MLIRContext {
   /// Enables creating operations in unregistered dialects.
   void allowUnregisteredDialects(bool allow = true);
 
+  /// Return true if multi-threading is enabled by the context.
+  bool isMultithreadingEnabled();
+
+  /// Set the flag specifying if multi-threading is disabled by the context.
+  void disableMultithreading(bool disable = true);
+
   /// Return true if we should attach the operation to diagnostics emitted via
   /// Operation::emit.
   bool shouldPrintOpOnDiagnostic();

diff  --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 15c88128102e..be5c5d5bcc22 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -99,7 +99,7 @@ class OpPassManager {
   void mergeStatisticsInto(OpPassManager &other);
 
 private:
-  OpPassManager(OperationName name, bool disableThreads, bool verifyPasses);
+  OpPassManager(OperationName name, bool verifyPasses);
 
   /// A pointer to an internal implementation instance.
   std::unique_ptr<detail::OpPassManagerImpl> impl;
@@ -139,13 +139,6 @@ class PassManager : public OpPassManager {
   LLVM_NODISCARD
   LogicalResult run(ModuleOp module);
 
-  /// Disable support for multi-threading within the pass manager.
-  void disableMultithreading(bool disable = true);
-
-  /// Return true if the pass manager is configured with multi-threading
-  /// enabled.
-  bool isMultithreadingEnabled();
-
   /// Enable support for the pass manager to generate a reproducer on the event
   /// of a crash or a pass failure. `outputFile` is a .mlir filename used to
   /// write the generated reproducer. If `genLocalReproducer` is true, the pass

diff  --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index 62a43ff6d1fe..f13a2fef9d50 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -65,6 +65,9 @@ class StorageUniquer {
   StorageUniquer();
   ~StorageUniquer();
 
+  /// Set the flag specifying if multi-threading is disabled within the uniquer.
+  void disableMultithreading(bool disable = true);
+
   /// This class acts as the base storage that all storage classes must derived
   /// from.
   class BaseStorage {

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index b25c5111d8dc..c59c53567488 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -50,6 +50,10 @@ namespace {
 /// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
 /// for global command line options.
 struct MLIRContextOptions {
+  llvm::cl::opt<bool> disableThreading{
+      "mlir-disable-threading",
+      llvm::cl::desc("Disable multi-threading within MLIR")};
+
   llvm::cl::opt<bool> printOpOnDiagnostic{
       "mlir-print-op-on-diagnostic",
       llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
@@ -101,6 +105,41 @@ struct BuiltinDialect : public Dialect {
 };
 } // end anonymous namespace.
 
+//===----------------------------------------------------------------------===//
+// Locking Utilities
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Utility reader lock that takes a runtime flag that specifies if we really
+/// need to lock.
+struct ScopedReaderLock {
+  ScopedReaderLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
+      : mutex(shouldLock ? &mutexParam : nullptr) {
+    if (mutex)
+      mutex->lock_shared();
+  }
+  ~ScopedReaderLock() {
+    if (mutex)
+      mutex->unlock_shared();
+  }
+  llvm::sys::SmartRWMutex<true> *mutex;
+};
+/// Utility writer lock that takes a runtime flag that specifies if we really
+/// need to lock.
+struct ScopedWriterLock {
+  ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
+      : mutex(shouldLock ? &mutexParam : nullptr) {
+    if (mutex)
+      mutex->lock();
+  }
+  ~ScopedWriterLock() {
+    if (mutex)
+      mutex->unlock();
+  }
+  llvm::sys::SmartRWMutex<true> *mutex;
+};
+} // end anonymous namespace.
+
 //===----------------------------------------------------------------------===//
 // AffineMap and IntegerSet hashing
 //===----------------------------------------------------------------------===//
@@ -111,8 +150,10 @@ template <typename ValueT, typename DenseInfoT, typename KeyT,
           typename ConstructorFn>
 static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container,
                               KeyT &&key, llvm::sys::SmartRWMutex<true> &mutex,
+                              bool threadingIsEnabled,
                               ConstructorFn &&constructorFn) {
-  { // Check for an existing instance in read-only mode.
+  // Check for an existing instance in read-only mode.
+  if (threadingIsEnabled) {
     llvm::sys::SmartScopedReader<true> instanceLock(mutex);
     auto it = container.find_as(key);
     if (it != container.end())
@@ -120,16 +161,14 @@ static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container,
   }
 
   // Acquire a writer-lock so that we can safely create the new instance.
-  llvm::sys::SmartScopedWriter<true> instanceLock(mutex);
+  ScopedWriterLock instanceLock(mutex, threadingIsEnabled);
 
   // Check for an existing instance again here, because another writer thread
-  // may have already created one.
+  // may have already created one. Otherwise, construct a new instance.
   auto existing = container.insert_as(ValueT(), key);
-  if (!existing.second)
-    return *existing.first;
-
-  // Otherwise, construct a new instance of the value.
-  return *existing.first = constructorFn();
+  if (existing.second)
+    return *existing.first = constructorFn();
+  return *existing.first;
 }
 
 namespace {
@@ -217,6 +256,9 @@ class MLIRContextImpl {
   /// detect such use cases
   bool allowUnregisteredDialects = false;
 
+  /// Enable support for multi-threading within MLIR.
+  bool threadingIsEnabled = true;
+
   /// If the operation should be attached to diagnostics printed via the
   /// Operation::emit methods.
   bool printOpOnDiagnostic = true;
@@ -288,17 +330,19 @@ class MLIRContextImpl {
   UnknownLoc unknownLocAttr;
 
 public:
-  MLIRContextImpl() : identifiers(identifierAllocator) {
-    // Initialize values based on the command line flags if they were provided.
-    if (clOptions.isConstructed()) {
-      printOpOnDiagnostic = clOptions->printOpOnDiagnostic;
-      printStackTraceOnDiagnostic = clOptions->printStackTraceOnDiagnostic;
-    }
-  }
+  MLIRContextImpl() : identifiers(identifierAllocator) {}
 };
 } // end namespace mlir
 
 MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
+  // Initialize values based on the command line flags if they were provided.
+  if (clOptions.isConstructed()) {
+    disableMultithreading(clOptions->disableThreading);
+    printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
+    printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
+  }
+
+  // Register dialects with this context.
   new BuiltinDialect(this);
   registerAllDialects(this);
 
@@ -372,11 +416,10 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
 /// Return information about all registered IR dialects.
 std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
   // Lock access to the context registry.
-  llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
-
+  ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
   std::vector<Dialect *> result;
-  result.reserve(getImpl().dialects.size());
-  for (auto &dialect : getImpl().dialects)
+  result.reserve(impl->dialects.size());
+  for (auto &dialect : impl->dialects)
     result.push_back(dialect.get());
   return result;
 }
@@ -385,11 +428,15 @@ std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
 /// then return nullptr.
 Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
   // Lock access to the context registry.
-  llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
-  for (auto &dialect : getImpl().dialects)
-    if (name == dialect->getNamespace())
-      return dialect.get();
-  return nullptr;
+  ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
+
+  // Dialects are sorted by name, so we can use binary search for lookup.
+  auto it = llvm::lower_bound(
+      impl->dialects, name,
+      [](const auto &lhs, StringRef rhs) { return lhs->getNamespace() < rhs; });
+  return (it != impl->dialects.end() && (*it)->getNamespace() == name)
+             ? (*it).get()
+             : nullptr;
 }
 
 /// Register this dialect object with the specified context.  The context
@@ -399,15 +446,13 @@ void Dialect::registerDialect(MLIRContext *context) {
   std::unique_ptr<Dialect> dialect(this);
 
   // Lock access to the context registry.
-  llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
+  ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
 
   // Get the correct insertion position sorted by namespace.
-  auto insertPt =
-      llvm::lower_bound(impl.dialects, dialect,
-                        [](const std::unique_ptr<Dialect> &lhs,
-                           const std::unique_ptr<Dialect> &rhs) {
-                          return lhs->getNamespace() < rhs->getNamespace();
-                        });
+  auto insertPt = llvm::lower_bound(
+      impl.dialects, dialect, [](const auto &lhs, const auto &rhs) {
+        return lhs->getNamespace() < rhs->getNamespace();
+      });
 
   // Abort if dialect with namespace has already been registered.
   if (insertPt != impl.dialects.end() &&
@@ -426,6 +471,21 @@ void MLIRContext::allowUnregisteredDialects(bool allowing) {
   impl->allowUnregisteredDialects = allowing;
 }
 
+/// Return true if multi-threading is enabled by the context.
+bool MLIRContext::isMultithreadingEnabled() {
+  return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
+}
+
+/// Set the flag specifying if multi-threading is disabled by the context.
+void MLIRContext::disableMultithreading(bool disable) {
+  impl->threadingIsEnabled = !disable;
+
+  // Update the threading mode for each of the uniquers.
+  impl->affineUniquer.disableMultithreading(disable);
+  impl->attributeUniquer.disableMultithreading(disable);
+  impl->typeUniquer.disableMultithreading(disable);
+}
+
 /// Return true if we should attach the operation to diagnostics emitted via
 /// Operation::emit.
 bool MLIRContext::shouldPrintOpOnDiagnostic() {
@@ -457,13 +517,13 @@ std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
   std::vector<std::pair<StringRef, AbstractOperation *>> opsToSort;
 
   { // Lock access to the context registry.
-    llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
+    ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
 
     // We just have the operations in a non-deterministic hash table order. Dump
     // into a temporary array, then sort it by operation name to get a stable
     // ordering.
     llvm::StringMap<AbstractOperation> &registeredOps =
-        getImpl().registeredOperations;
+        impl->registeredOperations;
 
     opsToSort.reserve(registeredOps.size());
     for (auto &elt : registeredOps)
@@ -487,7 +547,7 @@ void Dialect::addOperation(AbstractOperation opInfo) {
   auto &impl = context->getImpl();
 
   // Lock access to the context registry.
-  llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
+  ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
   if (!impl.registeredOperations.insert({opInfo.name, opInfo}).second) {
     llvm::errs() << "error: operation named '" << opInfo.name
                  << "' is already registered.\n";
@@ -500,7 +560,7 @@ void Dialect::addSymbol(TypeID typeID) {
   auto &impl = context->getImpl();
 
   // Lock access to the context registry.
-  llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
+  ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
   if (!impl.registeredDialectSymbols.insert({typeID, this}).second) {
     llvm::errs() << "error: dialect symbol already registered.\n";
     abort();
@@ -514,7 +574,7 @@ const AbstractOperation *AbstractOperation::lookup(StringRef opName,
   auto &impl = context->getImpl();
 
   // Lock access to the context registry.
-  llvm::sys::SmartScopedReader<true> registryLock(impl.contextMutex);
+  ScopedReaderLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
   auto it = impl.registeredOperations.find(opName);
   if (it != impl.registeredOperations.end())
     return &it->second;
@@ -529,7 +589,8 @@ const AbstractOperation *AbstractOperation::lookup(StringRef opName,
 Identifier Identifier::get(StringRef str, MLIRContext *context) {
   auto &impl = context->getImpl();
 
-  { // Check for an existing identifier in read-only mode.
+  // Check for an existing identifier in read-only mode.
+  if (context->isMultithreadingEnabled()) {
     llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
     auto it = impl.identifiers.find(str);
     if (it != impl.identifiers.end())
@@ -544,7 +605,7 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) {
          "Cannot create an identifier with a nul character");
 
   // Acquire a writer-lock so that we can safely create the new instance.
-  llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
+  ScopedWriterLock contextLock(impl.identifierMutex, impl.threadingIsEnabled);
   auto it = impl.identifiers.insert(str).first;
   return Identifier(&*it);
 }
@@ -696,16 +757,18 @@ AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
   auto key = std::make_tuple(dimCount, symbolCount, results);
 
   // Safely get or create an AffineMap instance.
-  return safeGetOrCreate(impl.affineMaps, key, impl.affineMutex, [&] {
-    auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>();
+  return safeGetOrCreate(
+      impl.affineMaps, key, impl.affineMutex, impl.threadingIsEnabled, [&] {
+        auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>();
 
-    // Copy the results into the bump pointer.
-    results = copyArrayRefInto(impl.affineAllocator, results);
+        // Copy the results into the bump pointer.
+        results = copyArrayRefInto(impl.affineAllocator, results);
 
-    // Initialize the memory using placement new.
-    new (res) detail::AffineMapStorage{dimCount, symbolCount, results, context};
-    return AffineMap(res);
-  });
+        // Initialize the memory using placement new.
+        new (res)
+            detail::AffineMapStorage{dimCount, symbolCount, results, context};
+        return AffineMap(res);
+      });
 }
 
 AffineMap AffineMap::get(MLIRContext *context) {
@@ -760,12 +823,12 @@ IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
   if (constraints.size() < IntegerSet::kUniquingThreshold) {
     auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags);
     return safeGetOrCreate(impl.integerSets, key, impl.affineMutex,
-                           constructorFn);
+                           impl.threadingIsEnabled, constructorFn);
   }
 
   // Otherwise, acquire a writer-lock so that we can safely create the new
   // instance.
-  llvm::sys::SmartScopedWriter<true> affineLock(impl.affineMutex);
+  ScopedWriterLock affineLock(impl.affineMutex, impl.threadingIsEnabled);
   return constructorFn();
 }
 

diff  --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp
index ba9ff989ea79..842b83c5d0f7 100644
--- a/mlir/lib/Pass/IRPrinting.cpp
+++ b/mlir/lib/Pass/IRPrinting.cpp
@@ -257,7 +257,8 @@ struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
 /// Add an instrumentation to print the IR before and after pass execution,
 /// using the provided configuration.
 void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
-  if (config->shouldPrintAtModuleScope() && isMultithreadingEnabled())
+  if (config->shouldPrintAtModuleScope() &&
+      getContext()->isMultithreadingEnabled())
     llvm::report_fatal_error("IR printing can't be setup on a pass-manager "
                              "without disabling multi-threading first.");
   addInstrumentation(

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index ef4ed760feb8..83855faca755 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -118,9 +118,8 @@ void VerifierPass::runOnOperation() {
 namespace mlir {
 namespace detail {
 struct OpPassManagerImpl {
-  OpPassManagerImpl(OperationName name, bool disableThreads, bool verifyPasses)
-      : name(name), disableThreads(disableThreads), verifyPasses(verifyPasses) {
-  }
+  OpPassManagerImpl(OperationName name, bool verifyPasses)
+      : name(name), verifyPasses(verifyPasses) {}
 
   /// Merge the passes of this pass manager into the one provided.
   void mergeInto(OpPassManagerImpl &rhs);
@@ -152,9 +151,6 @@ struct OpPassManagerImpl {
   /// The name of the operation that passes of this pass manager operate on.
   OperationName name;
 
-  /// Flag to disable multi-threading of passes.
-  bool disableThreads : 1;
-
   /// Flag that specifies if the IR should be verified after each pass has run.
   bool verifyPasses : 1;
 
@@ -172,7 +168,7 @@ void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
 }
 
 OpPassManager &OpPassManagerImpl::nest(const OperationName &nestedName) {
-  OpPassManager nested(nestedName, disableThreads, verifyPasses);
+  OpPassManager nested(nestedName, verifyPasses);
   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
   addPass(std::unique_ptr<Pass>(adaptor));
   return adaptor->getPassManagers().front();
@@ -269,9 +265,8 @@ void OpPassManagerImpl::splitAdaptorPasses() {
 // OpPassManager
 //===----------------------------------------------------------------------===//
 
-OpPassManager::OpPassManager(OperationName name, bool disableThreads,
-                             bool verifyPasses)
-    : impl(new OpPassManagerImpl(name, disableThreads, verifyPasses)) {
+OpPassManager::OpPassManager(OperationName name, bool verifyPasses)
+    : impl(new OpPassManagerImpl(name, verifyPasses)) {
   assert(name.getAbstractOperation() &&
          "OpPassManager can only operate on registered operations");
   assert(name.getAbstractOperation()->hasProperty(
@@ -282,8 +277,7 @@ OpPassManager::OpPassManager(OperationName name, bool disableThreads,
 OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {}
 OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
 OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
-  impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->disableThreads,
-                                   rhs.impl->verifyPasses));
+  impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->verifyPasses));
   for (auto &pass : rhs.impl->passes)
     impl->passes.emplace_back(pass->clone());
   return *this;
@@ -419,10 +413,10 @@ std::string OpToOpPassAdaptor::getAdaptorName() {
 
 /// Run the held pipeline over all nested operations.
 void OpToOpPassAdaptor::runOnOperation() {
-  if (mgrs.front().getImpl().disableThreads || !llvm::llvm_is_multithreaded())
-    runOnOperationImpl();
-  else
+  if (getContext().isMultithreadingEnabled())
     runOnOperationAsyncImpl();
+  else
+    runOnOperationImpl();
 }
 
 /// Run this pass adaptor synchronously.
@@ -576,7 +570,7 @@ struct RecoveryReproducerContext {
   /// The filename to use when generating the reproducer.
   StringRef filename;
 
-  /// Various pass manager flags.
+  /// Various pass manager and context flags.
   bool disableThreads;
   bool verifyPasses;
 
@@ -628,7 +622,7 @@ LogicalResult RecoveryReproducerContext::generate(std::string &error) {
   // Output the current pass manager configuration.
   outputOS << "// configuration: -pass-pipeline='" << pipeline << "'";
   if (disableThreads)
-    outputOS << " -disable-pass-threading";
+    outputOS << " -mlir-disable-threading";
 
   // TODO: Should this also be configured with a pass manager flag?
   outputOS << "\n// note: verifyPasses=" << (verifyPasses ? "true" : "false")
@@ -684,7 +678,8 @@ LogicalResult
 PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
                                   ModuleOp module, AnalysisManager am) {
   RecoveryReproducerContext context(passes, module, *crashReproducerFileName,
-                                    impl->disableThreads, impl->verifyPasses);
+                                    !getContext()->isMultithreadingEnabled(),
+                                    impl->verifyPasses);
 
   // Safely invoke the passes within a recovery context.
   llvm::CrashRecoveryContext::Enable();
@@ -715,7 +710,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
 
 PassManager::PassManager(MLIRContext *ctx, bool verifyPasses)
     : OpPassManager(OperationName(ModuleOp::getOperationName(), ctx),
-                    /*disableThreads=*/false, verifyPasses),
+                    verifyPasses),
       passTiming(false), localReproducer(false) {}
 
 PassManager::~PassManager() {}
@@ -741,15 +736,6 @@ LogicalResult PassManager::run(ModuleOp module) {
   return result;
 }
 
-/// Disable support for multi-threading within the pass manager.
-void PassManager::disableMultithreading(bool disable) {
-  getImpl().disableThreads = disable;
-}
-
-bool PassManager::isMultithreadingEnabled() {
-  return !getImpl().disableThreads;
-}
-
 /// Enable support for the pass manager to generate a reproducer on the event
 /// of a crash or a pass failure. `outputFile` is a .mlir filename used to write
 /// the generated reproducer. If `genLocalReproducer` is true, the pass manager

diff  --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp
index c1b1eee07ea5..b00f992eceb9 100644
--- a/mlir/lib/Pass/PassManagerOptions.cpp
+++ b/mlir/lib/Pass/PassManagerOptions.cpp
@@ -29,14 +29,6 @@ struct PassManagerOptions {
                      "a reproducer with the smallest pipeline."),
       llvm::cl::init(false)};
 
-  //===--------------------------------------------------------------------===//
-  // Multi-threading
-  //===--------------------------------------------------------------------===//
-  llvm::cl::opt<bool> disableThreads{
-      "disable-pass-threading",
-      llvm::cl::desc("Disable multithreading in the pass manager"),
-      llvm::cl::init(false)};
-
   //===--------------------------------------------------------------------===//
   // IR Printing
   //===--------------------------------------------------------------------===//
@@ -164,10 +156,6 @@ void mlir::applyPassManagerCLOptions(PassManager &pm) {
     pm.enableCrashReproducerGeneration(options->reproducerFile,
                                        options->localReproducer);
 
-  // Disable multi-threading.
-  if (options->disableThreads)
-    pm.disableMultithreading();
-
   // Enable statistics dumping.
   if (options->passStatistics)
     pm.enableStatistics(options->passStatisticsDisplayMode);

diff  --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp
index d50c599a7776..40304a544c4f 100644
--- a/mlir/lib/Support/StorageUniquer.cpp
+++ b/mlir/lib/Support/StorageUniquer.cpp
@@ -46,6 +46,8 @@ struct StorageUniquerImpl {
               function_ref<bool(const BaseStorage *)> isEqual,
               function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
     LookupKey lookupKey{kind, hashValue, isEqual};
+    if (!threadingIsEnabled)
+      return getOrCreateUnsafe(kind, hashValue, lookupKey, ctorFn);
 
     // Check for an existing instance in read-only mode.
     {
@@ -57,9 +59,12 @@ struct StorageUniquerImpl {
 
     // Acquire a writer-lock so that we can safely create the new type instance.
     llvm::sys::SmartScopedWriter<true> typeLock(mutex);
-
-    // Check for an existing instance again here, because another writer thread
-    // may have already created one.
+    return getOrCreateUnsafe(kind, hashValue, lookupKey, ctorFn);
+  }
+  /// Get or create an instance of a complex derived type in an unsafe fashion.
+  BaseStorage *
+  getOrCreateUnsafe(unsigned kind, unsigned hashValue, LookupKey &lookupKey,
+                    function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
     auto existing = storageTypes.insert_as({}, lookupKey);
     if (!existing.second)
       return existing.first->storage;
@@ -75,6 +80,9 @@ struct StorageUniquerImpl {
   BaseStorage *
   getOrCreate(unsigned kind,
               function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+    if (!threadingIsEnabled)
+      return getOrCreateUnsafe(kind, ctorFn);
+
     // Check for an existing instance in read-only mode.
     {
       llvm::sys::SmartScopedReader<true> typeLock(mutex);
@@ -85,9 +93,12 @@ struct StorageUniquerImpl {
 
     // Acquire a writer-lock so that we can safely create the new type instance.
     llvm::sys::SmartScopedWriter<true> typeLock(mutex);
-
-    // Check for an existing instance again here, because another writer thread
-    // may have already created one.
+    return getOrCreateUnsafe(kind, ctorFn);
+  }
+  /// Get or create an instance of a simple derived type in an unsafe fashion.
+  BaseStorage *
+  getOrCreateUnsafe(unsigned kind,
+                    function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
     auto &result = simpleTypes[kind];
     if (result)
       return result;
@@ -152,18 +163,21 @@ struct StorageUniquerImpl {
     }
   };
 
-  // Unique types with specific hashing or storage constraints.
+  /// Unique types with specific hashing or storage constraints.
   using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
   StorageTypeSet storageTypes;
 
-  // Unique types with just the kind.
+  /// Unique types with just the kind.
   DenseMap<unsigned, BaseStorage *> simpleTypes;
 
-  // Allocator to use when constructing derived type instances.
+  /// Allocator to use when constructing derived type instances.
   StorageUniquer::StorageAllocator allocator;
 
-  // A mutex to keep type uniquing thread-safe.
+  /// A mutex to keep type uniquing thread-safe.
   llvm::sys::SmartRWMutex<true> mutex;
+
+  /// Flag specifying if multi-threading is enabled within the uniquer.
+  bool threadingIsEnabled = true;
 };
 } // end namespace detail
 } // namespace mlir
@@ -171,6 +185,11 @@ struct StorageUniquerImpl {
 StorageUniquer::StorageUniquer() : impl(new StorageUniquerImpl()) {}
 StorageUniquer::~StorageUniquer() {}
 
+/// Set the flag specifying if multi-threading is disabled within the uniquer.
+void StorageUniquer::disableMultithreading(bool disable) {
+  impl->threadingIsEnabled = !disable;
+}
+
 /// Implementation for getting/creating an instance of a derived type with
 /// complex storage.
 auto StorageUniquer::getImpl(

diff  --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index c0f89da300f1..ee645cb55511 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -496,22 +496,28 @@ static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
   // NOTE: This is simple now, because we don't enable canonicalizing nodes
   // within children. When we remove this restriction, this logic will need to
   // be reworked.
-  ParallelDiagnosticHandler canonicalizationHandler(context);
-  llvm::parallel::for_each_n(
-      llvm::parallel::par, /*Begin=*/size_t(0),
-      /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
-        // Set the order for this thread so that diagnostics will be properly
-        // ordered.
-        canonicalizationHandler.setOrderIDForThread(index);
-
-        // Apply the canonicalization patterns to this region.
-        auto *node = nodesToCanonicalize[index];
-        applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns);
-
-        // Make sure to reset the order ID for the diagnostic handler, as this
-        // thread may be used in a different context.
-        canonicalizationHandler.eraseOrderIDForThread();
-      });
+  if (context->isMultithreadingEnabled()) {
+    ParallelDiagnosticHandler canonicalizationHandler(context);
+    llvm::parallel::for_each_n(
+        llvm::parallel::par, /*Begin=*/size_t(0),
+        /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
+          // Set the order for this thread so that diagnostics will be properly
+          // ordered.
+          canonicalizationHandler.setOrderIDForThread(index);
+
+          // Apply the canonicalization patterns to this region.
+          auto *node = nodesToCanonicalize[index];
+          applyPatternsAndFoldGreedily(*node->getCallableRegion(),
+                                       canonPatterns);
+
+          // Make sure to reset the order ID for the diagnostic handler, as this
+          // thread may be used in a different context.
+          canonicalizationHandler.eraseOrderIDForThread();
+        });
+  } else {
+    for (CallGraphNode *node : nodesToCanonicalize)
+      applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns);
+  }
 
   // Recompute the uses held by each of the nodes.
   for (CallGraphNode *node : nodesToCanonicalize)

diff  --git a/mlir/test/Dialect/SPIRV/availability.mlir b/mlir/test/Dialect/SPIRV/availability.mlir
index e31c1bdeacca..322cc533c826 100644
--- a/mlir/test/Dialect/SPIRV/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/availability.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -disable-pass-threading -test-spirv-op-availability %s | FileCheck %s
+// RUN: mlir-opt -mlir-disable-threading -test-spirv-op-availability %s | FileCheck %s
 
 // CHECK-LABEL: iadd
 func @iadd(%arg: i32) -> i32 {

diff  --git a/mlir/test/Dialect/SPIRV/target-env.mlir b/mlir/test/Dialect/SPIRV/target-env.mlir
index 9b42314e3f1d..27c4e8d04092 100644
--- a/mlir/test/Dialect/SPIRV/target-env.mlir
+++ b/mlir/test/Dialect/SPIRV/target-env.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -disable-pass-threading -test-spirv-target-env %s | FileCheck %s
+// RUN: mlir-opt -mlir-disable-threading -test-spirv-target-env %s | FileCheck %s
 
 // Note: The following tests check that a spv.target_env can properly control
 // the conversion target and filter unavailable ops during the conversion.

diff  --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir
index 60d5bcf7d81b..925b01bda110 100644
--- a/mlir/test/IR/test-matchers.mlir
+++ b/mlir/test/IR/test-matchers.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -disable-pass-threading=true -test-matchers -o /dev/null 2>&1 | FileCheck %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -test-matchers -o /dev/null 2>&1 | FileCheck %s
 
 func @test1(%a: f32, %b: f32, %c: f32) {
   %0 = addf %a, %b: f32

diff  --git a/mlir/test/Pass/ir-printing.mlir b/mlir/test/Pass/ir-printing.mlir
index 892dc40a034a..8bb86b36c181 100644
--- a/mlir/test/Pass/ir-printing.mlir
+++ b/mlir/test/Pass/ir-printing.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse  -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE %s
-// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before-all -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_ALL %s
-// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after=cse -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER %s
-// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after-all -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL %s
-// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -print-ir-module-scope -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_MODULE %s
-// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,cse)' -print-ir-after-all -print-ir-after-change -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL_CHANGE %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse  -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before-all -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_ALL %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after=cse -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after-all -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -print-ir-module-scope -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_MODULE %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,cse)' -print-ir-after-all -print-ir-after-change -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL_CHANGE %s
 
 func @foo() {
   %0 = constant 0 : i32

diff  --git a/mlir/test/Pass/pass-timing.mlir b/mlir/test/Pass/pass-timing.mlir
index db39ad6a1633..6cd8a29e6f48 100644
--- a/mlir/test/Pass/pass-timing.mlir
+++ b/mlir/test/Pass/pass-timing.mlir
@@ -1,8 +1,8 @@
-// RUN: mlir-opt %s -disable-pass-threading=true -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=list 2>&1 | FileCheck -check-prefix=LIST %s
-// RUN: mlir-opt %s -disable-pass-threading=true -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=PIPELINE %s
-// RUN: mlir-opt %s -disable-pass-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=list 2>&1 | FileCheck -check-prefix=MT_LIST %s
-// RUN: mlir-opt %s -disable-pass-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=MT_PIPELINE %s
-// RUN: mlir-opt %s -disable-pass-threading=false -verify-each=false -test-pm-nested-pipeline -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=NESTED_MT_PIPELINE %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=list 2>&1 | FileCheck -check-prefix=LIST %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=PIPELINE %s
+// RUN: mlir-opt %s -mlir-disable-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=list 2>&1 | FileCheck -check-prefix=MT_LIST %s
+// RUN: mlir-opt %s -mlir-disable-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=MT_PIPELINE %s
+// RUN: mlir-opt %s -mlir-disable-threading=false -verify-each=false -test-pm-nested-pipeline -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=NESTED_MT_PIPELINE %s
 
 // LIST: Pass execution timing report
 // LIST: Total Execution Time:


        


More information about the Mlir-commits mailing list