[Mlir-commits] [mlir] 74a8a1e - [mlir] Fix a use after free when loading dependent dialects
Benjamin Kramer
llvmlistbot at llvm.org
Wed Apr 5 06:47:06 PDT 2023
Author: Benjamin Kramer
Date: 2023-04-05T15:44:29+02:00
New Revision: 74a8a1e038022fb4ca9b8e444489e910f16a9741
URL: https://github.com/llvm/llvm-project/commit/74a8a1e038022fb4ca9b8e444489e910f16a9741
DIFF: https://github.com/llvm/llvm-project/commit/74a8a1e038022fb4ca9b8e444489e910f16a9741.diff
LOG: [mlir] Fix a use after free when loading dependent dialects
Dependent dialects are implemented by recursively calling loadDialect
from the dialect constructor. This means we have to look the entry up in
the dialect table again after the constructor returns, because the
recursive loads might have rehashed that table.
The steps for loading a dialect are:
1. Insert a nullptr into loadedDialects. This marks the dialect as
   currently loading.
2. Call ctor(). This recursively loads dependent dialects.
3. Store the new dialect in the table.
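The three steps above can be sketched in standalone C++17. This is a minimal sketch, not MLIR's code: `Dialect`, `Context`, and this `getOrLoadDialect` signature are hypothetical stand-ins, and `std::unordered_map` replaces `llvm::DenseMap`. Unlike DenseMap, `std::unordered_map` keeps references to elements stable across a rehash (only its iterators are invalidated), so here the re-lookup in step 3 mirrors the fix rather than being strictly required.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for mlir::Dialect; not MLIR's class.
struct Dialect {
  explicit Dialect(std::string ns) : name(std::move(ns)) {}
  std::string name;
};

// Hypothetical stand-in for MLIRContext's dialect table.
struct Context {
  // A nullptr value means "this dialect is currently loading".
  std::unordered_map<std::string, std::unique_ptr<Dialect>> loadedDialects;

  Dialect *getOrLoadDialect(
      const std::string &ns,
      const std::function<std::unique_ptr<Dialect>()> &ctor) {
    // Step 1: insert a nullptr placeholder; no-op if `ns` is already present.
    auto [it, inserted] = loadedDialects.try_emplace(ns, nullptr);
    if (!inserted) {
      // Already present. A nullptr here means `ns` is still loading (a
      // cycle); MLIR's debug builds report a fatal error in this case.
      return it->second.get();
    }
    // Step 2: run the constructor. It may load dependent dialects by calling
    // getOrLoadDialect recursively, which inserts into loadedDialects.
    std::unique_ptr<Dialect> dialect = ctor();
    assert(dialect && "dialect ctor failed");
    // Step 3: look the slot up again rather than reusing `it`, which a
    // rehash during step 2 may have invalidated.
    std::unique_ptr<Dialect> &slot = loadedDialects[ns];
    slot = std::move(dialect);
    return slot.get();
  }
};
```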
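The bug was between steps 2 and 3: a reference into the table was taken
before ctor() ran, so a rehash during recursive loading left step 3
writing through a dangling reference. You have to be extremely unlucky to
hit this, though, because rehashing is rare and DenseMap's operator[]
returns a plain reference with no generation checking. Holding an
iterator instead would have uncovered this issue immediately.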
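To illustrate why a generation-checked handle would catch this while a plain reference does not, here is a toy table. `EpochTable` is hypothetical and is not LLVM's implementation; it only mimics the idea of generation checking by bumping an epoch counter on every insertion, so a handle taken before a nested insertion reports itself stale. A raw `T&` into the same `std::vector` carries no such information and would silently dangle after a reallocation.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical toy, not LLVM's implementation: a vector-backed table whose
// handles remember the epoch (generation) in which they were created.
template <typename T>
class EpochTable {
public:
  class Handle {
  public:
    Handle(EpochTable *t, std::size_t i)
        : table(t), index(i), epoch(t->epoch) {}
    // True if no insertion has happened since this handle was created.
    bool valid() const { return table->epoch == epoch; }
    T &operator*() const {
      assert(valid() && "stale handle: table modified since lookup");
      return table->slots[index];
    }

  private:
    EpochTable *table;
    std::size_t index;
    unsigned epoch;
  };

  // Any insertion may reallocate `slots`, so it invalidates all handles.
  Handle insert(T value) {
    ++epoch;
    slots.push_back(std::move(value));
    return Handle(this, slots.size() - 1);
  }

  // A fresh lookup always yields a valid handle.
  Handle handle(std::size_t i) { return Handle(this, i); }

private:
  std::vector<T> slots;
  unsigned epoch = 0;
};
```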
Added:
Modified:
mlir/lib/IR/MLIRContext.cpp
Removed:
################################################################################
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index daa4a6af63020..e64babf35dac0 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -438,9 +438,9 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
function_ref<std::unique_ptr<Dialect>()> ctor) {
auto &impl = getImpl();
// Get the correct insertion position sorted by namespace.
- auto dialectIt = impl.loadedDialects.find(dialectNamespace);
+ auto dialectIt = impl.loadedDialects.try_emplace(dialectNamespace, nullptr);
- if (dialectIt == impl.loadedDialects.end()) {
+ if (dialectIt.second) {
LLVM_DEBUG(llvm::dbgs()
<< "Load new dialect in Context " << dialectNamespace << "\n");
#ifndef NDEBUG
@@ -452,9 +452,11 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
"missing `dependentDialects` in a pass for example.");
#endif // NDEBUG
// loadedDialects entry is initialized to nullptr, indicating that the
- // dialect is currently being loaded.
- std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
- dialect = ctor();
+ // dialect is currently being loaded. Re-lookup the address in
+ // loadedDialects because the table might have been rehashed by recursive
+ // dialect loading in ctor().
+ std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace] =
+ ctor();
assert(dialect && "dialect ctor failed");
// Refresh all the identifiers dialect field, this catches cases where a
@@ -473,7 +475,7 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
}
#ifndef NDEBUG
- if (dialectIt->second == nullptr)
+ if (dialectIt.first->second == nullptr)
llvm::report_fatal_error(
"Loading (and getting) a dialect (" + dialectNamespace +
") while the same dialect is still loading: use loadDialect instead "
@@ -481,7 +483,7 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
#endif // NDEBUG
// Abort if dialect with namespace has already been registered.
- std::unique_ptr<Dialect> &dialect = dialectIt->second;
+ std::unique_ptr<Dialect> &dialect = dialectIt.first->second;
if (dialect->getTypeID() != dialectID)
llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
"' has already been registered");
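The new form `loadedDialects[dialectNamespace] = ctor()` relies on C++17 sequencing: in an assignment, the right operand is evaluated before the left one, and this also holds for overloaded operators, so the table lookup only happens after ctor() (and any rehash it triggers) has finished. The `LoggingTable` below is hypothetical instrumentation, not MLIR code; it records the order of the lookup and the constructor call in the fixed form and in the original form. It assumes a C++17 compiler.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical instrumented table: records when a slot lookup happens.
struct LoggingTable {
  std::vector<std::string> *events;
  std::unique_ptr<int> slot;
  std::unique_ptr<int> &operator[](const std::string &) {
    events->push_back("lookup");
    return slot;
  }
};

std::vector<std::string> fixedOrder() {
  std::vector<std::string> events;
  LoggingTable table{&events, nullptr};
  auto ctor = [&events] {
    events.push_back("ctor");
    return std::make_unique<int>(42);
  };
  // Fixed form: since C++17 the right operand of `=` is sequenced first,
  // so ctor() runs before the lookup.
  std::unique_ptr<int> &slot = table["ns"] = ctor();
  (void)slot;
  return events;
}

std::vector<std::string> originalOrder() {
  std::vector<std::string> events;
  LoggingTable table{&events, nullptr};
  auto ctor = [&events] {
    events.push_back("ctor");
    return std::make_unique<int>(42);
  };
  // Original form: the reference is taken first, then ctor() runs. With a
  // real DenseMap, a rehash inside ctor() would leave `slot` dangling.
  std::unique_ptr<int> &slot = table["ns"];
  slot = ctor();
  return events;
}
```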