[llvm] 9c0314f - [ValueTracking] Switch isKnownNonZero() to switch over opcodes (NFCI)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 4 01:54:40 PDT 2022


Author: Nikita Popov
Date: 2022-10-04T10:54:28+02:00
New Revision: 9c0314f54ed693f4eeb38c6337c5c990e1594068

URL: https://github.com/llvm/llvm-project/commit/9c0314f54ed693f4eeb38c6337c5c990e1594068
DIFF: https://github.com/llvm/llvm-project/commit/9c0314f54ed693f4eeb38c6337c5c990e1594068.diff

LOG: [ValueTracking] Switch isKnownNonZero() to switch over opcodes (NFCI)

The change in the assume-queries-counter.ll test is because we skip
an unnecessary known bits query for arguments.

Added: 
    

Modified: 
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/test/Analysis/ValueTracking/assume-queries-counter.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 4d631945b441e..e7c7d636d5277 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -2568,98 +2568,104 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
   if (isKnownNonNullFromDominatingCondition(V, Q.CxtI, Q.DT))
     return true;
 
-  // Check for recursive pointer simplifications.
-  if (V->getType()->isPointerTy()) {
-    // Look through bitcast operations, GEPs, and int2ptr instructions as they
-    // do not alter the value, or at least not the nullness property of the
-    // value, e.g., int2ptr is allowed to zero/sign extend the value.
-    //
+  const Operator *I = dyn_cast<Operator>(V);
+  if (!I)
+    return false;
+
+  unsigned BitWidth = getBitWidth(V->getType()->getScalarType(), Q.DL);
+  switch (I->getOpcode()) {
+  case Instruction::GetElementPtr:
+    if (I->getType()->isPointerTy())
+      return isGEPKnownNonNull(cast<GEPOperator>(I), Depth, Q);
+    break;
+  case Instruction::BitCast:
+    if (I->getType()->isPointerTy())
+      return isKnownNonZero(I->getOperand(0), Depth, Q);
+    break;
+  case Instruction::IntToPtr:
     // Note that we have to take special care to avoid looking through
     // truncating casts, e.g., int2ptr/ptr2int with appropriate sizes, as well
     // as casts that can alter the value, e.g., AddrSpaceCasts.
-    if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V))
-      return isGEPKnownNonNull(GEP, Depth, Q);
-
-    if (auto *BCO = dyn_cast<BitCastOperator>(V))
-      return isKnownNonZero(BCO->getOperand(0), Depth, Q);
-
-    if (auto *I2P = dyn_cast<IntToPtrInst>(V))
-      if (Q.DL.getTypeSizeInBits(I2P->getSrcTy()).getFixedSize() <=
-          Q.DL.getTypeSizeInBits(I2P->getDestTy()).getFixedSize())
-        return isKnownNonZero(I2P->getOperand(0), Depth, Q);
-  }
-
-  // Similar to int2ptr above, we can look through ptr2int here if the cast
-  // is a no-op or an extend and not a truncate.
-  if (auto *P2I = dyn_cast<PtrToIntInst>(V))
-    if (Q.DL.getTypeSizeInBits(P2I->getSrcTy()).getFixedSize() <=
-        Q.DL.getTypeSizeInBits(P2I->getDestTy()).getFixedSize())
-      return isKnownNonZero(P2I->getOperand(0), Depth, Q);
-
-  unsigned BitWidth = getBitWidth(V->getType()->getScalarType(), Q.DL);
-
-  // X | Y != 0 if X != 0 or Y != 0.
-  Value *X = nullptr, *Y = nullptr;
-  if (match(V, m_Or(m_Value(X), m_Value(Y))))
-    return isKnownNonZero(X, DemandedElts, Depth, Q) ||
-           isKnownNonZero(Y, DemandedElts, Depth, Q);
-
-  // ext X != 0 if X != 0.
-  if (isa<SExtInst>(V) || isa<ZExtInst>(V))
-    return isKnownNonZero(cast<Instruction>(V)->getOperand(0), Depth, Q);
+    if (Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedSize() <=
+        Q.DL.getTypeSizeInBits(I->getType()).getFixedSize())
+      return isKnownNonZero(I->getOperand(0), Depth, Q);
+    break;
+  case Instruction::PtrToInt:
+    // Similar to int2ptr above, we can look through ptr2int here if the cast
+    // is a no-op or an extend and not a truncate.
+    if (Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedSize() <=
+        Q.DL.getTypeSizeInBits(I->getType()).getFixedSize())
+      return isKnownNonZero(I->getOperand(0), Depth, Q);
+    break;
+  case Instruction::Or:
+    // X | Y != 0 if X != 0 or Y != 0.
+    return isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q) ||
+           isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q);
+  case Instruction::SExt:
+  case Instruction::ZExt:
+    // ext X != 0 if X != 0.
+    return isKnownNonZero(I->getOperand(0), Depth, Q);
 
-  // shl X, Y != 0 if X is odd.  Note that the value of the shift is undefined
-  // if the lowest bit is shifted off the end.
-  if (match(V, m_Shl(m_Value(X), m_Value(Y)))) {
+  case Instruction::Shl: {
     // shl nuw can't remove any non-zero bits.
     const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V);
     if (Q.IIQ.hasNoUnsignedWrap(BO))
-      return isKnownNonZero(X, Depth, Q);
+      return isKnownNonZero(I->getOperand(0), Depth, Q);
 
+    // shl X, Y != 0 if X is odd.  Note that the value of the shift is undefined
+    // if the lowest bit is shifted off the end.
     KnownBits Known(BitWidth);
-    computeKnownBits(X, DemandedElts, Known, Depth, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth, Q);
     if (Known.One[0])
       return true;
+    break;
   }
-  // shr X, Y != 0 if X is negative.  Note that the value of the shift is not
-  // defined if the sign bit is shifted off the end.
-  else if (match(V, m_Shr(m_Value(X), m_Value(Y)))) {
+  case Instruction::LShr:
+  case Instruction::AShr: {
     // shr exact can only shift out zero bits.
     const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(V);
     if (BO->isExact())
-      return isKnownNonZero(X, Depth, Q);
+      return isKnownNonZero(I->getOperand(0), Depth, Q);
 
-    KnownBits Known = computeKnownBits(X, DemandedElts, Depth, Q);
+    // shr X, Y != 0 if X is negative.  Note that the value of the shift is not
+    // defined if the sign bit is shifted off the end.
+    KnownBits Known =
+        computeKnownBits(I->getOperand(0), DemandedElts, Depth, Q);
     if (Known.isNegative())
       return true;
 
     // If the shifter operand is a constant, and all of the bits shifted
     // out are known to be zero, and X is known non-zero then at least one
     // non-zero bit must remain.
-    if (ConstantInt *Shift = dyn_cast<ConstantInt>(Y)) {
+    if (ConstantInt *Shift = dyn_cast<ConstantInt>(I->getOperand(1))) {
       auto ShiftVal = Shift->getLimitedValue(BitWidth - 1);
       // Is there a known one in the portion not shifted out?
       if (Known.countMaxLeadingZeros() < BitWidth - ShiftVal)
         return true;
       // Are all the bits to be shifted out known zero?
       if (Known.countMinTrailingZeros() >= ShiftVal)
-        return isKnownNonZero(X, DemandedElts, Depth, Q);
+        return isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q);
     }
+    break;
   }
-  // div exact can only produce a zero if the dividend is zero.
-  else if (match(V, m_Exact(m_IDiv(m_Value(X), m_Value())))) {
-    return isKnownNonZero(X, DemandedElts, Depth, Q);
-  }
-  // X + Y.
-  else if (match(V, m_Add(m_Value(X), m_Value(Y)))) {
-    KnownBits XKnown = computeKnownBits(X, DemandedElts, Depth, Q);
-    KnownBits YKnown = computeKnownBits(Y, DemandedElts, Depth, Q);
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+    // div exact can only produce a zero if the dividend is zero.
+    if (cast<PossiblyExactOperator>(I)->isExact())
+      return isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q);
+    break;
+  case Instruction::Add: {
+    // X + Y.
+    KnownBits XKnown =
+        computeKnownBits(I->getOperand(0), DemandedElts, Depth, Q);
+    KnownBits YKnown =
+        computeKnownBits(I->getOperand(1), DemandedElts, Depth, Q);
 
     // If X and Y are both non-negative (as signed values) then their sum is not
     // zero unless both X and Y are zero.
     if (XKnown.isNonNegative() && YKnown.isNonNegative())
-      if (isKnownNonZero(X, DemandedElts, Depth, Q) ||
-          isKnownNonZero(Y, DemandedElts, Depth, Q))
+      if (isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q) ||
+          isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q))
         return true;
 
     // If X and Y are both negative (as signed values) then their sum is not
@@ -2678,30 +2684,31 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
 
     // The sum of a non-negative number and a power of two is not zero.
     if (XKnown.isNonNegative() &&
-        isKnownToBeAPowerOfTwo(Y, /*OrZero*/ false, Depth, Q))
+        isKnownToBeAPowerOfTwo(I->getOperand(1), /*OrZero*/ false, Depth, Q))
       return true;
     if (YKnown.isNonNegative() &&
-        isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q))
+        isKnownToBeAPowerOfTwo(I->getOperand(0), /*OrZero*/ false, Depth, Q))
       return true;
+    break;
   }
-  // X * Y.
-  else if (match(V, m_Mul(m_Value(X), m_Value(Y)))) {
-    const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V);
+  case Instruction::Mul: {
     // If X and Y are non-zero then so is X * Y as long as the multiplication
     // does not overflow.
+    const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V);
     if ((Q.IIQ.hasNoSignedWrap(BO) || Q.IIQ.hasNoUnsignedWrap(BO)) &&
-        isKnownNonZero(X, DemandedElts, Depth, Q) &&
-        isKnownNonZero(Y, DemandedElts, Depth, Q))
+        isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q) &&
+        isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q))
       return true;
+    break;
   }
-  // (C ? X : Y) != 0 if X != 0 and Y != 0.
-  else if (const SelectInst *SI = dyn_cast<SelectInst>(V)) {
-    if (isKnownNonZero(SI->getTrueValue(), DemandedElts, Depth, Q) &&
-        isKnownNonZero(SI->getFalseValue(), DemandedElts, Depth, Q))
+  case Instruction::Select:
+    // (C ? X : Y) != 0 if X != 0 and Y != 0.
+    if (isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q) &&
+        isKnownNonZero(I->getOperand(2), DemandedElts, Depth, Q))
       return true;
-  }
-  // PHI
-  else if (const PHINode *PN = dyn_cast<PHINode>(V)) {
+    break;
+  case Instruction::PHI: {
+    auto *PN = cast<PHINode>(I);
     if (Q.IIQ.UseInstrInfo && isNonZeroRecurrence(PN))
       return true;
 
@@ -2715,28 +2722,29 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
       return isKnownNonZero(U.get(), DemandedElts, NewDepth, RecQ);
     });
   }
-  // ExtractElement
-  else if (const auto *EEI = dyn_cast<ExtractElementInst>(V)) {
-    const Value *Vec = EEI->getVectorOperand();
-    const Value *Idx = EEI->getIndexOperand();
-    auto *CIdx = dyn_cast<ConstantInt>(Idx);
-    if (auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType())) {
-      unsigned NumElts = VecTy->getNumElements();
-      APInt DemandedVecElts = APInt::getAllOnes(NumElts);
-      if (CIdx && CIdx->getValue().ult(NumElts))
-        DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
-      return isKnownNonZero(Vec, DemandedVecElts, Depth, Q);
+  case Instruction::ExtractElement:
+    if (const auto *EEI = dyn_cast<ExtractElementInst>(V)) {
+      const Value *Vec = EEI->getVectorOperand();
+      const Value *Idx = EEI->getIndexOperand();
+      auto *CIdx = dyn_cast<ConstantInt>(Idx);
+      if (auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType())) {
+        unsigned NumElts = VecTy->getNumElements();
+        APInt DemandedVecElts = APInt::getAllOnes(NumElts);
+        if (CIdx && CIdx->getValue().ult(NumElts))
+          DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
+        return isKnownNonZero(Vec, DemandedVecElts, Depth, Q);
+      }
     }
-  }
-  // Freeze
-  else if (const FreezeInst *FI = dyn_cast<FreezeInst>(V)) {
-    auto *Op = FI->getOperand(0);
-    if (isKnownNonZero(Op, Depth, Q) &&
-        isGuaranteedNotToBePoison(Op, Q.AC, Q.CxtI, Q.DT, Depth))
+    break;
+  case Instruction::Freeze:
+    if (isKnownNonZero(I->getOperand(0), Depth, Q) &&
+        isGuaranteedNotToBePoison(I->getOperand(0), Q.AC, Q.CxtI, Q.DT, Depth))
       return true;
-  } else if (const auto *II = dyn_cast<IntrinsicInst>(V)) {
-    if (II->getIntrinsicID() == Intrinsic::vscale)
+    break;
+  case Instruction::Call:
+    if (cast<CallInst>(I)->getIntrinsicID() == Intrinsic::vscale)
       return true;
+    break;
   }
 
   KnownBits Known(BitWidth);

diff  --git a/llvm/test/Analysis/ValueTracking/assume-queries-counter.ll b/llvm/test/Analysis/ValueTracking/assume-queries-counter.ll
index f4e7956bc8cf0..406403de4942d 100644
--- a/llvm/test/Analysis/ValueTracking/assume-queries-counter.ll
+++ b/llvm/test/Analysis/ValueTracking/assume-queries-counter.ll
@@ -35,9 +35,8 @@ define dso_local i1 @test2(i32* readonly %0) {
 ; COUNTER1-NEXT:    ret i1 [[TMP2]]
 ;
 ; COUNTER2-LABEL: @test2(
-; COUNTER2-NEXT:    [[TMP2:%.*]] = icmp eq i32* [[TMP0:%.*]], null
-; COUNTER2-NEXT:    call void @llvm.assume(i1 true) [ "nonnull"(i32* [[TMP0]]) ]
-; COUNTER2-NEXT:    ret i1 [[TMP2]]
+; COUNTER2-NEXT:    call void @llvm.assume(i1 true) [ "nonnull"(i32* [[TMP0:%.*]]) ]
+; COUNTER2-NEXT:    ret i1 false
 ;
 ; COUNTER3-LABEL: @test2(
 ; COUNTER3-NEXT:    call void @llvm.assume(i1 true) [ "nonnull"(i32* [[TMP0:%.*]]) ]


        


More information about the llvm-commits mailing list