[llvm] r300252 - Remove all allocation and divisions from GreatestCommonDivisor

Richard Smith via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 13 13:29:59 PDT 2017


Author: rsmith
Date: Thu Apr 13 15:29:59 2017
New Revision: 300252

URL: http://llvm.org/viewvc/llvm-project?rev=300252&view=rev
Log:
Remove all allocation and divisions from GreatestCommonDivisor

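Switch from Euclid's algorithm to Stein's algorithm for computing GCD. This
avoids the (expensive) APInt division operation in favour of bit operations.
Remove all memory allocation from within the GCD loop by tweaking our `lshr`
implementation so it can operate in-place.

Note: this revision also modifies InstCombineCompares.cpp (a new fold that
merges "icmp eq (or (rem X, M), (rem X, N)), 0" divisibility checks, plus
code disabled with #if 0) and adds test/Transforms/InstCombine/divisibility.ll;
neither change is described above.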

Differential Revision: https://reviews.llvm.org/D31968

Added:
    llvm/trunk/test/Transforms/InstCombine/divisibility.ll
Modified:
    llvm/trunk/include/llvm/ADT/APInt.h
    llvm/trunk/lib/Support/APInt.cpp
    llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/trunk/unittests/ADT/APIntTest.cpp

Modified: llvm/trunk/include/llvm/ADT/APInt.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ADT/APInt.h?rev=300252&r1=300251&r2=300252&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ADT/APInt.h (original)
+++ llvm/trunk/include/llvm/ADT/APInt.h Thu Apr 13 15:29:59 2017
@@ -869,7 +869,14 @@ public:
   /// \brief Logical right-shift function.
   ///
   /// Logical right-shift this APInt by shiftAmt.
-  APInt lshr(unsigned shiftAmt) const;
+  APInt lshr(unsigned shiftAmt) const {
+    APInt Shifted = *this;
+    Shifted.lshrInPlace(shiftAmt);
+    return Shifted;
+  }
+
+  /// Logical right-shift this APInt by shiftAmt in place, without allocating.
+  void lshrInPlace(unsigned shiftAmt);
 
   /// \brief Left-shift function.
   ///
@@ -1949,7 +1956,7 @@ inline const APInt &umax(const APInt &A,
   return A.ugt(B) ? A : B;
 }
 
-/// \brief Compute GCD of two APInt values.
+/// \brief Compute GCD of two unsigned APInt values.
 ///
 /// This function returns the greatest common divisor of the two APInt values
-/// using Euclid's algorithm.
+/// using Stein's (binary GCD) algorithm.

Modified: llvm/trunk/lib/Support/APInt.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Support/APInt.cpp?rev=300252&r1=300251&r2=300252&view=diff
==============================================================================
--- llvm/trunk/lib/Support/APInt.cpp (original)
+++ llvm/trunk/lib/Support/APInt.cpp Thu Apr 13 15:29:59 2017
@@ -733,18 +733,6 @@ unsigned APInt::countPopulationSlowCase(
   return Count;
 }
 
-/// Perform a logical right-shift from Src to Dst, which must be equal or
-/// non-overlapping, of Words words, by Shift, which must be less than 64.
-static void lshrNear(uint64_t *Dst, uint64_t *Src, unsigned Words,
-                     unsigned Shift) {
-  uint64_t Carry = 0;
-  for (int I = Words - 1; I >= 0; --I) {
-    uint64_t Tmp = Src[I];
-    Dst[I] = (Tmp >> Shift) | Carry;
-    Carry = Tmp << (64 - Shift);
-  }
-}
-
 APInt APInt::byteSwap() const {
   assert(BitWidth >= 16 && BitWidth % 16 == 0 && "Cannot byteswap!");
   if (BitWidth == 16)
@@ -765,8 +753,7 @@ APInt APInt::byteSwap() const {
   for (unsigned I = 0, N = getNumWords(); I != N; ++I)
     Result.pVal[I] = ByteSwap_64(pVal[N - I - 1]);
   if (Result.BitWidth != BitWidth) {
-    lshrNear(Result.pVal, Result.pVal, getNumWords(),
-             Result.BitWidth - BitWidth);
+    Result.lshrInPlace(Result.BitWidth - BitWidth);
     Result.BitWidth = BitWidth;
   }
   return Result;
@@ -803,11 +790,45 @@ APInt APInt::reverseBits() const {
 }
 
 APInt llvm::APIntOps::GreatestCommonDivisor(APInt A, APInt B) {
-  while (!!B) {
-    APInt R = A.urem(B);
-    A = std::move(B);
-    B = std::move(R);
+  // Fast-path a common case. Equal operands are their own gcd; this also
+  // covers gcd(0, 0) == 0.
+  if (A == B) return A;
+
+  // Corner cases: if either operand is zero, the other is the gcd.
+  if (!A) return B;
+  if (!B) return A;
+
+  // Count the powers of 2 common to both operands (Pow2) and shift out the
+  // extra ones, so that both operands end up with exactly Pow2 trailing zeros.
+  unsigned Pow2;
+  {
+    unsigned Pow2_A = A.countTrailingZeros();
+    unsigned Pow2_B = B.countTrailingZeros();
+    if (Pow2_A > Pow2_B) {
+      A.lshrInPlace(Pow2_A - Pow2_B);
+      Pow2 = Pow2_B;
+    } else if (Pow2_B > Pow2_A) {
+      B.lshrInPlace(Pow2_B - Pow2_A);
+      Pow2 = Pow2_A;
+    } else {
+      Pow2 = Pow2_A;
+    }
+  }
+
+  // Both operands are now odd multiples of 2^Pow2. For such values
+  //
+  //   gcd(a, b) = gcd(|a - b| / 2^i, min(a, b))
+  //
+  // where 2^i is chosen so that |a - b| / 2^i is again an odd multiple of
+  // 2^Pow2: the difference of two odd multiples of 2^Pow2 has more than Pow2
+  // trailing zeros, so the shift amount below is always positive.
+  // This is a modified version of Stein's algorithm, taking advantage of
+  // efficient countTrailingZeros(). Operands are compared as unsigned (ugt).
+  while (A != B) {
+    if (A.ugt(B)) {
+      A -= B;
+      A.lshrInPlace(A.countTrailingZeros() - Pow2);
+    } else {
+      B -= A;
+      B.lshrInPlace(B.countTrailingZeros() - Pow2);
+    }
   }
+
   return A;
 }
 
@@ -1119,68 +1140,59 @@ APInt APInt::lshr(const APInt &shiftAmt)
   return lshr((unsigned)shiftAmt.getLimitedValue(BitWidth));
 }
 
+/// Perform a logical right-shift from Src to Dst of Words words, by Shift,
+/// which must be less than 64 whenever Words is non-zero. If the ranges
+/// overlap, we require that Src >= Dst (put another way, we require that the
+/// overall operation is a right shift on the combined range).
+static void lshrWords(APInt::WordType *Dst, APInt::WordType *Src,
+                      unsigned Words, unsigned Shift) {
+  // Empty range first: the caller may pass Shift >= 64 when Words == 0.
+  if (!Words)
+    return;
+  assert(Shift < APInt::APINT_BITS_PER_WORD);
+
+  if (Shift == 0) {
+    std::memmove(Dst, Src, Words * APInt::APINT_WORD_SIZE);
+    return;
+  }
+
+  APInt::WordType Low = Src[0];
+  for (unsigned I = 1; I != Words; ++I) {
+    APInt::WordType High = Src[I];
+    Dst[I - 1] =
+        (Low >> Shift) | (High << (APInt::APINT_BITS_PER_WORD - Shift));
+    Low = High;
+  }
+  Dst[Words - 1] = Low >> Shift;
+}
+
 /// Logical right-shift this APInt by shiftAmt.
 /// @brief Logical right-shift function.
-APInt APInt::lshr(unsigned shiftAmt) const {
+void APInt::lshrInPlace(unsigned shiftAmt) {
   if (isSingleWord()) {
     if (shiftAmt >= BitWidth)
-      return APInt(BitWidth, 0);
+      VAL = 0;
     else
-      return APInt(BitWidth, this->VAL >> shiftAmt);
+      VAL >>= shiftAmt;
+    return;
   }
 
-  // If all the bits were shifted out, the result is 0. This avoids issues
-  // with shifting by the size of the integer type, which produces undefined
-  // results. We define these "undefined results" to always be 0.
-  if (shiftAmt >= BitWidth)
-    return APInt(BitWidth, 0);
-
-  // If none of the bits are shifted out, the result is *this. This avoids
-  // issues with shifting by the size of the integer type, which produces
-  // undefined results in the code below. This is also an optimization.
-  if (shiftAmt == 0)
-    return *this;
-
-  // Create some space for the result.
-  uint64_t * val = new uint64_t[getNumWords()];
-
-  // If we are shifting less than a word, compute the shift with a simple carry
-  if (shiftAmt < APINT_BITS_PER_WORD) {
-    lshrNear(val, pVal, getNumWords(), shiftAmt);
-    APInt Result(val, BitWidth);
-    Result.clearUnusedBits();
-    return Result;
-  }
-
-  // Compute some values needed by the remaining shift algorithms
-  unsigned wordShift = shiftAmt % APINT_BITS_PER_WORD;
-  unsigned offset = shiftAmt / APINT_BITS_PER_WORD;
-
-  // If we are shifting whole words, just move whole words
-  if (wordShift == 0) {
-    for (unsigned i = 0; i < getNumWords() - offset; ++i)
-      val[i] = pVal[i+offset];
-    for (unsigned i = getNumWords()-offset; i < getNumWords(); i++)
-      val[i] = 0;
-    APInt Result(val, BitWidth);
-    Result.clearUnusedBits();
-    return Result;
-  }
-
-  // Shift the low order words
-  unsigned breakWord = getNumWords() - offset -1;
-  for (unsigned i = 0; i < breakWord; ++i)
-    val[i] = (pVal[i+offset] >> wordShift) |
-             (pVal[i+offset+1] << (APINT_BITS_PER_WORD - wordShift));
-  // Shift the break word.
-  val[breakWord] = pVal[breakWord+offset] >> wordShift;
-
-  // Remaining words are 0
-  for (unsigned i = breakWord+1; i < getNumWords(); ++i)
-    val[i] = 0;
-  APInt Result(val, BitWidth);
-  Result.clearUnusedBits();
-  return Result;
+  // Don't bother performing a no-op shift.
+  if (!shiftAmt)
+    return;
+
+  // Find number of complete words being shifted out and zeroed. This is
+  // clamped to Words, so when shiftAmt >= Words * APINT_BITS_PER_WORD no words
+  // are shifted below and the leftover shift amount handed to lshrWords can be
+  // APINT_BITS_PER_WORD or more.
+  const unsigned Words = getNumWords();
+  const unsigned ShiftFullWords =
+      std::min(shiftAmt / APINT_BITS_PER_WORD, Words);
+
+  // Fill in first Words - ShiftFullWords by shifting.
+  // NOTE(review): unlike the old code, nothing calls clearUnusedBits() here;
+  // for BitWidth <= shiftAmt < Words * APINT_BITS_PER_WORD the result is zero
+  // only if the top word's bits above BitWidth are already zero -- confirm.
+  lshrWords(pVal, pVal + ShiftFullWords, Words - ShiftFullWords,
+            shiftAmt - (ShiftFullWords * APINT_BITS_PER_WORD));
+
+  // The remaining high words are all zero.
+  for (unsigned I = Words - ShiftFullWords; I != Words; ++I)
+    pVal[I] = 0;
 }
 
 /// Left-shift this APInt by shiftAmt.

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp?rev=300252&r1=300251&r2=300252&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp Thu Apr 13 15:29:59 2017
@@ -1178,6 +1178,373 @@ Instruction *InstCombiner::foldICmpAddOp
   Constant *C = Builder->getInt(CI->getValue()-1);
   return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantExpr::getSub(SMax, C));
 }
+// NOTE(review): everything up to the matching #endif is compiled out. It
+// defines InstCombiner::FoldICmpDivCst / FoldICmpShrCst and calls helpers such
+// as ReplaceInstUsesWith; whether those are still declared on InstCombiner is
+// not visible here -- confirm before re-enabling, or delete this region.
+#if 0
+/// FoldICmpDivCst - Fold "icmp pred, ([su]div X, DivRHS), CmpRHS" where DivRHS
+/// and CmpRHS are both known to be integer constants.
+Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
+                                          ConstantInt *DivRHS) {
+  ConstantInt *CmpRHS = cast<ConstantInt>(ICI.getOperand(1));
+  const APInt &CmpRHSV = CmpRHS->getValue();
+
+  // FIXME: If the operand types don't match the type of the divide
+  // then don't attempt this transform. The code below doesn't have the
+  // logic to deal with a signed divide and an unsigned compare (and
+  // vice versa). This is because (x /s C1) <s C2  produces different
+  // results than (x /s C1) <u C2 or (x /u C1) <s C2 or even
+  // (x /u C1) <u C2.  Simply casting the operands and result won't
+  // work. :(  The if statement below tests that condition and bails
+  // if it finds it.
+  bool DivIsSigned = DivI->getOpcode() == Instruction::SDiv;
+  if (!ICI.isEquality() && DivIsSigned != ICI.isSigned())
+    return nullptr;
+  if (DivRHS->isZero())
+    return nullptr; // The ProdOV computation fails on divide by zero.
+  if (DivIsSigned && DivRHS->isAllOnesValue())
+    return nullptr; // The overflow computation also screws up here
+  if (DivRHS->isOne()) {
+    // This eliminates some funny cases with INT_MIN.
+    ICI.setOperand(0, DivI->getOperand(0));   // X/1 == X.
+    return &ICI;
+  }
+
+  // Compute Prod = CI * DivRHS. We are essentially solving an equation
+  // of form X/C1=C2. We solve for X by multiplying C1 (DivRHS) and
+  // C2 (CI). By solving for X we can turn this into a range check
+  // instead of computing a divide.
+  Constant *Prod = ConstantExpr::getMul(CmpRHS, DivRHS);
+
+  // Determine if the product overflows by seeing if the product is
+  // not equal to the divide. Make sure we do the same kind of divide
+  // as in the LHS instruction that we're folding.
+  bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS) :
+                 ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS;
+
+  // Get the ICmp opcode
+  ICmpInst::Predicate Pred = ICI.getPredicate();
+
+  /// If the division is known to be exact, then there is no remainder from the
+  /// divide, so the covered range size is unit, otherwise it is the divisor.
+  ConstantInt *RangeSize = DivI->isExact() ? getOne(Prod) : DivRHS;
+
+  // Figure out the interval that is being checked.  For example, a comparison
+  // like "X /u 5 == 0" is really checking that X is in the interval [0, 5).
+  // Compute this interval based on the constants involved and the signedness of
+  // the compare/divide.  This computes a half-open interval, keeping track of
+  // whether either value in the interval overflows.  After analysis each
+  // overflow variable is set to 0 if it's corresponding bound variable is valid
+  // -1 if overflowed off the bottom end, or +1 if overflowed off the top end.
+  int LoOverflow = 0, HiOverflow = 0;
+  Constant *LoBound = nullptr, *HiBound = nullptr;
+
+  if (!DivIsSigned) {  // udiv
+    // e.g. X/5 op 3  --> [15, 20)
+    LoBound = Prod;
+    HiOverflow = LoOverflow = ProdOV;
+    if (!HiOverflow) {
+      // If this is not an exact divide, then many values in the range collapse
+      // to the same result value.
+      HiOverflow = AddWithOverflow(HiBound, LoBound, RangeSize, false);
+    }
+  } else if (DivRHS->getValue().isStrictlyPositive()) { // Divisor is > 0.
+    if (CmpRHSV == 0) {       // (X / pos) op 0
+      // Can't overflow.  e.g.  X/2 op 0 --> [-1, 2)
+      LoBound = ConstantExpr::getNeg(SubOne(RangeSize));
+      HiBound = RangeSize;
+    } else if (CmpRHSV.isStrictlyPositive()) {   // (X / pos) op pos
+      LoBound = Prod;     // e.g.   X/5 op 3 --> [15, 20)
+      HiOverflow = LoOverflow = ProdOV;
+      if (!HiOverflow)
+        HiOverflow = AddWithOverflow(HiBound, Prod, RangeSize, true);
+    } else {                       // (X / pos) op neg
+      // e.g. X/5 op -3  --> [-15-4, -15+1) --> [-19, -14)
+      HiBound = AddOne(Prod);
+      LoOverflow = HiOverflow = ProdOV ? -1 : 0;
+      if (!LoOverflow) {
+        ConstantInt *DivNeg =cast<ConstantInt>(ConstantExpr::getNeg(RangeSize));
+        LoOverflow = AddWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0;
+      }
+    }
+  } else if (DivRHS->isNegative()) { // Divisor is < 0.
+    if (DivI->isExact())
+      RangeSize = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize));
+    if (CmpRHSV == 0) {       // (X / neg) op 0
+      // e.g. X/-5 op 0  --> [-4, 5)
+      LoBound = AddOne(RangeSize);
+      HiBound = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize));
+      if (HiBound == DivRHS) {     // -INTMIN = INTMIN
+        HiOverflow = 1;            // [INTMIN+1, overflow)
+        HiBound = nullptr;         // e.g. X/INTMIN = 0 --> X > INTMIN
+      }
+    } else if (CmpRHSV.isStrictlyPositive()) {   // (X / neg) op pos
+      // e.g. X/-5 op 3  --> [-19, -14)
+      HiBound = AddOne(Prod);
+      HiOverflow = LoOverflow = ProdOV ? -1 : 0;
+      if (!LoOverflow)
+        LoOverflow = AddWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0;
+    } else {                       // (X / neg) op neg
+      LoBound = Prod;       // e.g. X/-5 op -3  --> [15, 20)
+      LoOverflow = HiOverflow = ProdOV;
+      if (!HiOverflow)
+        HiOverflow = SubWithOverflow(HiBound, Prod, RangeSize, true);
+    }
+
+    // Dividing by a negative swaps the condition.  LT <-> GT
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+  }
+
+  Value *X = DivI->getOperand(0);
+  switch (Pred) {
+  default: llvm_unreachable("Unhandled icmp opcode!");
+  case ICmpInst::ICMP_EQ:
+    if (LoOverflow && HiOverflow)
+      return ReplaceInstUsesWith(ICI, Builder->getFalse());
+    if (HiOverflow)
+      return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE :
+                          ICmpInst::ICMP_UGE, X, LoBound);
+    if (LoOverflow)
+      return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
+                          ICmpInst::ICMP_ULT, X, HiBound);
+    return ReplaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound,
+                                                    DivIsSigned, true));
+  case ICmpInst::ICMP_NE:
+    if (LoOverflow && HiOverflow)
+      return ReplaceInstUsesWith(ICI, Builder->getTrue());
+    if (HiOverflow)
+      return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
+                          ICmpInst::ICMP_ULT, X, LoBound);
+    if (LoOverflow)
+      return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE :
+                          ICmpInst::ICMP_UGE, X, HiBound);
+    return ReplaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound,
+                                                    DivIsSigned, false));
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_SLT:
+    if (LoOverflow == +1)   // Low bound is greater than input range.
+      return ReplaceInstUsesWith(ICI, Builder->getTrue());
+    if (LoOverflow == -1)   // Low bound is less than input range.
+      return ReplaceInstUsesWith(ICI, Builder->getFalse());
+    return new ICmpInst(Pred, X, LoBound);
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_SGT:
+    if (HiOverflow == +1)       // High bound greater than input range.
+      return ReplaceInstUsesWith(ICI, Builder->getFalse());
+    if (HiOverflow == -1)       // High bound less than input range.
+      return ReplaceInstUsesWith(ICI, Builder->getTrue());
+    if (Pred == ICmpInst::ICMP_UGT)
+      return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound);
+    return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound);
+  }
+}
+
+/// FoldICmpShrCst - Handle "icmp(([al]shr X, cst1), cst2)".
+Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr,
+                                          ConstantInt *ShAmt) {
+  const APInt &CmpRHSV = cast<ConstantInt>(ICI.getOperand(1))->getValue();
+
+  // Check that the shift amount is in range.  If not, don't perform
+  // undefined shifts.  When the shift is visited it will be
+  // simplified.
+  uint32_t TypeBits = CmpRHSV.getBitWidth();
+  uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits);
+  if (ShAmtVal >= TypeBits || ShAmtVal == 0)
+    return nullptr;
+
+  if (!ICI.isEquality()) {
+    // If we have an unsigned comparison and an ashr, we can't simplify this.
+    // Similarly for signed comparisons with lshr.
+    if (ICI.isSigned() != (Shr->getOpcode() == Instruction::AShr))
+      return nullptr;
+
+    // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv
+    // by a power of 2.  Since we already have logic to simplify these,
+    // transform to div and then simplify the resultant comparison.
+    if (Shr->getOpcode() == Instruction::AShr &&
+        (!Shr->isExact() || ShAmtVal == TypeBits - 1))
+      return nullptr;
+
+    // Revisit the shift (to delete it).
+    Worklist.Add(Shr);
+
+    Constant *DivCst =
+      ConstantInt::get(Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal));
+
+    Value *Tmp =
+      Shr->getOpcode() == Instruction::AShr ?
+      Builder->CreateSDiv(Shr->getOperand(0), DivCst, "", Shr->isExact()) :
+      Builder->CreateUDiv(Shr->getOperand(0), DivCst, "", Shr->isExact());
+
+    ICI.setOperand(0, Tmp);
+
+    // If the builder folded the binop, just return it.
+    BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp);
+    if (!TheDiv)
+      return &ICI;
+
+    // Otherwise, fold this div/compare.
+    assert(TheDiv->getOpcode() == Instruction::SDiv ||
+           TheDiv->getOpcode() == Instruction::UDiv);
+
+    Instruction *Res = FoldICmpDivCst(ICI, TheDiv, cast<ConstantInt>(DivCst));
+    assert(Res && "This div/cst should have folded!");
+    return Res;
+  }
+
+  // If we are comparing against bits always shifted out, the
+  // comparison cannot succeed.
+  APInt Comp = CmpRHSV << ShAmtVal;
+  ConstantInt *ShiftedCmpRHS = Builder->getInt(Comp);
+  if (Shr->getOpcode() == Instruction::LShr)
+    Comp = Comp.lshr(ShAmtVal);
+  else
+    Comp = Comp.ashr(ShAmtVal);
+
+  if (Comp != CmpRHSV) { // Comparing against a bit that we know is zero.
+    bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE;
+    Constant *Cst = Builder->getInt1(IsICMP_NE);
+    return ReplaceInstUsesWith(ICI, Cst);
+  }
+
+  // Otherwise, check to see if the bits shifted out are known to be zero.
+  // If so, we can compare against the unshifted value:
+  //  (X & 4) >> 1 == 2  --> (X & 4) == 4.
+  if (Shr->hasOneUse() && Shr->isExact())
+    return new ICmpInst(ICI.getPredicate(), Shr->getOperand(0), ShiftedCmpRHS);
+
+  if (Shr->hasOneUse()) {
+    // Otherwise strength reduce the shift into an and.
+    APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal));
+    Constant *Mask = Builder->getInt(Val);
+
+    Value *And = Builder->CreateAnd(Shr->getOperand(0),
+                                    Mask, Shr->getName()+".mask");
+    return new ICmpInst(ICI.getPredicate(), And, ShiftedCmpRHS);
+  }
+  return nullptr;
+}
+#endif
+namespace {
+/// Models a check that LHS is divisible by Factor. The modeled value V is
+/// expected to be compared against zero, and is one of:
+///   srem LHS, Factor   (DC_SRem)
+///   urem LHS, Factor   (DC_URem)
+///   and  LHS, Factor   (DC_And; only a divisibility check when the mask is
+///                       equivalent to a power-of-2 mask)
+class DivisibilityCheck {
+  // Signedness of the check. A bitwise and is a divisibility check,
+  // if its mask is (equivalent to) a power of 2 mask.
+  enum { DC_Null, DC_SRem, DC_URem, DC_And } Kind;
+  Value *Check;
+  Value *LHS;
+  ConstantInt *Factor;
+
+public:
+  DivisibilityCheck() : Kind(DC_Null) {}
+
+  /// Try to extract a divisibility check from V, on the assumption
+  /// that it is being compared to 0.
+  bool match(Value *V) {
+    Kind = DC_Null;
+    Check = V;
+    if (::match(V, m_SRem(m_Value(LHS), m_ConstantInt(Factor))))
+      Kind = DC_SRem;
+    else if (::match(V, m_URem(m_Value(LHS), m_ConstantInt(Factor))))
+      Kind = DC_URem;
+    else if (::match(V, m_And(m_Value(LHS), m_ConstantInt(Factor))))
+      Kind = DC_And;
+    return Kind != DC_Null;
+  }
+
+  /// Merge another divisibility check into this one. On success this check
+  /// becomes a check for divisibility by the LCM of both factors (or by zero,
+  /// i.e. "LHS == 0", if that LCM is not representable) and true is returned.
+  bool merge(const DivisibilityCheck &O) {
+    assert(Kind != DC_Null && O.Kind != DC_Null);
+    if (LHS != O.LHS)
+      // We don't have two divisibility checks on the same operand.
+      return false;
+
+    if (!(Check->hasOneUse() && Kind != DC_And) &&
+        !(O.Check->hasOneUse() && O.Kind != DC_And))
+      // We would not remove a division: bail out.
+      return false;
+
+    // Determine the factors we're checking for. (Named so as not to shadow
+    // the LHS member, which is the value being checked.)
+    bool Failed = false;
+    APInt ThisFactor = getFactor(O, Failed);
+    APInt OtherFactor = O.getFactor(*this, Failed);
+    if (Failed)
+      return false;
+
+    // A zero factor comes from an all-ones-equivalent mask ("LHS == 0") or a
+    // remainder by zero. The signedness reconciliation below and the LCM
+    // computation (which divides by the GCD) both need non-zero factors.
+    if (!ThisFactor || !OtherFactor)
+      return false;
+
+    // If we don't have a single signedness, we can fold the checks
+    // together if one of them is for a power of 2, because
+    // divisibility by a power of 2 is the same for srem and urem.
+    if (Kind != O.Kind && O.Kind != DC_And && ThisFactor.isPowerOf2())
+      Kind = O.Kind;
+    if (Kind != O.Kind && !OtherFactor.isPowerOf2())
+      return false;
+    assert((Kind == DC_SRem || Kind == DC_URem) && "bad kind after merging");
+    bool Signed = Kind == DC_SRem;
+
+    // Fold them together.
+    APInt GCD = APIntOps::GreatestCommonDivisor(ThisFactor, OtherFactor);
+    APInt LCM = ThisFactor.udiv(GCD);
+    bool Overflow = false;
+    // Use a negative signed multiplication: producing INT_MIN should not
+    // be considered an overflow here.
+    LCM = Signed ? LCM.smul_ov(-OtherFactor, Overflow)
+                 : LCM.umul_ov(OtherFactor, Overflow);
+    // On overflow, there cannot exist a non-zero value that is divisible by
+    // both factors at once.
+    if (Overflow) LCM = 0;
+    Factor = cast<ConstantInt>(ConstantInt::get(Factor->getType(), LCM));
+    return true;
+  }
+
+  /// Emit the (possibly merged) check; the result is meant to be compared
+  /// against zero.
+  Value *create(InstCombiner::BuilderTy *Builder) {
+    // LHS is divisible by zero iff LHS is zero.
+    if (!Factor->getValue())
+      return LHS;
+    // Checking for divisibility by power of 2 doesn't need a division.
+    if (Factor->getValue().isPowerOf2())
+      return Builder->CreateAnd(
+          LHS, ConstantInt::get(Factor->getType(), Factor->getValue() - 1));
+    return Kind == DC_SRem ? Builder->CreateSRem(LHS, Factor)
+                           : Builder->CreateURem(LHS, Factor);
+  }
+
+private:
+  /// Get the unsigned multiplicative factor we're checking for. For DC_And
+  /// the factor depends on the other check O; Failed is set if the mask is
+  /// not equivalent to a power-of-2 mask.
+  APInt getFactor(const DivisibilityCheck &O, bool &Failed) const {
+    switch (Kind) {
+    case DC_Null:
+      llvm_unreachable("unexpected Kind");
+
+    case DC_SRem:
+      if (Factor->getValue().isNegative())
+        return -Factor->getValue();
+      // Fall through.
+    case DC_URem:
+      return Factor->getValue();
+
+    case DC_And: {
+      assert(O.Kind != DC_And && "bad kind pair");
+      // If we're also checking for divisibility by K * 2^N,
+      // the low N bits of the mask are irrelevant.
+      APInt Result =
+          Factor->getValue() |
+          APInt::getLowBitsSet(Factor->getValue().getBitWidth(),
+                               O.getFactor(*this, Failed).countTrailingZeros());
+      ++Result;
+      if (!!Result && !Result.isPowerOf2())
+        Failed = true;
+      return Result;
+    }
+    }
+    llvm_unreachable("covered switch");
+  }
+};
+
+struct DivisibilityCheck_match {
+  DivisibilityCheck &Check;
+  DivisibilityCheck_match(DivisibilityCheck &Check) : Check(Check) {}
+  bool match(Value *V) { return Check.match(V); }
+};
+
+/// Matcher for divisibility checks.
+DivisibilityCheck_match m_DivisibilityCheck(DivisibilityCheck &Check) {
+  return DivisibilityCheck_match(Check);
+}
+} // end anonymous namespace
 
 /// Handle "(icmp eq/ne (ashr/lshr AP2, A), AP1)" ->
 /// (icmp eq/ne A, Log2(AP2/AP1)) ->
@@ -1806,6 +2173,42 @@ Instruction *InstCombiner::foldICmpOrCon
   if (!Cmp.isEquality() || *C != 0 || !Or->hasOneUse())
     return nullptr;
 
+  DivisibilityCheck DivL, DivR;
+  if (match(Or, m_Or(m_DivisibilityCheck(DivL), m_DivisibilityCheck(DivR))) &&
+      DivL.merge(DivR)) {
+    // Simplify icmp eq (or (srem P, M), (srem P, N)), 0
+    //  -> icmp eq (srem P, lcm(M, N)), 0
+    return new ICmpInst(Pred, DivL.create(Builder),
+                        ConstantInt::getNullValue(Or->getType()));
+  }
+
+#if 0
+    // NOTE(review): disabled. This refers to LHSI, ICI and RHS, while the
+    // enclosing function uses Cmp, Or and C -- rename before enabling.
+    // icmp eq (or X, Y), 0
+    //  -> and (icmp eq X, 0), (icmp eq Y, 0)
+    // but only if this allows either subexpression to simplify further.
+    Instruction *ICmpX = nullptr, *ICmpY = nullptr;
+    if (auto *X = dyn_cast<Instruction>(LHSI->getOperand(0)))
+      ICmpX = visitICmpInstWithInstAndIntCst(ICI, X, RHS);
+    if (auto *Y = dyn_cast<Instruction>(LHSI->getOperand(1)))
+      ICmpY = visitICmpInstWithInstAndIntCst(ICI, Y, RHS);
+    if (ICmpX || ICmpY) {
+      Value *NewX, *NewY;
+      if (ICmpX) {
+        Worklist.Add(ICmpX);
+        NewX = Builder->Insert(ICmpX);
+      } else
+        NewX = Builder->CreateICmp(ICI.getPredicate(), LHSI->getOperand(0),
+                                    RHS);
+      if (ICmpY) {
+        Worklist.Add(ICmpY);
+        NewY = Builder->Insert(ICmpY);
+      } else
+        NewY = Builder->CreateICmp(ICI.getPredicate(), LHSI->getOperand(1),
+                                    RHS);
+      return BinaryOperator::CreateAnd(NewX, NewY);
+    }
+#endif
+
   Value *P, *Q;
   if (match(Or, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) {
     // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0

Added: llvm/trunk/test/Transforms/InstCombine/divisibility.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/divisibility.ll?rev=300252&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/divisibility.ll (added)
+++ llvm/trunk/test/Transforms/InstCombine/divisibility.ll Thu Apr 13 15:29:59 2017
@@ -0,0 +1,297 @@
+; Test that multiple divisibility checks on the same value, i.e.
+; "icmp eq (or (rem X, M), (rem X, N)), 0", are merged into a single check
+; for divisibility by lcm(M, N).
+
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+define i1 @test1(i32 %A) {
+  %B = srem i32 %A, 2
+  %C = srem i32 %A, 3
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test1(
+; CHECK-NEXT: srem i32 %A, 6
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test2(i32 %A) {
+  %B = urem i32 %A, 2
+  %C = urem i32 %A, 3
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test2(
+; CHECK-NEXT: urem i32 %A, 6
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+; Mixed srem/urem: divisibility by the power of 2 is sign-agnostic, so the
+; merged check takes the signedness of the other remainder.
+define i1 @test3(i32 %A) {
+  %B = srem i32 %A, 2
+  %C = urem i32 %A, 3
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test3(
+; CHECK-NEXT: urem i32 %A, 6
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test4(i32 %A) {
+  %B = urem i32 %A, 2
+  %C = srem i32 %A, 3
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test4(
+; CHECK-NEXT: srem i32 %A, 6
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test5(i32 %A) {
+  %B = srem i32 %A, 8
+  %C = srem i32 %A, 12
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test5(
+; CHECK-NEXT: srem i32 %A, 24
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+; "and %A, 6" together with "srem %A, 12" (which fixes the low 2 bits) acts
+; like the mask 7, i.e. a divisibility-by-8 check; lcm(8, 12) = 24.
+define i1 @test6(i32 %A) {
+  %B = and i32 %A, 6
+  %C = srem i32 %A, 12
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test6(
+; CHECK-NEXT: srem i32 %A, 24
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+; "and %A, 8" is not equivalent to a power-of-2 mask here (8 | 3 == 11), so
+; nothing is merged.
+define i1 @test7(i32 %A) {
+  %B = and i32 %A, 8
+  %C = srem i32 %A, 12
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test7(
+; CHECK-NEXT: and i32 %A, 8
+; CHECK-NEXT: srem i32 %A, 12
+; CHECK-NEXT: or
+; CHECK-NEXT: icmp
+; CHECK-NEXT: ret i1
+}
+
+; Checks on different values are not merged.
+define i1 @test8(i32 %A, i32 %B) {
+  %C = srem i32 %A, 2
+  %D = srem i32 %B, 3
+  %E = or i32 %C, %D
+  %F = icmp eq i32 %E, 0
+  ret i1 %F
+; CHECK-LABEL: @test8(
+; CHECK-NEXT: srem i32 %B, 3
+; CHECK-NEXT: and i32 %A, 1
+; CHECK-NEXT: or
+; CHECK-NEXT: icmp
+; CHECK-NEXT: ret i1
+}
+
+; lcm(7589, 395309) does not fit in a signed i32, so only 0 is divisible by
+; both.
+define i1 @test9(i32 %A) {
+  %B = srem i32 %A, 7589
+  %C = srem i32 %A, 395309
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test9(
+; CHECK-NEXT: icmp eq i32 %A, 0
+; CHECK-NEXT: ret i1 %E
+}
+
+define i1 @test10(i32 %A) {
+  ; 7589 and 395309 are prime, and
+  ; 7589 * 395309 == 3000000001 == -1294967295 (mod 2^32)
+  %B = urem i32 %A, 7589
+  %C = urem i32 %A, 395309
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test10(
+; CHECK-NEXT: urem i32 %A, -1294967295
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test11(i32 %A) {
+  %B = urem i32 %A, 65535
+  %C = urem i32 %A, 65537
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test11(
+; CHECK-NEXT: urem i32 %A, -1
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+; 65536 * 65537 overflows i32, so only 0 is divisible by both.
+define i1 @test12(i32 %A) {
+  %B = urem i32 %A, 65536
+  %C = urem i32 %A, 65537
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test12(
+; CHECK-NEXT: icmp eq i32 %A, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test13(i32 %A) {
+  %B = srem i32 %A, 65536
+  %C = urem i32 %A, 65535
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test13(
+; CHECK-NEXT: urem i32 %A, -65536
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test14(i32 %A) {
+  %B = srem i32 %A, 95
+  %C = srem i32 %A, 22605091
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test14(
+; CHECK-NEXT: srem i32 %A, 2147483645
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test15(i32 %A) {
+  %B = srem i32 %A, 97
+  %C = srem i32 %A, 22605091
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test15(
+; CHECK-NEXT: icmp eq i32 %A, 0
+; CHECK-NEXT: ret i1
+}
+
+; %B has another use but %C does not, so merging still removes a division.
+define i32 @test16(i32 %A) {
+  %B = srem i32 %A, 3
+  %C = srem i32 %A, 5
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  %F = zext i1 %E to i32
+  %G = add i32 %B, %F
+  ret i32 %G
+; CHECK-LABEL: @test16(
+; CHECK-NEXT:  %B = srem i32 %A, 3
+; CHECK-NEXT:  %[[REM:.*]] = srem i32 %A, 15
+; CHECK-NEXT:  %E = icmp eq i32 %[[REM]], 0
+; CHECK-NEXT:  %F = zext i1 %E to i32
+; CHECK-NEXT:  %G = add i32 %B, %F
+; CHECK-NEXT:  ret i32 %G
+}
+
+; Both remainders have other uses: merging would not remove a division.
+define i32 @test17(i32 %A) {
+  %B = srem i32 %A, 3
+  %C = srem i32 %A, 5
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  %F = zext i1 %E to i32
+  %G = add i32 %B, %F
+  %H = add i32 %C, %G
+  ret i32 %H
+; CHECK-LABEL: @test17(
+; CHECK-NEXT:  %B = srem i32 %A, 3
+; CHECK-NEXT:  %C = srem i32 %A, 5
+; CHECK-NOT: srem
+; CHECK: ret i32
+}
+
+define i32 @test18(i32 %A) {
+  %B = srem i32 %A, 3
+  %C = and i32 %A, 7
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  %F = zext i1 %E to i32
+  %G = add i32 %C, %F
+  ret i32 %G
+; CHECK-LABEL: @test18(
+; CHECK-NEXT:  %C = and i32 %A, 7
+; CHECK-NEXT:  %[[REM:.*]] = srem i32 %A, 24
+; CHECK-NEXT:  %E = icmp eq i32 %[[REM]], 0
+; CHECK-NEXT:  %F = zext i1 %E to i32
+; CHECK-NEXT:  %G = add
+; CHECK-NEXT:  ret i32 %G
+}
+
+define i1 @test19(i32 %A) {
+  %B = srem i32 %A, 6
+  %C = srem i32 %A, 10
+  %D = icmp eq i32 %B, 0
+  %E = icmp eq i32 %C, 0
+  %F = and i1 %D, %E
+  ret i1 %F
+; CHECK-LABEL: @test19(
+; CHECK-NEXT:  %[[REM:.*]] = srem i32 %A, 30
+; CHECK-NEXT:  icmp eq i32 %[[REM]], 0
+; CHECK-NEXT:  ret i1
+}
+
+define i1 @test20(i32 %A) {
+  %B = and i32 %A, 1
+  %C = srem i32 %A, 3
+  %D = and i32 %A, 3
+  %E = srem i32 %A, 5
+  %F = srem i32 %A, 6
+  %G = icmp eq i32 %B, 0
+  %H = icmp eq i32 %C, 0
+  %I = icmp eq i32 %D, 0
+  %J = icmp eq i32 %E, 0
+  %K = icmp eq i32 %F, 0
+  %L = and i1 %G, %H
+  %M = and i1 %L, %I
+  %N = and i1 %M, %J
+  %O = and i1 %N, %K
+  ret i1 %O
+; CHECK-LABEL: @test20(
+; CHECK-NEXT:  srem i32 %A, 60
+; CHECK-NEXT:  icmp eq i32
+; CHECK-NEXT:  ret i1
+}
+
+; srem by INT_MIN: the merged factor is INT_MIN (2^31), a power of 2, so the
+; check becomes a mask test.
+define i1 @test21(i32 %A) {
+  %B = srem i32 %A, -2147483648
+  %C = srem i32 %A, 1024
+  %D = icmp eq i32 %B, 0
+  %E = icmp eq i32 %C, 0
+  %F = and i1 %D, %E
+  ret i1 %F
+; CHECK-LABEL: @test21(
+; CHECK-NEXT:  and i32 %A, 2147483647
+; CHECK-NEXT:  icmp eq i32
+; CHECK-NEXT:  ret i1
+}
+
+define i1 @test22(i32 %A) {
+  %B = srem i32 %A, 1024
+  %C = srem i32 %A, -2147483648
+  %D = icmp eq i32 %B, 0
+  %E = icmp eq i32 %C, 0
+  %F = and i1 %D, %E
+  ret i1 %F
+; CHECK-LABEL: @test22(
+; CHECK-NEXT:  and i32 %A, 2147483647
+; CHECK-NEXT:  icmp eq i32
+; CHECK-NEXT:  ret i1
+}

Modified: llvm/trunk/unittests/ADT/APIntTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/ADT/APIntTest.cpp?rev=300252&r1=300251&r2=300252&view=diff
==============================================================================
--- llvm/trunk/unittests/ADT/APIntTest.cpp (original)
+++ llvm/trunk/unittests/ADT/APIntTest.cpp Thu Apr 13 15:29:59 2017
@@ -1977,3 +1977,46 @@ TEST(APIntTest, getHiBits) {
   i128.setHighBits(2);
   EXPECT_EQ(0xc, i128.getHiBits(4));
 }
+
+TEST(APIntTest, GCD) {
+  using APIntOps::GreatestCommonDivisor;
+
+  for (unsigned Bits : {1, 2, 32, 63, 64, 65}) {
+    // Test some corner cases near zero.
+    APInt Zero(Bits, 0), One(Bits, 1);
+    EXPECT_EQ(GreatestCommonDivisor(Zero, Zero), Zero);
+    EXPECT_EQ(GreatestCommonDivisor(Zero, One), One);
+    EXPECT_EQ(GreatestCommonDivisor(One, Zero), One);
+    EXPECT_EQ(GreatestCommonDivisor(One, One), One);
+
+    if (Bits > 1) {
+      APInt Two(Bits, 2);
+      EXPECT_EQ(GreatestCommonDivisor(Zero, Two), Two);
+      EXPECT_EQ(GreatestCommonDivisor(One, Two), One);
+      EXPECT_EQ(GreatestCommonDivisor(Two, Two), Two);
+
+      // Test some corner cases near the highest representable value.
+      APInt Max(Bits, 0);
+      Max.setAllBits();
+      EXPECT_EQ(GreatestCommonDivisor(Zero, Max), Max);
+      EXPECT_EQ(GreatestCommonDivisor(One, Max), One);
+      EXPECT_EQ(GreatestCommonDivisor(Two, Max), One);
+      EXPECT_EQ(GreatestCommonDivisor(Max, Max), Max);
+
+      APInt MaxOver2 = Max.udiv(Two);
+      EXPECT_EQ(GreatestCommonDivisor(MaxOver2, Max), One);
+      // Max - 1 == Max / 2 * 2, because Max is odd.
+      EXPECT_EQ(GreatestCommonDivisor(MaxOver2, Max - 1), MaxOver2);
+    }
+
+    if (Bits >= 32) {
+      // Both operands even with different powers of 2: exercises the shared
+      // power-of-2 (Pow2 > 0) path of the subtraction loop, in both orders.
+      APInt Six(Bits, 6), Twelve(Bits, 12), Eighteen(Bits, 18);
+      EXPECT_EQ(GreatestCommonDivisor(Twelve, Eighteen), Six);
+      EXPECT_EQ(GreatestCommonDivisor(Eighteen, Twelve), Six);
+    }
+  }
+
+  // Compute the 20th Mersenne prime.
+  const unsigned BitWidth = 4450;
+  APInt HugePrime = APInt::getLowBitsSet(BitWidth, 4423);
+
+  // 9931 and 123456 are coprime.
+  APInt A = HugePrime * APInt(BitWidth, 9931);
+  APInt B = HugePrime * APInt(BitWidth, 123456);
+  APInt C = GreatestCommonDivisor(A, B);
+  EXPECT_EQ(C, HugePrime);
+
+  // Shared power-of-2 factor on multi-word values: gcd(12p, 18p) == 6p.
+  EXPECT_EQ(GreatestCommonDivisor(HugePrime * APInt(BitWidth, 12),
+                                  HugePrime * APInt(BitWidth, 18)),
+            HugePrime * APInt(BitWidth, 6));
+}




More information about the llvm-commits mailing list