[Mlir-commits] [mlir] 012b148 - [mlir] Fix crash when adding nested dialect extensions

Matthias Springer llvmlistbot at llvm.org
Sat Aug 26 01:15:36 PDT 2023


Author: Matthias Springer
Date: 2023-08-26T10:01:18+02:00
New Revision: 012b148dd9402c6ad58805701863d31fb506caff

URL: https://github.com/llvm/llvm-project/commit/012b148dd9402c6ad58805701863d31fb506caff
DIFF: https://github.com/llvm/llvm-project/commit/012b148dd9402c6ad58805701863d31fb506caff.diff

LOG: [mlir] Fix crash when adding nested dialect extensions

A dialect extension can add additional dialect extensions in its `apply` function. This used to crash when the `extensions` vector was reallocated internally while it was being iterated over, which invalidated the iterators of the range-based for loop.

Differential Revision: https://reviews.llvm.org/D158838

Added: 
    

Modified: 
    mlir/lib/IR/Dialect.cpp
    mlir/unittests/IR/DialectTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index c4e01ca5a8ae43..e860299fe4c496 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -125,7 +125,8 @@ DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
     MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
   for (auto *dialect : ctx->getLoadedDialects()) {
 #ifndef NDEBUG
-  dialect->handleUseOfUndefinedPromisedInterface(interfaceKind, interfaceName);
+    dialect->handleUseOfUndefinedPromisedInterface(interfaceKind,
+                                                   interfaceName);
 #endif
     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
       interfaces.insert(interface);
@@ -248,8 +249,9 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const {
     extension.apply(ctx, requiredDialects);
   };
 
-  for (const auto &extension : extensions)
-    applyExtension(*extension);
+  // Note: Additional extensions may be added while applying an extension.
+  for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
+    applyExtension(*extensions[i]);
 }
 
 void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
@@ -274,8 +276,9 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
     extension.apply(ctx, requiredDialects);
   };
 
-  for (const auto &extension : extensions)
-    applyExtension(*extension);
+  // Note: Additional extensions may be added while applying an extension.
+  for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
+    applyExtension(*extensions[i]);
 }
 
 bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {

diff  --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index a2b58bf7319762..e99d46e6d26436 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -136,4 +136,50 @@ TEST(Dialect, RepeatedDelayedRegistration) {
   EXPECT_TRUE(testDialectInterface != nullptr);
 }
 
+namespace {
+/// A dummy extension that increases a counter when being applied and
+/// recursively adds additional extensions.
+struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
+  DummyExtension(int *counter, int numRecursive)
+      : DialectExtension(), counter(counter), numRecursive(numRecursive) {}
+
+  void apply(MLIRContext *ctx, TestDialect *dialect) const final {
+    ++(*counter);
+    DialectRegistry nestedRegistry;
+    for (int i = 0; i < numRecursive; ++i)
+      nestedRegistry.addExtension(
+          std::make_unique<DummyExtension>(counter, /*numRecursive=*/0));
+    // Adding additional extensions may trigger a reallocation of the
+    // `extensions` vector in the dialect registry.
+    ctx->appendDialectRegistry(nestedRegistry);
+  }
+
+private:
+  int *counter;
+  int numRecursive;
+};
+} // namespace
+
+TEST(Dialect, NestedDialectExtension) {
+  DialectRegistry registry;
+  registry.insert<TestDialect>();
+
+  // Add an extension that adds 100 more extensions.
+  int counter1 = 0;
+  registry.addExtension(std::make_unique<DummyExtension>(&counter1, 100));
+  // Add one more extension. This should not crash.
+  int counter2 = 0;
+  registry.addExtension(std::make_unique<DummyExtension>(&counter2, 0));
+
+  // Load dialect and apply extensions.
+  MLIRContext context(registry);
+  Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
+  ASSERT_TRUE(testDialect != nullptr);
+
+  // Extensions may be applied multiple times. Make sure that each expected
+  // extension was applied at least once.
+  EXPECT_GE(counter1, 101);
+  EXPECT_GE(counter2, 1);
+}
+
 } // namespace


        


More information about the Mlir-commits mailing list