[Mlir-commits] [mlir] ff459c1 - [mlir] Fix invalidated reference when loading dependent dialects

llvmlistbot at llvm.org
Thu Dec 16 10:59:17 PST 2021


Author: Mogball
Date: 2021-12-16T18:59:12Z
New Revision: ff459c1f67f13925d02fc2570bdc99bf56f7993c

URL: https://github.com/llvm/llvm-project/commit/ff459c1f67f13925d02fc2570bdc99bf56f7993c
DIFF: https://github.com/llvm/llvm-project/commit/ff459c1f67f13925d02fc2570bdc99bf56f7993c.diff

LOG: [mlir] Fix invalidated reference when loading dependent dialects

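When a dialect is loaded with `getOrLoadDialect`, its constructor may recursively call `getOrLoadDialect` on a dependent dialect. That nested call may insert into the dialect map, which can invalidate the reference to the (previously null) dialect pointer held by the outer call. Fix this by looking the dialect up with `find` first and inserting the new entry only after the constructor has returned.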
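
The hazard and the fix can be illustrated outside of MLIR with a minimal sketch. The `Dialect` and `Registry` types below are hypothetical stand-ins, not the MLIR classes, and the sketch assumes a container whose storage may move on insertion (here a `std::vector` of pairs), so a reference to a slot taken before running the constructor callback would not be safe to use afterwards. The sketch looks the entry up first, runs the callback, and only then touches the container.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a dialect; not the MLIR class.
struct Dialect {
  explicit Dialect(std::string n) : name(std::move(n)) {}
  std::string name;
};

// Hypothetical registry. Its vector may reallocate on insertion, so a
// reference to a slot does not survive a later insert.
class Registry {
public:
  using Ctor = std::function<std::unique_ptr<Dialect>()>;

  Dialect *getOrLoad(const std::string &name, const Ctor &ctor) {
    // Look up first, without creating a slot or keeping a reference to one.
    if (Dialect *existing = lookup(name))
      return existing;

    // ctor() may call getOrLoad recursively and grow `loaded`.
    std::unique_ptr<Dialect> created = ctor();
    assert(created && "dialect ctor failed");

    // Insert only after ctor() has returned; no further insertion can happen
    // between this line and the return.
    loaded.emplace_back(name, std::move(created));
    return loaded.back().second.get();
  }

  // Linear search; returns nullptr if no dialect with this name is loaded.
  Dialect *lookup(const std::string &name) const {
    for (const auto &entry : loaded)
      if (entry.first == name)
        return entry.second.get();
    return nullptr;
  }

  std::size_t size() const { return loaded.size(); }

private:
  std::vector<std::pair<std::string, std::unique_ptr<Dialect>>> loaded;
};
```

Calling `getOrLoad("a", ...)` with a callback that itself calls `getOrLoad("b", ...)` on the same registry loads both entries. The buggy shape would instead take a reference to the slot for "a" before running the callback and assign through it afterwards.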

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D115846

Added: 
    

Modified: 
    mlir/lib/IR/MLIRContext.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 1e10d60f59485..a670d9e42618b 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -409,9 +409,9 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
                               function_ref<std::unique_ptr<Dialect>()> ctor) {
   auto &impl = getImpl();
   // Get the correct insertion position sorted by namespace.
-  std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
+  auto dialectIt = impl.loadedDialects.find(dialectNamespace);
 
-  if (!dialect) {
+  if (dialectIt == impl.loadedDialects.end()) {
     LLVM_DEBUG(llvm::dbgs()
                << "Load new dialect in Context " << dialectNamespace << "\n");
 #ifndef NDEBUG
@@ -422,7 +422,8 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
           "the PassManager): this can indicate a "
           "missing `dependentDialects` in a pass for example.");
 #endif
-    dialect = ctor();
+    std::unique_ptr<Dialect> &dialect =
+        impl.loadedDialects.insert({dialectNamespace, ctor()}).first->second;
     assert(dialect && "dialect ctor failed");
 
     // Refresh all the identifiers dialect field, this catches cases where a
@@ -441,6 +442,7 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
   }
 
   // Abort if dialect with namespace has already been registered.
+  std::unique_ptr<Dialect> &dialect = dialectIt->second;
   if (dialect->getTypeID() != dialectID)
     llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
                              "' has already been registered");


        

