[Mlir-commits] [mlir] 4433e52 - [mlir] Fix circular dialect initialization
Matthias Springer
llvmlistbot at llvm.org
Thu Oct 27 02:50:46 PDT 2022
Author: Matthias Springer
Date: 2022-10-27T11:50:37+02:00
New Revision: 4433e52e69b1ce19b1d3c756e6d3262170ad4a30
URL: https://github.com/llvm/llvm-project/commit/4433e52e69b1ce19b1d3c756e6d3262170ad4a30
DIFF: https://github.com/llvm/llvm-project/commit/4433e52e69b1ce19b1d3c756e6d3262170ad4a30.diff
LOG: [mlir] Fix circular dialect initialization
This change fixes a bug where a dialect is initialized multiple times. This triggers an assertion when the ops of the dialect are registered (`error: operation named ... is already registered`).
This bug can be triggered as follows:
1. Dialect A depends on dialect B (as per ADialect.td).
2. Somewhere there is an extension of dialect B that depends on dialect A (e.g., it defines external models that create ops from dialect A). For example:
```
registry.addExtension(+[](MLIRContext *ctx, BDialect *dialect) {
  BDialectOp::attachInterface ...
  ctx->loadDialect<ADialect>();
});
```
3. When dialect A is loaded, its `initialize` function is called twice:
```
ADialect::ADialect()
  |       |
  |       v
  |     ADialect::initialize()
  v
getOrLoadDialect<BDialect>()
  |
  v
(load extension of BDialect)
  |
  v
ctx->loadDialect<ADialect>()   // user wrote this in the extension
  |
  v
getOrLoadDialect<ADialect>()   // the dialect is not "fully" loaded yet
  |
  v
ADialect::ADialect()
  |
  v
ADialect::initialize()
```
An example of a dialect extension that depends on other dialects is `Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp`. That particular extension does not trigger this bug, but it would if the SCF dialect depended on the Tensor dialect.
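This change introduces a new dialect state: dialects that are currently being loaded. Like dialects that are already fully loaded (and initialized), dialects that are still in the process of being loaded are not loaded a second time. While a dialect is loading, its entry in the context's `loadedDialects` map is `nullptr`. `loadDialect` skips such dialects, and the dependent-dialect loading code generated by mlir-tblgen now calls `loadDialect` instead of `getOrLoadDialect`.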
Differential Revision: https://reviews.llvm.org/D136685
Added:
Modified:
mlir/include/mlir/IR/MLIRContext.h
mlir/lib/IR/MLIRContext.cpp
mlir/tools/mlir-tblgen/DialectGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index c162b00f8402..b87dd27b2aaa 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -98,16 +98,22 @@ class MLIRContext {
}));
}
+ /// Return true if the given dialect is currently loading.
+ bool isDialectLoading(StringRef dialectNamespace);
+
/// Load a dialect in the context.
template <typename Dialect>
void loadDialect() {
- getOrLoadDialect<Dialect>();
+ // Do not load the dialect if it is currently loading. This can happen if a
+ // dialect initializer triggers loading the same dialect recursively.
+ if (!isDialectLoading(Dialect::getDialectNamespace()))
+ getOrLoadDialect<Dialect>();
}
/// Load a list dialects in the context.
template <typename Dialect, typename OtherDialect, typename... MoreDialects>
void loadDialect() {
- getOrLoadDialect<Dialect>();
+ loadDialect<Dialect>();
loadDialect<OtherDialect, MoreDialects...>();
}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 7ddcc2ff11d9..896938d4406f 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -429,9 +429,11 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
") while in a multi-threaded execution context (maybe "
"the PassManager): this can indicate a "
"missing `dependentDialects` in a pass for example.");
-#endif
- std::unique_ptr<Dialect> &dialect =
- impl.loadedDialects.insert({dialectNamespace, ctor()}).first->second;
+#endif // NDEBUG
+ // nullptr indicates that the dialect is currently being loaded.
+ impl.loadedDialects[dialectNamespace] = nullptr;
+ std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace] =
+ ctor();
assert(dialect && "dialect ctor failed");
// Refresh all the identifiers dialect field, this catches cases where a
@@ -449,6 +451,14 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
return dialect.get();
}
+#ifndef NDEBUG
+ if (dialectIt->second == nullptr)
+ llvm::report_fatal_error(
+ "Loading (and getting) a dialect (" + dialectNamespace +
+ ") while the same dialect is still loading: use loadDialect instead "
+ "of getOrLoadDialect.");
+#endif // NDEBUG
+
// Abort if dialect with namespace has already been registered.
std::unique_ptr<Dialect> &dialect = dialectIt->second;
if (dialect->getTypeID() != dialectID)
@@ -458,6 +468,12 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
return dialect.get();
}
+bool MLIRContext::isDialectLoading(StringRef dialectNamespace) {
+ auto it = getImpl().loadedDialects.find(dialectNamespace);
+ // nullptr indicates that the dialect is currently being loaded.
+ return it != getImpl().loadedDialects.end() && it->second == nullptr;
+}
+
DynamicDialect *MLIRContext::getOrLoadDynamicDialect(
StringRef dialectNamespace, function_ref<void(DynamicDialect *)> ctor) {
auto &impl = getImpl();
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index c7e42ac0ac8f..1085407f561d 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -107,7 +107,7 @@ class {0} : public ::mlir::{2} {
/// Registration for a single dependent dialect: to be inserted in the ctor
/// above for each dependent dialect.
const char *const dialectRegistrationTemplate = R"(
- getContext()->getOrLoadDialect<{0}>();
+ getContext()->loadDialect<{0}>();
)";
/// The code block for the attribute parser/printer hooks.
More information about the Mlir-commits mailing list