[Mlir-commits] [clang] [llvm] [mlir] pr/densemap nfc mlir (PR #199365)

Fangrui Song llvmlistbot at llvm.org
Sat May 23 11:26:18 PDT 2026


https://github.com/MaskRay created https://github.com/llvm/llvm-project/pull/199365

- **[llvm,clang] Don't assume non-erased DenseMap entries remain valid after erase. NFC**
- **[mlir] Don't assume non-erased DenseMap entries remain valid after erase. NFC**


>From 8e79ae4a7bc91efd62a984faf716d3208584c496 Mon Sep 17 00:00:00 2001
From: Fangrui Song <i at maskray.me>
Date: Mon, 11 May 2026 23:37:03 -0700
Subject: [PATCH 1/2] [llvm,clang] Don't assume non-erased DenseMap entries
 remain valid after erase. NFC

In preparation for switching DenseMap from tombstone deletion to
backward-shift deletion, update call sites that reuse an iterator or a
bucket reference after erasing another entry from the same map.

These work under tombstone deletion because unrelated buckets stay put,
but backward-shift deletion relocates entries to close the gap.

Add DenseMap::remove_if and DenseSet::remove_if, similar to
SmallPtrSet::remove_if, as a replacement for erase-while-iterating, and
use them where applicable.

Aided by Claude Opus 4.7
---
 clang/lib/AST/ASTImporter.cpp                 | 13 ++++--
 clang/lib/CodeGen/CoverageMappingGen.cpp      | 10 ++---
 clang/lib/Interpreter/IncrementalParser.cpp   |  7 ++-
 clang/lib/Sema/HLSLExternalSemaSource.cpp     |  6 ++-
 llvm/include/llvm/ADT/DenseMap.h              | 27 +++++++++++
 llvm/include/llvm/ADT/DenseSet.h              | 10 +++++
 llvm/lib/Analysis/IRSimilarityIdentifier.cpp  |  7 ++-
 llvm/lib/Analysis/LoopAccessAnalysis.cpp      | 11 +++--
 llvm/lib/Analysis/ScalarEvolution.cpp         | 22 +++------
 .../CodeGen/AssignmentTrackingAnalysis.cpp    | 19 ++++----
 llvm/lib/CodeGen/MachineCopyPropagation.cpp   |  2 +-
 llvm/lib/CodeGen/MachineLateInstrsCleanup.cpp | 12 ++---
 llvm/lib/CodeGen/PeepholeOptimizer.cpp        | 15 +++----
 .../CodeGen/RemoveRedundantDebugValues.cpp    |  8 ++--
 llvm/lib/ExecutionEngine/Orc/Core.cpp         | 21 ++++-----
 llvm/lib/IR/LegacyPassManager.cpp             | 40 ++++++-----------
 llvm/lib/MC/MCObjectStreamer.cpp              | 13 +++---
 llvm/lib/Support/CommandLine.cpp              | 13 ++++--
 llvm/lib/Transforms/IPO/FunctionImport.cpp    |  8 +---
 .../Scalar/LowerMatrixIntrinsics.cpp          |  3 +-
 .../Utils/PromoteMemoryToRegister.cpp         | 16 +++----
 llvm/unittests/ADT/DenseMapTest.cpp           | 45 +++++++++++++++++++
 llvm/unittests/ADT/DenseSetTest.cpp           | 16 +++++++
 llvm/unittests/Support/JSONTest.cpp           |  2 +-
 24 files changed, 212 insertions(+), 134 deletions(-)

diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index 7bab2d7dcddfa..f8527af2bfe6f 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -9864,10 +9864,15 @@ Expected<Decl *> ASTImporter::Import(Decl *FromD) {
     // Failed to import.
 
     auto Pos = ImportedDecls.find(FromD);
-    if (Pos != ImportedDecls.end()) {
+    bool ToDWasCreated = Pos != ImportedDecls.end();
+    // Capture the mapped decl before erasing: the iterator is invalidated by
+    // the erase below under backward-shift deletion, but it is still needed
+    // further down to record the import error.
+    Decl *CreatedToD = ToDWasCreated ? Pos->second : nullptr;
+    if (ToDWasCreated) {
       // Import failed after the object was created.
       // Remove all references to it.
-      auto *ToD = Pos->second;
+      auto *ToD = CreatedToD;
       ImportedDecls.erase(Pos);
 
       // ImportedDecls and ImportedFromDecls are not symmetric.  It may happen
@@ -9903,8 +9908,8 @@ Expected<Decl *> ASTImporter::Import(Decl *FromD) {
                     [&ErrOut](const ASTImportError &E) { ErrOut = E; });
     setImportDeclError(FromD, ErrOut);
     // Set the error for the mapped to Decl, which is in the "to" context.
-    if (Pos != ImportedDecls.end())
-      SharedState->setImportDeclError(Pos->second, ErrOut);
+    if (ToDWasCreated)
+      SharedState->setImportDeclError(CreatedToD, ErrOut);
 
     // Set the error for all nodes which have been created before we
     // recognized the error.
diff --git a/clang/lib/CodeGen/CoverageMappingGen.cpp b/clang/lib/CodeGen/CoverageMappingGen.cpp
index eadb6e3bb25a8..c90afacbde293 100644
--- a/clang/lib/CodeGen/CoverageMappingGen.cpp
+++ b/clang/lib/CodeGen/CoverageMappingGen.cpp
@@ -2260,13 +2260,9 @@ struct CounterCoverageMappingBuilder
     (void)FoundCount;
 
     // Tell CodeGenPGO not to instrument.
-    for (auto I = MCDCState.BranchByStmt.begin(),
-              E = MCDCState.BranchByStmt.end();
-         I != E;) {
-      auto II = I++;
-      if (II->second.DecisionStmt == Decision)
-        MCDCState.BranchByStmt.erase(II);
-    }
+    MCDCState.BranchByStmt.remove_if([&](const auto &Entry) {
+      return Entry.second.DecisionStmt == Decision;
+    });
     MCDCState.DecisionByStmt.erase(Decision);
   }
 
diff --git a/clang/lib/Interpreter/IncrementalParser.cpp b/clang/lib/Interpreter/IncrementalParser.cpp
index 16a954f3c15e7..f6d2779d64b2b 100644
--- a/clang/lib/Interpreter/IncrementalParser.cpp
+++ b/clang/lib/Interpreter/IncrementalParser.cpp
@@ -174,6 +174,9 @@ IncrementalParser::Parse(llvm::StringRef input) {
 
 void IncrementalParser::CleanUpPTU(TranslationUnitDecl *MostRecentTU) {
   if (StoredDeclsMap *Map = MostRecentTU->getPrimaryContext()->getLookupPtr()) {
+    // Collect the keys to erase: erasing during iteration invalidates the map
+    // iterator under backward-shift deletion.
+    llvm::SmallVector<DeclarationName, 16> KeysToErase;
     for (auto &&[Key, List] : *Map) {
       DeclContextLookupResult R = List.getLookupResult();
       std::vector<NamedDecl *> NamedDeclsToRemove;
@@ -185,12 +188,14 @@ void IncrementalParser::CleanUpPTU(TranslationUnitDecl *MostRecentTU) {
           RemoveAll = false;
       }
       if (LLVM_LIKELY(RemoveAll)) {
-        Map->erase(Key);
+        KeysToErase.push_back(Key);
       } else {
         for (NamedDecl *D : NamedDeclsToRemove)
           List.remove(D);
       }
     }
+    for (DeclarationName Key : KeysToErase)
+      Map->erase(Key);
   }
 
   ExternCContextDecl *ECCD = S.getASTContext().getExternCContextDecl();
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index 449b32a215631..ae61b590a1f71 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -726,6 +726,10 @@ void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
   auto It = Completions.find(Record);
   if (It == Completions.end())
     return;
-  It->second(Record);
+  // Move out the callback and erase before invoking it: the callback can
+  // re-enter CompleteType and mutate Completions, which invalidates It under
+  // backward-shift deletion.
+  CompletionFunction Fn = std::move(It->second);
   Completions.erase(It);
+  Fn(Record);
 }
diff --git a/llvm/include/llvm/ADT/DenseMap.h b/llvm/include/llvm/ADT/DenseMap.h
index b8b548a31acbc..e13b64b3e6bf4 100644
--- a/llvm/include/llvm/ADT/DenseMap.h
+++ b/llvm/include/llvm/ADT/DenseMap.h
@@ -344,6 +344,33 @@ class DenseMapBase : public DebugEpochBase {
     incrementNumTombstones();
   }
 
+  /// Remove all entries for which \p Pred returns true. \p Pred is invoked
+  /// with a mutable reference to each live bucket, so it may update the value
+  /// of an entry it keeps, but it must not access the map being modified.
+  /// This is the safe replacement for erase-while-iterating. Returns whether
+  /// anything was removed; if so, all iterators and references into the map
+  /// are invalidated.
+  template <typename Predicate> bool remove_if(Predicate Pred) {
+    const KeyT EmptyKey = KeyInfoT::getEmptyKey();
+    const KeyT TombstoneKey = KeyInfoT::getTombstoneKey();
+    bool Removed = false;
+    for (BucketT &B : buckets()) {
+      if (KeyInfoT::isEqual(B.getFirst(), EmptyKey) ||
+          KeyInfoT::isEqual(B.getFirst(), TombstoneKey))
+        continue;
+      if (Pred(B)) {
+        B.getSecond().~ValueT();
+        B.getFirst() = TombstoneKey;
+        decrementNumEntries();
+        incrementNumTombstones();
+        Removed = true;
+      }
+    }
+    if (Removed)
+      incrementEpoch();
+    return Removed;
+  }
+
   ValueT &operator[](const KeyT &Key) {
     return lookupOrInsertIntoBucket(Key).first->second;
   }
diff --git a/llvm/include/llvm/ADT/DenseSet.h b/llvm/include/llvm/ADT/DenseSet.h
index eec800d07b6df..645d6d1568f35 100644
--- a/llvm/include/llvm/ADT/DenseSet.h
+++ b/llvm/include/llvm/ADT/DenseSet.h
@@ -99,6 +99,16 @@ class DenseSetImpl {
 
   bool erase(const ValueT &V) { return TheMap.erase(V); }
 
+  /// Remove all elements for which \p Pred returns true.  This is the safe
+  /// replacement for erase-while-iterating; see DenseMap::remove_if.  The
+  /// predicate must not access the set being modified.  Returns whether
+  /// anything was removed; if so, all iterators are invalidated.
+  template <typename Predicate> bool remove_if(Predicate Pred) {
+    return TheMap.remove_if([&](const typename MapTy::value_type &KV) {
+      return Pred(KV.getFirst());
+    });
+  }
+
   void swap(DenseSetImpl &RHS) { TheMap.swap(RHS.TheMap); }
 
 private:
diff --git a/llvm/lib/Analysis/IRSimilarityIdentifier.cpp b/llvm/lib/Analysis/IRSimilarityIdentifier.cpp
index e5ebd1f908d55..11a824973de60 100644
--- a/llvm/lib/Analysis/IRSimilarityIdentifier.cpp
+++ b/llvm/lib/Analysis/IRSimilarityIdentifier.cpp
@@ -720,7 +720,12 @@ bool IRSimilarityCandidate::compareAssignmentMapping(
   if (!WasInserted && !ValueMappingIt->second.contains(InstValB))
     return false;
   else if (ValueMappingIt->second.size() != 1) {
-    for (unsigned OtherVal : ValueMappingIt->second) {
+    // Snapshot the set before iterating: when InstValA maps to itself the
+    // erase below removes InstValA from the very set being iterated, which
+    // invalidates the range iterator under backward-shift deletion.
+    SmallVector<unsigned> OtherVals(ValueMappingIt->second.begin(),
+                                    ValueMappingIt->second.end());
+    for (unsigned OtherVal : OtherVals) {
       if (OtherVal == InstValB)
         continue;
       auto OtherValIt = ValueNumberMappingA.find(OtherVal);
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 2b9efd22131c6..60e15b2a5bd82 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -3229,12 +3229,11 @@ void LoopAccessInfoManager::clear() {
   // analyzed loop or SCEVs that may have been modified or invalidated. At the
   // moment, that is loops requiring memory or SCEV runtime checks, as those cache
   // SCEVs, e.g. for pointer expressions.
-  for (const auto &[L, LAI] : LoopAccessInfoMap) {
-    if (LAI->getRuntimePointerChecking()->getChecks().empty() &&
-        LAI->getPSE().getPredicate().isAlwaysTrue())
-      continue;
-    LoopAccessInfoMap.erase(L);
-  }
+  LoopAccessInfoMap.remove_if([](const auto &Entry) {
+    const auto &LAI = Entry.second;
+    return !(LAI->getRuntimePointerChecking()->getChecks().empty() &&
+             LAI->getPSE().getPredicate().isAlwaysTrue());
+  });
 }
 
 bool LoopAccessInfoManager::invalidate(
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 657984aa0c3a8..855ab3bb5d621 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8763,8 +8763,8 @@ void ScalarEvolution::visitAndClearUsers(
     ValueExprMapType::iterator It =
         ValueExprMap.find_as(static_cast<Value *>(I));
     if (It != ValueExprMap.end()) {
-      eraseValueFromMap(It->first);
       ToForget.push_back(It->second);
+      eraseValueFromMap(It->first);
       if (PHINode *PN = dyn_cast<PHINode>(I))
         ConstantEvolutionLoopExitValue.erase(PN);
     }
@@ -8788,14 +8788,8 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
     forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
 
     // Drop information about predicated SCEV rewrites for this loop.
-    for (auto I = PredicatedSCEVRewrites.begin();
-         I != PredicatedSCEVRewrites.end();) {
-      std::pair<const SCEV *, const Loop *> Entry = I->first;
-      if (Entry.second == CurrL)
-        PredicatedSCEVRewrites.erase(I++);
-      else
-        ++I;
-    }
+    PredicatedSCEVRewrites.remove_if(
+        [&](const auto &Entry) { return Entry.first.second == CurrL; });
 
     auto LoopUsersItr = LoopUsers.find(CurrL);
     if (LoopUsersItr != LoopUsers.end())
@@ -14581,14 +14575,8 @@ void ScalarEvolution::forgetMemoizedResults(ArrayRef<SCEVUse> SCEVs) {
   for (const auto *S : ToForget)
     forgetMemoizedResultsImpl(S);
 
-  for (auto I = PredicatedSCEVRewrites.begin();
-       I != PredicatedSCEVRewrites.end();) {
-    std::pair<const SCEV *, const Loop *> Entry = I->first;
-    if (ToForget.count(Entry.first))
-      PredicatedSCEVRewrites.erase(I++);
-    else
-      ++I;
-  }
+  PredicatedSCEVRewrites.remove_if(
+      [&](const auto &Entry) { return ToForget.count(Entry.first.first); });
 }
 
 void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
diff --git a/llvm/lib/CodeGen/AssignmentTrackingAnalysis.cpp b/llvm/lib/CodeGen/AssignmentTrackingAnalysis.cpp
index 2bd278614f8ac..d4da80a49bc71 100644
--- a/llvm/lib/CodeGen/AssignmentTrackingAnalysis.cpp
+++ b/llvm/lib/CodeGen/AssignmentTrackingAnalysis.cpp
@@ -543,18 +543,15 @@ class MemLocFragmentFill {
     // Meet A and B.
     //
     // Result = meet(a, b) for a in A, b in B where Var(a) == Var(b)
-    for (auto It = A.begin(), End = A.end(); It != End; ++It) {
-      unsigned AVar = It->first;
-      FragsInMemMap &AFrags = It->second;
-      auto BIt = B.find(AVar);
-      if (BIt == B.end()) {
-        A.erase(It);
-        continue; // Var has no bits defined in B.
-      }
+    A.remove_if([&](auto &Entry) {
+      auto BIt = B.find(Entry.first);
+      if (BIt == B.end())
+        return true; // Var has no bits defined in B.
       LLVM_DEBUG(dbgs() << "meet fragment maps for "
-                        << Aggregates[AVar].first->getName() << "\n");
-      AFrags = meetFragments(AFrags, BIt->second);
-    }
+                        << Aggregates[Entry.first].first->getName() << "\n");
+      Entry.second = meetFragments(Entry.second, BIt->second);
+      return false;
+    });
   }
 
   bool meet(const BasicBlock &BB,
diff --git a/llvm/lib/CodeGen/MachineCopyPropagation.cpp b/llvm/lib/CodeGen/MachineCopyPropagation.cpp
index 1d8fdcab64909..6bfd21b215706 100644
--- a/llvm/lib/CodeGen/MachineCopyPropagation.cpp
+++ b/llvm/lib/CodeGen/MachineCopyPropagation.cpp
@@ -255,7 +255,7 @@ class CopyTracker {
         }
       }
       // Now we can erase the copy.
-      Copies.erase(I);
+      Copies.erase(Unit);
     }
   }
 
diff --git a/llvm/lib/CodeGen/MachineLateInstrsCleanup.cpp b/llvm/lib/CodeGen/MachineLateInstrsCleanup.cpp
index 4f281fa1361ca..811cc4fe65f3f 100644
--- a/llvm/lib/CodeGen/MachineLateInstrsCleanup.cpp
+++ b/llvm/lib/CodeGen/MachineLateInstrsCleanup.cpp
@@ -243,15 +243,17 @@ bool MachineLateInstrsCleanup::processBlock(MachineBasicBlock *MBB) {
     }
 
     // Clear any entries in map that MI clobbers.
-    for (auto DefI : llvm::make_early_inc_range(MBBDefs)) {
-      Register Reg = DefI.first;
+    MBBDefs.remove_if([&](const auto &Entry) {
+      Register Reg = Entry.first;
       if (MI.modifiesRegister(Reg, TRI)) {
-        MBBDefs.erase(Reg);
         MBBKills.erase(Reg);
-      } else if (MI.findRegisterUseOperandIdx(Reg, TRI, true /*isKill*/) != -1)
+        return true;
+      }
+      if (MI.findRegisterUseOperandIdx(Reg, TRI, true /*isKill*/) != -1)
         // Keep track of all instructions that fully or partially kills Reg.
         MBBKills[Reg].push_back(&MI);
-    }
+      return false;
+    });
 
     // Record this MI for potential later reuse.
     if (IsCandidate) {
diff --git a/llvm/lib/CodeGen/PeepholeOptimizer.cpp b/llvm/lib/CodeGen/PeepholeOptimizer.cpp
index cfbffb920ef36..19a81e3363086 100644
--- a/llvm/lib/CodeGen/PeepholeOptimizer.cpp
+++ b/llvm/lib/CodeGen/PeepholeOptimizer.cpp
@@ -1813,14 +1813,13 @@ bool PeepholeOptimizer::run(MachineFunction &MF) {
             }
           } else if (MO.isRegMask()) {
             const uint32_t *RegMask = MO.getRegMask();
-            for (auto &RegMI : NAPhysToVirtMIs) {
-              Register Def = RegMI.first;
-              if (MachineOperand::clobbersPhysReg(RegMask, Def)) {
-                LLVM_DEBUG(dbgs()
-                           << "NAPhysCopy: invalidating because of " << *MI);
-                NAPhysToVirtMIs.erase(Def);
-              }
-            }
+            NAPhysToVirtMIs.remove_if([&](const auto &RegMI) {
+              if (!MachineOperand::clobbersPhysReg(RegMask, RegMI.first))
+                return false;
+              LLVM_DEBUG(dbgs()
+                         << "NAPhysCopy: invalidating because of " << *MI);
+              return true;
+            });
           }
         }
       }
diff --git a/llvm/lib/CodeGen/RemoveRedundantDebugValues.cpp b/llvm/lib/CodeGen/RemoveRedundantDebugValues.cpp
index 11468245f8400..30057317a6383 100644
--- a/llvm/lib/CodeGen/RemoveRedundantDebugValues.cpp
+++ b/llvm/lib/CodeGen/RemoveRedundantDebugValues.cpp
@@ -128,11 +128,9 @@ static bool reduceDbgValsForwardScan(MachineBasicBlock &MBB) {
       continue;
 
     // Stop tracking any location that is clobbered by this instruction.
-    for (auto &Var : VariableMap) {
-      auto &LocOp = Var.second.first;
-      if (MI.modifiesRegister(LocOp->getReg(), TRI))
-        VariableMap.erase(Var.first);
-    }
+    VariableMap.remove_if([&](const auto &Var) {
+      return MI.modifiesRegister(Var.second.first->getReg(), TRI);
+    });
   }
 
   for (auto &Instr : DbgValsToBeRemoved) {
diff --git a/llvm/lib/ExecutionEngine/Orc/Core.cpp b/llvm/lib/ExecutionEngine/Orc/Core.cpp
index ac780dc82ae5a..1838397f7b270 100644
--- a/llvm/lib/ExecutionEngine/Orc/Core.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/Core.cpp
@@ -1065,9 +1065,7 @@ void JITDylib::removeFromLinkOrder(JITDylib &JD) {
 Error JITDylib::remove(const SymbolNameSet &Names) {
   return ES.runSessionLocked([&]() -> Error {
     assert(State == Open && "JD is defunct");
-    using SymbolMaterializerItrPair =
-        std::pair<SymbolTable::iterator, UnmaterializedInfosMap::iterator>;
-    std::vector<SymbolMaterializerItrPair> SymbolsToRemove;
+    SmallVector<SymbolStringPtr, 0> SymbolsToRemove;
     SymbolNameSet Missing;
     SymbolNameSet Materializing;
 
@@ -1087,10 +1085,7 @@ Error JITDylib::remove(const SymbolNameSet &Names) {
         continue;
       }
 
-      auto UMII = I->second.hasMaterializerAttached()
-                      ? UnmaterializedInfos.find(Name)
-                      : UnmaterializedInfos.end();
-      SymbolsToRemove.push_back(std::make_pair(I, UMII));
+      SymbolsToRemove.push_back(Name);
     }
 
     // If any of the symbols are not defined, return an error.
@@ -1103,18 +1098,18 @@ Error JITDylib::remove(const SymbolNameSet &Names) {
       return make_error<SymbolsCouldNotBeRemoved>(ES.getSymbolStringPool(),
                                                   std::move(Materializing));
 
-    // Remove the symbols.
-    for (auto &SymbolMaterializerItrPair : SymbolsToRemove) {
-      auto UMII = SymbolMaterializerItrPair.second;
-
+    // Remove the symbols. Erase by key rather than holding iterators across the
+    // loop: a prior erase invalidates other stored iterators under
+    // backward-shift deletion.
+    for (const SymbolStringPtr &Name : SymbolsToRemove) {
       // If there is a materializer attached, call discard.
+      auto UMII = UnmaterializedInfos.find(Name);
       if (UMII != UnmaterializedInfos.end()) {
         UMII->second->MU->doDiscard(*this, UMII->first);
         UnmaterializedInfos.erase(UMII);
       }
 
-      auto SymI = SymbolMaterializerItrPair.first;
-      Symbols.erase(SymI);
+      Symbols.erase(Name);
     }
 
     shrinkMaterializationInfoMemory();
diff --git a/llvm/lib/IR/LegacyPassManager.cpp b/llvm/lib/IR/LegacyPassManager.cpp
index 7b9ad89038dc6..b8efa7a399734 100644
--- a/llvm/lib/IR/LegacyPassManager.cpp
+++ b/llvm/lib/IR/LegacyPassManager.cpp
@@ -903,40 +903,26 @@ void PMDataManager::removeNotPreservedAnalysis(Pass *P) {
     return;
 
   const AnalysisUsage::VectorType &PreservedSet = AnUsage->getPreservedSet();
-  for (auto I = AvailableAnalysis.begin(), E = AvailableAnalysis.end();
-       I != E;) {
-    auto Info = I++;
-    if (Info->second->getAsImmutablePass() == nullptr &&
-        !is_contained(PreservedSet, Info->first)) {
-      // Remove this analysis
-      if (PassDebugging >= Details) {
-        Pass *S = Info->second;
-        dbgs() << " -- '" <<  P->getPassName() << "' is not preserving '";
-        dbgs() << S->getPassName() << "'\n";
-      }
-      AvailableAnalysis.erase(Info);
+  auto IsNotPreserved = [&](const auto &Entry) {
+    if (Entry.second->getAsImmutablePass() != nullptr ||
+        is_contained(PreservedSet, Entry.first))
+      return false;
+    // Remove this analysis
+    if (PassDebugging >= Details) {
+      Pass *S = Entry.second;
+      dbgs() << " -- '" << P->getPassName() << "' is not preserving '";
+      dbgs() << S->getPassName() << "'\n";
     }
-  }
+    return true;
+  };
+  AvailableAnalysis.remove_if(IsNotPreserved);
 
   // Check inherited analysis also. If P is not preserving analysis
   // provided by parent manager then remove it here.
   for (DenseMap<AnalysisID, Pass *> *IA : InheritedAnalysis) {
     if (!IA)
       continue;
-
-    for (auto I = IA->begin(), E = IA->end(); I != E;) {
-      auto Info = I++;
-      if (Info->second->getAsImmutablePass() == nullptr &&
-          !is_contained(PreservedSet, Info->first)) {
-        // Remove this analysis
-        if (PassDebugging >= Details) {
-          Pass *S = Info->second;
-          dbgs() << " -- '" <<  P->getPassName() << "' is not preserving '";
-          dbgs() << S->getPassName() << "'\n";
-        }
-        IA->erase(Info);
-      }
-    }
+    IA->remove_if(IsNotPreserved);
   }
 }
 
diff --git a/llvm/lib/MC/MCObjectStreamer.cpp b/llvm/lib/MC/MCObjectStreamer.cpp
index 88dafb94a4aaa..2bf5f05c1c315 100644
--- a/llvm/lib/MC/MCObjectStreamer.cpp
+++ b/llvm/lib/MC/MCObjectStreamer.cpp
@@ -272,12 +272,15 @@ void MCObjectStreamer::emitLabel(MCSymbol *Symbol, SMLoc Loc) {
 
 void MCObjectStreamer::emitPendingAssignments(MCSymbol *Symbol) {
   auto Assignments = pendingAssignments.find(Symbol);
-  if (Assignments != pendingAssignments.end()) {
-    for (const PendingAssignment &A : Assignments->second)
-      emitAssignment(A.Symbol, A.Value);
+  if (Assignments == pendingAssignments.end())
+    return;
 
-    pendingAssignments.erase(Assignments);
-  }
+  // emitAssignment can recursively re-enter emitPendingAssignments for
+  // other symbols, so move the list out and erase before iterating.
+  SmallVector<PendingAssignment, 1> Pending = std::move(Assignments->second);
+  pendingAssignments.erase(Assignments);
+  for (const PendingAssignment &A : Pending)
+    emitAssignment(A.Symbol, A.Value);
 }
 
 // Emit a label at a previously emitted fragment/offset position. This must be
diff --git a/llvm/lib/Support/CommandLine.cpp b/llvm/lib/Support/CommandLine.cpp
index 0c244b31724f1..5f679d50f8073 100644
--- a/llvm/lib/Support/CommandLine.cpp
+++ b/llvm/lib/Support/CommandLine.cpp
@@ -277,10 +277,11 @@ class CommandLineParser {
       OptionNames.push_back(O->ArgStr);
 
     SubCommand &Sub = *SC;
-    auto End = Sub.OptionsMap.end();
     for (auto Name : OptionNames) {
       auto I = Sub.OptionsMap.find(Name);
-      if (I != End && I->second == O)
+      // Re-query end() each iteration: a prior erase invalidates iterators
+      // (including a cached end()) under backward-shift deletion.
+      if (I != Sub.OptionsMap.end() && I->second == O)
         Sub.OptionsMap.erase(I);
     }
 
@@ -1494,8 +1495,14 @@ void CommandLineParser::ResetAllOptionOccurrences() {
   // Options might be reset twice (they can be reference in both OptionsMap
   // and one of the other members), but that does not harm.
   for (auto *SC : RegisteredSubCommands) {
+    // reset() removes default options from OptionsMap (via removeArgument), so
+    // collect the options first to avoid invalidating the map iterator.
+    SmallVector<Option *, 0> Opts;
+    Opts.reserve(SC->OptionsMap.size());
     for (auto &O : SC->OptionsMap)
-      O.second->reset();
+      Opts.push_back(O.second);
+    for (Option *O : Opts)
+      O->reset();
     for (Option *O : SC->PositionalOpts)
       O->reset();
     for (Option *O : SC->SinkOpts)
diff --git a/llvm/lib/Transforms/IPO/FunctionImport.cpp b/llvm/lib/Transforms/IPO/FunctionImport.cpp
index 456a9b116cc30..d305eadc12f35 100644
--- a/llvm/lib/Transforms/IPO/FunctionImport.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionImport.cpp
@@ -1284,12 +1284,8 @@ void llvm::ComputeCrossModuleImport(
     // exporting module. We do this after the above insertion since we may hit
     // the same ref/call target multiple times in above loop, and it is more
     // efficient to avoid a set lookup each time.
-    for (auto EI = NewExports.begin(); EI != NewExports.end();) {
-      if (!DefinedGVSummaries.count(EI->getGUID()))
-        NewExports.erase(EI++);
-      else
-        ++EI;
-    }
+    NewExports.remove_if(
+        [&](ValueInfo VI) { return !DefinedGVSummaries.count(VI.getGUID()); });
     ELI.second.insert_range(NewExports);
   }
 
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 04fdd24ad76da..c031574260c3c 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -902,9 +902,10 @@ class LowerMatrixIntrinsics {
     // it conditionally instead.
     auto S = ShapeMap.find(&Old);
     if (S != ShapeMap.end()) {
+      ShapeInfo Shape = S->second;
       ShapeMap.erase(S);
       if (supportsShapeInfo(New))
-        ShapeMap.insert({New, S->second});
+        ShapeMap.insert({New, Shape});
     }
     Old.replaceAllUsesWith(New);
   }
diff --git a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
index b635d805bf13d..ed0e864fd6905 100644
--- a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
+++ b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
@@ -929,22 +929,16 @@ void PromoteMem2Reg::run() {
     // simplify and RAUW them as we go.  If it was not, we could add uses to
     // the values we replace with in a non-deterministic order, thus creating
     // non-deterministic def->use chains.
-    for (DenseMap<std::pair<unsigned, unsigned>, PHINode *>::iterator
-             I = NewPhiNodes.begin(),
-             E = NewPhiNodes.end();
-         I != E;) {
-      PHINode *PN = I->second;
-
+    EliminatedAPHI = NewPhiNodes.remove_if([&](const auto &Entry) {
+      PHINode *PN = Entry.second;
       // If this PHI node merges one value and/or undefs, get the value.
       if (Value *V = simplifyInstruction(PN, SQ)) {
         PN->replaceAllUsesWith(V);
         PN->eraseFromParent();
-        NewPhiNodes.erase(I++);
-        EliminatedAPHI = true;
-        continue;
+        return true;
       }
-      ++I;
-    }
+      return false;
+    });
   }
 
   // At this point, the renamer has added entries to PHI nodes for all reachable
diff --git a/llvm/unittests/ADT/DenseMapTest.cpp b/llvm/unittests/ADT/DenseMapTest.cpp
index 553d159d33b1a..c38c0709f615a 100644
--- a/llvm/unittests/ADT/DenseMapTest.cpp
+++ b/llvm/unittests/ADT/DenseMapTest.cpp
@@ -1108,4 +1108,49 @@ TEST(DenseMapCustomTest, ValueDtor) {
   EXPECT_EQ(0u, CtorTester::getNumConstructed());
 }
 
+TEST(DenseMapCustomTest, RemoveIf) {
+  // Use enough entries to fill many buckets, so the lookups below exercise
+  // probing past the slots that remove_if turns into tombstones.
+  DenseMap<int, int> Map;
+  for (int I = 0; I < 100; ++I)
+    Map[I] = I * 10;
+
+  // Remove all even keys.
+  EXPECT_TRUE(Map.remove_if([](const auto &E) { return E.first % 2 == 0; }));
+  EXPECT_EQ(Map.size(), 50u);
+  for (int I = 0; I < 100; ++I) {
+    auto It = Map.find(I);
+    if (I % 2 == 0) {
+      EXPECT_EQ(It, Map.end());
+    } else {
+      ASSERT_NE(It, Map.end());
+      EXPECT_EQ(It->second, I * 10);
+    }
+  }
+
+  // A predicate that matches nothing returns false and leaves the map alone.
+  EXPECT_FALSE(Map.remove_if([](const auto &) { return false; }));
+  EXPECT_EQ(Map.size(), 50u);
+
+  // Remove everything.
+  EXPECT_TRUE(Map.remove_if([](const auto &) { return true; }));
+  EXPECT_TRUE(Map.empty());
+}
+
+TEST(DenseMapCustomTest, RemoveIfValueDtor) {
+  // remove_if must destroy the values of removed entries exactly once, and
+  // must neither leak nor double-destroy the values of surviving entries.
+  EXPECT_EQ(0u, CtorTester::getNumConstructed());
+  {
+    DenseMap<int, CtorTester> Map;
+    for (int I = 0; I < 16; ++I)
+      Map.try_emplace(I, CtorTester(I));
+    EXPECT_EQ(16u, CtorTester::getNumConstructed());
+    EXPECT_TRUE(Map.remove_if([](const auto &E) { return E.first < 10; }));
+    EXPECT_EQ(6u, CtorTester::getNumConstructed());
+    EXPECT_EQ(Map.size(), 6u);
+  }
+  EXPECT_EQ(0u, CtorTester::getNumConstructed());
+}
+
 } // namespace
diff --git a/llvm/unittests/ADT/DenseSetTest.cpp b/llvm/unittests/ADT/DenseSetTest.cpp
index a2a062b151b67..9d214b8649f66 100644
--- a/llvm/unittests/ADT/DenseSetTest.cpp
+++ b/llvm/unittests/ADT/DenseSetTest.cpp
@@ -65,6 +65,22 @@ TEST(SmallDenseSetTest, InsertRange) {
   EXPECT_THAT(set, ::testing::UnorderedElementsAre(7, 8, 9));
 }
 
+TEST(DenseSetTest, RemoveIf) {
+  llvm::DenseSet<unsigned> set;
+  for (unsigned I = 0; I < 100; ++I)
+    set.insert(I);
+
+  EXPECT_TRUE(set.remove_if([](unsigned V) { return V % 2 == 0; }));
+  EXPECT_EQ(set.size(), 50u);
+  for (unsigned I = 0; I < 100; ++I)
+    EXPECT_EQ(set.contains(I), I % 2 == 1);
+
+  EXPECT_FALSE(set.remove_if([](unsigned) { return false; }));
+  EXPECT_EQ(set.size(), 50u);
+  EXPECT_TRUE(set.remove_if([](unsigned) { return true; }));
+  EXPECT_TRUE(set.empty());
+}
+
 struct TestDenseSetInfo {
   static inline unsigned getEmptyKey() { return ~0; }
   static inline unsigned getTombstoneKey() { return ~0U - 1; }
diff --git a/llvm/unittests/Support/JSONTest.cpp b/llvm/unittests/Support/JSONTest.cpp
index a58799b4f0455..348f5cb04f9db 100644
--- a/llvm/unittests/Support/JSONTest.cpp
+++ b/llvm/unittests/Support/JSONTest.cpp
@@ -155,8 +155,8 @@ TEST(JSONTest, Object) {
   auto E = O.find("e");
   EXPECT_EQ(E, O.end());
 
-  O.erase("b");
   O.erase(D);
+  O.erase("b");
   EXPECT_EQ(O.size(), 2u);
   EXPECT_EQ(R"({"a":1,"c":3})", s(std::move(O)));
 }

>From c0f27695746f1aa269c73e6f87fa8fba33d6dec2 Mon Sep 17 00:00:00 2001
From: Fangrui Song <i at maskray.me>
Date: Fri, 22 May 2026 00:10:26 -0700
Subject: [PATCH 2/2] [mlir] Don't assume non-erased DenseMap entries remain
 valid after erase. NFC

Like the preceding llvm/ change, fix MLIR sites that reuse an iterator
or bucket reference after erasing from the same map, in preparation for
backward-shift DenseMap deletion which relocates surviving entries.

Use DenseMap::remove_if in ThreadLocalCache::clearExpiredEntries and the
RootOrdering cycle contraction (deferring the in-cycle graph erases until
after iteration). ThreadLocalCache::get reads the value into a local before
clearing expired entries, and bufferizeOp snapshots the worklist before
folding, since erases made while folding re-enter via a rewriter listener.

Aided by Claude Opus 4.7
---
 mlir/include/mlir/Support/ThreadLocalCache.h  | 16 ++++----
 .../PDLToPDLInterp/RootOrdering.cpp           | 40 +++++++++----------
 .../Bufferization/Transforms/Bufferize.cpp    | 11 ++++-
 3 files changed, 36 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
index 53b6d31a09555..92b6a7f5eec44 100644
--- a/mlir/include/mlir/Support/ThreadLocalCache.h
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -128,11 +128,8 @@ class ThreadLocalCache {
     /// Clear out any unused entries within the map. This method is not
     /// thread-safe, and should only be called by the same thread as the cache.
     void clearExpiredEntries() {
-      for (auto it = this->begin(), e = this->end(); it != e;) {
-        auto curIt = it++;
-        if (!curIt->second.ptr->second)
-          this->erase(curIt);
-      }
+      this->remove_if(
+          [](const auto &entry) { return !entry.second.ptr->second; });
     }
   };
 
@@ -159,11 +156,12 @@ class ThreadLocalCache {
     }
     threadInstance.keepalive = perInstanceState;
 
-    // Before returning the new instance, take the chance to clear out any used
-    // entries in the static map. The cache is only cleared within the same
-    // thread to remove the need to lock the cache itself.
+    // Capture the value before clearing expired entries: clearExpiredEntries
+    // erases from `staticCache`, and backward-shift deletion can relocate the
+    // bucket referenced by `threadInstance`.
+    ValueT &value = *threadInstance.ptr->first;
     staticCache.clearExpiredEntries();
-    return *threadInstance.ptr->first;
+    return value;
   }
   ValueT &operator*() { return get(); }
   ValueT *operator->() { return &get(); }
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp
index 2d9c661f7df2c..b36d5a774275d 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp
@@ -52,12 +52,11 @@ static void contract(RootOrderingGraph &graph, ArrayRef<Value> cycle,
 
   // Now, contract the cycle, marking the actual sources and targets.
   DenseMap<Value, RootOrderingEntry> repEntries;
-  for (auto outer = graph.begin(), e = graph.end(); outer != e; ++outer) {
-    Value target = outer->first;
+  for (auto &[target, edges] : graph) {
     if (cycleSet.contains(target)) {
       // Target in the cycle => edges incoming to the cycle or within the cycle.
       unsigned parentDepth = parentDepths.lookup(target);
-      for (const auto &inner : outer->second) {
+      for (const auto &inner : edges) {
         Value source = inner.first;
         // Ignore edges within the cycle.
         if (cycleSet.contains(source))
@@ -81,36 +80,37 @@ static void contract(RootOrderingGraph &graph, ArrayRef<Value> cycle,
           repEntries[source].cost = cost;
         }
       }
-      // Erase the node in the cycle.
-      graph.erase(outer);
+      // Defer erasing graph[target] until after the loop; backward-shift
+      // erase would otherwise invalidate the surrounding iterator.
     } else {
       // Target not in cycle => edges going away from or unrelated to the cycle.
-      DenseMap<Value, RootOrderingEntry> &entries = outer->second;
       Value bestSource;
       std::pair<unsigned, unsigned> bestCost;
-      auto inner = entries.begin(), innerE = entries.end();
-      while (inner != innerE) {
-        Value source = inner->first;
-        if (cycleSet.contains(source)) {
-          // Going-away edge => get its cost and erase it.
-          if (!bestSource || bestCost > inner->second.cost) {
-            bestSource = source;
-            bestCost = inner->second.cost;
-          }
-          entries.erase(inner++);
-        } else {
-          ++inner;
+      edges.remove_if([&](const auto &inner) {
+        Value source = inner.first;
+        if (!cycleSet.contains(source))
+          return false;
+        // Going-away edge => get its cost and erase it.
+        if (!bestSource || bestCost > inner.second.cost) {
+          bestSource = source;
+          bestCost = inner.second.cost;
         }
-      }
+        return true;
+      });
 
       // There were going-away edges, contract them.
       if (bestSource) {
-        entries[rep].cost = bestCost;
+        edges[rep].cost = bestCost;
         actualSource[target] = bestSource;
       }
     }
   }
 
+  // Erase all in-cycle nodes from the graph. Done after the iteration above
+  // because backward-shift erase relocates surviving entries.
+  for (Value node : cycle)
+    graph.erase(node);
+
   // Store the edges to the representative.
   graph[rep] = std::move(repEntries);
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 701ab52a491a8..0c155150d65c9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -345,8 +345,15 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
   if (erasedOps.contains(op))
     return success();
 
-  // Fold all to_buffer(to_tensor(x)) pairs.
-  for (Operation *op : toBufferOps) {
+  // Fold all to_buffer(to_tensor(x)) pairs.  Snapshot the set first:
+  // `foldToBufferToTensorPair` can erase ops, and the rewriter listener
+  // mutates `toBufferOps` from inside that call, which would invalidate
+  // any DenseSet iterator held across it.
+  SmallVector<Operation *> toBufferOpsSnapshot(toBufferOps.begin(),
+                                               toBufferOps.end());
+  for (Operation *op : toBufferOpsSnapshot) {
+    if (erasedOps.contains(op))
+      continue;
     rewriter.setInsertionPoint(op);
     (void)bufferization::foldToBufferToTensorPair(
         rewriter, cast<ToBufferOp>(op), options);



More information about the Mlir-commits mailing list