[llvm] r320269 - Infer lowest bits of an integer Multiply when the low bits of the operands are known

Simon Dardis via llvm-commits llvm-commits at lists.llvm.org
Sat Dec 9 15:25:57 PST 2017


Author: sdardis
Date: Sat Dec  9 15:25:57 2017
New Revision: 320269

URL: http://llvm.org/viewvc/llvm-project?rev=320269&view=rev
Log:
Infer lowest bits of an integer Multiply when the low bits of the operands are known

When the lowest bits of the operands to an integer multiply are known, the low bits of the result are deducible.
Code to deduce known-zero bottom bits already existed, but this change improves on that by deducing known-ones.

Patch by: Pedro Ferreira

Reviewers: craig.topper, sanjoy, efriedma

Differential Revision: https://reviews.llvm.org/D34029

Modified:
    llvm/trunk/lib/Analysis/ValueTracking.cpp
    llvm/trunk/unittests/Analysis/ValueTrackingTest.cpp

Modified: llvm/trunk/lib/Analysis/ValueTracking.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ValueTracking.cpp?rev=320269&r1=320268&r2=320269&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ValueTracking.cpp (original)
+++ llvm/trunk/lib/Analysis/ValueTracking.cpp Sat Dec  9 15:25:57 2017
@@ -336,21 +336,78 @@ static void computeKnownBitsMul(const Va
     }
   }
 
-  // If low bits are zero in either operand, output low known-0 bits.
-  // Also compute a conservative estimate for high known-0 bits.
-  // More trickiness is possible, but this is sufficient for the
-  // interesting case of alignment computation.
-  unsigned TrailZ = Known.countMinTrailingZeros() +
-                    Known2.countMinTrailingZeros();
+  assert(!Known.hasConflict() && !Known2.hasConflict());
+  // Compute a conservative estimate for high known-0 bits.
   unsigned LeadZ =  std::max(Known.countMinLeadingZeros() +
                              Known2.countMinLeadingZeros(),
                              BitWidth) - BitWidth;
-
-  TrailZ = std::min(TrailZ, BitWidth);
   LeadZ = std::min(LeadZ, BitWidth);
+
+  // The result of the bottom bits of an integer multiply can be
+  // inferred by looking at the bottom bits of both operands and
+  // multiplying them together.
+  // We can infer at least the minimum number of known trailing bits
+  // of both operands. Depending on the number of trailing zeros, we can
+  // infer more bits, because (a*b) == ((a/m) * (b/n)) * (m*n) assuming
+  // a and b are divisible by m and n respectively.
+  // We then calculate how many of those bits are inferrable and set
+  // the output. For example, the i8 mul:
+  //  a = XXXX1100 (12)
+  //  b = XXXX1110 (14)
+  // We know the bottom 3 bits are zero since the first can be divided by
+  // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
+  // Applying the multiplication to the trimmed arguments gets:
+  //    XX11 (3)
+  //    X111 (7)
+  // -------
+  //    XX11
+  //   XX11
+  //  XX11
+  // XXXX
+  // -------
+  // XXXXX01
+  // Which allows us to infer the 2 LSBs. Since we're multiplying the result
+  // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
+  // The proof for this can be described as:
+  // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
+  //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
+  //                    umin(countTrailingZeros(C2), C6) +
+  //                    umin(C5 - umin(countTrailingZeros(C1), C5),
+  //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
+  // %aa = shl i8 %a, C5
+  // %bb = shl i8 %b, C6
+  // %aaa = or i8 %aa, C1
+  // %bbb = or i8 %bb, C2
+  // %mul = mul i8 %aaa, %bbb
+  // %mask = and i8 %mul, C7
+  //   =>
+  // %mask = i8 ((C1*C2)&C7)
+  // Where C5, C6 are the numbers of known low bits of %aaa, %bbb,
+  // C1, C2 are the values of those known low bits, and
+  // C7 is the mask of the known low bits of the result.
+  APInt Bottom0 = Known.One;
+  APInt Bottom1 = Known2.One;
+
+  // Count each operand's contiguous known low bits and its known trailing
+  // zeros; the product has at least the sum of the operands' trailing zeros.
+  unsigned TrailBitsKnown0 = (Known.Zero | Known.One).countTrailingOnes();
+  unsigned TrailBitsKnown1 = (Known2.Zero | Known2.One).countTrailingOnes();
+  unsigned TrailZero0 = Known.countMinTrailingZeros();
+  unsigned TrailZero1 = Known2.countMinTrailingZeros();
+  unsigned TrailZ = TrailZero0 + TrailZero1;
+
+  // Find the operand with the fewest known bits above its trailing zeros.
+  unsigned SmallestOperand = std::min(TrailBitsKnown0 - TrailZero0,
+                                      TrailBitsKnown1 - TrailZero1);
+  unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);
+
+  APInt BottomKnown = Bottom0.getLoBits(TrailBitsKnown0) *
+                      Bottom1.getLoBits(TrailBitsKnown1);
+
   Known.resetAll();
-  Known.Zero.setLowBits(TrailZ);
   Known.Zero.setHighBits(LeadZ);
+  Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
+  Known.One |= BottomKnown.getLoBits(ResultBitsKnown);
 
   // Only make use of no-wrap flags if we failed to compute the sign bit
   // directly.  This matters if the multiplication always overflows, in

Modified: llvm/trunk/unittests/Analysis/ValueTrackingTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/Analysis/ValueTrackingTest.cpp?rev=320269&r1=320268&r2=320269&view=diff
==============================================================================
--- llvm/trunk/unittests/Analysis/ValueTrackingTest.cpp (original)
+++ llvm/trunk/unittests/Analysis/ValueTrackingTest.cpp Sat Dec  9 15:25:57 2017
@@ -15,6 +15,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/KnownBits.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -258,3 +259,57 @@ TEST(ValueTracking, ComputeNumSignBits_P
       cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
   EXPECT_EQ(ComputeNumSignBits(RVal, M->getDataLayout()), 1u);
 }
+
+TEST(ValueTracking, ComputeKnownBits) { // mul of two partly-known operands
+  StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { "
+                       "  %ash = mul i32 %a, 8 "
+                       "  %aad = add i32 %ash, 7 "
+                       "  %aan = and i32 %aad, 4095 "
+                       "  %bsh = shl i32 %b, 4 "
+                       "  %bad = or i32 %bsh, 6 "
+                       "  %ban = and i32 %bad, 4095 "
+                       "  %mul = mul i32 %aan, %ban "
+                       "  ret i32 %mul "
+                       "} ";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  assert(M && "Bad assembly?");
+
+  auto *F = M->getFunction("f");
+  assert(F && "Bad assembly?");
+
+  auto *RVal = // operand 0 of the entry block's ret, i.e. %mul
+      cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
+  auto Known = computeKnownBits(RVal, M->getDataLayout());
+  ASSERT_FALSE(Known.hasConflict());
+  EXPECT_EQ(Known.One.getZExtValue(), 10u); // 0b1010
+  EXPECT_EQ(Known.Zero.getZExtValue(), 4278190085u); // 0xFF000005
+}
+
+TEST(ValueTracking, ComputeKnownMulBits) { // operands with 5 known low bits
+  StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { "
+                       "  %aa = shl i32 %a, 5 "
+                       "  %bb = shl i32 %b, 5 "
+                       "  %aaa = or i32 %aa, 24 "
+                       "  %bbb = or i32 %bb, 28 "
+                       "  %mul = mul i32 %aaa, %bbb "
+                       "  ret i32 %mul "
+                       "} ";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  assert(M && "Bad assembly?");
+
+  auto *F = M->getFunction("f");
+  assert(F && "Bad assembly?");
+
+  auto *RVal = // operand 0 of the entry block's ret, i.e. %mul
+      cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
+  auto Known = computeKnownBits(RVal, M->getDataLayout());
+  ASSERT_FALSE(Known.hasConflict());
+  EXPECT_EQ(Known.One.getZExtValue(), 32u); // 0b0100000
+  EXPECT_EQ(Known.Zero.getZExtValue(), 95u); // 0b1011111
+}




More information about the llvm-commits mailing list