[Mlir-commits] [mlir] 0f304ef - [mlir] Add asserts when changing various MLIRContext configurations
River Riddle
llvmlistbot at llvm.org
Fri Apr 15 21:49:37 PDT 2022
Author: River Riddle
Date: 2022-04-15T21:49:03-07:00
New Revision: 0f304ef0170231b860a249f34e07f50686392253
URL: https://github.com/llvm/llvm-project/commit/0f304ef0170231b860a249f34e07f50686392253
DIFF: https://github.com/llvm/llvm-project/commit/0f304ef0170231b860a249f34e07f50686392253.diff
LOG: [mlir] Add asserts when changing various MLIRContext configurations
This helps to prevent TSan (ThreadSanitizer) failures when users inadvertently mutate the
context in a thread-unsafe way.
Differential Revision: https://reviews.llvm.org/D112021
Added:
Modified:
mlir/include/mlir/IR/DialectRegistry.h
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Reducer/OptReductionPass.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
index fbc81b0f04499..5e55fab3510e0 100644
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -212,6 +212,10 @@ class DialectRegistry {
addExtension(std::make_unique<Extension>(std::move(extensionFn)));
}
+ /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
+ /// contains all of the components of this registry.
+ bool isSubsetOf(const DialectRegistry &rhs) const;
+
private:
MapTy registry;
std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 2e983d641a9ec..b8f5aa29c31f5 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -228,3 +228,12 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
for (const auto &extension : extensions)
applyExtension(*extension);
}
+
+bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
+ // Any extension held here conservatively makes this return false.
+ if (!extensions.empty())
+ return false;
+ // Subset iff every dialect key in 'registry' is also a key of 'rhs.registry'.
+ return llvm::all_of(
+ registry, [&](const auto &it) { return rhs.registry.count(it.first); });
+}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 2c0b3ba049d7b..eae636347271a 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -355,6 +355,12 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
//===----------------------------------------------------------------------===//
void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
+ if (registry.isSubsetOf(impl->dialectsRegistry))
+ return;
+
+ assert(impl->multiThreadedExecutionContext == 0 &&
+ "appending to the MLIRContext dialect registry while in a "
+ "multi-threaded execution context");
registry.appendTo(impl->dialectsRegistry);
// For the already loaded dialects, apply any possible extensions immediately.
@@ -470,6 +476,9 @@ bool MLIRContext::allowsUnregisteredDialects() {
}
void MLIRContext::allowUnregisteredDialects(bool allowing) {
+ assert(impl->multiThreadedExecutionContext == 0 &&
+ "changing MLIRContext `allow-unregistered-dialects` configuration "
+ "while in a multi-threaded execution context");
impl->allowUnregisteredDialects = allowing;
}
@@ -484,6 +493,9 @@ void MLIRContext::disableMultithreading(bool disable) {
// --mlir-disable-threading
if (isThreadingGloballyDisabled())
return;
+ assert(impl->multiThreadedExecutionContext == 0 &&
+ "changing MLIRContext `disable-threading` configuration while "
+ "in a multi-threaded execution context");
impl->threadingIsEnabled = !disable;
@@ -557,6 +569,9 @@ bool MLIRContext::shouldPrintOpOnDiagnostic() {
/// Set the flag specifying if we should attach the operation to diagnostics
/// emitted via Operation::emit.
void MLIRContext::printOpOnDiagnostic(bool enable) {
+ assert(impl->multiThreadedExecutionContext == 0 &&
+ "changing MLIRContext `print-op-on-diagnostic` configuration while in "
+ "a multi-threaded execution context");
impl->printOpOnDiagnostic = enable;
}
@@ -569,6 +584,9 @@ bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
/// Set the flag specifying if we should attach the current stacktrace when
/// emitting diagnostics.
void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
+ assert(impl->multiThreadedExecutionContext == 0 &&
+ "changing MLIRContext `print-stacktrace-on-diagnostic` configuration "
+ "while in a multi-threaded execution context");
impl->printStackTraceOnDiagnostic = enable;
}
diff --git a/mlir/lib/Reducer/OptReductionPass.cpp b/mlir/lib/Reducer/OptReductionPass.cpp
index 806ce67e529a6..a7f09b4c923f6 100644
--- a/mlir/lib/Reducer/OptReductionPass.cpp
+++ b/mlir/lib/Reducer/OptReductionPass.cpp
@@ -42,7 +42,7 @@ void OptReductionPass::runOnOperation() {
ModuleOp module = this->getOperation();
ModuleOp moduleVariant = module.clone();
- PassManager passManager(module.getContext());
+ OpPassManager passManager("builtin.module");
if (failed(parsePassPipeline(optPass, passManager))) {
module.emitError() << "\nfailed to parse pass pipeline";
return signalPassFailure();
@@ -54,7 +54,13 @@ void OptReductionPass::runOnOperation() {
return signalPassFailure();
}
- if (failed(passManager.run(moduleVariant))) {
+ // Temporarily push the variant under the main module and execute the pipeline
+ // on it.
+ module.getBody()->push_back(moduleVariant);
+ LogicalResult pipelineResult = runPipeline(passManager, moduleVariant);
+ moduleVariant->remove();
+
+ if (failed(pipelineResult)) {
module.emitError() << "\nfailed to run pass pipeline";
return signalPassFailure();
}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 88ccc77ab59a2..573197a584e0e 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -255,14 +255,13 @@ struct TestLinalgGreedyFusion
patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+ OpPassManager pm(FuncOp::getOperationName());
+ pm.addPass(createLoopInvariantCodeMotionPass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
do {
(void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
- PassManager pm(context);
- pm.addPass(createLoopInvariantCodeMotionPass());
- pm.addPass(createCanonicalizerPass());
- pm.addPass(createCSEPass());
- LogicalResult res = pm.run(getOperation()->getParentOfType<ModuleOp>());
- if (failed(res))
+ if (failed(runPipeline(pm, getOperation())))
this->signalPassFailure();
} while (succeeded(fuseLinalgOpsGreedily(getOperation())));
}
More information about the Mlir-commits
mailing list