[llvm] [NFC] Replace uses of ConstantExpr::getCompare. (PR #91558)

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Wed May 8 21:58:29 PDT 2024


https://github.com/efriedma-quic created https://github.com/llvm/llvm-project/pull/91558

Use ICmpInst::compare() where possible, ConstantFoldCompareInstruction in other places.  This only changes places where either the fold is guaranteed to succeed, or the code doesn't use the resulting compare if we fail to fold.

>From 4cb3721d6030179a5f19cd4a46e790968d7f0a45 Mon Sep 17 00:00:00 2001
From: Eli Friedman <efriedma at quicinc.com>
Date: Wed, 8 May 2024 21:01:26 -0700
Subject: [PATCH] [NFC] Replace uses of ConstantExpr::getCompare.

Use ICmpInst::compare() where possible, ConstantFoldCompareInstruction
in other places.  This only changes places where either the fold is
guaranteed to succeed, or the code doesn't use the resulting compare if
we fail to fold.
---
 llvm/lib/Analysis/BranchProbabilityInfo.cpp       |  5 +++--
 llvm/lib/Analysis/ConstantFolding.cpp             |  8 ++++----
 llvm/lib/Analysis/InlineCost.cpp                  | 12 +++++-------
 llvm/lib/Analysis/ScalarEvolution.cpp             |  4 +---
 llvm/lib/IR/Constants.cpp                         |  4 ++--
 .../Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp  |  5 +++--
 llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp   |  5 +++--
 .../InstCombine/InstCombineAndOrXor.cpp           |  4 ++--
 .../Transforms/InstCombine/InstCombineCalls.cpp   |  3 ++-
 .../InstCombine/InstCombineCompares.cpp           | 15 ++++++---------
 .../Transforms/InstCombine/InstCombineSelect.cpp  | 12 ++++++------
 .../InstCombine/InstructionCombining.cpp          |  9 ++++++---
 llvm/lib/Transforms/Scalar/JumpThreading.cpp      |  6 ++++--
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp         |  8 ++++----
 14 files changed, 51 insertions(+), 49 deletions(-)

diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
index 36a2df6459132..0ef9ff836be29 100644
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -22,6 +22,7 @@
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
+#include "llvm/IR/ConstantFold.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
@@ -630,8 +631,8 @@ computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
       if (!CmpLHSConst)
         continue;
       // Now constant-evaluate the compare
-      Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
-                                                  CmpLHSConst, CmpConst, true);
+      Constant *Result = ConstantFoldCompareInstruction(CI->getPredicate(),
+                                                        CmpLHSConst, CmpConst);
       // If the result means we don't branch to the block then that block is
       // unlikely.
       if (Result &&
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 749374a3aa48a..046a769453808 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1268,10 +1268,10 @@ Constant *llvm::ConstantFoldCompareInstOperands(
       Value *Stripped1 =
           Ops1->stripAndAccumulateInBoundsConstantOffsets(DL, Offset1);
       if (Stripped0 == Stripped1)
-        return ConstantExpr::getCompare(
-            ICmpInst::getSignedPredicate(Predicate),
-            ConstantInt::get(CE0->getContext(), Offset0),
-            ConstantInt::get(CE0->getContext(), Offset1));
+        return ConstantInt::getBool(
+            Ops0->getContext(),
+            ICmpInst::compare(Offset0, Offset1,
+                              ICmpInst::getSignedPredicate(Predicate)));
     }
   } else if (isa<ConstantExpr>(Ops1)) {
     // If RHS is a constant expression, but the left side isn't, swap the
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index c75460f44c1d9..a531064e304d0 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -2046,13 +2046,11 @@ bool CallAnalyzer::visitCmpInst(CmpInst &I) {
     if (RHSBase && LHSBase == RHSBase) {
       // We have common bases, fold the icmp to a constant based on the
       // offsets.
-      Constant *CLHS = ConstantInt::get(LHS->getContext(), LHSOffset);
-      Constant *CRHS = ConstantInt::get(RHS->getContext(), RHSOffset);
-      if (Constant *C = ConstantExpr::getICmp(I.getPredicate(), CLHS, CRHS)) {
-        SimplifiedValues[&I] = C;
-        ++NumConstantPtrCmps;
-        return true;
-      }
+      SimplifiedValues[&I] = ConstantInt::getBool(
+          I.getType(),
+          ICmpInst::compare(LHSOffset, RHSOffset, I.getPredicate()));
+      ++NumConstantPtrCmps;
+      return true;
     }
   }
 
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 93f885c5d5ad8..7dc5aa084f3c3 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10615,9 +10615,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
     // Check for both operands constant.
     if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
-      if (ConstantExpr::getICmp(Pred,
-                                LHSC->getValue(),
-                                RHSC->getValue())->isNullValue())
+      if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
         return TrivialCase(false);
       return TrivialCase(true);
     }
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 5268eccf70144..db442c54125a7 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -315,8 +315,8 @@ bool Constant::isElementWiseEqual(Value *Y) const {
   Type *IntTy = VectorType::getInteger(VTy);
   Constant *C0 = ConstantExpr::getBitCast(const_cast<Constant *>(this), IntTy);
   Constant *C1 = ConstantExpr::getBitCast(cast<Constant>(Y), IntTy);
-  Constant *CmpEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, C0, C1);
-  return isa<PoisonValue>(CmpEq) || match(CmpEq, m_One());
+  Constant *CmpEq = ConstantFoldCompareInstruction(ICmpInst::ICMP_EQ, C0, C1);
+  return CmpEq && (isa<PoisonValue>(CmpEq) || match(CmpEq, m_One()));
 }
 
 static bool
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index 5b7fa13f2e835..5cb2de20f0957 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -854,8 +854,9 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
 
     if (auto *CSrc0 = dyn_cast<Constant>(Src0)) {
       if (auto *CSrc1 = dyn_cast<Constant>(Src1)) {
-        Constant *CCmp = ConstantExpr::getCompare(CCVal, CSrc0, CSrc1);
-        if (CCmp->isNullValue()) {
+        Constant *CCmp = ConstantFoldCompareInstruction(
+            (ICmpInst::Predicate)CCVal, CSrc0, CSrc1);
+        if (CCmp && CCmp->isNullValue()) {
           return IC.replaceInstUsesWith(
               II, IC.Builder.CreateSExt(CCmp, II.getType()));
         }
diff --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index e46fc034cc269..4ecc29317d6b2 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -29,8 +29,9 @@ using namespace llvm;
 static Constant *getNegativeIsTrueBoolVec(Constant *V) {
   VectorType *IntTy = VectorType::getInteger(cast<VectorType>(V->getType()));
   V = ConstantExpr::getBitCast(V, IntTy);
-  V = ConstantExpr::getICmp(CmpInst::ICMP_SGT, Constant::getNullValue(IntTy),
-                            V);
+  V = ConstantFoldCompareInstruction(CmpInst::ICMP_SGT,
+                                     Constant::getNullValue(IntTy), V);
+  assert(V && "Vector must be foldable");
   return V;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index ed9a89b14efcc..fe7716f26a5d7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2504,8 +2504,8 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
           match(C1, m_Power2())) {
         Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1);
         Constant *Cmp =
-            ConstantExpr::getCompare(ICmpInst::ICMP_ULT, Log2C3, C2);
-        if (Cmp->isZeroValue()) {
+            ConstantFoldCompareInstruction(ICmpInst::ICMP_ULT, Log2C3, C2);
+        if (Cmp && Cmp->isZeroValue()) {
           // iff C1,C3 is pow2 and Log2(C3) >= C2:
           // ((C1 >> X) << C2) & C3 -> X == (cttz(C1)+C2-cttz(C3)) ? C3 : 0
           Constant *ShlC = ConstantExpr::getAdd(C2, Log2C1);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index d7433ad3599f9..4255864331ecd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1982,7 +1982,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       if (ModuloC != ShAmtC)
         return replaceOperand(*II, 2, ModuloC);
 
-      assert(match(ConstantExpr::getICmp(ICmpInst::ICMP_UGT, WidthC, ShAmtC),
+      assert(match(ConstantFoldCompareInstruction(ICmpInst::ICMP_UGT, WidthC,
+                                                  ShAmtC),
                    m_One()) &&
              "Shift amount expected to be modulo bitwidth");
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 7092fb5e509bb..e1a3194a1beb7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3176,15 +3176,12 @@ Instruction *InstCombinerImpl::foldICmpSelectConstant(ICmpInst &Cmp,
                               C3GreaterThan)) {
     assert(C1LessThan && C2Equal && C3GreaterThan);
 
-    bool TrueWhenLessThan =
-        ConstantExpr::getCompare(Cmp.getPredicate(), C1LessThan, C)
-            ->isAllOnesValue();
-    bool TrueWhenEqual =
-        ConstantExpr::getCompare(Cmp.getPredicate(), C2Equal, C)
-            ->isAllOnesValue();
-    bool TrueWhenGreaterThan =
-        ConstantExpr::getCompare(Cmp.getPredicate(), C3GreaterThan, C)
-            ->isAllOnesValue();
+    bool TrueWhenLessThan = ICmpInst::compare(
+        C1LessThan->getValue(), C->getValue(), Cmp.getPredicate());
+    bool TrueWhenEqual = ICmpInst::compare(C2Equal->getValue(), C->getValue(),
+                                           Cmp.getPredicate());
+    bool TrueWhenGreaterThan = ICmpInst::compare(
+        C3GreaterThan->getValue(), C->getValue(), Cmp.getPredicate());
 
     // This generates the new instruction that will replace the original Cmp
     // Instruction. Instead of enumerating the various combinations when
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 8818369e79452..29462e2214242 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1493,14 +1493,14 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
     std::swap(ThresholdLowIncl, ThresholdHighExcl);
 
   // The fold has a precondition 1: C2 s>= ThresholdLow
-  auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2,
-                                         ThresholdLowIncl);
-  if (!match(Precond1, m_One()))
+  auto *Precond1 = ConstantFoldCompareInstruction(ICmpInst::Predicate::ICMP_SGE,
+                                                  C2, ThresholdLowIncl);
+  if (!Precond1 || !match(Precond1, m_One()))
     return nullptr;
   // The fold has a precondition 2: C2 s<= ThresholdHigh
-  auto *Precond2 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SLE, C2,
-                                         ThresholdHighExcl);
-  if (!match(Precond2, m_One()))
+  auto *Precond2 = ConstantFoldCompareInstruction(ICmpInst::Predicate::ICMP_SLE,
+                                                  C2, ThresholdHighExcl);
+  if (!Precond2 || !match(Precond2, m_One()))
     return nullptr;
 
   // If we are matching from a truncated input, we need to sext the
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index b6f8b24f43b8c..affa2fdcbb897 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -808,9 +808,12 @@ Instruction *InstCombinerImpl::tryFoldInstWithCtpopWithNot(Instruction *I) {
   Constant *BitWidthC = ConstantInt::get(Ty, Ty->getScalarSizeInBits());
   // Need extra check for icmp. Note if this check is true, it generally means
   // the icmp will simplify to true/false.
-  if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality() &&
-      !ConstantExpr::getICmp(ICmpInst::ICMP_UGT, C, BitWidthC)->isZeroValue())
-    return nullptr;
+  if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality()) {
+    Constant *Cmp =
+        ConstantFoldCompareInstruction(ICmpInst::ICMP_UGT, C, BitWidthC);
+    if (Cmp && !Cmp->isZeroValue())
+      return nullptr;
+  }
 
   // Check we can invert `(not x)` for free.
   bool Consumes = false;
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 08d82fa66da30..ac90ad8c08e61 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -37,6 +37,7 @@
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constant.h"
+#include "llvm/IR/ConstantFold.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -868,7 +869,7 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl(
 
       for (const auto &LHSVal : LHSVals) {
         Constant *V = LHSVal.first;
-        Constant *Folded = ConstantExpr::getCompare(Pred, V, CmpConst);
+        Constant *Folded = ConstantFoldCompareInstruction(Pred, V, CmpConst);
         if (Constant *KC = getKnownConstant(Folded, WantInteger))
           Result.emplace_back(KC, LHSVal.second);
       }
@@ -1538,7 +1539,8 @@ Constant *JumpThreadingPass::evaluateOnPredecessorEdge(BasicBlock *BB,
       Constant *Op1 =
           evaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(1));
       if (Op0 && Op1) {
-        return ConstantExpr::getCompare(CondCmp->getPredicate(), Op0, Op1);
+        return ConstantFoldCompareInstruction(CondCmp->getPredicate(), Op0,
+                                              Op1);
       }
     }
     return nullptr;
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 5a44a11ecfd2c..4efd39c930c99 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -6581,16 +6581,16 @@ static void reuseTableCompare(
   Constant *FalseConst = ConstantInt::getFalse(RangeCmp->getType());
 
   // Check if the compare with the default value is constant true or false.
-  Constant *DefaultConst = ConstantExpr::getICmp(CmpInst->getPredicate(),
-                                                 DefaultValue, CmpOp1, true);
+  Constant *DefaultConst = ConstantFoldCompareInstruction(
+      CmpInst->getPredicate(), DefaultValue, CmpOp1);
   if (DefaultConst != TrueConst && DefaultConst != FalseConst)
     return;
 
   // Check if the compare with the case values is distinct from the default
   // compare result.
   for (auto ValuePair : Values) {
-    Constant *CaseConst = ConstantExpr::getICmp(CmpInst->getPredicate(),
-                                                ValuePair.second, CmpOp1, true);
+    Constant *CaseConst = ConstantFoldCompareInstruction(
+        CmpInst->getPredicate(), ValuePair.second, CmpOp1);
     if (!CaseConst || CaseConst == DefaultConst ||
         (CaseConst != TrueConst && CaseConst != FalseConst))
       return;



More information about the llvm-commits mailing list