[llvm-branch-commits] [mlir] cc7e24c - [mlir] Fix crash when adding nested dialect extensions
Tobias Hieta via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Aug 30 23:56:14 PDT 2023
Author: Matthias Springer
Date: 2023-08-31T08:54:24+02:00
New Revision: cc7e24c7a723fee9c4209663ea1517aeba34e42a
URL: https://github.com/llvm/llvm-project/commit/cc7e24c7a723fee9c4209663ea1517aeba34e42a
DIFF: https://github.com/llvm/llvm-project/commit/cc7e24c7a723fee9c4209663ea1517aeba34e42a.diff
LOG: [mlir] Fix crash when adding nested dialect extensions
A dialect extension can register further dialect extensions in its `apply` function. This used to crash when the `extensions` vector was reallocated internally while it was being iterated over.
Differential Revision: https://reviews.llvm.org/D158838
Added:
Modified:
mlir/lib/IR/Dialect.cpp
mlir/unittests/IR/DialectTest.cpp
Removed:
################################################################################
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 501f52b83e026e..1de49769974ac6 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -125,7 +125,8 @@ DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
for (auto *dialect : ctx->getLoadedDialects()) {
#ifndef NDEBUG
- dialect->handleUseOfUndefinedPromisedInterface(interfaceKind, interfaceName);
+ dialect->handleUseOfUndefinedPromisedInterface(interfaceKind,
+ interfaceName);
#endif
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
interfaces.insert(interface);
@@ -243,8 +244,9 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const {
extension.apply(ctx, requiredDialects);
};
- for (const auto &extension : extensions)
- applyExtension(*extension);
+ // Note: Additional extensions may be added while applying an extension.
+ for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
+ applyExtension(*extensions[i]);
}
void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
@@ -264,8 +266,9 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
extension.apply(ctx, requiredDialects);
};
- for (const auto &extension : extensions)
- applyExtension(*extension);
+ // Note: Additional extensions may be added while applying an extension.
+ for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
+ applyExtension(*extensions[i]);
}
bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index a2b58bf7319762..e99d46e6d26436 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -136,4 +136,50 @@ TEST(Dialect, RepeatedDelayedRegistration) {
EXPECT_TRUE(testDialectInterface != nullptr);
}
+namespace {
+/// A dummy extension that increments `*counter` each time it is applied and
+/// registers `numRecursive` nested extensions that do not nest any further.
+struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
+ DummyExtension(int *counter, int numRecursive)
+ : DialectExtension(), counter(counter), numRecursive(numRecursive) {}
+
+ void apply(MLIRContext *ctx, TestDialect *dialect) const final {
+ ++(*counter);
+ DialectRegistry nestedRegistry;
+ for (int i = 0; i < numRecursive; ++i)
+ nestedRegistry.addExtension(
+ std::make_unique<DummyExtension>(counter, /*numRecursive=*/0));
+ // Appending while extensions are still being applied may reallocate the
+ // registry's `extensions` vector mid-iteration (the crash under test).
+ ctx->appendDialectRegistry(nestedRegistry);
+ }
+
+private:
+ int *counter;
+ int numRecursive;
+};
+} // namespace
+
+TEST(Dialect, NestedDialectExtension) {
+ DialectRegistry registry;
+ registry.insert<TestDialect>();
+
+ // Add an extension that registers 100 nested extensions when applied.
+ int counter1 = 0;
+ registry.addExtension(std::make_unique<DummyExtension>(&counter1, 100));
+ // Add one more top-level extension; applying it must not crash.
+ int counter2 = 0;
+ registry.addExtension(std::make_unique<DummyExtension>(&counter2, 0));
+
+ // Load the dialect, which applies the registered extensions.
+ MLIRContext context(registry);
+ Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
+ ASSERT_TRUE(testDialect != nullptr);
+
+ // Extensions may be applied multiple times, so only check lower bounds:
+ // the outer extension plus its 100 nested ones all increment counter1.
+ EXPECT_GE(counter1, 101);
+ EXPECT_GE(counter2, 1);
+}
+
} // namespace
More information about the llvm-branch-commits
mailing list