[llvm] [SCCP] get rid of potentially dangling iterator (PR #105609)

Florian Mayer via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 21 20:43:33 PDT 2024


https://github.com/fmayer created https://github.com/llvm/llvm-project/pull/105609

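getUserBonus cached an iterator into KnownConstants, but later in the
same function it inserts into KnownConstants again. That insertion can
grow the map, and because DenseMap keeps its buckets in a single
contiguous allocation, growing it invalidates every existing iterator.

The iterator is not needed: cache the Use and its Constant directly
instead. This is equivalent because we never change the value of an
entry after inserting it.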


>From 0c74cfdff70cc11edbfa3b1409beb504361e5a60 Mon Sep 17 00:00:00 2001
From: Florian Mayer <fmayer at google.com>
Date: Wed, 21 Aug 2024 20:43:15 -0700
Subject: [PATCH] [𝘀𝗽𝗿] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.4
---
 .../Transforms/IPO/FunctionSpecialization.h   |  3 +-
 .../Transforms/IPO/FunctionSpecialization.cpp | 60 ++++++++++---------
 2 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
index b001771951e0fe..05442c026fea88 100644
--- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
+++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
@@ -189,7 +189,8 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
   // (some of their incoming values may have become constant or dead).
   SmallVector<Instruction *> PendingPHIs;
 
-  ConstMap::iterator LastVisited;
+  Value *LastVisitedUse = nullptr;         // Last use passed to getUserBonus.
+  Constant *LastVisitedConstant = nullptr; // Constant known for that use.
 
 public:
   InstCostVisitor(const DataLayout &DL, BlockFrequencyInfo &BFI,
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 548335d750e33d..63b74733f69c1f 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -185,10 +185,14 @@ Bonus InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C)
   // We have already propagated a constant for this user.
   if (KnownConstants.contains(User))
     return {0, 0};
-
-  // Cache the iterator before visiting.
-  LastVisited = Use ? KnownConstants.insert({Use, C}).first
-                    : KnownConstants.end();
+
+  // Cache the use and the constant stored for it before visiting. Keep copies
+  // rather than an iterator: later insertions into KnownConstants may grow
+  // the map and invalidate iterators. insert() keeps an existing entry, so
+  // read the stored constant back instead of assuming it is C.
+  LastVisitedUse = Use;
+  LastVisitedConstant =
+      Use ? KnownConstants.insert({Use, C}).first->second : nullptr;
 
   Cost CodeSize = 0;
   if (auto *I = dyn_cast<SwitchInst>(User)) {
@@ -228,12 +232,12 @@ Bonus InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C)
 }
 
 Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
-  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+  assert(LastVisitedUse && "missing last use!");
 
-  if (I.getCondition() != LastVisited->first)
+  if (LastVisitedUse != I.getCondition())
     return 0;
 
-  auto *C = dyn_cast<ConstantInt>(LastVisited->second);
+  auto *C = dyn_cast<ConstantInt>(LastVisitedConstant);
   if (!C)
     return 0;
 
@@ -253,12 +257,12 @@ Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
 }
 
 Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
-  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+  assert(LastVisitedUse && "missing last use!");
 
-  if (I.getCondition() != LastVisited->first)
+  if (LastVisitedUse != I.getCondition())
     return 0;
 
-  BasicBlock *Succ = I.getSuccessor(LastVisited->second->isOneValue());
+  BasicBlock *Succ = I.getSuccessor(LastVisitedConstant->isOneValue());
   // Initialize the worklist with the dead successor as long as
   // it is executable and has a unique predecessor.
   SmallVector<BasicBlock *> WorkList;
@@ -369,10 +373,10 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
 }
 
 Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) {
-  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+  assert(LastVisitedUse && "missing last use!");
 
-  if (isGuaranteedNotToBeUndefOrPoison(LastVisited->second))
-    return LastVisited->second;
+  if (isGuaranteedNotToBeUndefOrPoison(LastVisitedConstant))
+    return LastVisitedConstant;
   return nullptr;
 }
 
@@ -397,11 +401,11 @@ Constant *InstCostVisitor::visitCallBase(CallBase &I) {
 }
 
 Constant *InstCostVisitor::visitLoadInst(LoadInst &I) {
-  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+  assert(LastVisitedUse && "missing last use!");
 
-  if (isa<ConstantPointerNull>(LastVisited->second))
+  if (isa<ConstantPointerNull>(LastVisitedConstant))
     return nullptr;
-  return ConstantFoldLoadFromConstPtr(LastVisited->second, I.getType(), DL);
+  return ConstantFoldLoadFromConstPtr(LastVisitedConstant, I.getType(), DL);
 }
 
 Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
@@ -421,53 +425,53 @@ Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
 }
 
 Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
-  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+  assert(LastVisitedUse && "missing last use!");
 
-  if (I.getCondition() != LastVisited->first)
+  if (LastVisitedUse != I.getCondition())
     return nullptr;
 
-  Value *V = LastVisited->second->isZeroValue() ? I.getFalseValue()
+  Value *V = LastVisitedConstant->isZeroValue() ? I.getFalseValue()
                                                  : I.getTrueValue();
   Constant *C = findConstantFor(V, KnownConstants);
   return C;
 }
 
 Constant *InstCostVisitor::visitCastInst(CastInst &I) {
-  return ConstantFoldCastOperand(I.getOpcode(), LastVisited->second,
+  return ConstantFoldCastOperand(I.getOpcode(), LastVisitedConstant,
                                   I.getType(), DL);
 }
 
 Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
-  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+  assert(LastVisitedUse && "missing last use!");
 
-  bool Swap = I.getOperand(1) == LastVisited->first;
+  bool Swap = LastVisitedUse == I.getOperand(1);
   Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
   Constant *Other = findConstantFor(V, KnownConstants);
   if (!Other)
     return nullptr;
 
-  Constant *Const = LastVisited->second;
+  Constant *Const = LastVisitedConstant;
   return Swap ?
         ConstantFoldCompareInstOperands(I.getPredicate(), Other, Const, DL)
       : ConstantFoldCompareInstOperands(I.getPredicate(), Const, Other, DL);
 }
 
 Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) {
-  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+  assert(LastVisitedUse && "missing last use!");
 
-  return ConstantFoldUnaryOpOperand(I.getOpcode(), LastVisited->second, DL);
+  return ConstantFoldUnaryOpOperand(I.getOpcode(), LastVisitedConstant, DL);
 }
 
 Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) {
-  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+  assert(LastVisitedUse && "missing last use!");
 
-  bool Swap = I.getOperand(1) == LastVisited->first;
+  bool Swap = LastVisitedUse == I.getOperand(1);
   Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
   Constant *Other = findConstantFor(V, KnownConstants);
   if (!Other)
     return nullptr;
 
-  Constant *Const = LastVisited->second;
+  Constant *Const = LastVisitedConstant;
   return dyn_cast_or_null<Constant>(Swap ?
         simplifyBinOp(I.getOpcode(), Other, Const, SimplifyQuery(DL))
       : simplifyBinOp(I.getOpcode(), Const, Other, SimplifyQuery(DL)));



More information about the llvm-commits mailing list