[Mlir-commits] [mlir] 6107069 - Add an assertion to protect against missing Dialect registration in a pass pipeline (NFC)

Mehdi Amini llvmlistbot at llvm.org
Sun Aug 23 23:49:48 PDT 2020


Author: Mehdi Amini
Date: 2020-08-24T06:49:29Z
New Revision: 610706906ae218eaff5b996f64554be7b279e4f0

URL: https://github.com/llvm/llvm-project/commit/610706906ae218eaff5b996f64554be7b279e4f0
DIFF: https://github.com/llvm/llvm-project/commit/610706906ae218eaff5b996f64554be7b279e4f0.diff

LOG: Add an assertion to protect against missing Dialect registration in a pass pipeline (NFC)

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D86327

Added: 
    

Modified: 
    mlir/include/mlir/IR/MLIRContext.h
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/Pass/Pass.cpp
    mlir/test/lib/Transforms/TestConvertCallOp.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index e8a5d6e6d236..b7fe642166dd 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -156,6 +156,12 @@ class MLIRContext {
   /// instances. This should not be used directly.
   StorageUniquer &getAttributeUniquer();
 
+  /// These APIs track whether the context is currently used in a
+  /// multi-threaded environment: they have no effect other than enabling
+  /// assertions on misuse of some APIs.
+  void enterMultiThreadedExecution();
+  void exitMultiThreadedExecution();
+
 private:
   const std::unique_ptr<MLIRContextImpl> impl;
 

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index b47c143fbfa8..61dafc3e39f8 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -264,6 +264,13 @@ class MLIRContextImpl {
   /// Enable support for multi-threading within MLIR.
   bool threadingIsEnabled = true;
 
+  /// Track whether we are currently executing in a threaded execution
+  /// environment (like the pass-manager): this is only a debugging feature to
+  /// help reduce the chances of data races on some context APIs.
+#ifndef NDEBUG
+  std::atomic<int> multiThreadedExecutionContext{0};
+#endif
+
   /// If the operation should be attached to diagnostics printed via the
   /// Operation::emit methods.
   bool printOpOnDiagnostic = true;
@@ -487,6 +494,15 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
   if (!dialect) {
     LLVM_DEBUG(llvm::dbgs()
                << "Load new dialect in Context" << dialectNamespace);
+#ifndef NDEBUG
+    if (impl.multiThreadedExecutionContext != 0) {
+      llvm::errs() << "Loading a dialect (" << dialectNamespace
+                   << ") while in a multi-threaded execution context (maybe "
+                      "the PassManager): this can indicate a "
+                      "missing `dependentDialects` in a pass for example.";
+      abort();
+    }
+#endif
     dialect = ctor();
     assert(dialect && "dialect ctor failed");
     return dialect.get();
@@ -527,6 +543,17 @@ void MLIRContext::disableMultithreading(bool disable) {
   impl->typeUniquer.disableMultithreading(disable);
 }
 
+void MLIRContext::enterMultiThreadedExecution() {
+#ifndef NDEBUG
+  ++impl->multiThreadedExecutionContext;
+#endif
+}
+void MLIRContext::exitMultiThreadedExecution() {
+#ifndef NDEBUG
+  --impl->multiThreadedExecutionContext;
+#endif
+}
+
 /// Return true if we should attach the operation to diagnostics emitted via
 /// Operation::emit.
 bool MLIRContext::shouldPrintOpOnDiagnostic() {
@@ -583,6 +610,9 @@ void Dialect::addOperation(AbstractOperation opInfo) {
          "op name doesn't start with dialect namespace");
   assert(&opInfo.dialect == this && "Dialect object mismatch");
   auto &impl = context->getImpl();
+  assert(impl.multiThreadedExecutionContext == 0 &&
+         "Registering a new operation kind while in a multi-threaded execution "
+         "context");
   StringRef opName = opInfo.name;
   if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) {
     llvm::errs() << "error: operation named '" << opInfo.name
@@ -593,6 +623,9 @@ void Dialect::addOperation(AbstractOperation opInfo) {
 
 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
   auto &impl = context->getImpl();
+  assert(impl.multiThreadedExecutionContext == 0 &&
+         "Registering a new type kind while in a multi-threaded execution "
+         "context");
   auto *newInfo =
       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
           AbstractType(std::move(typeInfo));
@@ -602,6 +635,9 @@ void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
 
 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
   auto &impl = context->getImpl();
+  assert(impl.multiThreadedExecutionContext == 0 &&
+         "Registering a new attribute kind while in a multi-threaded execution "
+         "context");
   auto *newInfo =
       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
           AbstractAttribute(std::move(attrInfo));

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 9bc23c2e4a65..fe5223b7daef 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -751,12 +751,18 @@ LogicalResult PassManager::run(ModuleOp module) {
   // Construct an analysis manager for the pipeline.
   ModuleAnalysisManager am(module, instrumentor.get());
 
+  // Notify the context that we start running a pipeline, for bookkeeping.
+  module.getContext()->enterMultiThreadedExecution();
+
   // If reproducer generation is enabled, run the pass manager with crash
   // handling enabled.
   LogicalResult result = crashReproducerFileName
                              ? runWithCrashRecovery(module, am)
                              : OpPassManager::run(module, am);
 
+  // Notify the context that the run is done.
+  module.getContext()->exitMultiThreadedExecution();
+
   // Dump all of the pass statistics if necessary.
   if (passStatisticsMode)
     dumpStatistics();

diff  --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
index 6cb596bfc71a..980c54258ab8 100644
--- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
@@ -34,6 +34,10 @@ class TestTypeProducerOpConverter
 class TestConvertCallOp
     : public PassWrapper<TestConvertCallOp, OperationPass<ModuleOp>> {
 public:
+  void getDependentDialects(DialectRegistry &registry) const final {
+    registry.insert<LLVM::LLVMDialect>();
+  }
+
   void runOnOperation() override {
     ModuleOp m = getOperation();
 


        


More information about the Mlir-commits mailing list