[llvm] d86fff6 - [ValueTracking] Fix computeKnownBits() with bitwidth-changing ptrtoint

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Sat May 16 05:19:12 PDT 2020


Author: Nikita Popov
Date: 2020-05-16T14:17:11+02:00
New Revision: d86fff6ae7cfd6666511e9c2129711b88ed514be

URL: https://github.com/llvm/llvm-project/commit/d86fff6ae7cfd6666511e9c2129711b88ed514be
DIFF: https://github.com/llvm/llvm-project/commit/d86fff6ae7cfd6666511e9c2129711b88ed514be.diff

LOG: [ValueTracking] Fix computeKnownBits() with bitwidth-changing ptrtoint

computeKnownBitsFromAssume() currently hits an assertion failure if
m_V matches a ptrtoint that changes the bitwidth. Because InstCombine
canonicalizes ptrtoint instructions to use explicit zext/trunc,
we have not run into the issue in practice. I'm adding unit tests,
as I don't know whether this can be triggered via IR anywhere.

Fix this by calling anyextOrTrunc(BitWidth) on the computed
KnownBits. Note that we are going from the KnownBits of the
ptrtoint result to the KnownBits of the ptrtoint operand,
so we need to truncate if the ptrtoint zero-extended and
any-extend if the ptrtoint truncated.

Differential Revision: https://reviews.llvm.org/D79234

Added: 
    

Modified: 
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/unittests/Analysis/ValueTrackingTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index a4b0202dc18c..8d69df3dec4e 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -785,6 +785,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     if (!Cmp)
       continue;
 
+    // Note that ptrtoint may change the bitwidth.
     Value *A, *B;
     auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
 
@@ -797,18 +798,18 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v = a)
       if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
         Known.Zero |= RHSKnown.Zero;
         Known.One  |= RHSKnown.One;
       // assume(v & b = a)
       } else if (match(Cmp,
                        m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
-        KnownBits MaskKnown(BitWidth);
-        computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+        KnownBits MaskKnown =
+            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
 
         // For those bits in the mask that are known to be one, we can propagate
         // known bits from the RHS to V.
@@ -818,10 +819,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
-        KnownBits MaskKnown(BitWidth);
-        computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+        KnownBits MaskKnown =
+            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
 
         // For those bits in the mask that are known to be one, we can propagate
         // inverted known bits from the RHS to V.
@@ -831,10 +832,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp,
                        m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
-        KnownBits BKnown(BitWidth);
-        computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+        KnownBits BKnown =
+            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate known
         // bits from the RHS to V.
@@ -844,10 +845,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
-        KnownBits BKnown(BitWidth);
-        computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+        KnownBits BKnown =
+            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate
         // inverted known bits from the RHS to V.
@@ -857,10 +858,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp,
                        m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
-        KnownBits BKnown(BitWidth);
-        computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+        KnownBits BKnown =
+            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate known
         // bits from the RHS to V. For those bits in B that are known to be one,
@@ -873,10 +874,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
-        KnownBits BKnown(BitWidth);
-        computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+        KnownBits BKnown =
+            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate
         // inverted known bits from the RHS to V. For those bits in B that are
@@ -889,8 +890,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+
         // For those bits in RHS that are known, we can propagate them to known
         // bits in V shifted to the right by C.
         RHSKnown.Zero.lshrInPlace(C);
@@ -901,8 +903,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them inverted
         // to known bits in V shifted to the right by C.
         RHSKnown.One.lshrInPlace(C);
@@ -913,8 +915,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them to known
         // bits in V shifted to the right by C.
         Known.Zero |= RHSKnown.Zero << C;
@@ -923,8 +925,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shr(m_V, m_ConstantInt(C))),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them inverted
         // to known bits in V shifted to the right by C.
         Known.Zero |= RHSKnown.One  << C;
@@ -935,8 +937,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v >=_s c) where c is non-negative
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isNonNegative()) {
           // We know that the sign bit is zero.
@@ -948,8 +950,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v >_s c) where c is at least -1.
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) {
           // We know that the sign bit is zero.
@@ -961,8 +963,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v <=_s c) where c is negative
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isNegative()) {
           // We know that the sign bit is one.
@@ -974,8 +976,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v <_s c) where c is non-positive
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isZero() || RHSKnown.isNegative()) {
           // We know that the sign bit is one.
@@ -987,8 +989,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v <=_u c)
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
 
         // Whatever high bits in c are zero are known to be zero.
         Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
@@ -998,8 +1000,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v <_u c)
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
 
         // If the RHS is known zero, then this assumption must be wrong (nothing
         // is unsigned less than zero). Signal a conflict and get out of here.

diff  --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
index 85456e6b3191..a5ebb7ff5ce1 100644
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/InstIterator.h"
@@ -40,7 +41,7 @@ class ValueTrackingTest : public testing::Test {
     M = parseModule(Assembly);
     ASSERT_TRUE(M);
 
-    Function *F = M->getFunction("test");
+    F = M->getFunction("test");
     ASSERT_TRUE(F) << "Test must have a function @test";
     if (!F)
       return;
@@ -57,6 +58,7 @@ class ValueTrackingTest : public testing::Test {
 
   LLVMContext Context;
   std::unique_ptr<Module> M;
+  Function *F = nullptr;
   Instruction *A = nullptr;
 };
 
@@ -954,6 +956,44 @@ TEST_F(ComputeKnownBitsTest, ComputeKnownUSubSatZerosPreserved) {
   expectKnownBits(/*zero*/ 2u, /*one*/ 0u);
 }
 
+TEST_F(ComputeKnownBitsTest, ComputeKnownBitsPtrToIntTrunc) {
+  // ptrtoint truncates the pointer to i32; low 5 bits of %A still known 0.
+  parseAssembly(
+      "define void @test(i8** %p) {\n"
+      "  %A = load i8*, i8** %p\n"
+      "  %i = ptrtoint i8* %A to i32\n"
+      "  %m = and i32 %i, 31\n"
+      "  %c = icmp eq i32 %m, 0\n"
+      "  call void @llvm.assume(i1 %c)\n"
+      "  ret void\n"
+      "}\n"
+      "declare void @llvm.assume(i1)\n");
+  AssumptionCache AC(*F);
+  KnownBits Known = computeKnownBits(
+      A, M->getDataLayout(), /* Depth */ 0, &AC, F->front().getTerminator());
+  EXPECT_EQ(Known.Zero.getZExtValue(), 31u);
+  EXPECT_EQ(Known.One.getZExtValue(), 0u);
+}
+
+TEST_F(ComputeKnownBitsTest, ComputeKnownBitsPtrToIntZext) {
+  // ptrtoint zero-extends the pointer to i128; low 5 bits of %A still known 0.
+  parseAssembly(
+      "define void @test(i8** %p) {\n"
+      "  %A = load i8*, i8** %p\n"
+      "  %i = ptrtoint i8* %A to i128\n"
+      "  %m = and i128 %i, 31\n"
+      "  %c = icmp eq i128 %m, 0\n"
+      "  call void @llvm.assume(i1 %c)\n"
+      "  ret void\n"
+      "}\n"
+      "declare void @llvm.assume(i1)\n");
+  AssumptionCache AC(*F);
+  KnownBits Known = computeKnownBits(
+      A, M->getDataLayout(), /* Depth */ 0, &AC, F->front().getTerminator());
+  EXPECT_EQ(Known.Zero.getZExtValue(), 31u);
+  EXPECT_EQ(Known.One.getZExtValue(), 0u);
+}
+
 class IsBytewiseValueTest : public ValueTrackingTest,
                             public ::testing::WithParamInterface<
                                 std::pair<const char *, const char *>> {


        


More information about the llvm-commits mailing list