[llvm] r281217 - [InstCombine] add helper function for foldICmpUsingKnownBits; NFCI

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 12 08:24:32 PDT 2016


Author: spatel
Date: Mon Sep 12 10:24:31 2016
New Revision: 281217

URL: http://llvm.org/viewvc/llvm-project?rev=281217&view=rev
Log:
[InstCombine] add helper function for foldICmpUsingKnownBits; NFCI

Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp?rev=281217&r1=281216&r2=281217&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp Mon Sep 12 10:24:31 2016
@@ -3135,6 +3135,274 @@ bool InstCombiner::replacedSelectWithOpe
   return false;
 }
 
+/// Try to fold the comparison based on range information we can get by checking
+/// whether bits are known to be zero or one in the inputs.
+Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  Type *Ty = Op0->getType();
+
+  // Get scalar or pointer size.
+  unsigned BitWidth = Ty->isIntOrIntVectorTy()
+                          ? Ty->getScalarSizeInBits()
+                          : DL.getTypeSizeInBits(Ty->getScalarType());
+
+  if (!BitWidth)
+    return nullptr;
+
+  // If this is a normal comparison, it demands all bits. If it is a sign bit
+  // comparison, it only demands the sign bit.
+  bool IsSignBit = false;
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+    bool UnusedBit;
+    IsSignBit = isSignBitCheck(I.getPredicate(), CI->getValue(), UnusedBit);
+  }
+
+  APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0);
+  APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0);
+
+  if (SimplifyDemandedBits(I.getOperandUse(0),
+                           DemandedBitsLHSMask(I, BitWidth, IsSignBit),
+                           Op0KnownZero, Op0KnownOne, 0))
+    return &I;
+
+  if (SimplifyDemandedBits(I.getOperandUse(1), APInt::getAllOnesValue(BitWidth),
+                           Op1KnownZero, Op1KnownOne, 0))
+    return &I;
+
+  // Given the known and unknown bits, compute a range that each operand could
+  // be in.  Compute the Min and Max values based on the known bits. For the
+  // EQ and NE we use unsigned values.
+  APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0);
+  APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0);
+  if (I.isSigned()) {
+    ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min,
+                                           Op0Max);
+    ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min,
+                                           Op1Max);
+  } else {
+    ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min,
+                                             Op0Max);
+    ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min,
+                                             Op1Max);
+  }
+
+  // If Min and Max are known to be the same, then SimplifyDemandedBits
+  // figured out that the operand is a constant.  Fold it to that constant now
+  // so that code below can assume that Min != Max.
+  if (!isa<Constant>(Op0) && Op0Min == Op0Max)
+    return new ICmpInst(I.getPredicate(),
+                        ConstantInt::get(Op0->getType(), Op0Min), Op1);
+  if (!isa<Constant>(Op1) && Op1Min == Op1Max)
+    return new ICmpInst(I.getPredicate(), Op0,
+                        ConstantInt::get(Op1->getType(), Op1Min));
+
+  // Based on the range information we know about the LHS, see if we can
+  // simplify this comparison.  For example, (x&4) < 8 is always true.
+  switch (I.getPredicate()) {
+  default:
+    llvm_unreachable("Unknown icmp opcode!");
+  case ICmpInst::ICMP_EQ: {
+    if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
+      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+
+    // If all bits are known zero except for one, then we know at most one
+    // bit is set.   If the comparison is against zero, then this is a check
+    // to see if *that* bit is set.
+    APInt Op0KnownZeroInverted = ~Op0KnownZero;
+    if (~Op1KnownZero == 0) {
+      // If the LHS is an AND with the same constant, look through it.
+      Value *LHS = nullptr;
+      ConstantInt *LHSC = nullptr;
+      if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) ||
+          LHSC->getValue() != Op0KnownZeroInverted)
+        LHS = Op0;
+
+      // If the LHS is 1 << x, and we know the result is a power of 2 like 8,
+      // then turn "((1 << x)&8) == 0" into "x != 3".
+      // or turn "((1 << x)&7) == 0" into "x > 2".
+      Value *X = nullptr;
+      if (match(LHS, m_Shl(m_One(), m_Value(X)))) {
+        APInt ValToCheck = Op0KnownZeroInverted;
+        if (ValToCheck.isPowerOf2()) {
+          unsigned CmpVal = ValToCheck.countTrailingZeros();
+          return new ICmpInst(ICmpInst::ICMP_NE, X,
+                              ConstantInt::get(X->getType(), CmpVal));
+        } else if ((++ValToCheck).isPowerOf2()) {
+          unsigned CmpVal = ValToCheck.countTrailingZeros() - 1;
+          return new ICmpInst(ICmpInst::ICMP_UGT, X,
+                              ConstantInt::get(X->getType(), CmpVal));
+        }
+      }
+
+      // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
+      // then turn "((8 >>u x)&1) == 0" into "x != 3".
+      const APInt *CI;
+      if (Op0KnownZeroInverted == 1 &&
+          match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
+        return new ICmpInst(
+            ICmpInst::ICMP_NE, X,
+            ConstantInt::get(X->getType(), CI->countTrailingZeros()));
+    }
+    break;
+  }
+  case ICmpInst::ICMP_NE: {
+    if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
+      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+
+    // If all bits are known zero except for one, then we know at most one
+    // bit is set.   If the comparison is against zero, then this is a check
+    // to see if *that* bit is set.
+    APInt Op0KnownZeroInverted = ~Op0KnownZero;
+    if (~Op1KnownZero == 0) {
+      // If the LHS is an AND with the same constant, look through it.
+      Value *LHS = nullptr;
+      ConstantInt *LHSC = nullptr;
+      if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) ||
+          LHSC->getValue() != Op0KnownZeroInverted)
+        LHS = Op0;
+
+      // If the LHS is 1 << x, and we know the result is a power of 2 like 8,
+      // then turn "((1 << x)&8) != 0" into "x == 3".
+      // or turn "((1 << x)&7) != 0" into "x < 3".
+      Value *X = nullptr;
+      if (match(LHS, m_Shl(m_One(), m_Value(X)))) {
+        APInt ValToCheck = Op0KnownZeroInverted;
+        if (ValToCheck.isPowerOf2()) {
+          unsigned CmpVal = ValToCheck.countTrailingZeros();
+          return new ICmpInst(ICmpInst::ICMP_EQ, X,
+                              ConstantInt::get(X->getType(), CmpVal));
+        } else if ((++ValToCheck).isPowerOf2()) {
+          unsigned CmpVal = ValToCheck.countTrailingZeros();
+          return new ICmpInst(ICmpInst::ICMP_ULT, X,
+                              ConstantInt::get(X->getType(), CmpVal));
+        }
+      }
+
+      // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
+      // then turn "((8 >>u x)&1) != 0" into "x == 3".
+      const APInt *CI;
+      if (Op0KnownZeroInverted == 1 &&
+          match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
+        return new ICmpInst(
+            ICmpInst::ICMP_EQ, X,
+            ConstantInt::get(X->getType(), CI->countTrailingZeros()));
+    }
+    break;
+  }
+  case ICmpInst::ICMP_ULT: {
+    if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B)
+      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+    if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
+      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+    if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
+      return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+
+    const APInt *CmpC;
+    if (match(Op1, m_APInt(CmpC))) {
+      // A <u C -> A == C-1 if min(A)+1 == C
+      if (Op1Max == Op0Min + 1) {
+        Constant *CMinus1 = ConstantInt::get(Op0->getType(), *CmpC - 1);
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0, CMinus1);
+      }
+      // (x <u 2147483648) -> (x >s -1)  -> true if sign bit clear
+      if (CmpC->isMinSignedValue()) {
+        Constant *AllOnes = Constant::getAllOnesValue(Op0->getType());
+        return new ICmpInst(ICmpInst::ICMP_SGT, Op0, AllOnes);
+      }
+    }
+    break;
+  }
+  case ICmpInst::ICMP_UGT: {
+    if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
+      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+
+    if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= min(B)
+      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+
+    if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
+      return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+
+    const APInt *CmpC;
+    if (match(Op1, m_APInt(CmpC))) {
+      // A >u C -> A == C+1 if max(A)-1 == C
+      if (*CmpC == Op0Max - 1)
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), *CmpC + 1));
+
+      // (x >u 2147483647) -> (x <s 0)  -> true if sign bit set
+      if (CmpC->isMaxSignedValue())
+        return new ICmpInst(ICmpInst::ICMP_SLT, Op0,
+                            Constant::getNullValue(Op0->getType()));
+    }
+    break;
+  }
+  case ICmpInst::ICMP_SLT:
+    if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(B)
+      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+    if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(B)
+      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+    if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
+      return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+      if (Op1Max == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            Builder->getInt(CI->getValue() - 1));
+    }
+    break;
+  case ICmpInst::ICMP_SGT:
+    if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
+      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+    if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
+      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+
+    if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
+      return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+      if (Op1Min == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            Builder->getInt(CI->getValue() + 1));
+    }
+    break;
+  case ICmpInst::ICMP_SGE:
+    assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
+    if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
+      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+    if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
+      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+    break;
+  case ICmpInst::ICMP_SLE:
+    assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
+    if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B)
+      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+    if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
+      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+    break;
+  case ICmpInst::ICMP_UGE:
+    assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
+    if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B)
+      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+    if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
+      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+    break;
+  case ICmpInst::ICMP_ULE:
+    assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
+    if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B)
+      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+    if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
+      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+    break;
+  }
+
+  // Turn a signed comparison into an unsigned one if both operands are known to
+  // have the same sign.
+  if (I.isSigned() &&
+      ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) ||
+       (Op0KnownOne.isNegative() && Op1KnownOne.isNegative())))
+    return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
+
+  return nullptr;
+}
+
 /// If we have an icmp le or icmp ge instruction with a constant operand, turn
 /// it into the appropriate icmp lt or icmp gt instruction. This transform
 /// allows them to be folded in visitICmpInst.
@@ -3276,14 +3544,6 @@ Instruction *InstCombiner::visitICmpInst
   if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I))
     return NewICmp;
 
-  unsigned BitWidth = 0;
-  if (Ty->isIntOrIntVectorTy())
-    BitWidth = Ty->getScalarSizeInBits();
-  else // Get pointer size.
-    BitWidth = DL.getTypeSizeInBits(Ty->getScalarType());
-
-  bool isSignBit = false;
-
   // See if we are doing a comparison with a constant.
   if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
     Value *A = nullptr, *B = nullptr;
@@ -3365,11 +3625,6 @@ Instruction *InstCombiner::visitICmpInst
       }
     }
 
-    // If this comparison is a normal comparison, it demands all
-    // bits, if it is a sign bit comparison, it only demands the sign bit.
-    bool UnusedBit;
-    isSignBit = isSignBitCheck(I.getPredicate(), CI->getValue(), UnusedBit);
-
     // Canonicalize icmp instructions based on dominating conditions.
     BasicBlock *Parent = I.getParent();
     BasicBlock *Dom = Parent->getSinglePredecessor();
@@ -3393,12 +3648,19 @@ Instruction *InstCombiner::visitICmpInst
         return replaceInstUsesWith(I, Builder->getFalse());
       if (Difference.isEmptySet())
         return replaceInstUsesWith(I, Builder->getTrue());
+
+      // If this is a normal comparison, it demands all bits. If it is a sign
+      // bit comparison, it only demands the sign bit.
+      bool UnusedBit;
+      bool IsSignBit =
+          isSignBitCheck(I.getPredicate(), CI->getValue(), UnusedBit);
+
       // Canonicalizing a sign bit comparison that gets used in a branch,
       // pessimizes codegen by generating branch on zero instruction instead
       // of a test and branch. So we avoid canonicalizing in such situations
       // because test and branch instruction has better branch displacement
       // than compare and branch instruction.
-      if (!isBranchOnSignBitCheck(I, isSignBit) && !I.isEquality()) {
+      if (!isBranchOnSignBitCheck(I, IsSignBit) && !I.isEquality()) {
         if (auto *AI = Intersection.getSingleElement())
           return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Builder->getInt(*AI));
         if (auto *AD = Difference.getSingleElement())
@@ -3407,251 +3669,8 @@ Instruction *InstCombiner::visitICmpInst
     }
   }
 
-  // See if we can fold the comparison based on range information we can get
-  // by checking whether bits are known to be zero or one in the input.
-  if (BitWidth != 0) {
-    APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0);
-    APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0);
-
-    if (SimplifyDemandedBits(I.getOperandUse(0),
-                             DemandedBitsLHSMask(I, BitWidth, isSignBit),
-                             Op0KnownZero, Op0KnownOne, 0))
-      return &I;
-    if (SimplifyDemandedBits(I.getOperandUse(1),
-                             APInt::getAllOnesValue(BitWidth), Op1KnownZero,
-                             Op1KnownOne, 0))
-      return &I;
-
-    // Given the known and unknown bits, compute a range that the LHS could be
-    // in.  Compute the Min, Max and RHS values based on the known bits. For the
-    // EQ and NE we use unsigned values.
-    APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0);
-    APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0);
-    if (I.isSigned()) {
-      ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne,
-                                             Op0Min, Op0Max);
-      ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne,
-                                             Op1Min, Op1Max);
-    } else {
-      ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne,
-                                               Op0Min, Op0Max);
-      ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne,
-                                               Op1Min, Op1Max);
-    }
-
-    // If Min and Max are known to be the same, then SimplifyDemandedBits
-    // figured out that the LHS is a constant.  Just constant fold this now so
-    // that code below can assume that Min != Max.
-    if (!isa<Constant>(Op0) && Op0Min == Op0Max)
-      return new ICmpInst(I.getPredicate(),
-                          ConstantInt::get(Op0->getType(), Op0Min), Op1);
-    if (!isa<Constant>(Op1) && Op1Min == Op1Max)
-      return new ICmpInst(I.getPredicate(), Op0,
-                          ConstantInt::get(Op1->getType(), Op1Min));
-
-    // Based on the range information we know about the LHS, see if we can
-    // simplify this comparison.  For example, (x&4) < 8 is always true.
-    switch (I.getPredicate()) {
-    default: llvm_unreachable("Unknown icmp opcode!");
-    case ICmpInst::ICMP_EQ: {
-      if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
-        return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-
-      // If all bits are known zero except for one, then we know at most one
-      // bit is set.   If the comparison is against zero, then this is a check
-      // to see if *that* bit is set.
-      APInt Op0KnownZeroInverted = ~Op0KnownZero;
-      if (~Op1KnownZero == 0) {
-        // If the LHS is an AND with the same constant, look through it.
-        Value *LHS = nullptr;
-        ConstantInt *LHSC = nullptr;
-        if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) ||
-            LHSC->getValue() != Op0KnownZeroInverted)
-          LHS = Op0;
-
-        // If the LHS is 1 << x, and we know the result is a power of 2 like 8,
-        // then turn "((1 << x)&8) == 0" into "x != 3".
-        // or turn "((1 << x)&7) == 0" into "x > 2".
-        Value *X = nullptr;
-        if (match(LHS, m_Shl(m_One(), m_Value(X)))) {
-          APInt ValToCheck = Op0KnownZeroInverted;
-          if (ValToCheck.isPowerOf2()) {
-            unsigned CmpVal = ValToCheck.countTrailingZeros();
-            return new ICmpInst(ICmpInst::ICMP_NE, X,
-                                ConstantInt::get(X->getType(), CmpVal));
-          } else if ((++ValToCheck).isPowerOf2()) {
-            unsigned CmpVal = ValToCheck.countTrailingZeros() - 1;
-            return new ICmpInst(ICmpInst::ICMP_UGT, X,
-                                ConstantInt::get(X->getType(), CmpVal));
-          }
-        }
-
-        // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
-        // then turn "((8 >>u x)&1) == 0" into "x != 3".
-        const APInt *CI;
-        if (Op0KnownZeroInverted == 1 &&
-            match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
-          return new ICmpInst(ICmpInst::ICMP_NE, X,
-                              ConstantInt::get(X->getType(),
-                                               CI->countTrailingZeros()));
-      }
-      break;
-    }
-    case ICmpInst::ICMP_NE: {
-      if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
-        return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-
-      // If all bits are known zero except for one, then we know at most one
-      // bit is set.   If the comparison is against zero, then this is a check
-      // to see if *that* bit is set.
-      APInt Op0KnownZeroInverted = ~Op0KnownZero;
-      if (~Op1KnownZero == 0) {
-        // If the LHS is an AND with the same constant, look through it.
-        Value *LHS = nullptr;
-        ConstantInt *LHSC = nullptr;
-        if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) ||
-            LHSC->getValue() != Op0KnownZeroInverted)
-          LHS = Op0;
-
-        // If the LHS is 1 << x, and we know the result is a power of 2 like 8,
-        // then turn "((1 << x)&8) != 0" into "x == 3".
-        // or turn "((1 << x)&7) != 0" into "x < 3".
-        Value *X = nullptr;
-        if (match(LHS, m_Shl(m_One(), m_Value(X)))) {
-          APInt ValToCheck = Op0KnownZeroInverted;
-          if (ValToCheck.isPowerOf2()) {
-            unsigned CmpVal = ValToCheck.countTrailingZeros();
-            return new ICmpInst(ICmpInst::ICMP_EQ, X,
-                                ConstantInt::get(X->getType(), CmpVal));
-          } else if ((++ValToCheck).isPowerOf2()) {
-            unsigned CmpVal = ValToCheck.countTrailingZeros();
-            return new ICmpInst(ICmpInst::ICMP_ULT, X,
-                                ConstantInt::get(X->getType(), CmpVal));
-          }
-        }
-
-        // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
-        // then turn "((8 >>u x)&1) != 0" into "x == 3".
-        const APInt *CI;
-        if (Op0KnownZeroInverted == 1 &&
-            match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
-          return new ICmpInst(ICmpInst::ICMP_EQ, X,
-                              ConstantInt::get(X->getType(),
-                                               CI->countTrailingZeros()));
-      }
-      break;
-    }
-    case ICmpInst::ICMP_ULT: {
-      if (Op0Max.ult(Op1Min))          // A <u B -> true if max(A) < min(B)
-        return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-      if (Op0Min.uge(Op1Max))          // A <u B -> false if min(A) >= max(B)
-        return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-      if (Op1Min == Op0Max)            // A <u B -> A != B if max(A) == min(B)
-        return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        // A <u C -> A == C-1 if min(A)+1 == C
-        if (Op1Max == Op0Min + 1) {
-          Constant *CMinus1 = ConstantInt::get(Op0->getType(), *CmpC - 1);
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0, CMinus1);
-        }
-        // (x <u 2147483648) -> (x >s -1)  -> true if sign bit clear
-        if (CmpC->isMinSignedValue()) {
-          Constant *AllOnes = Constant::getAllOnesValue(Op0->getType());
-          return new ICmpInst(ICmpInst::ICMP_SGT, Op0, AllOnes);
-        }
-      }
-      break;
-    }
-    case ICmpInst::ICMP_UGT: {
-      if (Op0Min.ugt(Op1Max))          // A >u B -> true if min(A) > max(B)
-        return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-
-      if (Op0Max.ule(Op1Min))          // A >u B -> false if max(A) <= max(B)
-        return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-
-      if (Op1Max == Op0Min)            // A >u B -> A != B if min(A) == max(B)
-        return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        // A >u C -> A == C+1 if max(a)-1 == C
-        if (*CmpC == Op0Max - 1)
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              ConstantInt::get(Op1->getType(), *CmpC + 1));
-
-        // (x >u 2147483647) -> (x <s 0)  -> true if sign bit set
-        if (CmpC->isMaxSignedValue())
-          return new ICmpInst(ICmpInst::ICMP_SLT, Op0,
-                              Constant::getNullValue(Op0->getType()));
-      }
-      break;
-    }
-    case ICmpInst::ICMP_SLT:
-      if (Op0Max.slt(Op1Min))          // A <s B -> true if max(A) < min(C)
-        return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-      if (Op0Min.sge(Op1Max))          // A <s B -> false if min(A) >= max(C)
-        return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-      if (Op1Min == Op0Max)            // A <s B -> A != B if max(A) == min(B)
-        return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
-        if (Op1Max == Op0Min+1)        // A <s C -> A == C-1 if min(A)+1 == C
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              Builder->getInt(CI->getValue()-1));
-      }
-      break;
-    case ICmpInst::ICMP_SGT:
-      if (Op0Min.sgt(Op1Max))          // A >s B -> true if min(A) > max(B)
-        return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-      if (Op0Max.sle(Op1Min))          // A >s B -> false if max(A) <= min(B)
-        return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-
-      if (Op1Max == Op0Min)            // A >s B -> A != B if min(A) == max(B)
-        return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
-        if (Op1Min == Op0Max-1)        // A >s C -> A == C+1 if max(A)-1 == C
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              Builder->getInt(CI->getValue()+1));
-      }
-      break;
-    case ICmpInst::ICMP_SGE:
-      assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
-      if (Op0Min.sge(Op1Max))          // A >=s B -> true if min(A) >= max(B)
-        return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-      if (Op0Max.slt(Op1Min))          // A >=s B -> false if max(A) < min(B)
-        return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-      break;
-    case ICmpInst::ICMP_SLE:
-      assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
-      if (Op0Max.sle(Op1Min))          // A <=s B -> true if max(A) <= min(B)
-        return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-      if (Op0Min.sgt(Op1Max))          // A <=s B -> false if min(A) > max(B)
-        return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-      break;
-    case ICmpInst::ICMP_UGE:
-      assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
-      if (Op0Min.uge(Op1Max))          // A >=u B -> true if min(A) >= max(B)
-        return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-      if (Op0Max.ult(Op1Min))          // A >=u B -> false if max(A) < min(B)
-        return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-      break;
-    case ICmpInst::ICMP_ULE:
-      assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
-      if (Op0Max.ule(Op1Min))          // A <=u B -> true if max(A) <= min(B)
-        return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-      if (Op0Min.ugt(Op1Max))          // A <=u B -> false if min(A) > max(B)
-        return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-      break;
-    }
-
-    // Turn a signed comparison into an unsigned one if both operands
-    // are known to have the same sign.
-    if (I.isSigned() &&
-        ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) ||
-         (Op0KnownOne.isNegative() && Op1KnownOne.isNegative())))
-      return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
-  }
+  if (Instruction *Res = foldICmpUsingKnownBits(I))
+    return Res;
 
   // Test if the ICmpInst instruction is used exclusively by a select as
   // part of a minimum or maximum operation. If so, refrain from doing

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h?rev=281217&r1=281216&r2=281217&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h Mon Sep 12 10:24:31 2016
@@ -556,6 +556,7 @@ private:
                                   ICmpInst::Predicate Pred);
   Instruction *foldICmpWithCastAndCast(ICmpInst &ICI);
 
+  Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp);
   Instruction *foldICmpInstWithConstant(ICmpInst &Cmp);
 
   Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc,




More information about the llvm-commits mailing list