[Mlir-commits] [mlir] 6107069 - Add an assertion to protect against missing Dialect registration in a pass pipeline (NFC)
Mehdi Amini
llvmlistbot at llvm.org
Sun Aug 23 23:49:48 PDT 2020
Author: Mehdi Amini
Date: 2020-08-24T06:49:29Z
New Revision: 610706906ae218eaff5b996f64554be7b279e4f0
URL: https://github.com/llvm/llvm-project/commit/610706906ae218eaff5b996f64554be7b279e4f0
DIFF: https://github.com/llvm/llvm-project/commit/610706906ae218eaff5b996f64554be7b279e4f0.diff
LOG: Add an assertion to protect against missing Dialect registration in a pass pipeline (NFC)
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D86327
Added:
Modified:
mlir/include/mlir/IR/MLIRContext.h
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Pass/Pass.cpp
mlir/test/lib/Transforms/TestConvertCallOp.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index e8a5d6e6d236..b7fe642166dd 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -156,6 +156,12 @@ class MLIRContext {
/// instances. This should not be used directly.
StorageUniquer &getAttributeUniquer();
+ /// These APIs track whether the context is being used in a
+ /// multithreaded environment: they have no effect other than enabling
+ /// assertions on misuses of some APIs.
+ void enterMultiThreadedExecution();
+ void exitMultiThreadedExecution();
+
private:
const std::unique_ptr<MLIRContextImpl> impl;
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index b47c143fbfa8..61dafc3e39f8 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -264,6 +264,13 @@ class MLIRContextImpl {
/// Enable support for multi-threading within MLIR.
bool threadingIsEnabled = true;
+ /// Track if we are currently executing in a threaded execution environment
+ /// (like the pass-manager): this is only a debugging feature to help reduce
+ /// the chances of data races on some context APIs.
+#ifndef NDEBUG
+ std::atomic<int> multiThreadedExecutionContext{0};
+#endif
+
/// If the operation should be attached to diagnostics printed via the
/// Operation::emit methods.
bool printOpOnDiagnostic = true;
@@ -487,6 +494,15 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
if (!dialect) {
LLVM_DEBUG(llvm::dbgs()
<< "Load new dialect in Context" << dialectNamespace);
+#ifndef NDEBUG
+ if (impl.multiThreadedExecutionContext != 0) {
+ llvm::errs() << "Loading a dialect (" << dialectNamespace
+ << ") while in a multi-threaded execution context (maybe "
+ "the PassManager): this can indicate a "
+ "missing `dependentDialects` in a pass for example.";
+ abort();
+ }
+#endif
dialect = ctor();
assert(dialect && "dialect ctor failed");
return dialect.get();
@@ -527,6 +543,17 @@ void MLIRContext::disableMultithreading(bool disable) {
impl->typeUniquer.disableMultithreading(disable);
}
+void MLIRContext::enterMultiThreadedExecution() { // Debug-only bookkeeping: does nothing when NDEBUG is defined.
+#ifndef NDEBUG
+ ++impl->multiThreadedExecutionContext; // While non-zero, loading a new dialect aborts and op/type/attr registration asserts.
+#endif
+}
+void MLIRContext::exitMultiThreadedExecution() { // Counterpart of enterMultiThreadedExecution(); also a no-op under NDEBUG.
+#ifndef NDEBUG
+ --impl->multiThreadedExecutionContext; // NOTE(review): no check against unbalanced calls; the counter can go negative.
+#endif
+}
+
/// Return true if we should attach the operation to diagnostics emitted via
/// Operation::emit.
bool MLIRContext::shouldPrintOpOnDiagnostic() {
@@ -583,6 +610,9 @@ void Dialect::addOperation(AbstractOperation opInfo) {
"op name doesn't start with dialect namespace");
assert(&opInfo.dialect == this && "Dialect object mismatch");
auto &impl = context->getImpl();
+ assert(impl.multiThreadedExecutionContext == 0 &&
+ "Registering a new operation kind while in a multi-threaded execution "
+ "context");
StringRef opName = opInfo.name;
if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) {
llvm::errs() << "error: operation named '" << opInfo.name
@@ -593,6 +623,9 @@ void Dialect::addOperation(AbstractOperation opInfo) {
void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
auto &impl = context->getImpl();
+ assert(impl.multiThreadedExecutionContext == 0 &&
+ "Registering a new type kind while in a multi-threaded execution "
+ "context");
auto *newInfo =
new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
AbstractType(std::move(typeInfo));
@@ -602,6 +635,9 @@ void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
auto &impl = context->getImpl();
+ assert(impl.multiThreadedExecutionContext == 0 &&
+ "Registering a new attribute kind while in a multi-threaded execution "
+ "context");
auto *newInfo =
new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
AbstractAttribute(std::move(attrInfo));
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 9bc23c2e4a65..fe5223b7daef 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -751,12 +751,18 @@ LogicalResult PassManager::run(ModuleOp module) {
// Construct an analysis manager for the pipeline.
ModuleAnalysisManager am(module, instrumentor.get());
+ // Notify the context that we are starting to run a pipeline, for bookkeeping.
+ module.getContext()->enterMultiThreadedExecution();
+
// If reproducer generation is enabled, run the pass manager with crash
// handling enabled.
LogicalResult result = crashReproducerFileName
? runWithCrashRecovery(module, am)
: OpPassManager::run(module, am);
+ // Notify the context that the run is done.
+ module.getContext()->exitMultiThreadedExecution();
+
// Dump all of the pass statistics if necessary.
if (passStatisticsMode)
dumpStatistics();
diff --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
index 6cb596bfc71a..980c54258ab8 100644
--- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
@@ -34,6 +34,10 @@ class TestTypeProducerOpConverter
class TestConvertCallOp
: public PassWrapper<TestConvertCallOp, OperationPass<ModuleOp>> {
public:
+ void getDependentDialects(DialectRegistry &registry) const final {
+ registry.insert<LLVM::LLVMDialect>(); // Declare LLVM up front: loading a dialect mid-pipeline now aborts in debug builds.
+ }
+
void runOnOperation() override {
ModuleOp m = getOperation();
More information about the Mlir-commits
mailing list