[llvm] r288589 - [InstSimplify] add more helper functions for SimplifyICmpInst; NFCI

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Sat Dec 3 10:03:53 PST 2016


Author: spatel
Date: Sat Dec  3 12:03:53 2016
New Revision: 288589

URL: http://llvm.org/viewvc/llvm-project?rev=288589&view=rev
Log:
[InstSimplify] add more helper functions for SimplifyICmpInst; NFCI

Modified:
    llvm/trunk/lib/Analysis/InstructionSimplify.cpp

Modified: llvm/trunk/lib/Analysis/InstructionSimplify.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/InstructionSimplify.cpp?rev=288589&r1=288588&r2=288589&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/InstructionSimplify.cpp (original)
+++ llvm/trunk/lib/Analysis/InstructionSimplify.cpp Sat Dec  3 12:03:53 2016
@@ -67,6 +67,8 @@ static Value *SimplifyFPBinOp(unsigned,
                               const Query &, unsigned);
 static Value *SimplifyCmpInst(unsigned, Value *, Value *, const Query &,
                               unsigned);
+static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+                               const Query &Q, unsigned MaxRecurse);
 static Value *SimplifyOrInst(Value *, Value *, const Query &, unsigned);
 static Value *SimplifyXorInst(Value *, Value *, const Query &, unsigned);
 static Value *SimplifyCastInst(unsigned, Value *, Type *,
@@ -2431,231 +2433,11 @@ static Value *simplifyICmpWithConstant(C
   return nullptr;
 }
 
-/// Given operands for an ICmpInst, see if we can fold the result.
-/// If not, this returns null.
-static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
-                               const Query &Q, unsigned MaxRecurse) {
-  CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate;
-  assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!");
-
-  if (Constant *CLHS = dyn_cast<Constant>(LHS)) {
-    if (Constant *CRHS = dyn_cast<Constant>(RHS))
-      return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI);
-
-    // If we have a constant, make sure it is on the RHS.
-    std::swap(LHS, RHS);
-    Pred = CmpInst::getSwappedPredicate(Pred);
-  }
-
+static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
+                                    Value *RHS, const Query &Q,
+                                    unsigned MaxRecurse) {
   Type *ITy = GetCompareTy(LHS); // The return type.
 
-  // icmp X, X -> true/false
-  // X icmp undef -> true/false.  For example, icmp ugt %X, undef -> false
-  // because X could be 0.
-  if (LHS == RHS || isa<UndefValue>(RHS))
-    return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred));
-
-  if (Value *V = simplifyICmpOfBools(Pred, LHS, RHS, Q))
-    return V;
-
-  if (Value *V = simplifyICmpWithZero(Pred, LHS, RHS, Q))
-    return V;
-
-  if (Value *V = simplifyICmpWithConstant(Pred, LHS, RHS))
-    return V;
-
-  // If both operands have range metadata, use the metadata
-  // to simplify the comparison.
-  if (isa<Instruction>(RHS) && isa<Instruction>(LHS)) {
-    auto RHS_Instr = dyn_cast<Instruction>(RHS);
-    auto LHS_Instr = dyn_cast<Instruction>(LHS);
-
-    if (RHS_Instr->getMetadata(LLVMContext::MD_range) &&
-        LHS_Instr->getMetadata(LLVMContext::MD_range)) {
-      auto RHS_CR = getConstantRangeFromMetadata(
-          *RHS_Instr->getMetadata(LLVMContext::MD_range));
-      auto LHS_CR = getConstantRangeFromMetadata(
-          *LHS_Instr->getMetadata(LLVMContext::MD_range));
-
-      auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR);
-      if (Satisfied_CR.contains(LHS_CR))
-        return ConstantInt::getTrue(RHS->getContext());
-
-      auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion(
-                CmpInst::getInversePredicate(Pred), RHS_CR);
-      if (InversedSatisfied_CR.contains(LHS_CR))
-        return ConstantInt::getFalse(RHS->getContext());
-    }
-  }
-
-  // Compare of cast, for example (zext X) != 0 -> X != 0
-  if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) {
-    Instruction *LI = cast<CastInst>(LHS);
-    Value *SrcOp = LI->getOperand(0);
-    Type *SrcTy = SrcOp->getType();
-    Type *DstTy = LI->getType();
-
-    // Turn icmp (ptrtoint x), (ptrtoint/constant) into a compare of the input
-    // if the integer type is the same size as the pointer type.
-    if (MaxRecurse && isa<PtrToIntInst>(LI) &&
-        Q.DL.getTypeSizeInBits(SrcTy) == DstTy->getPrimitiveSizeInBits()) {
-      if (Constant *RHSC = dyn_cast<Constant>(RHS)) {
-        // Transfer the cast to the constant.
-        if (Value *V = SimplifyICmpInst(Pred, SrcOp,
-                                        ConstantExpr::getIntToPtr(RHSC, SrcTy),
-                                        Q, MaxRecurse-1))
-          return V;
-      } else if (PtrToIntInst *RI = dyn_cast<PtrToIntInst>(RHS)) {
-        if (RI->getOperand(0)->getType() == SrcTy)
-          // Compare without the cast.
-          if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0),
-                                          Q, MaxRecurse-1))
-            return V;
-      }
-    }
-
-    if (isa<ZExtInst>(LHS)) {
-      // Turn icmp (zext X), (zext Y) into a compare of X and Y if they have the
-      // same type.
-      if (ZExtInst *RI = dyn_cast<ZExtInst>(RHS)) {
-        if (MaxRecurse && SrcTy == RI->getOperand(0)->getType())
-          // Compare X and Y.  Note that signed predicates become unsigned.
-          if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred),
-                                          SrcOp, RI->getOperand(0), Q,
-                                          MaxRecurse-1))
-            return V;
-      }
-      // Turn icmp (zext X), Cst into a compare of X and Cst if Cst is extended
-      // too.  If not, then try to deduce the result of the comparison.
-      else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
-        // Compute the constant that would happen if we truncated to SrcTy then
-        // reextended to DstTy.
-        Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy);
-        Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy);
-
-        // If the re-extended constant didn't change then this is effectively
-        // also a case of comparing two zero-extended values.
-        if (RExt == CI && MaxRecurse)
-          if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred),
-                                        SrcOp, Trunc, Q, MaxRecurse-1))
-            return V;
-
-        // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit
-        // there.  Use this to work out the result of the comparison.
-        if (RExt != CI) {
-          switch (Pred) {
-          default: llvm_unreachable("Unknown ICmp predicate!");
-          // LHS <u RHS.
-          case ICmpInst::ICMP_EQ:
-          case ICmpInst::ICMP_UGT:
-          case ICmpInst::ICMP_UGE:
-            return ConstantInt::getFalse(CI->getContext());
-
-          case ICmpInst::ICMP_NE:
-          case ICmpInst::ICMP_ULT:
-          case ICmpInst::ICMP_ULE:
-            return ConstantInt::getTrue(CI->getContext());
-
-          // LHS is non-negative.  If RHS is negative then LHS >s LHS.  If RHS
-          // is non-negative then LHS <s RHS.
-          case ICmpInst::ICMP_SGT:
-          case ICmpInst::ICMP_SGE:
-            return CI->getValue().isNegative() ?
-              ConstantInt::getTrue(CI->getContext()) :
-              ConstantInt::getFalse(CI->getContext());
-
-          case ICmpInst::ICMP_SLT:
-          case ICmpInst::ICMP_SLE:
-            return CI->getValue().isNegative() ?
-              ConstantInt::getFalse(CI->getContext()) :
-              ConstantInt::getTrue(CI->getContext());
-          }
-        }
-      }
-    }
-
-    if (isa<SExtInst>(LHS)) {
-      // Turn icmp (sext X), (sext Y) into a compare of X and Y if they have the
-      // same type.
-      if (SExtInst *RI = dyn_cast<SExtInst>(RHS)) {
-        if (MaxRecurse && SrcTy == RI->getOperand(0)->getType())
-          // Compare X and Y.  Note that the predicate does not change.
-          if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0),
-                                          Q, MaxRecurse-1))
-            return V;
-      }
-      // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended
-      // too.  If not, then try to deduce the result of the comparison.
-      else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
-        // Compute the constant that would happen if we truncated to SrcTy then
-        // reextended to DstTy.
-        Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy);
-        Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy);
-
-        // If the re-extended constant didn't change then this is effectively
-        // also a case of comparing two sign-extended values.
-        if (RExt == CI && MaxRecurse)
-          if (Value *V = SimplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse-1))
-            return V;
-
-        // Otherwise the upper bits of LHS are all equal, while RHS has varying
-        // bits there.  Use this to work out the result of the comparison.
-        if (RExt != CI) {
-          switch (Pred) {
-          default: llvm_unreachable("Unknown ICmp predicate!");
-          case ICmpInst::ICMP_EQ:
-            return ConstantInt::getFalse(CI->getContext());
-          case ICmpInst::ICMP_NE:
-            return ConstantInt::getTrue(CI->getContext());
-
-          // If RHS is non-negative then LHS <s RHS.  If RHS is negative then
-          // LHS >s RHS.
-          case ICmpInst::ICMP_SGT:
-          case ICmpInst::ICMP_SGE:
-            return CI->getValue().isNegative() ?
-              ConstantInt::getTrue(CI->getContext()) :
-              ConstantInt::getFalse(CI->getContext());
-          case ICmpInst::ICMP_SLT:
-          case ICmpInst::ICMP_SLE:
-            return CI->getValue().isNegative() ?
-              ConstantInt::getFalse(CI->getContext()) :
-              ConstantInt::getTrue(CI->getContext());
-
-          // If LHS is non-negative then LHS <u RHS.  If LHS is negative then
-          // LHS >u RHS.
-          case ICmpInst::ICMP_UGT:
-          case ICmpInst::ICMP_UGE:
-            // Comparison is true iff the LHS <s 0.
-            if (MaxRecurse)
-              if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SLT, SrcOp,
-                                              Constant::getNullValue(SrcTy),
-                                              Q, MaxRecurse-1))
-                return V;
-            break;
-          case ICmpInst::ICMP_ULT:
-          case ICmpInst::ICMP_ULE:
-            // Comparison is true iff the LHS >=s 0.
-            if (MaxRecurse)
-              if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SGE, SrcOp,
-                                              Constant::getNullValue(SrcTy),
-                                              Q, MaxRecurse-1))
-                return V;
-            break;
-          }
-        }
-      }
-    }
-  }
-
-  // icmp eq|ne X, Y -> false|true if X != Y
-  if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
-      isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT)) {
-    LLVMContext &Ctx = LHS->getType()->getContext();
-    return Pred == ICmpInst::ICMP_NE ?
-      ConstantInt::getTrue(Ctx) : ConstantInt::getFalse(Ctx);
-  }
-
-  // Special logic for binary operators.
   BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS);
   BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS);
   if (MaxRecurse && (LBO || RBO)) {
@@ -2664,35 +2446,39 @@ static Value *SimplifyICmpInst(unsigned
     // LHS = A + B (or A and B are null); RHS = C + D (or C and D are null).
     bool NoLHSWrapProblem = false, NoRHSWrapProblem = false;
     if (LBO && LBO->getOpcode() == Instruction::Add) {
-      A = LBO->getOperand(0); B = LBO->getOperand(1);
-      NoLHSWrapProblem = ICmpInst::isEquality(Pred) ||
-        (CmpInst::isUnsigned(Pred) && LBO->hasNoUnsignedWrap()) ||
-        (CmpInst::isSigned(Pred) && LBO->hasNoSignedWrap());
+      A = LBO->getOperand(0);
+      B = LBO->getOperand(1);
+      NoLHSWrapProblem =
+          ICmpInst::isEquality(Pred) ||
+          (CmpInst::isUnsigned(Pred) && LBO->hasNoUnsignedWrap()) ||
+          (CmpInst::isSigned(Pred) && LBO->hasNoSignedWrap());
     }
     if (RBO && RBO->getOpcode() == Instruction::Add) {
-      C = RBO->getOperand(0); D = RBO->getOperand(1);
-      NoRHSWrapProblem = ICmpInst::isEquality(Pred) ||
-        (CmpInst::isUnsigned(Pred) && RBO->hasNoUnsignedWrap()) ||
-        (CmpInst::isSigned(Pred) && RBO->hasNoSignedWrap());
+      C = RBO->getOperand(0);
+      D = RBO->getOperand(1);
+      NoRHSWrapProblem =
+          ICmpInst::isEquality(Pred) ||
+          (CmpInst::isUnsigned(Pred) && RBO->hasNoUnsignedWrap()) ||
+          (CmpInst::isSigned(Pred) && RBO->hasNoSignedWrap());
     }
 
     // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow.
     if ((A == RHS || B == RHS) && NoLHSWrapProblem)
       if (Value *V = SimplifyICmpInst(Pred, A == RHS ? B : A,
-                                      Constant::getNullValue(RHS->getType()),
-                                      Q, MaxRecurse-1))
+                                      Constant::getNullValue(RHS->getType()), Q,
+                                      MaxRecurse - 1))
         return V;
 
     // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow.
     if ((C == LHS || D == LHS) && NoRHSWrapProblem)
-      if (Value *V = SimplifyICmpInst(Pred,
-                                      Constant::getNullValue(LHS->getType()),
-                                      C == LHS ? D : C, Q, MaxRecurse-1))
+      if (Value *V =
+              SimplifyICmpInst(Pred, Constant::getNullValue(LHS->getType()),
+                               C == LHS ? D : C, Q, MaxRecurse - 1))
         return V;
 
     // icmp (X+Y), (X+Z) -> icmp Y,Z for equalities or if there is no overflow.
-    if (A && C && (A == C || A == D || B == C || B == D) &&
-        NoLHSWrapProblem && NoRHSWrapProblem) {
+    if (A && C && (A == C || A == D || B == C || B == D) && NoLHSWrapProblem &&
+        NoRHSWrapProblem) {
       // Determine Y and Z in the form icmp (X+Y), (X+Z).
       Value *Y, *Z;
       if (A == C) {
@@ -2713,7 +2499,7 @@ static Value *SimplifyICmpInst(unsigned
         Y = A;
         Z = C;
       }
-      if (Value *V = SimplifyICmpInst(Pred, Y, Z, Q, MaxRecurse-1))
+      if (Value *V = SimplifyICmpInst(Pred, Y, Z, Q, MaxRecurse - 1))
         return V;
     }
   }
@@ -2923,7 +2709,8 @@ static Value *SimplifyICmpInst(unsigned
   if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() &&
       LBO->getOperand(1) == RBO->getOperand(1)) {
     switch (LBO->getOpcode()) {
-    default: break;
+    default:
+      break;
     case Instruction::UDiv:
     case Instruction::LShr:
       if (ICmpInst::isSigned(Pred))
@@ -2934,7 +2721,7 @@ static Value *SimplifyICmpInst(unsigned
       if (!LBO->isExact() || !RBO->isExact())
         break;
       if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0),
-                                      RBO->getOperand(0), Q, MaxRecurse-1))
+                                      RBO->getOperand(0), Q, MaxRecurse - 1))
         return V;
       break;
     case Instruction::Shl: {
@@ -2945,40 +2732,49 @@ static Value *SimplifyICmpInst(unsigned
       if (!NSW && ICmpInst::isSigned(Pred))
         break;
       if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0),
-                                      RBO->getOperand(0), Q, MaxRecurse-1))
+                                      RBO->getOperand(0), Q, MaxRecurse - 1))
         return V;
       break;
     }
     }
   }
+  return nullptr;
+}
 
-  // Simplify comparisons involving max/min.
+/// Simplify comparisons corresponding to integer min/max idioms.
+static Value *simplifyMinMax(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
+                             const Query &Q, unsigned MaxRecurse) {
+  Type *ITy = GetCompareTy(LHS); // The return type.
   Value *A, *B;
   CmpInst::Predicate P = CmpInst::BAD_ICMP_PREDICATE;
   CmpInst::Predicate EqP; // Chosen so that "A == max/min(A,B)" iff "A EqP B".
 
   // Signed variants on "max(a,b)>=a -> true".
   if (match(LHS, m_SMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) {
-    if (A != RHS) std::swap(A, B); // smax(A, B) pred A.
+    if (A != RHS)
+      std::swap(A, B);       // smax(A, B) pred A.
     EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B".
     // We analyze this as smax(A, B) pred A.
     P = Pred;
   } else if (match(RHS, m_SMax(m_Value(A), m_Value(B))) &&
              (A == LHS || B == LHS)) {
-    if (A != LHS) std::swap(A, B); // A pred smax(A, B).
+    if (A != LHS)
+      std::swap(A, B);       // A pred smax(A, B).
     EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B".
     // We analyze this as smax(A, B) swapped-pred A.
     P = CmpInst::getSwappedPredicate(Pred);
   } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) &&
              (A == RHS || B == RHS)) {
-    if (A != RHS) std::swap(A, B); // smin(A, B) pred A.
+    if (A != RHS)
+      std::swap(A, B);       // smin(A, B) pred A.
     EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B".
     // We analyze this as smax(-A, -B) swapped-pred -A.
     // Note that we do not need to actually form -A or -B thanks to EqP.
     P = CmpInst::getSwappedPredicate(Pred);
   } else if (match(RHS, m_SMin(m_Value(A), m_Value(B))) &&
              (A == LHS || B == LHS)) {
-    if (A != LHS) std::swap(A, B); // A pred smin(A, B).
+    if (A != LHS)
+      std::swap(A, B);       // A pred smin(A, B).
     EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B".
     // We analyze this as smax(-A, -B) pred -A.
     // Note that we do not need to actually form -A or -B thanks to EqP.
@@ -2999,7 +2795,7 @@ static Value *SimplifyICmpInst(unsigned
         return V;
       // Otherwise, see if "A EqP B" simplifies.
       if (MaxRecurse)
-        if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse-1))
+        if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1))
           return V;
       break;
     case CmpInst::ICMP_NE:
@@ -3013,7 +2809,7 @@ static Value *SimplifyICmpInst(unsigned
         return V;
       // Otherwise, see if "A InvEqP B" simplifies.
       if (MaxRecurse)
-        if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse-1))
+        if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1))
           return V;
       break;
     }
@@ -3029,26 +2825,30 @@ static Value *SimplifyICmpInst(unsigned
   // Unsigned variants on "max(a,b)>=a -> true".
   P = CmpInst::BAD_ICMP_PREDICATE;
   if (match(LHS, m_UMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) {
-    if (A != RHS) std::swap(A, B); // umax(A, B) pred A.
+    if (A != RHS)
+      std::swap(A, B);       // umax(A, B) pred A.
     EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B".
     // We analyze this as umax(A, B) pred A.
     P = Pred;
   } else if (match(RHS, m_UMax(m_Value(A), m_Value(B))) &&
              (A == LHS || B == LHS)) {
-    if (A != LHS) std::swap(A, B); // A pred umax(A, B).
+    if (A != LHS)
+      std::swap(A, B);       // A pred umax(A, B).
     EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B".
     // We analyze this as umax(A, B) swapped-pred A.
     P = CmpInst::getSwappedPredicate(Pred);
   } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) &&
              (A == RHS || B == RHS)) {
-    if (A != RHS) std::swap(A, B); // umin(A, B) pred A.
+    if (A != RHS)
+      std::swap(A, B);       // umin(A, B) pred A.
     EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B".
     // We analyze this as umax(-A, -B) swapped-pred -A.
     // Note that we do not need to actually form -A or -B thanks to EqP.
     P = CmpInst::getSwappedPredicate(Pred);
   } else if (match(RHS, m_UMin(m_Value(A), m_Value(B))) &&
              (A == LHS || B == LHS)) {
-    if (A != LHS) std::swap(A, B); // A pred umin(A, B).
+    if (A != LHS)
+      std::swap(A, B);       // A pred umin(A, B).
     EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B".
     // We analyze this as umax(-A, -B) pred -A.
     // Note that we do not need to actually form -A or -B thanks to EqP.
@@ -3069,7 +2869,7 @@ static Value *SimplifyICmpInst(unsigned
         return V;
       // Otherwise, see if "A EqP B" simplifies.
       if (MaxRecurse)
-        if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse-1))
+        if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1))
           return V;
       break;
     case CmpInst::ICMP_NE:
@@ -3083,7 +2883,7 @@ static Value *SimplifyICmpInst(unsigned
         return V;
       // Otherwise, see if "A InvEqP B" simplifies.
       if (MaxRecurse)
-        if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse-1))
+        if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1))
           return V;
       break;
     }
@@ -3140,6 +2940,239 @@ static Value *SimplifyICmpInst(unsigned
       return getFalse(ITy);
   }
 
+  return nullptr;
+}
+
+/// Given operands for an ICmpInst, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+                               const Query &Q, unsigned MaxRecurse) {
+  CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate;
+  assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!");
+
+  if (Constant *CLHS = dyn_cast<Constant>(LHS)) {
+    if (Constant *CRHS = dyn_cast<Constant>(RHS))
+      return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI);
+
+    // If we have a constant, make sure it is on the RHS.
+    std::swap(LHS, RHS);
+    Pred = CmpInst::getSwappedPredicate(Pred);
+  }
+
+  Type *ITy = GetCompareTy(LHS); // The return type.
+
+  // icmp X, X -> true/false
+  // X icmp undef -> true/false.  For example, icmp ugt %X, undef -> false
+  // because X could be 0.
+  if (LHS == RHS || isa<UndefValue>(RHS))
+    return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred));
+
+  if (Value *V = simplifyICmpOfBools(Pred, LHS, RHS, Q))
+    return V;
+
+  if (Value *V = simplifyICmpWithZero(Pred, LHS, RHS, Q))
+    return V;
+
+  if (Value *V = simplifyICmpWithConstant(Pred, LHS, RHS))
+    return V;
+
+  // If both operands have range metadata, use the metadata
+  // to simplify the comparison.
+  if (isa<Instruction>(RHS) && isa<Instruction>(LHS)) {
+    auto RHS_Instr = dyn_cast<Instruction>(RHS);
+    auto LHS_Instr = dyn_cast<Instruction>(LHS);
+
+    if (RHS_Instr->getMetadata(LLVMContext::MD_range) &&
+        LHS_Instr->getMetadata(LLVMContext::MD_range)) {
+      auto RHS_CR = getConstantRangeFromMetadata(
+          *RHS_Instr->getMetadata(LLVMContext::MD_range));
+      auto LHS_CR = getConstantRangeFromMetadata(
+          *LHS_Instr->getMetadata(LLVMContext::MD_range));
+
+      auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR);
+      if (Satisfied_CR.contains(LHS_CR))
+        return ConstantInt::getTrue(RHS->getContext());
+
+      auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion(
+                CmpInst::getInversePredicate(Pred), RHS_CR);
+      if (InversedSatisfied_CR.contains(LHS_CR))
+        return ConstantInt::getFalse(RHS->getContext());
+    }
+  }
+
+  // Compare of cast, for example (zext X) != 0 -> X != 0
+  if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) {
+    Instruction *LI = cast<CastInst>(LHS);
+    Value *SrcOp = LI->getOperand(0);
+    Type *SrcTy = SrcOp->getType();
+    Type *DstTy = LI->getType();
+
+    // Turn icmp (ptrtoint x), (ptrtoint/constant) into a compare of the input
+    // if the integer type is the same size as the pointer type.
+    if (MaxRecurse && isa<PtrToIntInst>(LI) &&
+        Q.DL.getTypeSizeInBits(SrcTy) == DstTy->getPrimitiveSizeInBits()) {
+      if (Constant *RHSC = dyn_cast<Constant>(RHS)) {
+        // Transfer the cast to the constant.
+        if (Value *V = SimplifyICmpInst(Pred, SrcOp,
+                                        ConstantExpr::getIntToPtr(RHSC, SrcTy),
+                                        Q, MaxRecurse-1))
+          return V;
+      } else if (PtrToIntInst *RI = dyn_cast<PtrToIntInst>(RHS)) {
+        if (RI->getOperand(0)->getType() == SrcTy)
+          // Compare without the cast.
+          if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0),
+                                          Q, MaxRecurse-1))
+            return V;
+      }
+    }
+
+    if (isa<ZExtInst>(LHS)) {
+      // Turn icmp (zext X), (zext Y) into a compare of X and Y if they have the
+      // same type.
+      if (ZExtInst *RI = dyn_cast<ZExtInst>(RHS)) {
+        if (MaxRecurse && SrcTy == RI->getOperand(0)->getType())
+          // Compare X and Y.  Note that signed predicates become unsigned.
+          if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred),
+                                          SrcOp, RI->getOperand(0), Q,
+                                          MaxRecurse-1))
+            return V;
+      }
+      // Turn icmp (zext X), Cst into a compare of X and Cst if Cst is extended
+      // too.  If not, then try to deduce the result of the comparison.
+      else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
+        // Compute the constant that would happen if we truncated to SrcTy then
+        // reextended to DstTy.
+        Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy);
+        Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy);
+
+        // If the re-extended constant didn't change then this is effectively
+        // also a case of comparing two zero-extended values.
+        if (RExt == CI && MaxRecurse)
+          if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred),
+                                        SrcOp, Trunc, Q, MaxRecurse-1))
+            return V;
+
+        // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit
+        // there.  Use this to work out the result of the comparison.
+        if (RExt != CI) {
+          switch (Pred) {
+          default: llvm_unreachable("Unknown ICmp predicate!");
+          // LHS <u RHS.
+          case ICmpInst::ICMP_EQ:
+          case ICmpInst::ICMP_UGT:
+          case ICmpInst::ICMP_UGE:
+            return ConstantInt::getFalse(CI->getContext());
+
+          case ICmpInst::ICMP_NE:
+          case ICmpInst::ICMP_ULT:
+          case ICmpInst::ICMP_ULE:
+            return ConstantInt::getTrue(CI->getContext());
+
+          // LHS is non-negative.  If RHS is negative then LHS >s LHS.  If RHS
+          // is non-negative then LHS <s RHS.
+          case ICmpInst::ICMP_SGT:
+          case ICmpInst::ICMP_SGE:
+            return CI->getValue().isNegative() ?
+              ConstantInt::getTrue(CI->getContext()) :
+              ConstantInt::getFalse(CI->getContext());
+
+          case ICmpInst::ICMP_SLT:
+          case ICmpInst::ICMP_SLE:
+            return CI->getValue().isNegative() ?
+              ConstantInt::getFalse(CI->getContext()) :
+              ConstantInt::getTrue(CI->getContext());
+          }
+        }
+      }
+    }
+
+    if (isa<SExtInst>(LHS)) {
+      // Turn icmp (sext X), (sext Y) into a compare of X and Y if they have the
+      // same type.
+      if (SExtInst *RI = dyn_cast<SExtInst>(RHS)) {
+        if (MaxRecurse && SrcTy == RI->getOperand(0)->getType())
+          // Compare X and Y.  Note that the predicate does not change.
+          if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0),
+                                          Q, MaxRecurse-1))
+            return V;
+      }
+      // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended
+      // too.  If not, then try to deduce the result of the comparison.
+      else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
+        // Compute the constant that would happen if we truncated to SrcTy then
+        // reextended to DstTy.
+        Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy);
+        Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy);
+
+        // If the re-extended constant didn't change then this is effectively
+        // also a case of comparing two sign-extended values.
+        if (RExt == CI && MaxRecurse)
+          if (Value *V = SimplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse-1))
+            return V;
+
+        // Otherwise the upper bits of LHS are all equal, while RHS has varying
+        // bits there.  Use this to work out the result of the comparison.
+        if (RExt != CI) {
+          switch (Pred) {
+          default: llvm_unreachable("Unknown ICmp predicate!");
+          case ICmpInst::ICMP_EQ:
+            return ConstantInt::getFalse(CI->getContext());
+          case ICmpInst::ICMP_NE:
+            return ConstantInt::getTrue(CI->getContext());
+
+          // If RHS is non-negative then LHS <s RHS.  If RHS is negative then
+          // LHS >s RHS.
+          case ICmpInst::ICMP_SGT:
+          case ICmpInst::ICMP_SGE:
+            return CI->getValue().isNegative() ?
+              ConstantInt::getTrue(CI->getContext()) :
+              ConstantInt::getFalse(CI->getContext());
+          case ICmpInst::ICMP_SLT:
+          case ICmpInst::ICMP_SLE:
+            return CI->getValue().isNegative() ?
+              ConstantInt::getFalse(CI->getContext()) :
+              ConstantInt::getTrue(CI->getContext());
+
+          // If LHS is non-negative then LHS <u RHS.  If LHS is negative then
+          // LHS >u RHS.
+          case ICmpInst::ICMP_UGT:
+          case ICmpInst::ICMP_UGE:
+            // Comparison is true iff the LHS <s 0.
+            if (MaxRecurse)
+              if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SLT, SrcOp,
+                                              Constant::getNullValue(SrcTy),
+                                              Q, MaxRecurse-1))
+                return V;
+            break;
+          case ICmpInst::ICMP_ULT:
+          case ICmpInst::ICMP_ULE:
+            // Comparison is true iff the LHS >=s 0.
+            if (MaxRecurse)
+              if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SGE, SrcOp,
+                                              Constant::getNullValue(SrcTy),
+                                              Q, MaxRecurse-1))
+                return V;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // icmp eq|ne X, Y -> false|true if X != Y
+  if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
+      isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT)) {
+    LLVMContext &Ctx = LHS->getType()->getContext();
+    return Pred == ICmpInst::ICMP_NE ?
+      ConstantInt::getTrue(Ctx) : ConstantInt::getFalse(Ctx);
+  }
+
+  if (Value *V = simplifyICmpWithBinOp(Pred, LHS, RHS, Q, MaxRecurse))
+    return V;
+
+  if (Value *V = simplifyMinMax(Pred, LHS, RHS, Q, MaxRecurse))
+    return V;
+
   // Simplify comparisons of related pointers using a powerful, recursive
   // GEP-walk when we have target data available..
   if (LHS->getType()->isPointerTy())




More information about the llvm-commits mailing list