[llvm-branch-commits] [mlir] cc7e24c - [mlir] Fix crash when adding nested dialect extensions

Tobias Hieta via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Aug 30 23:56:14 PDT 2023


Author: Matthias Springer
Date: 2023-08-31T08:54:24+02:00
New Revision: cc7e24c7a723fee9c4209663ea1517aeba34e42a

URL: https://github.com/llvm/llvm-project/commit/cc7e24c7a723fee9c4209663ea1517aeba34e42a
DIFF: https://github.com/llvm/llvm-project/commit/cc7e24c7a723fee9c4209663ea1517aeba34e42a.diff

LOG: [mlir] Fix crash when adding nested dialect extensions

A dialect extension can register further dialect extensions in its `apply` function. This used to crash because the `extensions` vector could be reallocated while it was being iterated over.

Differential Revision: https://reviews.llvm.org/D158838

Added: 
    

Modified: 
    mlir/lib/IR/Dialect.cpp
    mlir/unittests/IR/DialectTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 501f52b83e026e..1de49769974ac6 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -125,7 +125,8 @@ DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
     MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
   for (auto *dialect : ctx->getLoadedDialects()) {
 #ifndef NDEBUG
-  dialect->handleUseOfUndefinedPromisedInterface(interfaceKind, interfaceName);
+    dialect->handleUseOfUndefinedPromisedInterface(interfaceKind,
+                                                   interfaceName);
 #endif
     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
       interfaces.insert(interface);
@@ -243,8 +244,9 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const {
     extension.apply(ctx, requiredDialects);
   };
 
-  for (const auto &extension : extensions)
-    applyExtension(*extension);
+  // Note: Additional extensions may be added while applying an extension.
+  for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
+    applyExtension(*extensions[i]);
 }
 
 void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
@@ -264,8 +266,9 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
     extension.apply(ctx, requiredDialects);
   };
 
-  for (const auto &extension : extensions)
-    applyExtension(*extension);
+  // Note: Additional extensions may be added while applying an extension.
+  for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
+    applyExtension(*extensions[i]);
 }
 
 bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {

diff  --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index a2b58bf7319762..e99d46e6d26436 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -136,4 +136,50 @@ TEST(Dialect, RepeatedDelayedRegistration) {
   EXPECT_TRUE(testDialectInterface != nullptr);
 }
 
+namespace {
+/// A dummy extension that increments `*counter` each time it is applied and
+/// then registers `numRecursive` further (non-recursive) extensions.
+struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
+  DummyExtension(int *counter, int numRecursive)
+      : DialectExtension(), counter(counter), numRecursive(numRecursive) {}
+
+  void apply(MLIRContext *ctx, TestDialect *dialect) const final {
+    ++(*counter);
+    DialectRegistry nestedRegistry;
+    for (int i = 0; i < numRecursive; ++i)
+      nestedRegistry.addExtension(
+          std::make_unique<DummyExtension>(counter, /*numRecursive=*/0));
+    // Appending from inside `apply` may reallocate the registry's
+    // `extensions` vector while that vector is still being iterated over.
+    ctx->appendDialectRegistry(nestedRegistry);
+  }
+
+private:
+  int *counter;
+  int numRecursive;
+};
+} // namespace
+
+TEST(Dialect, NestedDialectExtension) {
+  DialectRegistry registry;
+  registry.insert<TestDialect>();
+
+  // Add an extension that registers 100 more extensions when applied.
+  int counter1 = 0;
+  registry.addExtension(std::make_unique<DummyExtension>(&counter1, 100));
+  // Add one more extension after it; reaching and applying it must not crash.
+  int counter2 = 0;
+  registry.addExtension(std::make_unique<DummyExtension>(&counter2, 0));
+
+  // Load the dialect, which applies the registered extensions.
+  MLIRContext context(registry);
+  Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
+  ASSERT_TRUE(testDialect != nullptr);
+
+  // Extensions may be applied multiple times. Make sure that each expected
+  // extension (1 outer + 100 nested for counter1) was applied at least once.
+  EXPECT_GE(counter1, 101);
+  EXPECT_GE(counter2, 1);
+}
+
 } // namespace


        


More information about the llvm-branch-commits mailing list