[Mlir-commits] [mlir] 0f304ef - [mlir] Add asserts when changing various MLIRContext configurations

River Riddle llvmlistbot at llvm.org
Fri Apr 15 21:49:37 PDT 2022


Author: River Riddle
Date: 2022-04-15T21:49:03-07:00
New Revision: 0f304ef0170231b860a249f34e07f50686392253

URL: https://github.com/llvm/llvm-project/commit/0f304ef0170231b860a249f34e07f50686392253
DIFF: https://github.com/llvm/llvm-project/commit/0f304ef0170231b860a249f34e07f50686392253.diff

LOG: [mlir] Add asserts when changing various MLIRContext configurations

This helps to catch, with a clear assertion message, cases that would otherwise
surface as TSan failures when users inadvertently mutate the context in a
thread-unsafe way.

Differential Revision: https://reviews.llvm.org/D112021

Added: 
    

Modified: 
    mlir/include/mlir/IR/DialectRegistry.h
    mlir/lib/IR/Dialect.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/Reducer/OptReductionPass.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
index fbc81b0f04499..5e55fab3510e0 100644
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -212,6 +212,10 @@ class DialectRegistry {
     addExtension(std::make_unique<Extension>(std::move(extensionFn)));
   }
 
+  /// Returns true if every dialect registered here is also registered in
+  /// 'rhs'. A registry holding extensions is conservatively never a subset.
+  bool isSubsetOf(const DialectRegistry &rhs) const;
+
 private:
   MapTy registry;
   std::vector<std::unique_ptr<DialectExtensionBase>> extensions;

diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 2e983d641a9ec..b8f5aa29c31f5 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -228,3 +228,12 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
   for (const auto &extension : extensions)
     applyExtension(*extension);
 }
+
+bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
+  // Extensions are not compared; any extension conservatively yields false.
+  if (!extensions.empty())
+    return false;
+  // Every dialect registered here must also be registered in 'rhs'.
+  return llvm::all_of(
+      registry, [&](const auto &it) { return rhs.registry.count(it.first); });
+}

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 2c0b3ba049d7b..eae636347271a 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -355,6 +355,12 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
 //===----------------------------------------------------------------------===//
 
 void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
+  if (registry.isSubsetOf(impl->dialectsRegistry))
+    return;
+
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "appending to the MLIRContext dialect registry while in a "
+         "multi-threaded execution context");
   registry.appendTo(impl->dialectsRegistry);
 
   // For the already loaded dialects, apply any possible extensions immediately.
@@ -470,6 +476,9 @@ bool MLIRContext::allowsUnregisteredDialects() {
 }
 
 void MLIRContext::allowUnregisteredDialects(bool allowing) {
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "changing MLIRContext `allow-unregistered-dialects` configuration "
+         "while in a multi-threaded execution context");
   impl->allowUnregisteredDialects = allowing;
 }
 
@@ -484,6 +493,9 @@ void MLIRContext::disableMultithreading(bool disable) {
   // --mlir-disable-threading
   if (isThreadingGloballyDisabled())
     return;
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "changing MLIRContext `disable-threading` configuration while "
+         "in a multi-threaded execution context");
 
   impl->threadingIsEnabled = !disable;
 
@@ -557,6 +569,9 @@ bool MLIRContext::shouldPrintOpOnDiagnostic() {
 /// Set the flag specifying if we should attach the operation to diagnostics
 /// emitted via Operation::emit.
 void MLIRContext::printOpOnDiagnostic(bool enable) {
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "changing MLIRContext `print-op-on-diagnostic` configuration while in "
+         "a multi-threaded execution context");
   impl->printOpOnDiagnostic = enable;
 }
 
@@ -569,6 +584,9 @@ bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
 /// Set the flag specifying if we should attach the current stacktrace when
 /// emitting diagnostics.
 void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
+  assert(impl->multiThreadedExecutionContext == 0 &&
+         "changing MLIRContext `print-stacktrace-on-diagnostic` configuration "
+         "while in a multi-threaded execution context");
   impl->printStackTraceOnDiagnostic = enable;
 }
 

diff  --git a/mlir/lib/Reducer/OptReductionPass.cpp b/mlir/lib/Reducer/OptReductionPass.cpp
index 806ce67e529a6..a7f09b4c923f6 100644
--- a/mlir/lib/Reducer/OptReductionPass.cpp
+++ b/mlir/lib/Reducer/OptReductionPass.cpp
@@ -42,7 +42,7 @@ void OptReductionPass::runOnOperation() {
   ModuleOp module = this->getOperation();
   ModuleOp moduleVariant = module.clone();
 
-  PassManager passManager(module.getContext());
+  OpPassManager passManager("builtin.module");
   if (failed(parsePassPipeline(optPass, passManager))) {
     module.emitError() << "\nfailed to parse pass pipeline";
     return signalPassFailure();
@@ -54,7 +54,13 @@ void OptReductionPass::runOnOperation() {
     return signalPassFailure();
   }
 
-  if (failed(passManager.run(moduleVariant))) {
+  // Temporarily push the variant under the main module and execute the pipeline
+  // on it.
+  module.getBody()->push_back(moduleVariant);
+  LogicalResult pipelineResult = runPipeline(passManager, moduleVariant);
+  moduleVariant->remove();
+
+  if (failed(pipelineResult)) {
     module.emitError() << "\nfailed to run pass pipeline";
     return signalPassFailure();
   }

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 88ccc77ab59a2..573197a584e0e 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -255,14 +255,13 @@ struct TestLinalgGreedyFusion
     patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+    OpPassManager pm(FuncOp::getOperationName());
+    pm.addPass(createLoopInvariantCodeMotionPass());
+    pm.addPass(createCanonicalizerPass());
+    pm.addPass(createCSEPass());
     do {
       (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
-      PassManager pm(context);
-      pm.addPass(createLoopInvariantCodeMotionPass());
-      pm.addPass(createCanonicalizerPass());
-      pm.addPass(createCSEPass());
-      LogicalResult res = pm.run(getOperation()->getParentOfType<ModuleOp>());
-      if (failed(res))
+      if (failed(runPipeline(pm, getOperation())))
         this->signalPassFailure();
     } while (succeeded(fuseLinalgOpsGreedily(getOperation())));
   }


        


More information about the Mlir-commits mailing list