[llvm] f893dcc - Replace uses of ConstantExpr::getCompare. (#91558)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 9 16:50:05 PDT 2024


Author: Eli Friedman
Date: 2024-05-09T16:50:01-07:00
New Revision: f893dccbba372792e7e7095d741f98a234654875

URL: https://github.com/llvm/llvm-project/commit/f893dccbba372792e7e7095d741f98a234654875
DIFF: https://github.com/llvm/llvm-project/commit/f893dccbba372792e7e7095d741f98a234654875.diff

LOG: Replace uses of ConstantExpr::getCompare. (#91558)

Use ICmpInst::compare() where possible, and ConstantFoldCompareInstOperands
in other places. This only changes places where either the fold is
guaranteed to succeed, or the code doesn't use the resulting compare if
we fail to fold.

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Scalar/JumpThreading.h
    llvm/lib/Analysis/BranchProbabilityInfo.cpp
    llvm/lib/Analysis/ConstantFolding.cpp
    llvm/lib/Analysis/InlineCost.cpp
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/lib/IR/Constants.cpp
    llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
    llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
    llvm/lib/Transforms/Scalar/JumpThreading.cpp
    llvm/lib/Transforms/Utils/SimplifyCFG.cpp
    llvm/test/Transforms/JumpThreading/thread-two-bbs.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Scalar/JumpThreading.h b/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
index f7358ac9b1ee0..65d43775bdc1d 100644
--- a/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
+++ b/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
@@ -142,7 +142,7 @@ class JumpThreadingPass : public PassInfoMixin<JumpThreadingPass> {
   }
 
   Constant *evaluateOnPredecessorEdge(BasicBlock *BB, BasicBlock *PredPredBB,
-                                      Value *cond);
+                                      Value *cond, const DataLayout &DL);
   bool maybethreadThroughTwoBasicBlocks(BasicBlock *BB, Value *Cond);
   void threadThroughTwoBasicBlocks(BasicBlock *PredPredBB, BasicBlock *PredBB,
                                    BasicBlock *BB, BasicBlock *SuccBB);

diff  --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
index 36a2df6459132..cd3e3a4991327 100644
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -630,8 +630,8 @@ computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
       if (!CmpLHSConst)
         continue;
       // Now constant-evaluate the compare
-      Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
-                                                  CmpLHSConst, CmpConst, true);
+      Constant *Result = ConstantFoldCompareInstOperands(
+          CI->getPredicate(), CmpLHSConst, CmpConst, DL);
       // If the result means we don't branch to the block then that block is
       // unlikely.
       if (Result &&

diff  --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 749374a3aa48a..046a769453808 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1268,10 +1268,10 @@ Constant *llvm::ConstantFoldCompareInstOperands(
       Value *Stripped1 =
           Ops1->stripAndAccumulateInBoundsConstantOffsets(DL, Offset1);
       if (Stripped0 == Stripped1)
-        return ConstantExpr::getCompare(
-            ICmpInst::getSignedPredicate(Predicate),
-            ConstantInt::get(CE0->getContext(), Offset0),
-            ConstantInt::get(CE0->getContext(), Offset1));
+        return ConstantInt::getBool(
+            Ops0->getContext(),
+            ICmpInst::compare(Offset0, Offset1,
+                              ICmpInst::getSignedPredicate(Predicate)));
     }
   } else if (isa<ConstantExpr>(Ops1)) {
     // If RHS is a constant expression, but the left side isn't, swap the

diff  --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index c75460f44c1d9..a531064e304d0 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -2046,13 +2046,11 @@ bool CallAnalyzer::visitCmpInst(CmpInst &I) {
     if (RHSBase && LHSBase == RHSBase) {
       // We have common bases, fold the icmp to a constant based on the
       // offsets.
-      Constant *CLHS = ConstantInt::get(LHS->getContext(), LHSOffset);
-      Constant *CRHS = ConstantInt::get(RHS->getContext(), RHSOffset);
-      if (Constant *C = ConstantExpr::getICmp(I.getPredicate(), CLHS, CRHS)) {
-        SimplifiedValues[&I] = C;
-        ++NumConstantPtrCmps;
-        return true;
-      }
+      SimplifiedValues[&I] = ConstantInt::getBool(
+          I.getType(),
+          ICmpInst::compare(LHSOffset, RHSOffset, I.getPredicate()));
+      ++NumConstantPtrCmps;
+      return true;
     }
   }
 

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 93f885c5d5ad8..7dc5aa084f3c3 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10615,9 +10615,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
     // Check for both operands constant.
     if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
-      if (ConstantExpr::getICmp(Pred,
-                                LHSC->getValue(),
-                                RHSC->getValue())->isNullValue())
+      if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
         return TrivialCase(false);
       return TrivialCase(true);
     }

diff  --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 5268eccf70144..db442c54125a7 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -315,8 +315,8 @@ bool Constant::isElementWiseEqual(Value *Y) const {
   Type *IntTy = VectorType::getInteger(VTy);
   Constant *C0 = ConstantExpr::getBitCast(const_cast<Constant *>(this), IntTy);
   Constant *C1 = ConstantExpr::getBitCast(cast<Constant>(Y), IntTy);
-  Constant *CmpEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, C0, C1);
-  return isa<PoisonValue>(CmpEq) || match(CmpEq, m_One());
+  Constant *CmpEq = ConstantFoldCompareInstruction(ICmpInst::ICMP_EQ, C0, C1);
+  return CmpEq && (isa<PoisonValue>(CmpEq) || match(CmpEq, m_One()));
 }
 
 static bool

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index 5b7fa13f2e835..160a17584ca3a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -854,8 +854,9 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
 
     if (auto *CSrc0 = dyn_cast<Constant>(Src0)) {
       if (auto *CSrc1 = dyn_cast<Constant>(Src1)) {
-        Constant *CCmp = ConstantExpr::getCompare(CCVal, CSrc0, CSrc1);
-        if (CCmp->isNullValue()) {
+        Constant *CCmp = ConstantFoldCompareInstOperands(
+            (ICmpInst::Predicate)CCVal, CSrc0, CSrc1, DL);
+        if (CCmp && CCmp->isNullValue()) {
           return IC.replaceInstUsesWith(
               II, IC.Builder.CreateSExt(CCmp, II.getType()));
         }

diff  --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index e46fc034cc269..8e75e185f0f66 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -26,20 +26,21 @@ using namespace llvm;
 
 /// Return a constant boolean vector that has true elements in all positions
 /// where the input constant data vector has an element with the sign bit set.
-static Constant *getNegativeIsTrueBoolVec(Constant *V) {
+static Constant *getNegativeIsTrueBoolVec(Constant *V, const DataLayout &DL) {
   VectorType *IntTy = VectorType::getInteger(cast<VectorType>(V->getType()));
   V = ConstantExpr::getBitCast(V, IntTy);
-  V = ConstantExpr::getICmp(CmpInst::ICMP_SGT, Constant::getNullValue(IntTy),
-                            V);
+  V = ConstantFoldCompareInstOperands(CmpInst::ICMP_SGT,
+                                      Constant::getNullValue(IntTy), V, DL);
+  assert(V && "Vector must be foldable");
   return V;
 }
 
 /// Convert the x86 XMM integer vector mask to a vector of bools based on
 /// each element's most significant bit (the sign bit).
-static Value *getBoolVecFromMask(Value *Mask) {
+static Value *getBoolVecFromMask(Value *Mask, const DataLayout &DL) {
   // Fold Constant Mask.
   if (auto *ConstantMask = dyn_cast<ConstantDataVector>(Mask))
-    return getNegativeIsTrueBoolVec(ConstantMask);
+    return getNegativeIsTrueBoolVec(ConstantMask, DL);
 
   // Mask was extended from a boolean vector.
   Value *ExtMask;
@@ -65,7 +66,7 @@ static Instruction *simplifyX86MaskedLoad(IntrinsicInst &II, InstCombiner &IC) {
 
   // The mask is constant or extended from a bool vector. Convert this x86
   // intrinsic to the LLVM intrinsic to allow target-independent optimizations.
-  if (Value *BoolMask = getBoolVecFromMask(Mask)) {
+  if (Value *BoolMask = getBoolVecFromMask(Mask, IC.getDataLayout())) {
     // First, cast the x86 intrinsic scalar pointer to a vector pointer to match
     // the LLVM intrinsic definition for the pointer argument.
     unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
@@ -102,7 +103,7 @@ static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) {
 
   // The mask is constant or extended from a bool vector. Convert this x86
   // intrinsic to the LLVM intrinsic to allow target-independent optimizations.
-  if (Value *BoolMask = getBoolVecFromMask(Mask)) {
+  if (Value *BoolMask = getBoolVecFromMask(Mask, IC.getDataLayout())) {
     unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
     PointerType *VecPtrTy = PointerType::get(Vec->getType(), AddrSpace);
     Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec");
@@ -2688,7 +2689,8 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
 
     // Constant Mask - select 1st/2nd argument lane based on top bit of mask.
     if (auto *ConstantMask = dyn_cast<ConstantDataVector>(Mask)) {
-      Constant *NewSelector = getNegativeIsTrueBoolVec(ConstantMask);
+      Constant *NewSelector =
+          getNegativeIsTrueBoolVec(ConstantMask, IC.getDataLayout());
       return SelectInst::Create(NewSelector, Op1, Op0, "blendv");
     }
 

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index a52c70dbdf3f4..8695e9e69df20 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2504,8 +2504,8 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
           match(C1, m_Power2())) {
         Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1);
         Constant *Cmp =
-            ConstantExpr::getCompare(ICmpInst::ICMP_ULT, Log2C3, C2);
-        if (Cmp->isZeroValue()) {
+            ConstantFoldCompareInstOperands(ICmpInst::ICMP_ULT, Log2C3, C2, DL);
+        if (Cmp && Cmp->isZeroValue()) {
           // iff C1,C3 is pow2 and Log2(C3) >= C2:
           // ((C1 >> X) << C2) & C3 -> X == (cttz(C1)+C2-cttz(C3)) ? C3 : 0
           Constant *ShlC = ConstantExpr::getAdd(C2, Log2C1);

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index d7433ad3599f9..77534e0d36131 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1982,7 +1982,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       if (ModuloC != ShAmtC)
         return replaceOperand(*II, 2, ModuloC);
 
-      assert(match(ConstantExpr::getICmp(ICmpInst::ICMP_UGT, WidthC, ShAmtC),
+      assert(match(ConstantFoldCompareInstOperands(ICmpInst::ICMP_UGT, WidthC,
+                                                   ShAmtC, DL),
                    m_One()) &&
              "Shift amount expected to be modulo bitwidth");
 

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 7092fb5e509bb..e1a3194a1beb7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3176,15 +3176,12 @@ Instruction *InstCombinerImpl::foldICmpSelectConstant(ICmpInst &Cmp,
                               C3GreaterThan)) {
     assert(C1LessThan && C2Equal && C3GreaterThan);
 
-    bool TrueWhenLessThan =
-        ConstantExpr::getCompare(Cmp.getPredicate(), C1LessThan, C)
-            ->isAllOnesValue();
-    bool TrueWhenEqual =
-        ConstantExpr::getCompare(Cmp.getPredicate(), C2Equal, C)
-            ->isAllOnesValue();
-    bool TrueWhenGreaterThan =
-        ConstantExpr::getCompare(Cmp.getPredicate(), C3GreaterThan, C)
-            ->isAllOnesValue();
+    bool TrueWhenLessThan = ICmpInst::compare(
+        C1LessThan->getValue(), C->getValue(), Cmp.getPredicate());
+    bool TrueWhenEqual = ICmpInst::compare(C2Equal->getValue(), C->getValue(),
+                                           Cmp.getPredicate());
+    bool TrueWhenGreaterThan = ICmpInst::compare(
+        C3GreaterThan->getValue(), C->getValue(), Cmp.getPredicate());
 
     // This generates the new instruction that will replace the original Cmp
     // Instruction. Instead of enumerating the various combinations when

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 8818369e79452..ee090e0125082 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1365,7 +1365,8 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
 // Also ULT predicate can also be UGT iff C0 != -1 (+invert result)
 //      SLT predicate can also be SGT iff C2 != INT_MAX (+invert res.)
 static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
-                                    InstCombiner::BuilderTy &Builder) {
+                                    InstCombiner::BuilderTy &Builder,
+                                    InstCombiner &IC) {
   Value *X = Sel0.getTrueValue();
   Value *Sel1 = Sel0.getFalseValue();
 
@@ -1493,14 +1494,14 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
     std::swap(ThresholdLowIncl, ThresholdHighExcl);
 
   // The fold has a precondition 1: C2 s>= ThresholdLow
-  auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2,
-                                         ThresholdLowIncl);
-  if (!match(Precond1, m_One()))
+  auto *Precond1 = ConstantFoldCompareInstOperands(
+      ICmpInst::Predicate::ICMP_SGE, C2, ThresholdLowIncl, IC.getDataLayout());
+  if (!Precond1 || !match(Precond1, m_One()))
     return nullptr;
   // The fold has a precondition 2: C2 s<= ThresholdHigh
-  auto *Precond2 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SLE, C2,
-                                         ThresholdHighExcl);
-  if (!match(Precond2, m_One()))
+  auto *Precond2 = ConstantFoldCompareInstOperands(
+      ICmpInst::Predicate::ICMP_SLE, C2, ThresholdHighExcl, IC.getDataLayout());
+  if (!Precond2 || !match(Precond2, m_One()))
     return nullptr;
 
   // If we are matching from a truncated input, we need to sext the
@@ -1803,7 +1804,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
   if (Value *V = foldSelectInstWithICmpConst(SI, ICI, Builder))
     return replaceInstUsesWith(SI, V);
 
-  if (Value *V = canonicalizeClampLike(SI, *ICI, Builder))
+  if (Value *V = canonicalizeClampLike(SI, *ICI, Builder, *this))
     return replaceInstUsesWith(SI, V);
 
   if (Instruction *NewSel =

diff  --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index b6f8b24f43b8c..6c25ff215c375 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -808,9 +808,12 @@ Instruction *InstCombinerImpl::tryFoldInstWithCtpopWithNot(Instruction *I) {
   Constant *BitWidthC = ConstantInt::get(Ty, Ty->getScalarSizeInBits());
   // Need extra check for icmp. Note if this check is true, it generally means
   // the icmp will simplify to true/false.
-  if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality() &&
-      !ConstantExpr::getICmp(ICmpInst::ICMP_UGT, C, BitWidthC)->isZeroValue())
-    return nullptr;
+  if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality()) {
+    Constant *Cmp =
+        ConstantFoldCompareInstOperands(ICmpInst::ICMP_UGT, C, BitWidthC, DL);
+    if (!Cmp || !Cmp->isZeroValue())
+      return nullptr;
+  }
 
   // Check we can invert `(not x)` for free.
   bool Consumes = false;

diff  --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 08d82fa66da30..802467b5b1835 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -868,7 +868,8 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl(
 
       for (const auto &LHSVal : LHSVals) {
         Constant *V = LHSVal.first;
-        Constant *Folded = ConstantExpr::getCompare(Pred, V, CmpConst);
+        Constant *Folded =
+            ConstantFoldCompareInstOperands(Pred, V, CmpConst, DL);
         if (Constant *KC = getKnownConstant(Folded, WantInteger))
           Result.emplace_back(KC, LHSVal.second);
       }
@@ -1509,7 +1510,8 @@ findMostPopularDest(BasicBlock *BB,
 // BB->getSinglePredecessor() and then on to BB.
 Constant *JumpThreadingPass::evaluateOnPredecessorEdge(BasicBlock *BB,
                                                        BasicBlock *PredPredBB,
-                                                       Value *V) {
+                                                       Value *V,
+                                                       const DataLayout &DL) {
   BasicBlock *PredBB = BB->getSinglePredecessor();
   assert(PredBB && "Expected a single predecessor");
 
@@ -1534,11 +1536,12 @@ Constant *JumpThreadingPass::evaluateOnPredecessorEdge(BasicBlock *BB,
   if (CmpInst *CondCmp = dyn_cast<CmpInst>(V)) {
     if (CondCmp->getParent() == BB) {
       Constant *Op0 =
-          evaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(0));
+          evaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(0), DL);
       Constant *Op1 =
-          evaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(1));
+          evaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(1), DL);
       if (Op0 && Op1) {
-        return ConstantExpr::getCompare(CondCmp->getPredicate(), Op0, Op1);
+        return ConstantFoldCompareInstOperands(CondCmp->getPredicate(), Op0,
+                                               Op1, DL);
       }
     }
     return nullptr;
@@ -2191,12 +2194,13 @@ bool JumpThreadingPass::maybethreadThroughTwoBasicBlocks(BasicBlock *BB,
   unsigned OneCount = 0;
   BasicBlock *ZeroPred = nullptr;
   BasicBlock *OnePred = nullptr;
+  const DataLayout &DL = BB->getModule()->getDataLayout();
   for (BasicBlock *P : predecessors(PredBB)) {
     // If PredPred ends with IndirectBrInst, we can't handle it.
     if (isa<IndirectBrInst>(P->getTerminator()))
       continue;
     if (ConstantInt *CI = dyn_cast_or_null<ConstantInt>(
-            evaluateOnPredecessorEdge(BB, P, Cond))) {
+            evaluateOnPredecessorEdge(BB, P, Cond, DL))) {
       if (CI->isZero()) {
         ZeroCount++;
         ZeroPred = P;

diff  --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 23a896c59bf66..93701b2a77916 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -6580,16 +6580,17 @@ static void reuseTableCompare(
   Constant *FalseConst = ConstantInt::getFalse(RangeCmp->getType());
 
   // Check if the compare with the default value is constant true or false.
-  Constant *DefaultConst = ConstantExpr::getICmp(CmpInst->getPredicate(),
-                                                 DefaultValue, CmpOp1, true);
+  const DataLayout &DL = PhiBlock->getModule()->getDataLayout();
+  Constant *DefaultConst = ConstantFoldCompareInstOperands(
+      CmpInst->getPredicate(), DefaultValue, CmpOp1, DL);
   if (DefaultConst != TrueConst && DefaultConst != FalseConst)
     return;
 
   // Check if the compare with the case values is distinct from the default
   // compare result.
   for (auto ValuePair : Values) {
-    Constant *CaseConst = ConstantExpr::getICmp(CmpInst->getPredicate(),
-                                                ValuePair.second, CmpOp1, true);
+    Constant *CaseConst = ConstantFoldCompareInstOperands(
+        CmpInst->getPredicate(), ValuePair.second, CmpOp1, DL);
     if (!CaseConst || CaseConst == DefaultConst ||
         (CaseConst != TrueConst && CaseConst != FalseConst))
       return;

diff  --git a/llvm/test/Transforms/JumpThreading/thread-two-bbs.ll b/llvm/test/Transforms/JumpThreading/thread-two-bbs.ll
index f7e6b2189dc8b..09394a9462417 100644
--- a/llvm/test/Transforms/JumpThreading/thread-two-bbs.ll
+++ b/llvm/test/Transforms/JumpThreading/thread-two-bbs.ll
@@ -130,8 +130,8 @@ exit:
 }
 
 
-; Verify that we do *not* thread any edge.  We used to evaluate
-; constant expressions like:
+; Verify that we thread the edge correctly.  We used to evaluate constant
+; expressions like:
 ;
 ;   icmp ugt ptr null, inttoptr (i64 4 to ptr)
 ;
@@ -141,16 +141,17 @@ define void @icmp_ult_null_constexpr(ptr %arg1, ptr %arg2) {
 ; CHECK-LABEL: @icmp_ult_null_constexpr(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[ARG1:%.*]], null
-; CHECK-NEXT:    br i1 [[CMP1]], label [[BB_BAR1:%.*]], label [[BB_END:%.*]]
-; CHECK:       bb_bar1:
-; CHECK-NEXT:    call void @bar(i32 1)
-; CHECK-NEXT:    br label [[BB_END]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[BB_END_THREAD:%.*]], label [[BB_END:%.*]]
 ; CHECK:       bb_end:
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp ne ptr [[ARG2:%.*]], null
 ; CHECK-NEXT:    br i1 [[CMP2]], label [[BB_CONT:%.*]], label [[BB_BAR2:%.*]]
+; CHECK:       bb_end.thread:
+; CHECK-NEXT:    call void @bar(i32 1)
+; CHECK-NEXT:    [[CMP21:%.*]] = icmp ne ptr [[ARG2]], null
+; CHECK-NEXT:    br i1 [[CMP21]], label [[BB_EXIT:%.*]], label [[BB_BAR2]]
 ; CHECK:       bb_bar2:
 ; CHECK-NEXT:    call void @bar(i32 2)
-; CHECK-NEXT:    br label [[BB_EXIT:%.*]]
+; CHECK-NEXT:    br label [[BB_EXIT]]
 ; CHECK:       bb_cont:
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[ARG1]], inttoptr (i64 4 to ptr)
 ; CHECK-NEXT:    br i1 [[CMP3]], label [[BB_EXIT]], label [[BB_BAR3:%.*]]


        


More information about the llvm-commits mailing list