[llvm] r314640 - [InstCombine] Use APInt for all the math in foldICmpDivConstant

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 1 16:53:54 PDT 2017


Author: ctopper
Date: Sun Oct  1 16:53:54 2017
New Revision: 314640

URL: http://llvm.org/viewvc/llvm-project?rev=314640&view=rev
Log:
[InstCombine] Use APInt for all the math in foldICmpDivConstant

Summary: foldICmpDivConstant currently uses ConstantExpr to do its math but, as noted in a TODO, all of it can be done directly on APInt.

Reviewers: spatel, majnemer

Reviewed By: majnemer

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D38440

Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp?rev=314640&r1=314639&r2=314640&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp Sun Oct  1 16:53:54 2017
@@ -37,77 +37,30 @@ using namespace PatternMatch;
 STATISTIC(NumSel, "Number of select opts");
 
 
-static ConstantInt *extractElement(Constant *V, Constant *Idx) {
-  return cast<ConstantInt>(ConstantExpr::getExtractElement(V, Idx));
-}
-
-static bool hasAddOverflow(ConstantInt *Result,
-                           ConstantInt *In1, ConstantInt *In2,
-                           bool IsSigned) {
-  if (!IsSigned)
-    return Result->getValue().ult(In1->getValue());
-
-  if (In2->isNegative())
-    return Result->getValue().sgt(In1->getValue());
-  return Result->getValue().slt(In1->getValue());
-}
-
 /// Compute Result = In1+In2, returning true if the result overflowed for this
 /// type.
-static bool addWithOverflow(Constant *&Result, Constant *In1,
-                            Constant *In2, bool IsSigned = false) {
-  Result = ConstantExpr::getAdd(In1, In2);
-
-  if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) {
-    for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
-      Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i);
-      if (hasAddOverflow(extractElement(Result, Idx),
-                         extractElement(In1, Idx),
-                         extractElement(In2, Idx),
-                         IsSigned))
-        return true;
-    }
-    return false;
-  }
-
-  return hasAddOverflow(cast<ConstantInt>(Result),
-                        cast<ConstantInt>(In1), cast<ConstantInt>(In2),
-                        IsSigned);
-}
-
-static bool hasSubOverflow(ConstantInt *Result,
-                           ConstantInt *In1, ConstantInt *In2,
-                           bool IsSigned) {
-  if (!IsSigned)
-    return Result->getValue().ugt(In1->getValue());
-
-  if (In2->isNegative())
-    return Result->getValue().slt(In1->getValue());
+static bool addWithOverflow(APInt &Result, const APInt &In1,
+                            const APInt &In2, bool IsSigned = false) {
+  bool Overflow;
+  if (IsSigned)
+    Result = In1.sadd_ov(In2, Overflow);
+  else
+    Result = In1.uadd_ov(In2, Overflow);
 
-  return Result->getValue().sgt(In1->getValue());
+  return Overflow;
 }
 
 /// Compute Result = In1-In2, returning true if the result overflowed for this
 /// type.
-static bool subWithOverflow(Constant *&Result, Constant *In1,
-                            Constant *In2, bool IsSigned = false) {
-  Result = ConstantExpr::getSub(In1, In2);
-
-  if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) {
-    for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
-      Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i);
-      if (hasSubOverflow(extractElement(Result, Idx),
-                         extractElement(In1, Idx),
-                         extractElement(In2, Idx),
-                         IsSigned))
-        return true;
-    }
-    return false;
-  }
+static bool subWithOverflow(APInt &Result, const APInt &In1,
+                            const APInt &In2, bool IsSigned = false) {
+  bool Overflow;
+  if (IsSigned)
+    Result = In1.ssub_ov(In2, Overflow);
+  else
+    Result = In1.usub_ov(In2, Overflow);
 
-  return hasSubOverflow(cast<ConstantInt>(Result),
-                        cast<ConstantInt>(In1), cast<ConstantInt>(In2),
-                        IsSigned);
+  return Overflow;
 }
 
 /// Given an icmp instruction, return true if any use of this comparison is a
@@ -2186,28 +2139,22 @@ Instruction *InstCombiner::foldICmpDivCo
       (DivIsSigned && C2->isAllOnesValue()))
     return nullptr;
 
-  // TODO: We could do all of the computations below using APInt.
-  Constant *CmpRHS = cast<Constant>(Cmp.getOperand(1));
-  Constant *DivRHS = cast<Constant>(Div->getOperand(1));
-
-  // Compute Prod = CmpRHS * DivRHS. We are essentially solving an equation of
-  // form X / C2 = C. We solve for X by multiplying C2 (DivRHS) and C (CmpRHS).
+  // Compute Prod = C * C2. We are essentially solving an equation of
+  // form X / C2 = C. We solve for X by multiplying C2 and C.
   // By solving for X, we can turn this into a range check instead of computing
   // a divide.
-  Constant *Prod = ConstantExpr::getMul(CmpRHS, DivRHS);
+  APInt Prod = *C * *C2;
 
   // Determine if the product overflows by seeing if the product is not equal to
   // the divide. Make sure we do the same kind of divide as in the LHS
   // instruction that we're folding.
-  bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS)
-                             : ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS;
+  bool ProdOV = (DivIsSigned ? Prod.sdiv(*C2) : Prod.udiv(*C2)) != *C;
 
   ICmpInst::Predicate Pred = Cmp.getPredicate();
 
   // If the division is known to be exact, then there is no remainder from the
   // divide, so the covered range size is unit, otherwise it is the divisor.
-  Constant *RangeSize =
-      Div->isExact() ? ConstantInt::get(Div->getType(), 1) : DivRHS;
+  APInt RangeSize = Div->isExact() ? APInt(C2->getBitWidth(), 1) : *C2;
 
   // Figure out the interval that is being checked.  For example, a comparison
   // like "X /u 5 == 0" is really checking that X is in the interval [0, 5).
@@ -2217,7 +2164,7 @@ Instruction *InstCombiner::foldICmpDivCo
   // overflow variable is set to 0 if it's corresponding bound variable is valid
   // -1 if overflowed off the bottom end, or +1 if overflowed off the top end.
   int LoOverflow = 0, HiOverflow = 0;
-  Constant *LoBound = nullptr, *HiBound = nullptr;
+  APInt LoBound, HiBound;
 
   if (!DivIsSigned) {  // udiv
     // e.g. X/5 op 3  --> [15, 20)
@@ -2231,7 +2178,7 @@ Instruction *InstCombiner::foldICmpDivCo
   } else if (C2->isStrictlyPositive()) { // Divisor is > 0.
     if (C->isNullValue()) {       // (X / pos) op 0
       // Can't overflow.  e.g.  X/2 op 0 --> [-1, 2)
-      LoBound = ConstantExpr::getNeg(SubOne(RangeSize));
+      LoBound = -(RangeSize - 1);
       HiBound = RangeSize;
     } else if (C->isStrictlyPositive()) {   // (X / pos) op pos
       LoBound = Prod;     // e.g.   X/5 op 3 --> [15, 20)
@@ -2240,27 +2187,27 @@ Instruction *InstCombiner::foldICmpDivCo
         HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true);
     } else {                       // (X / pos) op neg
       // e.g. X/5 op -3  --> [-15-4, -15+1) --> [-19, -14)
-      HiBound = AddOne(Prod);
+      HiBound = Prod + 1;
       LoOverflow = HiOverflow = ProdOV ? -1 : 0;
       if (!LoOverflow) {
-        Constant *DivNeg = ConstantExpr::getNeg(RangeSize);
+        APInt DivNeg = -RangeSize;
         LoOverflow = addWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0;
       }
     }
   } else if (C2->isNegative()) { // Divisor is < 0.
     if (Div->isExact())
-      RangeSize = ConstantExpr::getNeg(RangeSize);
+      RangeSize.negate();
     if (C->isNullValue()) { // (X / neg) op 0
       // e.g. X/-5 op 0  --> [-4, 5)
-      LoBound = AddOne(RangeSize);
-      HiBound = ConstantExpr::getNeg(RangeSize);
-      if (HiBound == DivRHS) {     // -INTMIN = INTMIN
+      LoBound = RangeSize + 1;
+      HiBound = -RangeSize;
+      if (HiBound == *C2) {        // -INTMIN = INTMIN
         HiOverflow = 1;            // [INTMIN+1, overflow)
-        HiBound = nullptr;         // e.g. X/INTMIN = 0 --> X > INTMIN
+        HiBound = APInt();         // e.g. X/INTMIN = 0 --> X > INTMIN
       }
     } else if (C->isStrictlyPositive()) {   // (X / neg) op pos
       // e.g. X/-5 op 3  --> [-19, -14)
-      HiBound = AddOne(Prod);
+      HiBound = Prod + 1;
       HiOverflow = LoOverflow = ProdOV ? -1 : 0;
       if (!LoOverflow)
         LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0;
@@ -2283,25 +2230,27 @@ Instruction *InstCombiner::foldICmpDivCo
         return replaceInstUsesWith(Cmp, Builder.getFalse());
       if (HiOverflow)
         return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE :
-                            ICmpInst::ICMP_UGE, X, LoBound);
+                            ICmpInst::ICMP_UGE, X,
+                            ConstantInt::get(Div->getType(), LoBound));
       if (LoOverflow)
         return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
-                            ICmpInst::ICMP_ULT, X, HiBound);
+                            ICmpInst::ICMP_ULT, X,
+                            ConstantInt::get(Div->getType(), HiBound));
       return replaceInstUsesWith(
-          Cmp, insertRangeTest(X, LoBound->getUniqueInteger(),
-                               HiBound->getUniqueInteger(), DivIsSigned, true));
+          Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true));
     case ICmpInst::ICMP_NE:
       if (LoOverflow && HiOverflow)
         return replaceInstUsesWith(Cmp, Builder.getTrue());
       if (HiOverflow)
         return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
-                            ICmpInst::ICMP_ULT, X, LoBound);
+                            ICmpInst::ICMP_ULT, X,
+                            ConstantInt::get(Div->getType(), LoBound));
       if (LoOverflow)
         return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE :
-                            ICmpInst::ICMP_UGE, X, HiBound);
+                            ICmpInst::ICMP_UGE, X,
+                            ConstantInt::get(Div->getType(), HiBound));
       return replaceInstUsesWith(Cmp,
-                                 insertRangeTest(X, LoBound->getUniqueInteger(),
-                                                 HiBound->getUniqueInteger(),
+                                 insertRangeTest(X, LoBound, HiBound,
                                                  DivIsSigned, false));
     case ICmpInst::ICMP_ULT:
     case ICmpInst::ICMP_SLT:
@@ -2309,7 +2258,7 @@ Instruction *InstCombiner::foldICmpDivCo
         return replaceInstUsesWith(Cmp, Builder.getTrue());
       if (LoOverflow == -1)   // Low bound is less than input range.
         return replaceInstUsesWith(Cmp, Builder.getFalse());
-      return new ICmpInst(Pred, X, LoBound);
+      return new ICmpInst(Pred, X, ConstantInt::get(Div->getType(), LoBound));
     case ICmpInst::ICMP_UGT:
     case ICmpInst::ICMP_SGT:
       if (HiOverflow == +1)       // High bound greater than input range.
@@ -2317,8 +2266,10 @@ Instruction *InstCombiner::foldICmpDivCo
       if (HiOverflow == -1)       // High bound less than input range.
         return replaceInstUsesWith(Cmp, Builder.getTrue());
       if (Pred == ICmpInst::ICMP_UGT)
-        return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound);
-      return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound);
+        return new ICmpInst(ICmpInst::ICMP_UGE, X,
+                            ConstantInt::get(Div->getType(), HiBound));
+      return new ICmpInst(ICmpInst::ICMP_SGE, X,
+                          ConstantInt::get(Div->getType(), HiBound));
   }
 
   return nullptr;




More information about the llvm-commits mailing list