[llvm] 9431f8a - [KnownBits] Add a computeForMul method

Quentin Colombet via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 8 11:35:13 PDT 2020


Author: Quentin Colombet
Date: 2020-10-08T11:33:06-07:00
New Revision: 9431f8ad2e033b3c7629ff74fe41d7c42a9554f8

URL: https://github.com/llvm/llvm-project/commit/9431f8ad2e033b3c7629ff74fe41d7c42a9554f8
DIFF: https://github.com/llvm/llvm-project/commit/9431f8ad2e033b3c7629ff74fe41d7c42a9554f8.diff

LOG: [KnownBits] Add a computeForMul method

This patch refactors the logic in ValueTracking.cpp so that
computeKnownBitsMul now uses a new helper, KnownBits::computeForMul.

NFC

Differential Revision: https://reviews.llvm.org/D88935

Added: 
    

Modified: 
    llvm/include/llvm/Support/KnownBits.h
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/lib/Support/KnownBits.cpp
    llvm/unittests/Support/KnownBitsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 8da6c7d98ba5..f3fde0c74b02 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -245,6 +245,9 @@ struct KnownBits {
   static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
                                     KnownBits RHS);
 
+  /// Compute known bits resulting from multiplying LHS and RHS.
+  static KnownBits computeForMul(const KnownBits &LHS, const KnownBits &RHS);
+
   /// Compute known bits for umax(LHS, RHS).
   static KnownBits umax(const KnownBits &LHS, const KnownBits &RHS);
 

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e78beb04e5ea..f84531fee2fa 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -415,7 +415,6 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
                                 const APInt &DemandedElts, KnownBits &Known,
                                 KnownBits &Known2, unsigned Depth,
                                 const Query &Q) {
-  unsigned BitWidth = Known.getBitWidth();
   computeKnownBits(Op1, DemandedElts, Known, Depth + 1, Q);
   computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q);
 
@@ -433,7 +432,7 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
       bool isKnownNegativeOp0 = Known2.isNegative();
       // The product of two numbers with the same sign is non-negative.
       isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) ||
-        (isKnownNonNegativeOp1 && isKnownNonNegativeOp0);
+                           (isKnownNonNegativeOp1 && isKnownNonNegativeOp0);
       // The product of a negative number and a non-negative number is either
       // negative or zero.
       if (!isKnownNonNegative)
@@ -444,78 +443,7 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
     }
   }
 
-  assert(!Known.hasConflict() && !Known2.hasConflict());
-  // Compute a conservative estimate for high known-0 bits.
-  unsigned LeadZ =  std::max(Known.countMinLeadingZeros() +
-                             Known2.countMinLeadingZeros(),
-                             BitWidth) - BitWidth;
-  LeadZ = std::min(LeadZ, BitWidth);
-
-  // The result of the bottom bits of an integer multiply can be
-  // inferred by looking at the bottom bits of both operands and
-  // multiplying them together.
-  // We can infer at least the minimum number of known trailing bits
-  // of both operands. Depending on number of trailing zeros, we can
-  // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
-  // a and b are divisible by m and n respectively.
-  // We then calculate how many of those bits are inferrable and set
-  // the output. For example, the i8 mul:
-  //  a = XXXX1100 (12)
-  //  b = XXXX1110 (14)
-  // We know the bottom 3 bits are zero since the first can be divided by
-  // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
-  // Applying the multiplication to the trimmed arguments gets:
-  //    XX11 (3)
-  //    X111 (7)
-  // -------
-  //    XX11
-  //   XX11
-  //  XX11
-  // XX11
-  // -------
-  // XXXXX01
-  // Which allows us to infer the 2 LSBs. Since we're multiplying the result
-  // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
-  // The proof for this can be described as:
-  // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
-  //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
-  //                    umin(countTrailingZeros(C2), C6) +
-  //                    umin(C5 - umin(countTrailingZeros(C1), C5),
-  //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
-  // %aa = shl i8 %a, C5
-  // %bb = shl i8 %b, C6
-  // %aaa = or i8 %aa, C1
-  // %bbb = or i8 %bb, C2
-  // %mul = mul i8 %aaa, %bbb
-  // %mask = and i8 %mul, C7
-  //   =>
-  // %mask = i8 ((C1*C2)&C7)
-  // Where C5, C6 describe the known bits of %a, %b
-  // C1, C2 describe the known bottom bits of %a, %b.
-  // C7 describes the mask of the known bits of the result.
-  APInt Bottom0 = Known.One;
-  APInt Bottom1 = Known2.One;
-
-  // How many times we'd be able to divide each argument by 2 (shr by 1).
-  // This gives us the number of trailing zeros on the multiplication result.
-  unsigned TrailBitsKnown0 = (Known.Zero | Known.One).countTrailingOnes();
-  unsigned TrailBitsKnown1 = (Known2.Zero | Known2.One).countTrailingOnes();
-  unsigned TrailZero0 = Known.countMinTrailingZeros();
-  unsigned TrailZero1 = Known2.countMinTrailingZeros();
-  unsigned TrailZ = TrailZero0 + TrailZero1;
-
-  // Figure out the fewest known-bits operand.
-  unsigned SmallestOperand = std::min(TrailBitsKnown0 - TrailZero0,
-                                      TrailBitsKnown1 - TrailZero1);
-  unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);
-
-  APInt BottomKnown = Bottom0.getLoBits(TrailBitsKnown0) *
-                      Bottom1.getLoBits(TrailBitsKnown1);
-
-  Known.resetAll();
-  Known.Zero.setHighBits(LeadZ);
-  Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
-  Known.One |= BottomKnown.getLoBits(ResultBitsKnown);
+  Known = KnownBits::computeForMul(Known, Known2);
 
   // Only make use of no-wrap flags if we failed to compute the sign bit
   // directly.  This matters if the multiplication always overflows, in

diff  --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index ed32a80a061d..532eef34a99e 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -163,6 +163,85 @@ KnownBits KnownBits::abs() const {
   return KnownAbs;
 }
 
+KnownBits KnownBits::computeForMul(const KnownBits &LHS, const KnownBits &RHS) {
+  unsigned BitWidth = LHS.getBitWidth();
+  // NOTE(review): RHS is assumed to have LHS's bit width; only LHS's is read.
+  assert(!LHS.hasConflict() && !RHS.hasConflict());
+  // High known-0 bits: an M-bit value times an N-bit value fits in M+N bits.
+  unsigned LeadZ =
+      std::max(LHS.countMinLeadingZeros() + RHS.countMinLeadingZeros(),
+               BitWidth) -
+      BitWidth;
+  LeadZ = std::min(LeadZ, BitWidth);
+
+  // The result of the bottom bits of an integer multiply can be
+  // inferred by looking at the bottom bits of both operands and
+  // multiplying them together.
+  // We can infer at least the minimum number of known trailing bits
+  // of both operands. Depending on number of trailing zeros, we can
+  // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
+  // a and b are divisible by m and n respectively.
+  // We then calculate how many of those bits are inferrable and set
+  // the output. For example, the i8 mul:
+  //  a = XXXX1100 (12)
+  //  b = XXXX1110 (14)
+  // We know the bottom 3 bits are zero since the first can be divided by
+  // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
+  // Applying the multiplication to the trimmed arguments gets:
+  //    XX11 (3)
+  //    X111 (7)
+  // -------
+  //    XX11
+  //   XX11
+  //  XX11
+  // XX11
+  // -------
+  // XXXXX01
+  // Which allows us to infer the 2 LSBs. Since we're multiplying the result
+  // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
+  // The proof for this can be described as:
+  // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
+  //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
+  //                    umin(countTrailingZeros(C2), C6) +
+  //                    umin(C5 - umin(countTrailingZeros(C1), C5),
+  //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
+  // %aa = shl i8 %a, C5
+  // %bb = shl i8 %b, C6
+  // %aaa = or i8 %aa, C1
+  // %bbb = or i8 %bb, C2
+  // %mul = mul i8 %aaa, %bbb
+  // %mask = and i8 %mul, C7
+  //   =>
+  // %mask = i8 ((C1*C2)&C7)
+  // Where C5, C6 are the numbers of known low bits of %a, %b,
+  // C1, C2 describe the known bottom bits of %a, %b.
+  // C7 describes the mask of the known bits of the result.
+  APInt Bottom0 = LHS.One;
+  APInt Bottom1 = RHS.One;
+
+  // TrailBitsKnown*: length of each operand's run of known low bits.
+  // TrailZero*: known trailing zeros; their sum is known zero in the product.
+  unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countTrailingOnes();
+  unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countTrailingOnes();
+  unsigned TrailZero0 = LHS.countMinTrailingZeros();
+  unsigned TrailZero1 = RHS.countMinTrailingZeros();
+  unsigned TrailZ = TrailZero0 + TrailZero1;
+
+  // Known low bits beyond the trailing zeros, minimized over both operands.
+  unsigned SmallestOperand =
+      std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1);
+  unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);
+  // Multiply the known low parts; only the low ResultBitsKnown bits are valid.
+  APInt BottomKnown =
+      Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1);
+  // Known high zeros, plus the low ResultBitsKnown bits of the product above.
+  KnownBits Res(BitWidth);
+  Res.Zero.setHighBits(LeadZ);
+  Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
+  Res.One = BottomKnown.getLoBits(ResultBitsKnown);
+  return Res;
+}
+
 KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
   // Result bit is 0 if either operand bit is 0.
   Zero |= RHS.Zero;

diff  --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 89555a5881a5..701293f7dae5 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -112,6 +112,7 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       KnownBits KnownUMin(KnownAnd);
       KnownBits KnownSMax(KnownAnd);
       KnownBits KnownSMin(KnownAnd);
+      KnownBits KnownMul(KnownAnd);
 
       ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
         ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
@@ -144,6 +145,10 @@ TEST(KnownBitsTest, BinaryExhaustive) {
           Res = APIntOps::smin(N1, N2);
           KnownSMin.One &= Res;
           KnownSMin.Zero &= ~Res;
+
+          Res = N1 * N2;
+          KnownMul.One &= Res;
+          KnownMul.Zero &= ~Res;
         });
       });
 
@@ -174,6 +179,12 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       KnownBits ComputedSMin = KnownBits::smin(Known1, Known2);
       EXPECT_EQ(KnownSMin.Zero, ComputedSMin.Zero);
       EXPECT_EQ(KnownSMin.One, ComputedSMin.One);
+
+      // ComputedMul is conservatively correct, but not guaranteed to be
+      // precise.
+      KnownBits ComputedMul = KnownBits::computeForMul(Known1, Known2);
+      EXPECT_TRUE(ComputedMul.Zero.isSubsetOf(KnownMul.Zero));
+      EXPECT_TRUE(ComputedMul.One.isSubsetOf(KnownMul.One));
     });
   });
 }


        


More information about the llvm-commits mailing list