[PATCH] Fold together repeated tests for divisibility by constants

Richard Smith richard at metafoo.co.uk
Mon Jul 21 18:55:05 PDT 2014


Hi!

The attached patch teaches LLVM to fold together repeated tests for
divisibility by a constant into a single test for divisibility by the LCM
of the constants.

This was inspired by PR20205 (but doesn't directly help there, because we
can't compute a trip count for its loop, nor peel the loop like GCC does).
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20140721/23aac22c/attachment.html>
-------------- next part --------------
Index: lib/Support/APInt.cpp
===================================================================
--- lib/Support/APInt.cpp	(revision 213404)
+++ lib/Support/APInt.cpp	(working copy)
@@ -802,15 +802,44 @@
   return Result;
 }
 
-APInt llvm::APIntOps::GreatestCommonDivisor(const APInt& API1,
-                                            const APInt& API2) {
-  APInt A = API1, B = API2;
-  while (!!B) {
-    APInt T = B;
-    B = APIntOps::urem(A, B);
-    A = T;
+APInt llvm::APIntOps::GreatestCommonDivisor(APInt API1, APInt API2) {
+  // Binary GCD on unsigned values. Fast path: equal operands are the gcd.
+  if (API1 == API2) return API1;
+
+  // gcd(0, x) == x. Both operands are known to be non-zero afterwards.
+  if (!API1) return API2;
+  if (!API2) return API1;
+
+  // Pow2 counts the factors of 2 common to both; shift out any surplus.
+  unsigned Pow2;
+  {
+    unsigned Pow2_1 = API1.countTrailingZeros();
+    unsigned Pow2_2 = API2.countTrailingZeros();
+    if (Pow2_1 > Pow2_2) {
+      API1 = API1.lshr(Pow2_1 - Pow2_2);
+      Pow2 = Pow2_2;
+    } else if (Pow2_2 > Pow2_1) {
+      API2 = API2.lshr(Pow2_2 - Pow2_1);
+      Pow2 = Pow2_1;
+    } else {
+      Pow2 = Pow2_1;
+    }
   }
-  return A;
+
+  // Both operands are odd multiples of 2^Pow2. Their difference is an even
+  // multiple, whose 2^i surplus power of 2 cannot be part of the gcd:
+  //   gcd(a, b) = gcd(|a - b| / 2^i, min(a, b))
+  while (API1 != API2) {
+    if (API1.ugt(API2)) {
+      API1 -= API2;
+      API1 = API1.lshr(API1.countTrailingZeros() - Pow2);
+    } else {
+      API2 -= API1;
+      API2 = API2.lshr(API2.countTrailingZeros() - Pow2);
+    }
+  }
+
+  return API1;
 }
 
 APInt llvm::APIntOps::RoundDoubleToAPInt(double Double, unsigned width) {
Index: lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCompares.cpp	(revision 213404)
+++ lib/Transforms/InstCombine/InstCombineCompares.cpp	(working copy)
@@ -1044,7 +1044,157 @@
   return nullptr;
 }
 
+namespace {
+/// Models a check that LHS is divisible by Factor.
+///
+/// A check is an srem, urem or and of LHS with a constant integer, on the
+/// assumption that its result is being compared to 0.
+class DivisibilityCheck {
+  // Signedness of the check. A bitwise and is a divisibility check,
+  // if its mask is (equivalent to) a power of 2 mask.
+  enum { DC_Null, DC_SRem, DC_URem, DC_And } Kind;
+  // The matched instruction, the value it tests, and the constant it tests
+  // against (after merge(): the combined factor).
+  Value *Check;
+  Value *LHS;
+  ConstantInt *Factor;
+
+public:
+  DivisibilityCheck()
+      : Kind(DC_Null), Check(nullptr), LHS(nullptr), Factor(nullptr) {}
+
+  /// Try to extract a divisibility check from V, on the assumption
+  /// that it is being compared to 0.
+  ///
+  /// \returns true if V is an srem, urem or and with a ConstantInt RHS.
+  bool match(Value *V) {
+    Kind = DC_Null;
+    Check = V;
+    if (::match(V, m_SRem(m_Value(LHS), m_ConstantInt(Factor))))
+      Kind = DC_SRem;
+    else if (::match(V, m_URem(m_Value(LHS), m_ConstantInt(Factor))))
+      Kind = DC_URem;
+    else if (::match(V, m_And(m_Value(LHS), m_ConstantInt(Factor))))
+      Kind = DC_And;
+    return Kind != DC_Null;
+  }
+
+  /// Merge another divisibility check into this one, so that this check
+  /// alone passes exactly when both checks pass.
+  ///
+  /// \returns false, without modifying this check, if the two checks
+  /// cannot be merged or merging would not remove a division.
+  bool merge(const DivisibilityCheck &O) {
+    assert(Kind != DC_Null && O.Kind != DC_Null);
+    if (LHS != O.LHS)
+      // We don't have two divisibility checks on the same operand.
+      return false;
+
+    // Note that this also rejects a pair of 'and's, so at most one of the
+    // two checks is DC_And from here on.
+    if (!(Check->hasOneUse() && Kind != DC_And) &&
+        !(O.Check->hasOneUse() && O.Kind != DC_And))
+      // We would not remove a division: bail out.
+      return false;
+
+    // Determine the factors we're checking for.
+    bool Failed = false;
+    APInt ThisFactor = getFactor(O, Failed);
+    APInt OtherFactor = O.getFactor(*this, Failed);
+    if (Failed)
+      return false;
+
+    // If we don't have a single signedness, we can fold the checks
+    // together if one of them is for a power of 2, because
+    // divisibility by a power of 2 is the same for srem and urem.
+    // An 'and' has no signedness of its own and always adopts the other
+    // check's, even when its factor has wrapped to zero.
+    if (Kind != O.Kind) {
+      if (O.Kind != DC_And && (Kind == DC_And || ThisFactor.isPowerOf2()))
+        Kind = O.Kind;
+      else if (!OtherFactor.isPowerOf2())
+        return false;
+    }
+    assert((Kind == DC_SRem || Kind == DC_URem) && "bad kind after merging");
+    bool Signed = Kind == DC_SRem;
+
+    // Fold them together. The factors are unsigned magnitudes, so the LCM is
+    // computed unsigned and the signed limit is applied separately: srem can
+    // still test for a magnitude of exactly 2^(w-1), using INT_MIN.
+    APInt GCD = APIntOps::GreatestCommonDivisor(ThisFactor, OtherFactor);
+    bool Overflow = false;
+    APInt LCM = ThisFactor.udiv(GCD).umul_ov(OtherFactor, Overflow);
+    if (Signed && LCM.ugt(APInt::getSignedMinValue(LCM.getBitWidth())))
+      Overflow = true;
+    // On overflow, there cannot exist a non-zero value that is divisible by
+    // both factors at once.
+    if (Overflow)
+      LCM = 0;
+    Factor = cast<ConstantInt>(ConstantInt::get(Factor->getType(), LCM));
+    return true;
+  }
+
+  /// Create the value to compare against 0 in place of the merged checks:
+  /// a remainder of LHS by Factor, or LHS itself if Factor is zero.
+  Value *create(InstCombiner::BuilderTy *Builder) {
+    // Avoid division by zero.
+    if (!Factor->getValue())
+      return LHS;
+    return Kind == DC_SRem ? Builder->CreateSRem(LHS, Factor)
+                           : Builder->CreateURem(LHS, Factor);
+  }
+
+private:
+  /// Get the unsigned multiplicative factor we're checking for.
+  ///
+  /// \param O the check being merged with; it is only consulted for DC_And.
+  /// \param Failed set to true if this is an 'and' whose mask is not
+  /// equivalent to a power of 2 mask; otherwise left alone.
+  /// \returns the factor. For DC_And, zero stands for 2^BitWidth.
+  APInt getFactor(const DivisibilityCheck &O, bool &Failed) const {
+    switch (Kind) {
+    case DC_SRem:
+      if (Factor->getValue().isNegative())
+        return -Factor->getValue();
+      // Fall through.
+    case DC_URem:
+      return Factor->getValue();
+
+    case DC_And: {
+      assert(O.Kind != DC_And && "bad kind pair");
+      // If we're also checking for divisibility by K * 2^N,
+      // the low N bits of the mask are irrelevant.
+      APInt Result =
+          Factor->getValue() |
+          APInt::getLowBitsSet(Factor->getValue().getBitWidth(),
+                               O.getFactor(*this, Failed).countTrailingZeros());
+      // A mask of 2^M - 1 is a check for divisibility by 2^M.
+      ++Result;
+      if (!!Result && !Result.isPowerOf2())
+        Failed = true;
+      return Result;
+    }
+
+    case DC_Null:
+      break;
+    }
+    llvm_unreachable("unexpected Kind");
+  }
+};
+
+/// PatternMatch-style adaptor that forwards to DivisibilityCheck::match.
+struct DivisibilityCheck_match {
+  DivisibilityCheck &Check;
+  DivisibilityCheck_match(DivisibilityCheck &Check) : Check(Check) {}
+  bool match(Value *V) { return Check.match(V); }
+};
+
+/// Matcher for divisibility checks.
+DivisibilityCheck_match m_DivisibilityCheck(DivisibilityCheck &Check) {
+  return DivisibilityCheck_match(Check);
+}
+} // end anonymous namespace
 
 /// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)".
 ///
 Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
@@ -1334,6 +1484,15 @@
         Op = BinaryOperator::CreateOr(ICIP, ICIQ);
       return Op;
     }
+    DivisibilityCheck DivL, DivR;
+    if (match(LHSI,
+              m_Or(m_DivisibilityCheck(DivL), m_DivisibilityCheck(DivR))) &&
+        DivL.merge(DivR)) {
+      // Simplify icmp eq (or (srem P, M), (srem P, N)), 0
+      //  -> icmp eq (srem P, lcm(M, N)), 0
+      return new ICmpInst(ICI.getPredicate(), DivL.create(Builder),
+                          Constant::getNullValue(LHSI->getType()));
+    }
     break;
   }
 
Index: test/Transforms/InstCombine/divisibility.ll
===================================================================
--- test/Transforms/InstCombine/divisibility.ll	(revision 0)
+++ test/Transforms/InstCombine/divisibility.ll	(revision 0)
@@ -0,0 +1,276 @@
+; Test that multiple divisibility checks are merged.
+
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+define i1 @test1(i32 %A) {
+  %B = srem i32 %A, 2
+  %C = srem i32 %A, 3
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test1(
+; CHECK-NEXT: srem i32 %A, 6
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test2(i32 %A) {
+  %B = urem i32 %A, 2
+  %C = urem i32 %A, 3
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test2(
+; CHECK-NEXT: urem i32 %A, 6
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test3(i32 %A) {
+  %B = srem i32 %A, 2
+  %C = urem i32 %A, 3
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test3(
+; CHECK-NEXT: urem i32 %A, 6
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test4(i32 %A) {
+  %B = urem i32 %A, 2
+  %C = srem i32 %A, 3
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test4(
+; CHECK-NEXT: srem i32 %A, 6
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test5(i32 %A) {
+  %B = srem i32 %A, 8
+  %C = srem i32 %A, 12
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test5(
+; CHECK-NEXT: srem i32 %A, 24
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test6(i32 %A) {
+  %B = and i32 %A, 6
+  %C = srem i32 %A, 12
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test6(
+; CHECK-NEXT: srem i32 %A, 24
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test7(i32 %A) {
+  %B = and i32 %A, 8
+  %C = srem i32 %A, 12
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test7(
+; CHECK-NEXT: and i32 %A, 8
+; CHECK-NEXT: srem i32 %A, 12
+; CHECK-NEXT: or
+; CHECK-NEXT: icmp
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test8(i32 %A, i32 %B) {
+  %C = srem i32 %A, 2
+  %D = srem i32 %B, 3
+  %E = or i32 %C, %D
+  %F = icmp eq i32 %E, 0
+  ret i1 %F
+; CHECK-LABEL: @test8(
+; CHECK-NEXT: srem i32 %A, 2
+; CHECK-NEXT: srem i32 %B, 3
+; CHECK-NEXT: or
+; CHECK-NEXT: icmp
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test9(i32 %A) {
+  %B = srem i32 %A, 7589
+  %C = srem i32 %A, 395309
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test9(
+; CHECK-NEXT: icmp eq i32 %A, 0
+; CHECK-NEXT: ret i1 %E
+}
+
+define i1 @test10(i32 %A) {
+  ; 7589 and 395309 are prime, and 7589 * 395309 == 3000000001, which fits
+  ; in an unsigned i32 and prints as -1294967295 (== 3000000001 - 2^32).
+  %B = urem i32 %A, 7589
+  %C = urem i32 %A, 395309
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test10(
+; CHECK-NEXT: urem i32 %A, -1294967295
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test11(i32 %A) {
+  %B = urem i32 %A, 65535
+  %C = urem i32 %A, 65537
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test11(
+; CHECK-NEXT: urem i32 %A, -1
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test12(i32 %A) {
+  %B = urem i32 %A, 65536
+  %C = urem i32 %A, 65537
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test12(
+; CHECK-NEXT: icmp eq i32 %A, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test13(i32 %A) {
+  %B = srem i32 %A, 65536
+  %C = urem i32 %A, 65535
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test13(
+; CHECK-NEXT: urem i32 %A, -65536
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test14(i32 %A) {
+  %B = srem i32 %A, 95
+  %C = srem i32 %A, 22605091
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test14(
+; CHECK-NEXT: srem i32 %A, 2147483645
+; CHECK-NEXT: icmp eq i32 %{{.*}}, 0
+; CHECK-NEXT: ret i1
+}
+
+define i1 @test15(i32 %A) {
+  %B = srem i32 %A, 97
+  %C = srem i32 %A, 22605091
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  ret i1 %E
+; CHECK-LABEL: @test15(
+; CHECK-NEXT: icmp eq i32 %A, 0
+; CHECK-NEXT: ret i1
+}
+
+define i32 @test16(i32 %A) {
+  %B = srem i32 %A, 3
+  %C = srem i32 %A, 5
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  %F = zext i1 %E to i32
+  %G = add i32 %B, %F
+  ret i32 %G
+; CHECK-LABEL: @test16(
+; CHECK-NEXT:  %B = srem i32 %A, 3
+; CHECK-NEXT:  %[[REM:.*]] = srem i32 %A, 15
+; CHECK-NEXT:  %E = icmp eq i32 %[[REM]], 0
+; CHECK-NEXT:  %F = zext i1 %E to i32
+; CHECK-NEXT:  %G = add i32 %B, %F
+; CHECK-NEXT:  ret i32 %G
+}
+
+define i32 @test17(i32 %A) {
+  %B = srem i32 %A, 3
+  %C = srem i32 %A, 5
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  %F = zext i1 %E to i32
+  %G = add i32 %B, %F
+  %H = add i32 %C, %G
+  ret i32 %H
+; CHECK-LABEL: @test17(
+; CHECK-NEXT:  %B = srem i32 %A, 3
+; CHECK-NEXT:  %C = srem i32 %A, 5
+; CHECK-NOT: srem
+; CHECK: ret i32
+}
+
+define i32 @test18(i32 %A) {
+  %B = srem i32 %A, 3
+  %C = and i32 %A, 7
+  %D = or i32 %B, %C
+  %E = icmp eq i32 %D, 0
+  %F = zext i1 %E to i32
+  %G = add i32 %C, %F
+  ret i32 %G
+; CHECK-LABEL: @test18(
+; CHECK-NEXT:  %C = and i32 %A, 7
+; CHECK-NEXT:  %[[REM:.*]] = srem i32 %A, 24
+; CHECK-NEXT:  %E = icmp eq i32 %[[REM]], 0
+; CHECK-NEXT:  %F = zext i1 %E to i32
+; CHECK-NEXT:  %G = add
+; CHECK-NEXT:  ret i32 %G
+}
+
+define i1 @test19(i32 %A) {
+  %B = srem i32 %A, 6
+  %C = srem i32 %A, 10
+  %D = icmp eq i32 %B, 0
+  %E = icmp eq i32 %C, 0
+  %F = and i1 %D, %E
+  ret i1 %F
+; CHECK-LABEL: @test19(
+; CHECK-NEXT:  %[[REM:.*]] = srem i32 %A, 30
+; CHECK-NEXT:  icmp eq i32 %[[REM]], 0
+; CHECK-NEXT:  ret i1
+}
+
+define i1 @test20(i32 %A) {
+  %B = and i32 %A, 1
+  %C = srem i32 %A, 3
+  %D = and i32 %A, 3
+  %E = srem i32 %A, 5
+  %F = srem i32 %A, 6
+  %G = icmp eq i32 %B, 0
+  %H = icmp eq i32 %C, 0
+  %I = icmp eq i32 %D, 0
+  %J = icmp eq i32 %E, 0
+  %K = icmp eq i32 %F, 0
+  %L = and i1 %G, %H
+  %M = and i1 %L, %I
+  %N = and i1 %M, %J
+  %O = and i1 %N, %K
+  ret i1 %O
+; FIXME: We should recurse into operands of 'or's in comparisons to 0.
+; CHECK-LABEL: @test20(
+; CHECK-NEXT:  srem i32 %A, 5
+; CHECK-NEXT:  srem i32 %A, 6
+; CHECK-NEXT:  srem i32 %A, 12
+; CHECK-NEXT:  or i32
+; CHECK-NEXT:  or i32
+; CHECK-NEXT:  icmp eq i32
+; CHECK-NEXT:  ret i1
+}
Index: include/llvm/ADT/APInt.h
===================================================================
--- include/llvm/ADT/APInt.h	(revision 213404)
+++ include/llvm/ADT/APInt.h	(working copy)
@@ -1756,13 +1756,14 @@
 /// \brief Returns the floor log base 2 of the specified APInt value.
 inline unsigned logBase2(const APInt &APIVal) { return APIVal.logBase2(); }
 
-/// \brief Compute GCD of two APInt values.
+/// \brief Compute GCD of two unsigned APInt values.
 ///
 /// This function returns the greatest common divisor of the two APInt values
-/// using Euclid's algorithm.
+/// using the binary GCD algorithm. If either value is zero, the other value
+/// is returned.
 ///
 /// \returns the greatest common divisor of Val1 and Val2
-APInt GreatestCommonDivisor(const APInt &Val1, const APInt &Val2);
+APInt GreatestCommonDivisor(APInt Val1, APInt Val2);
 
 /// \brief Converts the given APInt to a double value.
 ///


More information about the llvm-commits mailing list