[Mlir-commits] [mlir] a8a8119 - [mlir] Fix a rare use-after-free in dialect loading

Benjamin Kramer llvmlistbot at llvm.org
Mon Jun 19 09:45:00 PDT 2023


Author: Benjamin Kramer
Date: 2023-06-19T18:20:36+02:00
New Revision: a8a811997062d38c6e16e190ecd6377213b77be6

URL: https://github.com/llvm/llvm-project/commit/a8a811997062d38c6e16e190ecd6377213b77be6
DIFF: https://github.com/llvm/llvm-project/commit/a8a811997062d38c6e16e190ecd6377213b77be6.diff

LOG: [mlir] Fix a rare use-after-free in dialect loading

applyExtensions can load further dialects, which may rehash the
loadedDialects DenseMap and invalidate the reference to the
std::unique_ptr<Dialect> stored in it. Copy the raw Dialect pointer into a
local variable before applying extensions so that it stays valid.

Modified: 
    mlir/lib/IR/MLIRContext.cpp


################################################################################
diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index cf355703aaf47..cc4b33f9ca669 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -457,8 +457,9 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
     // dialect is currently being loaded. Re-lookup the address in
     // loadedDialects because the table might have been rehashed by recursive
     // dialect loading in ctor().
-    std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace] =
-        ctor();
+    std::unique_ptr<Dialect> &dialectOwned =
+        impl.loadedDialects[dialectNamespace] = ctor();
+    Dialect *dialect = dialectOwned.get();
     assert(dialect && "dialect ctor failed");
 
     // Refresh all the identifiers dialect field, this catches cases where a
@@ -467,13 +468,13 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
     auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace);
     if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) {
       for (StringAttrStorage *storage : stringAttrsIt->second)
-        storage->referencedDialect = dialect.get();
+        storage->referencedDialect = dialect;
       impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
     }
 
     // Apply any extensions to this newly loaded dialect.
-    impl.dialectsRegistry.applyExtensions(dialect.get());
-    return dialect.get();
+    impl.dialectsRegistry.applyExtensions(dialect);
+    return dialect;
   }
 
 #ifndef NDEBUG
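The failure mode described in the log can be reproduced in miniature outside MLIR. In this minimal sketch, `Dialect`, `Context`, `load` and the `onLoad` hook are hypothetical stand-ins, not MLIR's API. A `std::vector` plays the role of `llvm::DenseMap` because both store their elements inline, so growth moves the elements and invalidates references to them. A node-based `std::unordered_map` would not show the problem. The `Dialect` objects themselves live on the heap behind `std::unique_ptr`, so a raw pointer copied out before the hook runs stays valid.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for mlir::Dialect.
struct Dialect {
  std::string name;
  explicit Dialect(std::string n) : name(std::move(n)) {}
};

// Hypothetical stand-in for the context. std::vector models DenseMap's
// inline storage: push_back may reallocate and move every unique_ptr,
// invalidating references to elements (but not the heap Dialect objects).
struct Context {
  std::vector<std::unique_ptr<Dialect>> loaded;
  // Hypothetical hook playing the role of applyExtensions; it may load
  // further dialects, i.e. push_back into `loaded` recursively.
  std::function<void(Context &, Dialect *)> onLoad;

  Dialect *load(const std::string &name) {
    loaded.push_back(std::make_unique<Dialect>(name));
    // Fixed pattern: copy the raw pointer now, instead of holding a
    // std::unique_ptr<Dialect>& into `loaded` across the hook call.
    Dialect *dialect = loaded.back().get();
    if (onLoad)
      onLoad(*this, dialect); // May reallocate `loaded`.
    return dialect;           // Still valid: the Dialect never moved.
  }
};
```

The buggy shape keeps `std::unique_ptr<Dialect> &owned = loaded.back();` and calls `owned.get()` after the hook. If the hook's `push_back` reallocates, that read goes through a dangling reference, which is the kind of use-after-free this commit removes.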