[llvm] r278816 - [InstCombine] use m_APInt in foldICmpWithConstant; NFCI

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 16 09:08:11 PDT 2016


Author: spatel
Date: Tue Aug 16 11:08:11 2016
New Revision: 278816

URL: http://llvm.org/viewvc/llvm-project?rev=278816&view=rev
Log:
[InstCombine] use m_APInt in foldICmpWithConstant; NFCI

There's some formatting and pointer deref ugliness here that I intend to fix in
subsequent patches. The overall goal is to refactor the obnoxiously long switch
and incrementally remove the restriction to scalar types (allow folds for vector
splats). This patch introduces the use of m_APInt, which means the RHSV reference
is now a pointer (and may have matched a vector splat), but the
dyn_cast<ConstantInt> check of 'RHS' remains, so vector folds are still
disallowed and no functional change is intended.


Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp?rev=278816&r1=278815&r2=278816&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp Tue Aug 16 11:08:11 2016
@@ -1530,15 +1530,22 @@ Instruction *InstCombiner::foldICmpCstSh
   return getConstant(false);
 }
 
-/// Handle "icmp (instr, intcst)".
-Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
-                                                Instruction *LHSI,
-                                                ConstantInt *RHS) {
-  const APInt &RHSV = RHS->getValue();
+/// Try to fold integer comparisons with a constant operand: icmp Pred X, C.
+Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI) {
+  Instruction *LHSI;
+  const APInt *RHSV;
+  if (!match(ICI.getOperand(0), m_Instruction(LHSI)) ||
+      !match(ICI.getOperand(1), m_APInt(RHSV)))
+    return nullptr;
+
+  // FIXME: This check restricts all folds under here to scalar types.
+  ConstantInt *RHS = dyn_cast<ConstantInt>(ICI.getOperand(1));
+  if (!RHS)
+    return nullptr;
 
   switch (LHSI->getOpcode()) {
   case Instruction::Trunc:
-    if (RHS->isOne() && RHSV.getBitWidth() > 1) {
+    if (RHS->isOne() && RHSV->getBitWidth() > 1) {
       // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1
       Value *V = nullptr;
       if (ICI.getPredicate() == ICmpInst::ICMP_SLT &&
@@ -1569,8 +1576,8 @@ Instruction *InstCombiner::foldICmpWithC
     if (ConstantInt *XorCst = dyn_cast<ConstantInt>(LHSI->getOperand(1))) {
       // If this is a comparison that tests the signbit (X < 0) or (x > -1),
       // fold the xor.
-      if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && RHSV == 0) ||
-          (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV.isAllOnesValue())) {
+      if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && *RHSV == 0) ||
+          (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV->isAllOnesValue())) {
         Value *CompareVal = LHSI->getOperand(0);
 
         // If the sign bit of the XorCst is not set, there is no change to
@@ -1603,7 +1610,7 @@ Instruction *InstCombiner::foldICmpWithC
                                          ? ICI.getUnsignedPredicate()
                                          : ICI.getSignedPredicate();
           return new ICmpInst(Pred, LHSI->getOperand(0),
-                              Builder->getInt(RHSV ^ SignBit));
+                              Builder->getInt(*RHSV ^ SignBit));
         }
 
         // (icmp u/s (xor A ~SignBit), C) -> (icmp s/u (xor C ~SignBit), A)
@@ -1614,20 +1621,20 @@ Instruction *InstCombiner::foldICmpWithC
                                          : ICI.getSignedPredicate();
           Pred = ICI.getSwappedPredicate(Pred);
           return new ICmpInst(Pred, LHSI->getOperand(0),
-                              Builder->getInt(RHSV ^ NotSignBit));
+                              Builder->getInt(*RHSV ^ NotSignBit));
         }
       }
 
       // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C)
       //   iff -C is a power of 2
       if (ICI.getPredicate() == ICmpInst::ICMP_UGT &&
-          XorCst->getValue() == ~RHSV && (RHSV + 1).isPowerOf2())
+          XorCst->getValue() == ~(*RHSV) && (*RHSV + 1).isPowerOf2())
         return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), XorCst);
 
       // (icmp ult (xor X, C), -C) -> (icmp uge X, C)
       //   iff -C is a power of 2
       if (ICI.getPredicate() == ICmpInst::ICMP_ULT &&
-          XorCst->getValue() == -RHSV && RHSV.isPowerOf2())
+          XorCst->getValue() == -(*RHSV) && RHSV->isPowerOf2())
         return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), XorCst);
     }
     break;
@@ -1645,7 +1652,7 @@ Instruction *InstCombiner::foldICmpWithC
         // Extending a relational comparison when we're checking the sign
         // bit would not work.
         if (ICI.isEquality() ||
-            (!AndCst->isNegative() && RHSV.isNonNegative())) {
+            (!AndCst->isNegative() && RHSV->isNonNegative())) {
           Value *NewAnd =
             Builder->CreateAnd(Cast->getOperand(0),
                                ConstantExpr::getZExt(AndCst, Cast->getSrcTy()));
@@ -1661,7 +1668,7 @@ Instruction *InstCombiner::foldICmpWithC
         IntegerType *Ty = cast<IntegerType>(Cast->getSrcTy());
         // Make sure we don't compare the upper bits, SimplifyDemandedBits
         // should fold the icmp to true/false in that case.
-        if (ICI.isEquality() && RHSV.getActiveBits() <= Ty->getBitWidth()) {
+        if (ICI.isEquality() && RHSV->getActiveBits() <= Ty->getBitWidth()) {
           Value *NewAnd =
             Builder->CreateAnd(Cast->getOperand(0),
                                ConstantExpr::getTrunc(AndCst, Ty));
@@ -1754,7 +1761,7 @@ Instruction *InstCombiner::foldICmpWithC
       // Turn ((X >> Y) & C) == 0  into  (X & (C << Y)) == 0.  The later is
       // preferable because it allows the C<<Y expression to be hoisted out
       // of a loop if Y is invariant and X is not.
-      if (Shift && Shift->hasOneUse() && RHSV == 0 &&
+      if (Shift && Shift->hasOneUse() && *RHSV == 0 &&
           ICI.isEquality() && !Shift->isArithmeticShift() &&
           !isa<Constant>(Shift->getOperand(0))) {
         // Compute C << Y.
@@ -1780,7 +1787,7 @@ Instruction *InstCombiner::foldICmpWithC
       // iff pred isn't signed
       {
         Value *X, *Y, *LShr;
-        if (!ICI.isSigned() && RHSV == 0) {
+        if (!ICI.isSigned() && *RHSV == 0) {
           if (match(LHSI->getOperand(1), m_One())) {
             Constant *One = cast<Constant>(LHSI->getOperand(1));
             Value *Or = LHSI->getOperand(0);
@@ -1821,7 +1828,7 @@ Instruction *InstCombiner::foldICmpWithC
       if (ICI.getPredicate() == ICmpInst::ICMP_UGT) {
         unsigned NTZ = AndCst->getValue().countTrailingZeros();
         if ((NTZ < AndCst->getBitWidth()) &&
-            APInt::getOneBitSet(AndCst->getBitWidth(), NTZ).ugt(RHSV))
+            APInt::getOneBitSet(AndCst->getBitWidth(), NTZ).ugt(*RHSV))
           return new ICmpInst(ICmpInst::ICMP_NE, LHSI,
                               Constant::getNullValue(RHS->getType()));
       }
@@ -1843,7 +1850,7 @@ Instruction *InstCombiner::foldICmpWithC
     // X & -C == -C -> X >  u ~C
     // X & -C != -C -> X <= u ~C
     //   iff C is a power of 2
-    if (ICI.isEquality() && RHS == LHSI->getOperand(1) && (-RHSV).isPowerOf2())
+    if (ICI.isEquality() && RHS == LHSI->getOperand(1) && (-(*RHSV)).isPowerOf2())
       return new ICmpInst(
           ICI.getPredicate() == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_UGT
                                                   : ICmpInst::ICMP_ULE,
@@ -1915,13 +1922,13 @@ Instruction *InstCombiner::foldICmpWithC
   }
 
   case Instruction::Shl: {       // (icmp pred (shl X, ShAmt), CI)
-    uint32_t TypeBits = RHSV.getBitWidth();
+    uint32_t TypeBits = RHSV->getBitWidth();
     ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1));
     if (!ShAmt) {
       Value *X;
       // (1 << X) pred P2 -> X pred Log2(P2)
       if (match(LHSI, m_Shl(m_One(), m_Value(X)))) {
-        bool RHSVIsPowerOf2 = RHSV.isPowerOf2();
+        bool RHSVIsPowerOf2 = RHSV->isPowerOf2();
         ICmpInst::Predicate Pred = ICI.getPredicate();
         if (ICI.isUnsigned()) {
           if (!RHSVIsPowerOf2) {
@@ -1934,7 +1941,7 @@ Instruction *InstCombiner::foldICmpWithC
             else if (Pred == ICmpInst::ICMP_UGE)
               Pred = ICmpInst::ICMP_UGT;
           }
-          unsigned RHSLog2 = RHSV.logBase2();
+          unsigned RHSLog2 = RHSV->logBase2();
 
           // (1 << X) >= 2147483648 -> X >= 31 -> X == 31
           // (1 << X) <  2147483648 -> X <  31 -> X != 31
@@ -1948,7 +1955,7 @@ Instruction *InstCombiner::foldICmpWithC
           return new ICmpInst(Pred, X,
                               ConstantInt::get(RHS->getType(), RHSLog2));
         } else if (ICI.isSigned()) {
-          if (RHSV.isAllOnesValue()) {
+          if (RHSV->isAllOnesValue()) {
             // (1 << X) <= -1 -> X == 31
             if (Pred == ICmpInst::ICMP_SLE)
               return new ICmpInst(ICmpInst::ICMP_EQ, X,
@@ -1958,7 +1965,7 @@ Instruction *InstCombiner::foldICmpWithC
             if (Pred == ICmpInst::ICMP_SGT)
               return new ICmpInst(ICmpInst::ICMP_NE, X,
                                   ConstantInt::get(RHS->getType(), TypeBits-1));
-          } else if (!RHSV) {
+          } else if (!(*RHSV)) {
             // (1 << X) <  0 -> X == 31
             // (1 << X) <= 0 -> X == 31
             if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
@@ -1974,7 +1981,7 @@ Instruction *InstCombiner::foldICmpWithC
         } else if (ICI.isEquality()) {
           if (RHSVIsPowerOf2)
             return new ICmpInst(
-                Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2()));
+                Pred, X, ConstantInt::get(RHS->getType(), RHSV->logBase2()));
         }
       }
       break;
@@ -2006,7 +2013,7 @@ Instruction *InstCombiner::foldICmpWithC
 
       // If the shift is NSW and we compare to 0, then it is just shifting out
       // sign bits, no need for an AND either.
-      if (cast<BinaryOperator>(LHSI)->hasNoSignedWrap() && RHSV == 0)
+      if (cast<BinaryOperator>(LHSI)->hasNoSignedWrap() && *RHSV == 0)
         return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0),
                             ConstantExpr::getLShr(RHS, ShAmt));
 
@@ -2054,7 +2061,7 @@ Instruction *InstCombiner::foldICmpWithC
     // smaller constant, which will be target friendly.
     unsigned Amt = ShAmt->getLimitedValue(TypeBits-1);
     if (LHSI->hasOneUse() &&
-        Amt != 0 && RHSV.countTrailingZeros() >= Amt) {
+        Amt != 0 && RHSV->countTrailingZeros() >= Amt) {
       Type *NTy = IntegerType::get(ICI.getContext(), TypeBits - Amt);
       Constant *NCI = ConstantExpr::getTrunc(
                         ConstantExpr::getAShr(RHS,
@@ -2079,7 +2086,7 @@ Instruction *InstCombiner::foldICmpWithC
 
     // Handle exact shr's.
     if (ICI.isEquality() && BO->isExact() && BO->hasOneUse()) {
-      if (RHSV.isMinValue())
+      if (RHSV->isMinValue())
         return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), RHS);
     }
     break;
@@ -2128,18 +2135,18 @@ Instruction *InstCombiner::foldICmpWithC
     //   iff C1 & (C2-1) == C2-1
     //       C2 is a power of 2
     if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() &&
-        RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == (RHSV - 1))
+        RHSV->isPowerOf2() && (LHSV & (*RHSV - 1)) == (*RHSV - 1))
       return new ICmpInst(ICmpInst::ICMP_EQ,
-                          Builder->CreateOr(LHSI->getOperand(1), RHSV - 1),
+                          Builder->CreateOr(LHSI->getOperand(1), *RHSV - 1),
                           LHSC);
 
     // C1-X >u C2 -> (X|C2) != C1
     //   iff C1 & C2 == C2
     //       C2+1 is a power of 2
     if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() &&
-        (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == RHSV)
+        (*RHSV + 1).isPowerOf2() && (LHSV & *RHSV) == *RHSV)
       return new ICmpInst(ICmpInst::ICMP_NE,
-                          Builder->CreateOr(LHSI->getOperand(1), RHSV), LHSC);
+                          Builder->CreateOr(LHSI->getOperand(1), *RHSV), LHSC);
     break;
   }
 
@@ -2150,7 +2157,7 @@ Instruction *InstCombiner::foldICmpWithC
       if (!LHSC) break;
       const APInt &LHSV = LHSC->getValue();
 
-      ConstantRange CR = ICI.makeConstantRange(ICI.getPredicate(), RHSV)
+      ConstantRange CR = ICI.makeConstantRange(ICI.getPredicate(), *RHSV)
                             .subtract(LHSV);
 
       if (ICI.isSigned()) {
@@ -2175,18 +2182,18 @@ Instruction *InstCombiner::foldICmpWithC
       //   iff C1 & (C2-1) == 0
       //       C2 is a power of 2
       if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() &&
-          RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == 0)
+          RHSV->isPowerOf2() && (LHSV & (*RHSV - 1)) == 0)
         return new ICmpInst(ICmpInst::ICMP_EQ,
-                            Builder->CreateAnd(LHSI->getOperand(0), -RHSV),
+                            Builder->CreateAnd(LHSI->getOperand(0), -(*RHSV)),
                             ConstantExpr::getNeg(LHSC));
 
       // X-C1 >u C2 -> (X & ~C2) != C1
       //   iff C1 & C2 == 0
       //       C2+1 is a power of 2
       if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() &&
-          (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == 0)
+          (*RHSV + 1).isPowerOf2() && (LHSV & *RHSV) == 0)
         return new ICmpInst(ICmpInst::ICMP_NE,
-                            Builder->CreateAnd(LHSI->getOperand(0), ~RHSV),
+                            Builder->CreateAnd(LHSI->getOperand(0), ~(*RHSV)),
                             ConstantExpr::getNeg(LHSC));
     }
     break;
@@ -3627,17 +3634,8 @@ Instruction *InstCombiner::visitICmpInst
   // See if we are doing a comparison between a constant and an instruction that
   // can be folded into the comparison.
 
-  // FIXME: Use m_APInt instead of dyn_cast<ConstantInt> to allow these
-  // transforms for vectors.
-
-  if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
-    // Since the RHS is a ConstantInt (CI), if the left hand side is an
-    // instruction, see if that instruction also has constants so that the
-    // instruction can be folded into the icmp
-    if (Instruction *LHSI = dyn_cast<Instruction>(Op0))
-      if (Instruction *Res = foldICmpWithConstant(I, LHSI, CI))
-        return Res;
-  }
+  if (Instruction *Res = foldICmpWithConstant(I))
+    return Res;
 
   if (Instruction *Res = foldICmpEqualityWithConstant(I))
     return Res;

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h?rev=278816&r1=278815&r2=278816&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h Tue Aug 16 11:08:11 2016
@@ -559,8 +559,7 @@ private:
   Instruction *foldICmpAddOpConst(Instruction &ICI, Value *X, ConstantInt *CI,
                                   ICmpInst::Predicate Pred);
   Instruction *foldICmpWithCastAndCast(ICmpInst &ICI);
-  Instruction *foldICmpWithConstant(ICmpInst &ICI, Instruction *LHS,
-                                    ConstantInt *RHS);
+  Instruction *foldICmpWithConstant(ICmpInst &ICI);
   Instruction *foldICmpEqualityWithConstant(ICmpInst &ICI);
   Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI);
 




More information about the llvm-commits mailing list