[Mlir-commits] [mlir] cb9ae00 - [mlir] Add a new context flag for disabling/enabling multi-threading
River Riddle
llvmlistbot at llvm.org
Sat May 2 12:35:45 PDT 2020
Author: River Riddle
Date: 2020-05-02T12:32:25-07:00
New Revision: cb9ae0025c4ed966a3a9b5539a9ff6b6e865516f
URL: https://github.com/llvm/llvm-project/commit/cb9ae0025c4ed966a3a9b5539a9ff6b6e865516f
DIFF: https://github.com/llvm/llvm-project/commit/cb9ae0025c4ed966a3a9b5539a9ff6b6e865516f.diff
LOG: [mlir] Add a new context flag for disabling/enabling multi-threading
This is useful for several reasons:
* In some situations the user can guarantee that thread-safety isn't necessary and doesn't want to pay the cost of synchronization, e.g., when parsing a very large module.
* For things like logging, threading is not desirable because the output is not guaranteed to be in a stable order.
This flag also subsumes the pass manager flag for multi-threading.
Differential Revision: https://reviews.llvm.org/D79266
Added:
Modified:
mlir/docs/PassManagement.md
mlir/include/mlir/IR/MLIRContext.h
mlir/include/mlir/Pass/PassManager.h
mlir/include/mlir/Support/StorageUniquer.h
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Pass/IRPrinting.cpp
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassManagerOptions.cpp
mlir/lib/Support/StorageUniquer.cpp
mlir/lib/Transforms/Inliner.cpp
mlir/test/Dialect/SPIRV/availability.mlir
mlir/test/Dialect/SPIRV/target-env.mlir
mlir/test/IR/test-matchers.mlir
mlir/test/Pass/ir-printing.mlir
mlir/test/Pass/pass-timing.mlir
Removed:
################################################################################
diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index 90d30deed914..04a4ca0a7b3c 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -801,7 +801,7 @@ pipeline. This display mode is available in mlir-opt via
`-pass-timing-display=list`.
```shell
-$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing -pass-timing-display=list
+$ mlir-opt foo.mlir -mlir-disable-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing -pass-timing-display=list
===-------------------------------------------------------------------------===
... Pass execution timing report ...
@@ -826,7 +826,7 @@ the most time, and can also be used to identify when analyses are being
invalidated and recomputed. This is the default display mode.
```shell
-$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing
+$ mlir-opt foo.mlir -mlir-disable-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing
===-------------------------------------------------------------------------===
... Pass execution timing report ...
@@ -943,10 +943,10 @@ func @simple_constant() -> (i32, i32) {
* Always print the top-level module operation, regardless of pass type or
operation nesting level.
* Note: Printing at module scope should only be used when multi-threading
- is disabled(`-disable-pass-threading`)
+ is disabled (`-mlir-disable-threading`)
```shell
-$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse)' -print-ir-after=cse -print-ir-module-scope
+$ mlir-opt foo.mlir -mlir-disable-threading -pass-pipeline='func(cse)' -print-ir-after=cse -print-ir-module-scope
*** IR Dump After CSE *** ('func' operation: @bar)
func @bar(%arg0: f32, %arg1: f32) -> f32 {
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index e1213889fe73..40b332698c44 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -55,6 +55,12 @@ class MLIRContext {
/// Enables creating operations in unregistered dialects.
void allowUnregisteredDialects(bool allow = true);
+ /// Return true if multi-threading is enabled by the context.
+ bool isMultithreadingEnabled();
+
+ /// Set the flag specifying if multi-threading is disabled by the context.
+ void disableMultithreading(bool disable = true);
+
/// Return true if we should attach the operation to diagnostics emitted via
/// Operation::emit.
bool shouldPrintOpOnDiagnostic();
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 15c88128102e..be5c5d5bcc22 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -99,7 +99,7 @@ class OpPassManager {
void mergeStatisticsInto(OpPassManager &other);
private:
- OpPassManager(OperationName name, bool disableThreads, bool verifyPasses);
+ OpPassManager(OperationName name, bool verifyPasses);
/// A pointer to an internal implementation instance.
std::unique_ptr<detail::OpPassManagerImpl> impl;
@@ -139,13 +139,6 @@ class PassManager : public OpPassManager {
LLVM_NODISCARD
LogicalResult run(ModuleOp module);
- /// Disable support for multi-threading within the pass manager.
- void disableMultithreading(bool disable = true);
-
- /// Return true if the pass manager is configured with multi-threading
- /// enabled.
- bool isMultithreadingEnabled();
-
/// Enable support for the pass manager to generate a reproducer on the event
/// of a crash or a pass failure. `outputFile` is a .mlir filename used to
/// write the generated reproducer. If `genLocalReproducer` is true, the pass
diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index 62a43ff6d1fe..f13a2fef9d50 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -65,6 +65,9 @@ class StorageUniquer {
StorageUniquer();
~StorageUniquer();
+ /// Set the flag specifying if multi-threading is disabled within the uniquer.
+ void disableMultithreading(bool disable = true);
+
/// This class acts as the base storage that all storage classes must derived
/// from.
class BaseStorage {
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index b25c5111d8dc..c59c53567488 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -50,6 +50,10 @@ namespace {
/// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
/// for global command line options.
struct MLIRContextOptions {
+ llvm::cl::opt<bool> disableThreading{
+ "mlir-disable-threading",
+ llvm::cl::desc("Disabling multi-threading within MLIR")};
+
llvm::cl::opt<bool> printOpOnDiagnostic{
"mlir-print-op-on-diagnostic",
llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
@@ -101,6 +105,41 @@ struct BuiltinDialect : public Dialect {
};
} // end anonymous namespace.
+//===----------------------------------------------------------------------===//
+// Locking Utilities
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Utility reader lock that takes a runtime flag that specifies if we really
+/// need to lock.
+struct ScopedReaderLock {
+ ScopedReaderLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
+ : mutex(shouldLock ? &mutexParam : nullptr) {
+ if (mutex)
+ mutex->lock_shared();
+ }
+ ~ScopedReaderLock() {
+ if (mutex)
+ mutex->unlock_shared();
+ }
+ llvm::sys::SmartRWMutex<true> *mutex;
+};
+/// Utility writer lock that takes a runtime flag that specifies if we really
+/// need to lock.
+struct ScopedWriterLock {
+ ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
+ : mutex(shouldLock ? &mutexParam : nullptr) {
+ if (mutex)
+ mutex->lock();
+ }
+ ~ScopedWriterLock() {
+ if (mutex)
+ mutex->unlock();
+ }
+ llvm::sys::SmartRWMutex<true> *mutex;
+};
+} // end anonymous namespace.
+
//===----------------------------------------------------------------------===//
// AffineMap and IntegerSet hashing
//===----------------------------------------------------------------------===//
@@ -111,8 +150,10 @@ template <typename ValueT, typename DenseInfoT, typename KeyT,
typename ConstructorFn>
static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container,
KeyT &&key, llvm::sys::SmartRWMutex<true> &mutex,
+ bool threadingIsEnabled,
ConstructorFn &&constructorFn) {
- { // Check for an existing instance in read-only mode.
+ // Check for an existing instance in read-only mode.
+ if (threadingIsEnabled) {
llvm::sys::SmartScopedReader<true> instanceLock(mutex);
auto it = container.find_as(key);
if (it != container.end())
@@ -120,16 +161,14 @@ static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container,
}
// Acquire a writer-lock so that we can safely create the new instance.
- llvm::sys::SmartScopedWriter<true> instanceLock(mutex);
+ ScopedWriterLock instanceLock(mutex, threadingIsEnabled);
// Check for an existing instance again here, because another writer thread
- // may have already created one.
+ // may have already created one. Otherwise, construct a new instance.
auto existing = container.insert_as(ValueT(), key);
- if (!existing.second)
- return *existing.first;
-
- // Otherwise, construct a new instance of the value.
- return *existing.first = constructorFn();
+ if (existing.second)
+ return *existing.first = constructorFn();
+ return *existing.first;
}
namespace {
@@ -217,6 +256,9 @@ class MLIRContextImpl {
/// detect such use cases
bool allowUnregisteredDialects = false;
+ /// Enable support for multi-threading within MLIR.
+ bool threadingIsEnabled = true;
+
/// If the operation should be attached to diagnostics printed via the
/// Operation::emit methods.
bool printOpOnDiagnostic = true;
@@ -288,17 +330,19 @@ class MLIRContextImpl {
UnknownLoc unknownLocAttr;
public:
- MLIRContextImpl() : identifiers(identifierAllocator) {
- // Initialize values based on the command line flags if they were provided.
- if (clOptions.isConstructed()) {
- printOpOnDiagnostic = clOptions->printOpOnDiagnostic;
- printStackTraceOnDiagnostic = clOptions->printStackTraceOnDiagnostic;
- }
- }
+ MLIRContextImpl() : identifiers(identifierAllocator) {}
};
} // end namespace mlir
MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
+ // Initialize values based on the command line flags if they were provided.
+ if (clOptions.isConstructed()) {
+ disableMultithreading(clOptions->disableThreading);
+ printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
+ printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
+ }
+
+ // Register dialects with this context.
new BuiltinDialect(this);
registerAllDialects(this);
@@ -372,11 +416,10 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
/// Return information about all registered IR dialects.
std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
// Lock access to the context registry.
- llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
-
+ ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
std::vector<Dialect *> result;
- result.reserve(getImpl().dialects.size());
- for (auto &dialect : getImpl().dialects)
+ result.reserve(impl->dialects.size());
+ for (auto &dialect : impl->dialects)
result.push_back(dialect.get());
return result;
}
@@ -385,11 +428,15 @@ std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
/// then return nullptr.
Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
// Lock access to the context registry.
- llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
- for (auto &dialect : getImpl().dialects)
- if (name == dialect->getNamespace())
- return dialect.get();
- return nullptr;
+ ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
+
+ // Dialects are sorted by name, so we can use binary search for lookup.
+ auto it = llvm::lower_bound(
+ impl->dialects, name,
+ [](const auto &lhs, StringRef rhs) { return lhs->getNamespace() < rhs; });
+ return (it != impl->dialects.end() && (*it)->getNamespace() == name)
+ ? (*it).get()
+ : nullptr;
}
/// Register this dialect object with the specified context. The context
@@ -399,15 +446,13 @@ void Dialect::registerDialect(MLIRContext *context) {
std::unique_ptr<Dialect> dialect(this);
// Lock access to the context registry.
- llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
+ ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
// Get the correct insertion position sorted by namespace.
- auto insertPt =
- llvm::lower_bound(impl.dialects, dialect,
- [](const std::unique_ptr<Dialect> &lhs,
- const std::unique_ptr<Dialect> &rhs) {
- return lhs->getNamespace() < rhs->getNamespace();
- });
+ auto insertPt = llvm::lower_bound(
+ impl.dialects, dialect, [](const auto &lhs, const auto &rhs) {
+ return lhs->getNamespace() < rhs->getNamespace();
+ });
// Abort if dialect with namespace has already been registered.
if (insertPt != impl.dialects.end() &&
@@ -426,6 +471,21 @@ void MLIRContext::allowUnregisteredDialects(bool allowing) {
impl->allowUnregisteredDialects = allowing;
}
+/// Return true if multi-threading is enabled by the context.
+bool MLIRContext::isMultithreadingEnabled() {
+ return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
+}
+
+/// Set the flag specifying if multi-threading is disabled by the context.
+void MLIRContext::disableMultithreading(bool disable) {
+ impl->threadingIsEnabled = !disable;
+
+ // Update the threading mode for each of the uniquers.
+ impl->affineUniquer.disableMultithreading(disable);
+ impl->attributeUniquer.disableMultithreading(disable);
+ impl->typeUniquer.disableMultithreading(disable);
+}
+
/// Return true if we should attach the operation to diagnostics emitted via
/// Operation::emit.
bool MLIRContext::shouldPrintOpOnDiagnostic() {
@@ -457,13 +517,13 @@ std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
std::vector<std::pair<StringRef, AbstractOperation *>> opsToSort;
{ // Lock access to the context registry.
- llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
+ ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
// We just have the operations in a non-deterministic hash table order. Dump
// into a temporary array, then sort it by operation name to get a stable
// ordering.
llvm::StringMap<AbstractOperation> &registeredOps =
- getImpl().registeredOperations;
+ impl->registeredOperations;
opsToSort.reserve(registeredOps.size());
for (auto &elt : registeredOps)
@@ -487,7 +547,7 @@ void Dialect::addOperation(AbstractOperation opInfo) {
auto &impl = context->getImpl();
// Lock access to the context registry.
- llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
+ ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
if (!impl.registeredOperations.insert({opInfo.name, opInfo}).second) {
llvm::errs() << "error: operation named '" << opInfo.name
<< "' is already registered.\n";
@@ -500,7 +560,7 @@ void Dialect::addSymbol(TypeID typeID) {
auto &impl = context->getImpl();
// Lock access to the context registry.
- llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
+ ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
if (!impl.registeredDialectSymbols.insert({typeID, this}).second) {
llvm::errs() << "error: dialect symbol already registered.\n";
abort();
@@ -514,7 +574,7 @@ const AbstractOperation *AbstractOperation::lookup(StringRef opName,
auto &impl = context->getImpl();
// Lock access to the context registry.
- llvm::sys::SmartScopedReader<true> registryLock(impl.contextMutex);
+ ScopedReaderLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
auto it = impl.registeredOperations.find(opName);
if (it != impl.registeredOperations.end())
return &it->second;
@@ -529,7 +589,8 @@ const AbstractOperation *AbstractOperation::lookup(StringRef opName,
Identifier Identifier::get(StringRef str, MLIRContext *context) {
auto &impl = context->getImpl();
- { // Check for an existing identifier in read-only mode.
+ // Check for an existing identifier in read-only mode.
+ if (context->isMultithreadingEnabled()) {
llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
auto it = impl.identifiers.find(str);
if (it != impl.identifiers.end())
@@ -544,7 +605,7 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) {
"Cannot create an identifier with a nul character");
// Acquire a writer-lock so that we can safely create the new instance.
- llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
+ ScopedWriterLock contextLock(impl.identifierMutex, impl.threadingIsEnabled);
auto it = impl.identifiers.insert(str).first;
return Identifier(&*it);
}
@@ -696,16 +757,18 @@ AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
auto key = std::make_tuple(dimCount, symbolCount, results);
// Safely get or create an AffineMap instance.
- return safeGetOrCreate(impl.affineMaps, key, impl.affineMutex, [&] {
- auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>();
+ return safeGetOrCreate(
+ impl.affineMaps, key, impl.affineMutex, impl.threadingIsEnabled, [&] {
+ auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>();
- // Copy the results into the bump pointer.
- results = copyArrayRefInto(impl.affineAllocator, results);
+ // Copy the results into the bump pointer.
+ results = copyArrayRefInto(impl.affineAllocator, results);
- // Initialize the memory using placement new.
- new (res) detail::AffineMapStorage{dimCount, symbolCount, results, context};
- return AffineMap(res);
- });
+ // Initialize the memory using placement new.
+ new (res)
+ detail::AffineMapStorage{dimCount, symbolCount, results, context};
+ return AffineMap(res);
+ });
}
AffineMap AffineMap::get(MLIRContext *context) {
@@ -760,12 +823,12 @@ IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
if (constraints.size() < IntegerSet::kUniquingThreshold) {
auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags);
return safeGetOrCreate(impl.integerSets, key, impl.affineMutex,
- constructorFn);
+ impl.threadingIsEnabled, constructorFn);
}
// Otherwise, acquire a writer-lock so that we can safely create the new
// instance.
- llvm::sys::SmartScopedWriter<true> affineLock(impl.affineMutex);
+ ScopedWriterLock affineLock(impl.affineMutex, impl.threadingIsEnabled);
return constructorFn();
}
diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp
index ba9ff989ea79..842b83c5d0f7 100644
--- a/mlir/lib/Pass/IRPrinting.cpp
+++ b/mlir/lib/Pass/IRPrinting.cpp
@@ -257,7 +257,8 @@ struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
/// Add an instrumentation to print the IR before and after pass execution,
/// using the provided configuration.
void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
- if (config->shouldPrintAtModuleScope() && isMultithreadingEnabled())
+ if (config->shouldPrintAtModuleScope() &&
+ getContext()->isMultithreadingEnabled())
llvm::report_fatal_error("IR printing can't be setup on a pass-manager "
"without disabling multi-threading first.");
addInstrumentation(
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index ef4ed760feb8..83855faca755 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -118,9 +118,8 @@ void VerifierPass::runOnOperation() {
namespace mlir {
namespace detail {
struct OpPassManagerImpl {
- OpPassManagerImpl(OperationName name, bool disableThreads, bool verifyPasses)
- : name(name), disableThreads(disableThreads), verifyPasses(verifyPasses) {
- }
+ OpPassManagerImpl(OperationName name, bool verifyPasses)
+ : name(name), verifyPasses(verifyPasses) {}
/// Merge the passes of this pass manager into the one provided.
void mergeInto(OpPassManagerImpl &rhs);
@@ -152,9 +151,6 @@ struct OpPassManagerImpl {
/// The name of the operation that passes of this pass manager operate on.
OperationName name;
- /// Flag to disable multi-threading of passes.
- bool disableThreads : 1;
-
/// Flag that specifies if the IR should be verified after each pass has run.
bool verifyPasses : 1;
@@ -172,7 +168,7 @@ void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
}
OpPassManager &OpPassManagerImpl::nest(const OperationName &nestedName) {
- OpPassManager nested(nestedName, disableThreads, verifyPasses);
+ OpPassManager nested(nestedName, verifyPasses);
auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
addPass(std::unique_ptr<Pass>(adaptor));
return adaptor->getPassManagers().front();
@@ -269,9 +265,8 @@ void OpPassManagerImpl::splitAdaptorPasses() {
// OpPassManager
//===----------------------------------------------------------------------===//
-OpPassManager::OpPassManager(OperationName name, bool disableThreads,
- bool verifyPasses)
- : impl(new OpPassManagerImpl(name, disableThreads, verifyPasses)) {
+OpPassManager::OpPassManager(OperationName name, bool verifyPasses)
+ : impl(new OpPassManagerImpl(name, verifyPasses)) {
assert(name.getAbstractOperation() &&
"OpPassManager can only operate on registered operations");
assert(name.getAbstractOperation()->hasProperty(
@@ -282,8 +277,7 @@ OpPassManager::OpPassManager(OperationName name, bool disableThreads,
OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {}
OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
- impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->disableThreads,
- rhs.impl->verifyPasses));
+ impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->verifyPasses));
for (auto &pass : rhs.impl->passes)
impl->passes.emplace_back(pass->clone());
return *this;
@@ -419,10 +413,10 @@ std::string OpToOpPassAdaptor::getAdaptorName() {
/// Run the held pipeline over all nested operations.
void OpToOpPassAdaptor::runOnOperation() {
- if (mgrs.front().getImpl().disableThreads || !llvm::llvm_is_multithreaded())
- runOnOperationImpl();
- else
+ if (getContext().isMultithreadingEnabled())
runOnOperationAsyncImpl();
+ else
+ runOnOperationImpl();
}
/// Run this pass adaptor synchronously.
@@ -576,7 +570,7 @@ struct RecoveryReproducerContext {
/// The filename to use when generating the reproducer.
StringRef filename;
- /// Various pass manager flags.
+ /// Various pass manager and context flags.
bool disableThreads;
bool verifyPasses;
@@ -628,7 +622,7 @@ LogicalResult RecoveryReproducerContext::generate(std::string &error) {
// Output the current pass manager configuration.
outputOS << "// configuration: -pass-pipeline='" << pipeline << "'";
if (disableThreads)
- outputOS << " -disable-pass-threading";
+ outputOS << " -mlir-disable-threading";
// TODO: Should this also be configured with a pass manager flag?
outputOS << "\n// note: verifyPasses=" << (verifyPasses ? "true" : "false")
@@ -684,7 +678,8 @@ LogicalResult
PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
ModuleOp module, AnalysisManager am) {
RecoveryReproducerContext context(passes, module, *crashReproducerFileName,
- impl->disableThreads, impl->verifyPasses);
+ !getContext()->isMultithreadingEnabled(),
+ impl->verifyPasses);
// Safely invoke the passes within a recovery context.
llvm::CrashRecoveryContext::Enable();
@@ -715,7 +710,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
PassManager::PassManager(MLIRContext *ctx, bool verifyPasses)
: OpPassManager(OperationName(ModuleOp::getOperationName(), ctx),
- /*disableThreads=*/false, verifyPasses),
+ verifyPasses),
passTiming(false), localReproducer(false) {}
PassManager::~PassManager() {}
@@ -741,15 +736,6 @@ LogicalResult PassManager::run(ModuleOp module) {
return result;
}
-/// Disable support for multi-threading within the pass manager.
-void PassManager::disableMultithreading(bool disable) {
- getImpl().disableThreads = disable;
-}
-
-bool PassManager::isMultithreadingEnabled() {
- return !getImpl().disableThreads;
-}
-
/// Enable support for the pass manager to generate a reproducer on the event
/// of a crash or a pass failure. `outputFile` is a .mlir filename used to write
/// the generated reproducer. If `genLocalReproducer` is true, the pass manager
diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp
index c1b1eee07ea5..b00f992eceb9 100644
--- a/mlir/lib/Pass/PassManagerOptions.cpp
+++ b/mlir/lib/Pass/PassManagerOptions.cpp
@@ -29,14 +29,6 @@ struct PassManagerOptions {
"a reproducer with the smallest pipeline."),
llvm::cl::init(false)};
- //===--------------------------------------------------------------------===//
- // Multi-threading
- //===--------------------------------------------------------------------===//
- llvm::cl::opt<bool> disableThreads{
- "disable-pass-threading",
- llvm::cl::desc("Disable multithreading in the pass manager"),
- llvm::cl::init(false)};
-
//===--------------------------------------------------------------------===//
// IR Printing
//===--------------------------------------------------------------------===//
@@ -164,10 +156,6 @@ void mlir::applyPassManagerCLOptions(PassManager &pm) {
pm.enableCrashReproducerGeneration(options->reproducerFile,
options->localReproducer);
- // Disable multi-threading.
- if (options->disableThreads)
- pm.disableMultithreading();
-
// Enable statistics dumping.
if (options->passStatistics)
pm.enableStatistics(options->passStatisticsDisplayMode);
diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp
index d50c599a7776..40304a544c4f 100644
--- a/mlir/lib/Support/StorageUniquer.cpp
+++ b/mlir/lib/Support/StorageUniquer.cpp
@@ -46,6 +46,8 @@ struct StorageUniquerImpl {
function_ref<bool(const BaseStorage *)> isEqual,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
LookupKey lookupKey{kind, hashValue, isEqual};
+ if (!threadingIsEnabled)
+ return getOrCreateUnsafe(kind, hashValue, lookupKey, ctorFn);
// Check for an existing instance in read-only mode.
{
@@ -57,9 +59,12 @@ struct StorageUniquerImpl {
// Acquire a writer-lock so that we can safely create the new type instance.
llvm::sys::SmartScopedWriter<true> typeLock(mutex);
-
- // Check for an existing instance again here, because another writer thread
- // may have already created one.
+ return getOrCreateUnsafe(kind, hashValue, lookupKey, ctorFn);
+ }
+ /// Get or create an instance of a complex derived type in an unsafe fashion.
+ BaseStorage *
+ getOrCreateUnsafe(unsigned kind, unsigned hashValue, LookupKey &lookupKey,
+ function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
auto existing = storageTypes.insert_as({}, lookupKey);
if (!existing.second)
return existing.first->storage;
@@ -75,6 +80,9 @@ struct StorageUniquerImpl {
BaseStorage *
getOrCreate(unsigned kind,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+ if (!threadingIsEnabled)
+ return getOrCreateUnsafe(kind, ctorFn);
+
// Check for an existing instance in read-only mode.
{
llvm::sys::SmartScopedReader<true> typeLock(mutex);
@@ -85,9 +93,12 @@ struct StorageUniquerImpl {
// Acquire a writer-lock so that we can safely create the new type instance.
llvm::sys::SmartScopedWriter<true> typeLock(mutex);
-
- // Check for an existing instance again here, because another writer thread
- // may have already created one.
+ return getOrCreateUnsafe(kind, ctorFn);
+ }
+ /// Get or create an instance of a simple derived type in an unsafe fashion.
+ BaseStorage *
+ getOrCreateUnsafe(unsigned kind,
+ function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
auto &result = simpleTypes[kind];
if (result)
return result;
@@ -152,18 +163,21 @@ struct StorageUniquerImpl {
}
};
- // Unique types with specific hashing or storage constraints.
+ /// Unique types with specific hashing or storage constraints.
using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
StorageTypeSet storageTypes;
- // Unique types with just the kind.
+ /// Unique types with just the kind.
DenseMap<unsigned, BaseStorage *> simpleTypes;
- // Allocator to use when constructing derived type instances.
+ /// Allocator to use when constructing derived type instances.
StorageUniquer::StorageAllocator allocator;
- // A mutex to keep type uniquing thread-safe.
+ /// A mutex to keep type uniquing thread-safe.
llvm::sys::SmartRWMutex<true> mutex;
+
+ /// Flag specifying if multi-threading is enabled within the uniquer.
+ bool threadingIsEnabled = true;
};
} // end namespace detail
} // namespace mlir
@@ -171,6 +185,11 @@ struct StorageUniquerImpl {
StorageUniquer::StorageUniquer() : impl(new StorageUniquerImpl()) {}
StorageUniquer::~StorageUniquer() {}
+/// Set the flag specifying if multi-threading is disabled within the uniquer.
+void StorageUniquer::disableMultithreading(bool disable) {
+ impl->threadingIsEnabled = !disable;
+}
+
/// Implementation for getting/creating an instance of a derived type with
/// complex storage.
auto StorageUniquer::getImpl(
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index c0f89da300f1..ee645cb55511 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -496,22 +496,28 @@ static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
// NOTE: This is simple now, because we don't enable canonicalizing nodes
// within children. When we remove this restriction, this logic will need to
// be reworked.
- ParallelDiagnosticHandler canonicalizationHandler(context);
- llvm::parallel::for_each_n(
- llvm::parallel::par, /*Begin=*/size_t(0),
- /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
- // Set the order for this thread so that diagnostics will be properly
- // ordered.
- canonicalizationHandler.setOrderIDForThread(index);
-
- // Apply the canonicalization patterns to this region.
- auto *node = nodesToCanonicalize[index];
- applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns);
-
- // Make sure to reset the order ID for the diagnostic handler, as this
- // thread may be used in a different context.
- canonicalizationHandler.eraseOrderIDForThread();
- });
+ if (context->isMultithreadingEnabled()) {
+ ParallelDiagnosticHandler canonicalizationHandler(context);
+ llvm::parallel::for_each_n(
+ llvm::parallel::par, /*Begin=*/size_t(0),
+ /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
+ // Set the order for this thread so that diagnostics will be properly
+ // ordered.
+ canonicalizationHandler.setOrderIDForThread(index);
+
+ // Apply the canonicalization patterns to this region.
+ auto *node = nodesToCanonicalize[index];
+ applyPatternsAndFoldGreedily(*node->getCallableRegion(),
+ canonPatterns);
+
+ // Make sure to reset the order ID for the diagnostic handler, as this
+ // thread may be used in a different context.
+ canonicalizationHandler.eraseOrderIDForThread();
+ });
+ } else {
+ for (CallGraphNode *node : nodesToCanonicalize)
+ applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns);
+ }
// Recompute the uses held by each of the nodes.
for (CallGraphNode *node : nodesToCanonicalize)
diff --git a/mlir/test/Dialect/SPIRV/availability.mlir b/mlir/test/Dialect/SPIRV/availability.mlir
index e31c1bdeacca..322cc533c826 100644
--- a/mlir/test/Dialect/SPIRV/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/availability.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -disable-pass-threading -test-spirv-op-availability %s | FileCheck %s
+// RUN: mlir-opt -mlir-disable-threading -test-spirv-op-availability %s | FileCheck %s
// CHECK-LABEL: iadd
func @iadd(%arg: i32) -> i32 {
diff --git a/mlir/test/Dialect/SPIRV/target-env.mlir b/mlir/test/Dialect/SPIRV/target-env.mlir
index 9b42314e3f1d..27c4e8d04092 100644
--- a/mlir/test/Dialect/SPIRV/target-env.mlir
+++ b/mlir/test/Dialect/SPIRV/target-env.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -disable-pass-threading -test-spirv-target-env %s | FileCheck %s
+// RUN: mlir-opt -mlir-disable-threading -test-spirv-target-env %s | FileCheck %s
// Note: The following tests check that a spv.target_env can properly control
// the conversion target and filter unavailable ops during the conversion.
diff --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir
index 60d5bcf7d81b..925b01bda110 100644
--- a/mlir/test/IR/test-matchers.mlir
+++ b/mlir/test/IR/test-matchers.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -disable-pass-threading=true -test-matchers -o /dev/null 2>&1 | FileCheck %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -test-matchers -o /dev/null 2>&1 | FileCheck %s
func @test1(%a: f32, %b: f32, %c: f32) {
%0 = addf %a, %b: f32
diff --git a/mlir/test/Pass/ir-printing.mlir b/mlir/test/Pass/ir-printing.mlir
index 892dc40a034a..8bb86b36c181 100644
--- a/mlir/test/Pass/ir-printing.mlir
+++ b/mlir/test/Pass/ir-printing.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE %s
-// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before-all -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_ALL %s
-// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after=cse -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER %s
-// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after-all -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL %s
-// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -print-ir-module-scope -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_MODULE %s
-// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,cse)' -print-ir-after-all -print-ir-after-change -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL_CHANGE %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before-all -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_ALL %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after=cse -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after-all -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -print-ir-module-scope -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_MODULE %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,cse)' -print-ir-after-all -print-ir-after-change -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL_CHANGE %s
func @foo() {
%0 = constant 0 : i32
diff --git a/mlir/test/Pass/pass-timing.mlir b/mlir/test/Pass/pass-timing.mlir
index db39ad6a1633..6cd8a29e6f48 100644
--- a/mlir/test/Pass/pass-timing.mlir
+++ b/mlir/test/Pass/pass-timing.mlir
@@ -1,8 +1,8 @@
-// RUN: mlir-opt %s -disable-pass-threading=true -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=list 2>&1 | FileCheck -check-prefix=LIST %s
-// RUN: mlir-opt %s -disable-pass-threading=true -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=PIPELINE %s
-// RUN: mlir-opt %s -disable-pass-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=list 2>&1 | FileCheck -check-prefix=MT_LIST %s
-// RUN: mlir-opt %s -disable-pass-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=MT_PIPELINE %s
-// RUN: mlir-opt %s -disable-pass-threading=false -verify-each=false -test-pm-nested-pipeline -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=NESTED_MT_PIPELINE %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=list 2>&1 | FileCheck -check-prefix=LIST %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=PIPELINE %s
+// RUN: mlir-opt %s -mlir-disable-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=list 2>&1 | FileCheck -check-prefix=MT_LIST %s
+// RUN: mlir-opt %s -mlir-disable-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=MT_PIPELINE %s
+// RUN: mlir-opt %s -mlir-disable-threading=false -verify-each=false -test-pm-nested-pipeline -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=NESTED_MT_PIPELINE %s
// LIST: Pass execution timing report
// LIST: Total Execution Time:
More information about the Mlir-commits mailing list