[llvm] 8d38568 - [ValueTracking] Don't handle ptrtoint with mismatched sizes

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 5 03:18:09 PDT 2023


Author: Nikita Popov
Date: 2023-07-05T12:18:00+02:00
New Revision: 8d3856836cd0aa5c87800b039a8c854b8044a711

URL: https://github.com/llvm/llvm-project/commit/8d3856836cd0aa5c87800b039a8c854b8044a711
DIFF: https://github.com/llvm/llvm-project/commit/8d3856836cd0aa5c87800b039a8c854b8044a711.diff

LOG: [ValueTracking] Don't handle ptrtoint with mismatched sizes

When processing assumes, we also handle assumes on ptrtoint of the
value. In canonical IR, these will have the same size as the value.
However, in non-canonical IR there may be an implicit zext or
trunc, which results in a bit width mismatch. We currently handle
this by adjusting bitwidth everywhere, but this is fragile and I'm
pretty sure that the way we do this is incorrect for some predicates,
because we effectively end up commuting an ext/trunc and an icmp.

Instead, add an m_PtrToIntSameSize() matcher that will only handle
bitwidth-preserving cases. For the bitwidth-changing cases, wait
until they have been canonicalized.

The original handling for this was added purely to prevent crashes
in an earlier implementation which failed to account for this
entirely.

Added: 
    

Modified: 
    llvm/include/llvm/IR/PatternMatch.h
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/unittests/Analysis/ValueTrackingTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 5a4c67744a5cb9..621eba6bd0b679 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1596,6 +1596,23 @@ template <typename Op_t, unsigned Opcode> struct CastClass_match {
   }
 };
 
+template <typename Op_t> struct PtrToIntSameSize_match {
+  const DataLayout &DL;
+  Op_t Op;
+
+  PtrToIntSameSize_match(const DataLayout &DL, const Op_t &OpMatch)
+      : DL(DL), Op(OpMatch) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    if (auto *O = dyn_cast<Operator>(V))
+      return O->getOpcode() == Instruction::PtrToInt &&
+             DL.getTypeSizeInBits(O->getType()) ==
+                 DL.getTypeSizeInBits(O->getOperand(0)->getType()) &&
+             Op.match(O->getOperand(0));
+    return false;
+  }
+};
+
 /// Matches BitCast.
 template <typename OpTy>
 inline CastClass_match<OpTy, Instruction::BitCast> m_BitCast(const OpTy &Op) {
@@ -1608,6 +1625,12 @@ inline CastClass_match<OpTy, Instruction::PtrToInt> m_PtrToInt(const OpTy &Op) {
   return CastClass_match<OpTy, Instruction::PtrToInt>(Op);
 }
 
+template <typename OpTy>
+inline PtrToIntSameSize_match<OpTy> m_PtrToIntSameSize(const DataLayout &DL,
+                                                       const OpTy &Op) {
+  return PtrToIntSameSize_match<OpTy>(DL, Op);
+}
+
 /// Matches IntToPtr.
 template <typename OpTy>
 inline CastClass_match<OpTy, Instruction::IntToPtr> m_IntToPtr(const OpTy &Op) {

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 51879c42482ca5..5499fe23219c2f 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -638,7 +638,8 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
 
   // Note that ptrtoint may change the bitwidth.
   Value *A, *B;
-  auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
+  auto m_V =
+      m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V)));
 
   CmpInst::Predicate Pred;
   uint64_t C;
@@ -648,16 +649,13 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
   case ICmpInst::ICMP_EQ:
     // assume(v = a)
     if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A)))) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
       Known = Known.unionWith(RHSKnown);
       // assume(v & b = a)
     } else if (match(Cmp,
                      m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A)))) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
-      KnownBits MaskKnown =
-          computeKnownBits(B, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
+      KnownBits MaskKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
 
       // For those bits in the mask that are known to be one, we can propagate
       // known bits from the RHS to V.
@@ -666,10 +664,8 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
       // assume(~(v & b) = a)
     } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))),
                                    m_Value(A)))) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
-      KnownBits MaskKnown =
-          computeKnownBits(B, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
+      KnownBits MaskKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
 
       // For those bits in the mask that are known to be one, we can propagate
       // inverted known bits from the RHS to V.
@@ -678,10 +674,8 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
       // assume(v | b = a)
     } else if (match(Cmp,
                      m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A)))) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
-      KnownBits BKnown =
-          computeKnownBits(B, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
+      KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
 
       // For those bits in B that are known to be zero, we can propagate known
       // bits from the RHS to V.
@@ -690,10 +684,8 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
       // assume(~(v | b) = a)
     } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))),
                                    m_Value(A)))) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
-      KnownBits BKnown =
-          computeKnownBits(B, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
+      KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
 
       // For those bits in B that are known to be zero, we can propagate
       // inverted known bits from the RHS to V.
@@ -702,10 +694,8 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
       // assume(v ^ b = a)
     } else if (match(Cmp,
                      m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A)))) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
-      KnownBits BKnown =
-          computeKnownBits(B, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
+      KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
 
       // For those bits in B that are known to be zero, we can propagate known
       // bits from the RHS to V. For those bits in B that are known to be one,
@@ -717,10 +707,8 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
       // assume(~(v ^ b) = a)
     } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))),
                                    m_Value(A)))) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
-      KnownBits BKnown =
-          computeKnownBits(B, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
+      KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
 
       // For those bits in B that are known to be zero, we can propagate
       // inverted known bits from the RHS to V. For those bits in B that are
@@ -733,8 +721,7 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
     } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)),
                                    m_Value(A))) &&
                C < BitWidth) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
 
       // For those bits in RHS that are known, we can propagate them to known
       // bits in V shifted to the right by C.
@@ -745,8 +732,7 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
     } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))),
                                    m_Value(A))) &&
                C < BitWidth) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
       // For those bits in RHS that are known, we can propagate them inverted
       // to known bits in V shifted to the right by C.
       RHSKnown.One.lshrInPlace(C);
@@ -757,8 +743,7 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
     } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)),
                                    m_Value(A))) &&
                C < BitWidth) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
       // For those bits in RHS that are known, we can propagate them to known
       // bits in V shifted to the right by C.
       Known.Zero |= RHSKnown.Zero << C;
@@ -767,8 +752,7 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
     } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shr(m_V, m_ConstantInt(C))),
                                    m_Value(A))) &&
                C < BitWidth) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
       // For those bits in RHS that are known, we can propagate them inverted
       // to known bits in V shifted to the right by C.
       Known.Zero |= RHSKnown.One << C;
@@ -778,8 +762,7 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
   case ICmpInst::ICMP_SGE:
     // assume(v >=_s c) where c is non-negative
     if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A)))) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
 
       if (RHSKnown.isNonNegative()) {
         // We know that the sign bit is zero.
@@ -790,8 +773,7 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
   case ICmpInst::ICMP_SGT:
     // assume(v >_s c) where c is at least -1.
     if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A)))) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
 
       if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) {
         // We know that the sign bit is zero.
@@ -802,8 +784,7 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
   case ICmpInst::ICMP_SLE:
     // assume(v <=_s c) where c is negative
     if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A)))) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
 
       if (RHSKnown.isNegative()) {
         // We know that the sign bit is one.
@@ -814,8 +795,7 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
   case ICmpInst::ICMP_SLT:
     // assume(v <_s c) where c is non-positive
     if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A)))) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
 
       if (RHSKnown.isZero() || RHSKnown.isNegative()) {
         // We know that the sign bit is one.
@@ -826,8 +806,7 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
   case ICmpInst::ICMP_ULE:
     // assume(v <=_u c)
     if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A)))) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
 
       // Whatever high bits in c are zero are known to be zero.
       Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
@@ -836,8 +815,7 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
   case ICmpInst::ICMP_ULT:
     // assume(v <_u c)
     if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A)))) {
-      KnownBits RHSKnown =
-          computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
+      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
 
       // If the RHS is known zero, then this assumption must be wrong (nothing
       // is unsigned less than zero). Signal a conflict and get out of here.
@@ -859,7 +837,7 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
     // assume (v & b != 0) where b is a power of 2
     const APInt *BPow2;
     if (match(Cmp, m_ICmp(Pred, m_c_And(m_V, m_Power2(BPow2)), m_Zero()))) {
-      Known.One |= BPow2->zextOrTrunc(BitWidth);
+      Known.One |= *BPow2;
     }
   } break;
   }

diff  --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
index e8f15f3c5f969e..4eed22d224588f 100644
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -2254,7 +2254,7 @@ TEST_F(ComputeKnownBitsTest, ComputeKnownUSubSatZerosPreserved) {
 }
 
 TEST_F(ComputeKnownBitsTest, ComputeKnownBitsPtrToIntTrunc) {
-  // ptrtoint truncates the pointer type.
+  // ptrtoint truncates the pointer type. Make sure we don't crash.
   parseAssembly(
       "define void @test(ptr %p) {\n"
       "  %A = load ptr, ptr %p\n"
@@ -2268,12 +2268,11 @@ TEST_F(ComputeKnownBitsTest, ComputeKnownBitsPtrToIntTrunc) {
   AssumptionCache AC(*F);
   KnownBits Known = computeKnownBits(
       A, M->getDataLayout(), /* Depth */ 0, &AC, F->front().getTerminator());
-  EXPECT_EQ(Known.Zero.getZExtValue(), 31u);
-  EXPECT_EQ(Known.One.getZExtValue(), 0u);
+  EXPECT_TRUE(Known.isUnknown());
 }
 
 TEST_F(ComputeKnownBitsTest, ComputeKnownBitsPtrToIntZext) {
-  // ptrtoint zero extends the pointer type.
+  // ptrtoint zero extends the pointer type. Make sure we don't crash.
   parseAssembly(
       "define void @test(ptr %p) {\n"
       "  %A = load ptr, ptr %p\n"
@@ -2287,8 +2286,7 @@ TEST_F(ComputeKnownBitsTest, ComputeKnownBitsPtrToIntZext) {
   AssumptionCache AC(*F);
   KnownBits Known = computeKnownBits(
       A, M->getDataLayout(), /* Depth */ 0, &AC, F->front().getTerminator());
-  EXPECT_EQ(Known.Zero.getZExtValue(), 31u);
-  EXPECT_EQ(Known.One.getZExtValue(), 0u);
+  EXPECT_TRUE(Known.isUnknown());
 }
 
 TEST_F(ComputeKnownBitsTest, ComputeKnownBitsFreeze) {


        


More information about the llvm-commits mailing list