[Mlir-commits] [mlir] 012b148 - [mlir] Fix crash when adding nested dialect extensions
Matthias Springer
llvmlistbot at llvm.org
Sat Aug 26 01:15:36 PDT 2023
Author: Matthias Springer
Date: 2023-08-26T10:01:18+02:00
New Revision: 012b148dd9402c6ad58805701863d31fb506caff
URL: https://github.com/llvm/llvm-project/commit/012b148dd9402c6ad58805701863d31fb506caff
DIFF: https://github.com/llvm/llvm-project/commit/012b148dd9402c6ad58805701863d31fb506caff.diff
LOG: [mlir] Fix crash when adding nested dialect extensions
A dialect extension can add additional dialect extensions in its `apply` function. This used to crash because the `extensions` vector could be reallocated while it was still being iterated over with a range-based `for` loop.
Differential Revision: https://reviews.llvm.org/D158838
Added:
Modified:
mlir/lib/IR/Dialect.cpp
mlir/unittests/IR/DialectTest.cpp
Removed:
################################################################################
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index c4e01ca5a8ae43..e860299fe4c496 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -125,7 +125,8 @@ DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
for (auto *dialect : ctx->getLoadedDialects()) {
#ifndef NDEBUG
- dialect->handleUseOfUndefinedPromisedInterface(interfaceKind, interfaceName);
+ dialect->handleUseOfUndefinedPromisedInterface(interfaceKind,
+ interfaceName);
#endif
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
interfaces.insert(interface);
@@ -248,8 +249,9 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const {
extension.apply(ctx, requiredDialects);
};
- for (const auto &extension : extensions)
- applyExtension(*extension);
+ // Note: Additional extensions may be added while applying an extension.
+ for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
+ applyExtension(*extensions[i]);
}
void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
@@ -274,8 +276,9 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
extension.apply(ctx, requiredDialects);
};
- for (const auto &extension : extensions)
- applyExtension(*extension);
+ // Note: Additional extensions may be added while applying an extension.
+ for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
+ applyExtension(*extensions[i]);
}
bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index a2b58bf7319762..e99d46e6d26436 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -136,4 +136,50 @@ TEST(Dialect, RepeatedDelayedRegistration) {
EXPECT_TRUE(testDialectInterface != nullptr);
}
+namespace {
+/// A dummy extension that increases a counter when being applied and
+/// recursively adds additional extensions.
+///
+/// Used to grow a dialect registry's `extensions` vector while extensions are
+/// being applied: each `apply` call hands a registry containing
+/// `numRecursive` nested extensions to `ctx->appendDialectRegistry`.
+struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
+ /// `counter` is incremented on every `apply`. It is stored as a raw,
+ /// non-owned pointer, so it must stay valid for as long as this extension
+ /// (or any nested extension it creates) may be applied.
+ /// `numRecursive` is the number of nested DummyExtensions, each with
+ /// numRecursive=0, that `apply` registers.
+ DummyExtension(int *counter, int numRecursive)
+ : DialectExtension(), counter(counter), numRecursive(numRecursive) {}
+
+ /// Bumps the shared counter, then builds a registry holding `numRecursive`
+ /// non-recursive DummyExtensions that share the same counter, and passes
+ /// it to `ctx->appendDialectRegistry`. The `dialect` argument is unused.
+ void apply(MLIRContext *ctx, TestDialect *dialect) const final {
+ ++(*counter);
+ DialectRegistry nestedRegistry;
+ // Nested extensions are created with numRecursive=0, so nesting stops
+ // after one level.
+ for (int i = 0; i < numRecursive; ++i)
+ nestedRegistry.addExtension(
+ std::make_unique<DummyExtension>(counter, /*numRecursive=*/0));
+ // Adding additional extensions may trigger a reallocation of the
+ // `extensions` vector in the dialect registry.
+ ctx->appendDialectRegistry(nestedRegistry);
+ }
+
+private:
+ /// Not owned; shared with every nested extension created by `apply`.
+ int *counter;
+ /// How many nested extensions `apply` registers.
+ int numRecursive;
+};
+} // namespace
+
+TEST(Dialect, NestedDialectExtension) {
+ // Regression test: applying an extension that registers further extensions
+ // must not crash when the registry's `extensions` vector is reallocated
+ // while it is being iterated over.
+ DialectRegistry registry;
+ registry.insert<TestDialect>();
+
+ // Add an extension that adds 100 more extensions.
+ int counter1 = 0;
+ registry.addExtension(std::make_unique<DummyExtension>(&counter1, 100));
+ // Add one more extension. This should not crash.
+ // NOTE(review): presumably this extension is reached only after the first
+ // one's `apply` has appended more extensions (i.e. after a possible
+ // reallocation), which is what exercised the old iterator bug.
+ int counter2 = 0;
+ registry.addExtension(std::make_unique<DummyExtension>(&counter2, 0));
+
+ // Load dialect and apply extensions.
+ MLIRContext context(registry);
+ Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
+ ASSERT_TRUE(testDialect != nullptr);
+
+ // Extensions may be applied multiple times. Make sure that each expected
+ // extension was applied at least once.
+ // counter1: the outer extension plus the 100 nested ones it registers.
+ EXPECT_GE(counter1, 101);
+ // counter2: the single non-recursive extension.
+ EXPECT_GE(counter2, 1);
+}
+
} // namespace
More information about the Mlir-commits
mailing list