[llvm] goldsteinn/demanded elts consistent (PR #99080)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 16 11:50:35 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-llvm-analysis

Author: None (goldsteinn)

Changes:

- **[ValueTracking] Consistently propagate `DemandedElts` in `computeKnownBits`**
- **[ValueTracking] Consistently propagate `DemandedElts` in `isKnownNonZero`**
- **[ValueTracking] Consistently propagate `DemandedElts` in `ComputeNumSignBits`**
- **[ValueTracking] Consistently propagate `DemandedElts` in `computeKnownFPClass`**
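
All four commits follow the same convention, already used in parts of ValueTracking: build a one-bit-per-lane `DemandedElts` mask for fixed vectors (a single bit for scalars and scalable vectors) at the entry point, then thread that mask through every recursive query instead of letting inner calls fall back to an implicit all-ones mask. A minimal sketch of that calling convention follows; it reuses the `DemandedElts` construction from the new wrappers in the diff, while the helper name and the use of the `KnownBits`-returning overload are assumptions for illustration, not part of the patch:

```cpp
#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/SimplifyQuery.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/KnownBits.h"

using namespace llvm;

// Hypothetical helper (not in the patch): ask what is known about the lanes
// of V, mirroring the entry-point pattern used by this series.
static KnownBits knownBitsForAllLanes(const Value *V, const SimplifyQuery &Q) {
  // One demanded-elements bit per lane for fixed vectors; a single bit
  // otherwise (the same fallback the new wrappers in the patch use).
  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
  APInt DemandedElts =
      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);

  // With this series, the recursive helpers keep passing DemandedElts
  // down, so lanes the caller does not demand no longer widen the result.
  return computeKnownBits(V, DemandedElts, /*Depth=*/0, Q);
}
```

Passing the mask explicitly is also what lets individual cases remap it, for example `llvm.vector.reverse` (noted again after the diff), instead of silently analyzing every lane.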


---

Patch is 34.76 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/99080.diff


2 Files Affected:

- (modified) llvm/include/llvm/Analysis/ValueTracking.h (+17-5) 
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+146-104) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 354ad5bc95317..2c2f965a3cd6f 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -526,16 +526,17 @@ inline KnownFPClass computeKnownFPClass(
 }
 
 /// Wrapper to account for known fast math flags at the use instruction.
-inline KnownFPClass computeKnownFPClass(const Value *V, FastMathFlags FMF,
-                                        FPClassTest InterestedClasses,
-                                        unsigned Depth,
-                                        const SimplifyQuery &SQ) {
+inline KnownFPClass
+computeKnownFPClass(const Value *V, const APInt &DemandedElts,
+                    FastMathFlags FMF, FPClassTest InterestedClasses,
+                    unsigned Depth, const SimplifyQuery &SQ) {
   if (FMF.noNaNs())
     InterestedClasses &= ~fcNan;
   if (FMF.noInfs())
     InterestedClasses &= ~fcInf;
 
-  KnownFPClass Result = computeKnownFPClass(V, InterestedClasses, Depth, SQ);
+  KnownFPClass Result =
+      computeKnownFPClass(V, DemandedElts, InterestedClasses, Depth, SQ);
 
   if (FMF.noNaNs())
     Result.KnownFPClasses &= ~fcNan;
@@ -544,6 +545,17 @@ inline KnownFPClass computeKnownFPClass(const Value *V, FastMathFlags FMF,
   return Result;
 }
 
+inline KnownFPClass computeKnownFPClass(const Value *V, FastMathFlags FMF,
+                                        FPClassTest InterestedClasses,
+                                        unsigned Depth,
+                                        const SimplifyQuery &SQ) {
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
+  return computeKnownFPClass(V, DemandedElts, FMF, InterestedClasses, Depth,
+                             SQ);
+}
+
 /// Return true if we can prove that the specified FP value is never equal to
 /// -0.0. Users should use caution when considering PreserveSign
 /// denormal-fp-math.
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index f8ec868398323..6e039ad2deadb 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -303,15 +303,21 @@ bool llvm::isKnownNegative(const Value *V, const SimplifyQuery &SQ,
   return computeKnownBits(V, Depth, SQ).isNegative();
 }
 
-static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
+static bool isKnownNonEqual(const Value *V1, const Value *V2,
+                            const APInt &DemandedElts, unsigned Depth,
                             const SimplifyQuery &Q);
 
 bool llvm::isKnownNonEqual(const Value *V1, const Value *V2,
                            const DataLayout &DL, AssumptionCache *AC,
                            const Instruction *CxtI, const DominatorTree *DT,
                            bool UseInstrInfo) {
+  assert(V1->getType() == V2->getType() &&
+         "Testing equality of non-equal types!");
+  auto *FVTy = dyn_cast<FixedVectorType>(V1->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
   return ::isKnownNonEqual(
-      V1, V2, 0,
+      V1, V2, DemandedElts, 0,
       SimplifyQuery(DL, DT, AC, safeCxtI(V2, V1, CxtI), UseInstrInfo));
 }
 
@@ -1091,15 +1097,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
     break;
   }
   case Instruction::UDiv: {
-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-    computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
     Known =
         KnownBits::udiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
     break;
   }
   case Instruction::SDiv: {
-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-    computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
     Known =
         KnownBits::sdiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
     break;
@@ -1107,7 +1113,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
   case Instruction::Select: {
     auto ComputeForArm = [&](Value *Arm, bool Invert) {
       KnownBits Res(Known.getBitWidth());
-      computeKnownBits(Arm, Res, Depth + 1, Q);
+      computeKnownBits(Arm, DemandedElts, Res, Depth + 1, Q);
       adjustKnownBitsForSelectArm(Res, I->getOperand(0), Arm, Invert, Depth, Q);
       return Res;
     };
@@ -1142,7 +1148,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
 
     assert(SrcBitWidth && "SrcBitWidth can't be zero");
     Known = Known.anyextOrTrunc(SrcBitWidth);
-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
     if (auto *Inst = dyn_cast<PossiblyNonNegInst>(I);
         Inst && Inst->hasNonNeg() && !Known.isNegative())
       Known.makeNonNegative();
@@ -1164,7 +1170,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
     if (match(I, m_ElementWiseBitCast(m_Value(V))) &&
         V->getType()->isFPOrFPVectorTy()) {
       Type *FPType = V->getType()->getScalarType();
-      KnownFPClass Result = computeKnownFPClass(V, fcAllFlags, Depth + 1, Q);
+      KnownFPClass Result =
+          computeKnownFPClass(V, DemandedElts, fcAllFlags, Depth + 1, Q);
       FPClassTest FPClasses = Result.KnownFPClasses;
 
       // TODO: Treat it as zero/poison if the use of I is unreachable.
@@ -1245,7 +1252,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
     unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
 
     Known = Known.trunc(SrcBitWidth);
-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
     // If the sign bit of the input is known set or clear, then we know the
     // top bits of the result.
     Known = Known.sext(BitWidth);
@@ -1305,14 +1312,14 @@ static void computeKnownBitsFromOperator(const Operator *I,
     break;
   }
   case Instruction::SRem:
-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-    computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
     Known = KnownBits::srem(Known, Known2);
     break;
 
   case Instruction::URem:
-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-    computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
     Known = KnownBits::urem(Known, Known2);
     break;
   case Instruction::Alloca:
@@ -1465,17 +1472,17 @@ static void computeKnownBitsFromOperator(const Operator *I,
 
         unsigned OpNum = P->getOperand(0) == R ? 0 : 1;
         Instruction *RInst = P->getIncomingBlock(OpNum)->getTerminator();
-        Instruction *LInst = P->getIncomingBlock(1-OpNum)->getTerminator();
+        Instruction *LInst = P->getIncomingBlock(1 - OpNum)->getTerminator();
 
         // Ok, we have a PHI of the form L op= R. Check for low
         // zero bits.
         RecQ.CxtI = RInst;
-        computeKnownBits(R, Known2, Depth + 1, RecQ);
+        computeKnownBits(R, DemandedElts, Known2, Depth + 1, RecQ);
 
         // We need to take the minimum number of known bits
         KnownBits Known3(BitWidth);
         RecQ.CxtI = LInst;
-        computeKnownBits(L, Known3, Depth + 1, RecQ);
+        computeKnownBits(L, DemandedElts, Known3, Depth + 1, RecQ);
 
         Known.Zero.setLowBits(std::min(Known2.countMinTrailingZeros(),
                                        Known3.countMinTrailingZeros()));
@@ -1548,7 +1555,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
         // want to waste time spinning around in loops.
         // TODO: See if we can base recursion limiter on number of incoming phi
         // edges so we don't overly clamp analysis.
-        computeKnownBits(IncValue, Known2, MaxAnalysisRecursionDepth - 1, RecQ);
+        computeKnownBits(IncValue, DemandedElts, Known2,
+                         MaxAnalysisRecursionDepth - 1, RecQ);
 
         // See if we can further use a conditional branch into the phi
         // to help us determine the range of the value.
@@ -1619,9 +1627,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
     }
     if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
       switch (II->getIntrinsicID()) {
-      default: break;
+      default:
+        break;
       case Intrinsic::abs: {
-        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
         bool IntMinIsPoison = match(II->getArgOperand(1), m_One());
         Known = Known2.abs(IntMinIsPoison);
         break;
@@ -1637,7 +1646,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
         Known.One |= Known2.One.byteSwap();
         break;
       case Intrinsic::ctlz: {
-        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
         // If we have a known 1, its position is our upper bound.
         unsigned PossibleLZ = Known2.countMaxLeadingZeros();
         // If this call is poison for 0 input, the result will be less than 2^n.
@@ -1648,7 +1657,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
         break;
       }
       case Intrinsic::cttz: {
-        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
         // If we have a known 1, its position is our upper bound.
         unsigned PossibleTZ = Known2.countMaxTrailingZeros();
         // If this call is poison for 0 input, the result will be less than 2^n.
@@ -1659,7 +1668,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
         break;
       }
       case Intrinsic::ctpop: {
-        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
         // We can bound the space the count needs.  Also, bits known to be zero
         // can't contribute to the population.
         unsigned BitsPossiblySet = Known2.countMaxPopulation();
@@ -1681,8 +1690,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
           ShiftAmt = BitWidth - ShiftAmt;
 
         KnownBits Known3(BitWidth);
-        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known3, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known3, Depth + 1, Q);
 
         Known.Zero =
             Known2.Zero.shl(ShiftAmt) | Known3.Zero.lshr(BitWidth - ShiftAmt);
@@ -1691,27 +1700,30 @@ static void computeKnownBitsFromOperator(const Operator *I,
         break;
       }
       case Intrinsic::uadd_sat:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::uadd_sat(Known, Known2);
         break;
       case Intrinsic::usub_sat:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::usub_sat(Known, Known2);
         break;
       case Intrinsic::sadd_sat:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::sadd_sat(Known, Known2);
         break;
       case Intrinsic::ssub_sat:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::ssub_sat(Known, Known2);
         break;
         // Vec reverse preserves bits from input vec.
       case Intrinsic::vector_reverse:
+        computeKnownBits(I->getOperand(0), DemandedElts.reverseBits(), Known,
+                         Depth + 1, Q);
+        break;
         // for min/max/and/or reduce, any bit common to each element in the
         // input vec is set in the output.
       case Intrinsic::vector_reduce_and:
@@ -1738,31 +1750,31 @@ static void computeKnownBitsFromOperator(const Operator *I,
         break;
       }
       case Intrinsic::umin:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::umin(Known, Known2);
         break;
       case Intrinsic::umax:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::umax(Known, Known2);
         break;
       case Intrinsic::smin:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::smin(Known, Known2);
         break;
       case Intrinsic::smax:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::smax(Known, Known2);
         break;
       case Intrinsic::ptrmask: {
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
 
         const Value *Mask = I->getOperand(1);
         Known2 = KnownBits(Mask->getType()->getScalarSizeInBits());
-        computeKnownBits(Mask, Known2, Depth + 1, Q);
+        computeKnownBits(Mask, DemandedElts, Known2, Depth + 1, Q);
         // TODO: 1-extend would be more precise.
         Known &= Known2.anyextOrTrunc(BitWidth);
         break;
@@ -2648,7 +2660,7 @@ static bool isNonZeroSub(const APInt &DemandedElts, unsigned Depth,
     if (C->isNullValue() && isKnownNonZero(Y, DemandedElts, Q, Depth))
       return true;
 
-  return ::isKnownNonEqual(X, Y, Depth, Q);
+  return ::isKnownNonEqual(X, Y, DemandedElts, Depth, Q);
 }
 
 static bool isNonZeroMul(const APInt &DemandedElts, unsigned Depth,
@@ -2772,8 +2784,11 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
     //    This all implies the 2 i16 elements are non-zero.
     Type *FromTy = I->getOperand(0)->getType();
     if ((FromTy->isIntOrIntVectorTy() || FromTy->isPtrOrPtrVectorTy()) &&
-        (BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0)
+        (BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0) {
+      if (match(I, m_ElementWiseBitCast(m_Value())))
+        return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
       return isKnownNonZero(I->getOperand(0), Q, Depth);
+    }
   } break;
   case Instruction::IntToPtr:
     // Note that we have to take special care to avoid looking through
@@ -2782,7 +2797,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
     if (!isa<ScalableVectorType>(I->getType()) &&
         Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
             Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
-      return isKnownNonZero(I->getOperand(0), Q, Depth);
+      return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
     break;
   case Instruction::PtrToInt:
     // Similar to int2ptr above, we can look through ptr2int here if the cast
@@ -2790,13 +2805,13 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
     if (!isa<ScalableVectorType>(I->getType()) &&
         Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
             Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
-      return isKnownNonZero(I->getOperand(0), Q, Depth);
+      return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
     break;
   case Instruction::Trunc:
     // nuw/nsw trunc preserves zero/non-zero status of input.
     if (auto *TI = dyn_cast<TruncInst>(I))
       if (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap())
-        return isKnownNonZero(TI->getOperand(0), Q, Depth);
+        return isKnownNonZero(TI->getOperand(0), DemandedElts, Q, Depth);
     break;
 
   case Instruction::Sub:
@@ -2817,13 +2832,13 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
   case Instruction::SExt:
   case Instruction::ZExt:
     // ext X != 0 if X != 0.
-    return isKnownNonZero(I->getOperand(0), Q, Depth);
+    return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
 
   case Instruction::Shl: {
     // shl nsw/nuw can't remove any non-zero bits.
     const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(I);
     if (Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO))
-      return isKnownNonZero(I->getOperand(0), Q, Depth);
+      return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
 
     // shl X, Y != 0 if X is odd.  Note that the value of the shift is undefined
     // if the lowest bit is shifted off the end.
@@ -2839,7 +2854,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
     // shr exact can only shift out zero bits.
     const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(I);
     if (BO->isExact())
-      return isKnownNonZero(I->getOperand(0), Q, Depth);
+      return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
 
     // shr X, Y != 0 if X is negative.  Note that the value of the shift is not
     // defined if the sign bit is shifted off the end.
@@ -3094,6 +3109,8 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
                             /*NSW=*/true, /* NUW=*/false);
         // Vec reverse preserves zero/non-zero status from input vec.
       case Intrinsic::vector_reverse:
+        return isKnownNonZero(II->getArgOperand(0), DemandedElts.reverseBits(),
+                              Q, Depth);
         // umin/smin/smax/smin/or of all non-zero elements is always non-zero.
       case Intrinsic::vector_reduce_or:
       case Intrinsic::vector_reduce_umax:
@@ -3418,7 +3435,8 @@ getInvertibleOperands(const Operator *Op1,
 /// Only handle a small subset of binops where (binop V2, X) with non-zero X
 /// implies V2 != V1.
 static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
-                                      unsigned Depth, const SimplifyQuery &Q) {
+                                      const APInt &DemandedElts, unsigned Depth,
+                                      const SimplifyQuery &Q) {
   const BinaryOperator *BO = dyn_cast<BinaryOperator>(V1);
   if (!BO)
     return false;
@@ -3438,39 +3456,43 @@ static bool isModifyingBinopOfNonZero(const Value *V1, con...
[truncated]

``````````
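
One detail from the diff worth highlighting: for `llvm.vector.reverse`, the mask itself is remapped with `APInt::reverseBits()` before recursing, because result lane `i` is sourced from input lane `N-1-i`. A small, self-contained sketch of that remapping (the concrete lane values are made up for the example):

```cpp
#include "llvm/ADT/APInt.h"
#include <cassert>

int main() {
  // Suppose a caller demands lanes {0, 1} of a reversed <4 x i32> value.
  llvm::APInt Demanded(/*numBits=*/4, /*val=*/0b0011);

  // Those lanes come from lanes {2, 3} of the reverse's input operand,
  // which is exactly what reversing the per-lane mask expresses.
  llvm::APInt DemandedOfInput = Demanded.reverseBits();
  assert(DemandedOfInput == llvm::APInt(4, 0b1100));
  return 0;
}
```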



https://github.com/llvm/llvm-project/pull/99080

