[llvm] c5592f7 - [InstCombine] Fix use after free when removing unreachable code (PR64235)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 1 01:19:35 PDT 2023


Author: Nikita Popov
Date: 2023-08-01T10:19:27+02:00
New Revision: c5592f7acdf2e505b0431f44775b4339444a4711

URL: https://github.com/llvm/llvm-project/commit/c5592f7acdf2e505b0431f44775b4339444a4711
DIFF: https://github.com/llvm/llvm-project/commit/c5592f7acdf2e505b0431f44775b4339444a4711.diff

LOG: [InstCombine] Fix use after free when removing unreachable code (PR64235)

In degenerate cases, unreachable code removal can erase the instruction
that is currently being visited. However, we still return that
instruction to report a change, resulting in a use-after-free.

Instead, report changes the same way eraseInstFromFunction() does:
set MadeIRChange directly and return nullptr.

Fixes https://github.com/llvm/llvm-project/issues/64235.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
    llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
    llvm/test/Transforms/InstCombine/unreachable-code.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 16812467f4c427..bccbc953e1f16e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -672,10 +672,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   bool tryToSinkInstruction(Instruction *I, BasicBlock *DestBlock);
 
   bool removeInstructionsBeforeUnreachable(Instruction &I);
-  bool handleUnreachableFrom(Instruction *I,
+  void handleUnreachableFrom(Instruction *I,
                              SmallVectorImpl<BasicBlock *> &Worklist);
-  bool handlePotentiallyDeadBlocks(SmallVectorImpl<BasicBlock *> &Worklist);
-  bool handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc);
+  void handlePotentiallyDeadBlocks(SmallVectorImpl<BasicBlock *> &Worklist);
+  void handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc);
   void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr);
 };
 

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index c69098299c91a5..939ef8e5e67731 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -1560,10 +1560,8 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
     // Remove all instructions after the marker and handle dead blocks this
     // implies.
     SmallVector<BasicBlock *> Worklist;
-    bool Changed = handleUnreachableFrom(SI.getNextNode(), Worklist);
-    Changed |= handlePotentiallyDeadBlocks(Worklist);
-    if (Changed)
-      return &SI;
+    handleUnreachableFrom(SI.getNextNode(), Worklist);
+    handlePotentiallyDeadBlocks(Worklist);
     return nullptr;
   }
 

diff  --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index beea3dd259a808..82fd76db8f2b4b 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2749,22 +2749,21 @@ Instruction *InstCombinerImpl::visitUnconditionalBranchInst(BranchInst &BI) {
 }
 
 // Under the assumption that I is unreachable, remove it and following
-// instructions.
-bool InstCombinerImpl::handleUnreachableFrom(
+// instructions. Changes are reported directly to MadeIRChange.
+void InstCombinerImpl::handleUnreachableFrom(
     Instruction *I, SmallVectorImpl<BasicBlock *> &Worklist) {
-  bool Changed = false;
   BasicBlock *BB = I->getParent();
   for (Instruction &Inst : make_early_inc_range(
            make_range(std::next(BB->getTerminator()->getReverseIterator()),
                       std::next(I->getReverseIterator())))) {
     if (!Inst.use_empty() && !Inst.getType()->isTokenTy()) {
       replaceInstUsesWith(Inst, PoisonValue::get(Inst.getType()));
-      Changed = true;
+      MadeIRChange = true;
     }
     if (Inst.isEHPad() || Inst.getType()->isTokenTy())
       continue;
     eraseInstFromFunction(Inst);
-    Changed = true;
+    MadeIRChange = true;
   }
 
   // Replace phi node operands in successor blocks with poison.
@@ -2774,19 +2773,17 @@ bool InstCombinerImpl::handleUnreachableFrom(
         if (PN.getIncomingBlock(U) == BB && !isa<PoisonValue>(U)) {
           replaceUse(U, PoisonValue::get(PN.getType()));
           addToWorklist(&PN);
-          Changed = true;
+          MadeIRChange = true;
         }
 
   // Handle potentially dead successors.
   for (BasicBlock *Succ : successors(BB))
     if (DeadEdges.insert({BB, Succ}).second)
       Worklist.push_back(Succ);
-  return Changed;
 }
 
-bool InstCombinerImpl::handlePotentiallyDeadBlocks(
+void InstCombinerImpl::handlePotentiallyDeadBlocks(
     SmallVectorImpl<BasicBlock *> &Worklist) {
-  bool Changed = false;
   while (!Worklist.empty()) {
     BasicBlock *BB = Worklist.pop_back_val();
     if (!all_of(predecessors(BB), [&](BasicBlock *Pred) {
@@ -2794,12 +2791,11 @@ bool InstCombinerImpl::handlePotentiallyDeadBlocks(
         }))
       continue;
 
-    Changed |= handleUnreachableFrom(&BB->front(), Worklist);
+    handleUnreachableFrom(&BB->front(), Worklist);
   }
-  return Changed;
 }
 
-bool InstCombinerImpl::handlePotentiallyDeadSuccessors(BasicBlock *BB,
+void InstCombinerImpl::handlePotentiallyDeadSuccessors(BasicBlock *BB,
                                                        BasicBlock *LiveSucc) {
   SmallVector<BasicBlock *> Worklist;
   for (BasicBlock *Succ : successors(BB)) {
@@ -2811,7 +2807,7 @@ bool InstCombinerImpl::handlePotentiallyDeadSuccessors(BasicBlock *BB,
       Worklist.push_back(Succ);
   }
 
-  return handlePotentiallyDeadBlocks(Worklist);
+  handlePotentiallyDeadBlocks(Worklist);
 }
 
 Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
@@ -2857,13 +2853,15 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
     return &BI;
   }
 
-  if (isa<UndefValue>(Cond) &&
-      handlePotentiallyDeadSuccessors(BI.getParent(), /*LiveSucc*/ nullptr))
-    return &BI;
-  if (auto *CI = dyn_cast<ConstantInt>(Cond))
-    if (handlePotentiallyDeadSuccessors(BI.getParent(),
-                                        BI.getSuccessor(!CI->getZExtValue())))
-      return &BI;
+  if (isa<UndefValue>(Cond)) {
+    handlePotentiallyDeadSuccessors(BI.getParent(), /*LiveSucc*/ nullptr);
+    return nullptr;
+  }
+  if (auto *CI = dyn_cast<ConstantInt>(Cond)) {
+    handlePotentiallyDeadSuccessors(BI.getParent(),
+                                    BI.getSuccessor(!CI->getZExtValue()));
+    return nullptr;
+  }
 
   return nullptr;
 }
@@ -2883,14 +2881,6 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
     return replaceOperand(SI, 0, Op0);
   }
 
-  if (isa<UndefValue>(Cond) &&
-      handlePotentiallyDeadSuccessors(SI.getParent(), /*LiveSucc*/ nullptr))
-    return &SI;
-  if (auto *CI = dyn_cast<ConstantInt>(Cond))
-    if (handlePotentiallyDeadSuccessors(
-            SI.getParent(), SI.findCaseValue(CI)->getCaseSuccessor()))
-      return &SI;
-
   KnownBits Known = computeKnownBits(Cond, 0, &SI);
   unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
   unsigned LeadingKnownOnes = Known.countMinLeadingOnes();
@@ -2923,6 +2913,16 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
     return replaceOperand(SI, 0, NewCond);
   }
 
+  if (isa<UndefValue>(Cond)) {
+    handlePotentiallyDeadSuccessors(SI.getParent(), /*LiveSucc*/ nullptr);
+    return nullptr;
+  }
+  if (auto *CI = dyn_cast<ConstantInt>(Cond)) {
+    handlePotentiallyDeadSuccessors(SI.getParent(),
+                                    SI.findCaseValue(CI)->getCaseSuccessor());
+    return nullptr;
+  }
+
   return nullptr;
 }
 

diff  --git a/llvm/test/Transforms/InstCombine/unreachable-code.ll b/llvm/test/Transforms/InstCombine/unreachable-code.ll
index 6a05a2804bf6df..c886257e0d7fc4 100644
--- a/llvm/test/Transforms/InstCombine/unreachable-code.ll
+++ b/llvm/test/Transforms/InstCombine/unreachable-code.ll
@@ -3,6 +3,7 @@
 ; RUN: opt -S -passes='instcombine<max-iterations=1>' < %s | FileCheck %s --check-prefixes=CHECK,MAX1
 
 declare void @dummy()
+declare void @llvm.assume(i1)
 
 define i32 @br_true(i1 %x) {
 ; CHECK-LABEL: define i32 @br_true
@@ -465,6 +466,33 @@ exit:
   ret void
 }
 
+define i32 @pr64235() {
+; CHECK-LABEL: define i32 @pr64235() {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 false, label [[BB:%.*]], label [[BB3:%.*]]
+; CHECK:       bb3:
+; CHECK-NEXT:    store i1 true, ptr poison, align 1
+; CHECK-NEXT:    br label [[BB2:%.*]]
+; CHECK:       bb:
+; CHECK-NEXT:    br label [[BB2]]
+; CHECK:       bb2:
+; CHECK-NEXT:    br label [[BB]]
+;
+entry:
+  br i1 false, label %bb, label %bb3
+
+bb3:
+  call void @llvm.assume(i1 false)
+  br label %bb2
+
+bb:
+  br label %bb2
+
+bb2:
+  call void @llvm.assume(i1 false)
+  br label %bb
+}
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; DEFAULT_ITER: {{.*}}
 ; MAX1: {{.*}}


        


More information about the llvm-commits mailing list