[llvm] goldsteinn/demanded elts consistent (PR #99080)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 16 11:50:04 PDT 2024


https://github.com/goldsteinn created https://github.com/llvm/llvm-project/pull/99080

- **[ValueTracking] Consistently propagate `DemandedElts` in `computeKnownBits`**
- **[ValueTracking] Consistently propagate `DemandedElts` in `isKnownNonZero`**
- **[ValueTracking] Consistently propagate `DemandedElts` in `ComputeNumSignBits`**
- **[ValueTracking] Consistently propagate `DemandedElts` in `computeKnownFPClass`**


>From 012caddcb9277297dd9beb0cd74f55ab4b380bac Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 16 Jul 2024 20:27:55 +0800
Subject: [PATCH 1/4] [ValueTracking] Consistently propagate `DemandedElts` in
 `computeKnownBits`

---
 llvm/lib/Analysis/ValueTracking.cpp | 88 +++++++++++++++--------------
 1 file changed, 47 insertions(+), 41 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index f8ec868398323..8bc0e7f23b81c 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1091,15 +1091,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
     break;
   }
   case Instruction::UDiv: {
-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-    computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
     Known =
         KnownBits::udiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
     break;
   }
   case Instruction::SDiv: {
-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-    computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
     Known =
         KnownBits::sdiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
     break;
@@ -1107,7 +1107,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
   case Instruction::Select: {
     auto ComputeForArm = [&](Value *Arm, bool Invert) {
       KnownBits Res(Known.getBitWidth());
-      computeKnownBits(Arm, Res, Depth + 1, Q);
+      computeKnownBits(Arm, DemandedElts, Res, Depth + 1, Q);
       adjustKnownBitsForSelectArm(Res, I->getOperand(0), Arm, Invert, Depth, Q);
       return Res;
     };
@@ -1142,7 +1142,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
 
     assert(SrcBitWidth && "SrcBitWidth can't be zero");
     Known = Known.anyextOrTrunc(SrcBitWidth);
-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
     if (auto *Inst = dyn_cast<PossiblyNonNegInst>(I);
         Inst && Inst->hasNonNeg() && !Known.isNegative())
       Known.makeNonNegative();
@@ -1164,7 +1164,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
     if (match(I, m_ElementWiseBitCast(m_Value(V))) &&
         V->getType()->isFPOrFPVectorTy()) {
       Type *FPType = V->getType()->getScalarType();
-      KnownFPClass Result = computeKnownFPClass(V, fcAllFlags, Depth + 1, Q);
+      KnownFPClass Result =
+          computeKnownFPClass(V, DemandedElts, fcAllFlags, Depth + 1, Q);
       FPClassTest FPClasses = Result.KnownFPClasses;
 
       // TODO: Treat it as zero/poison if the use of I is unreachable.
@@ -1245,7 +1246,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
     unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
 
     Known = Known.trunc(SrcBitWidth);
-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
     // If the sign bit of the input is known set or clear, then we know the
     // top bits of the result.
     Known = Known.sext(BitWidth);
@@ -1305,14 +1306,14 @@ static void computeKnownBitsFromOperator(const Operator *I,
     break;
   }
   case Instruction::SRem:
-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-    computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
     Known = KnownBits::srem(Known, Known2);
     break;
 
   case Instruction::URem:
-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-    computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
     Known = KnownBits::urem(Known, Known2);
     break;
   case Instruction::Alloca:
@@ -1465,17 +1466,17 @@ static void computeKnownBitsFromOperator(const Operator *I,
 
         unsigned OpNum = P->getOperand(0) == R ? 0 : 1;
         Instruction *RInst = P->getIncomingBlock(OpNum)->getTerminator();
-        Instruction *LInst = P->getIncomingBlock(1-OpNum)->getTerminator();
+        Instruction *LInst = P->getIncomingBlock(1 - OpNum)->getTerminator();
 
         // Ok, we have a PHI of the form L op= R. Check for low
         // zero bits.
         RecQ.CxtI = RInst;
-        computeKnownBits(R, Known2, Depth + 1, RecQ);
+        computeKnownBits(R, DemandedElts, Known2, Depth + 1, RecQ);
 
         // We need to take the minimum number of known bits
         KnownBits Known3(BitWidth);
         RecQ.CxtI = LInst;
-        computeKnownBits(L, Known3, Depth + 1, RecQ);
+        computeKnownBits(L, DemandedElts, Known3, Depth + 1, RecQ);
 
         Known.Zero.setLowBits(std::min(Known2.countMinTrailingZeros(),
                                        Known3.countMinTrailingZeros()));
@@ -1548,7 +1549,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
         // want to waste time spinning around in loops.
         // TODO: See if we can base recursion limiter on number of incoming phi
         // edges so we don't overly clamp analysis.
-        computeKnownBits(IncValue, Known2, MaxAnalysisRecursionDepth - 1, RecQ);
+        computeKnownBits(IncValue, DemandedElts, Known2,
+                         MaxAnalysisRecursionDepth - 1, RecQ);
 
         // See if we can further use a conditional branch into the phi
         // to help us determine the range of the value.
@@ -1619,9 +1621,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
     }
     if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
       switch (II->getIntrinsicID()) {
-      default: break;
+      default:
+        break;
       case Intrinsic::abs: {
-        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
         bool IntMinIsPoison = match(II->getArgOperand(1), m_One());
         Known = Known2.abs(IntMinIsPoison);
         break;
@@ -1637,7 +1640,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
         Known.One |= Known2.One.byteSwap();
         break;
       case Intrinsic::ctlz: {
-        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
         // If we have a known 1, its position is our upper bound.
         unsigned PossibleLZ = Known2.countMaxLeadingZeros();
         // If this call is poison for 0 input, the result will be less than 2^n.
@@ -1648,7 +1651,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
         break;
       }
       case Intrinsic::cttz: {
-        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
         // If we have a known 1, its position is our upper bound.
         unsigned PossibleTZ = Known2.countMaxTrailingZeros();
         // If this call is poison for 0 input, the result will be less than 2^n.
@@ -1659,7 +1662,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
         break;
       }
       case Intrinsic::ctpop: {
-        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
         // We can bound the space the count needs.  Also, bits known to be zero
         // can't contribute to the population.
         unsigned BitsPossiblySet = Known2.countMaxPopulation();
@@ -1681,8 +1684,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
           ShiftAmt = BitWidth - ShiftAmt;
 
         KnownBits Known3(BitWidth);
-        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known3, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known3, Depth + 1, Q);
 
         Known.Zero =
             Known2.Zero.shl(ShiftAmt) | Known3.Zero.lshr(BitWidth - ShiftAmt);
@@ -1691,27 +1694,30 @@ static void computeKnownBitsFromOperator(const Operator *I,
         break;
       }
       case Intrinsic::uadd_sat:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::uadd_sat(Known, Known2);
         break;
       case Intrinsic::usub_sat:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::usub_sat(Known, Known2);
         break;
       case Intrinsic::sadd_sat:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::sadd_sat(Known, Known2);
         break;
       case Intrinsic::ssub_sat:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::ssub_sat(Known, Known2);
         break;
         // Vec reverse preserves bits from input vec.
       case Intrinsic::vector_reverse:
+        computeKnownBits(I->getOperand(0), DemandedElts.reverseBits(), Known,
+                         Depth + 1, Q);
+        break;
         // for min/max/and/or reduce, any bit common to each element in the
         // input vec is set in the output.
       case Intrinsic::vector_reduce_and:
@@ -1738,31 +1744,31 @@ static void computeKnownBitsFromOperator(const Operator *I,
         break;
       }
       case Intrinsic::umin:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::umin(Known, Known2);
         break;
       case Intrinsic::umax:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::umax(Known, Known2);
         break;
       case Intrinsic::smin:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::smin(Known, Known2);
         break;
       case Intrinsic::smax:
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-        computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
         Known = KnownBits::smax(Known, Known2);
         break;
       case Intrinsic::ptrmask: {
-        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
 
         const Value *Mask = I->getOperand(1);
         Known2 = KnownBits(Mask->getType()->getScalarSizeInBits());
-        computeKnownBits(Mask, Known2, Depth + 1, Q);
+        computeKnownBits(Mask, DemandedElts, Known2, Depth + 1, Q);
         // TODO: 1-extend would be more precise.
         Known &= Known2.anyextOrTrunc(BitWidth);
         break;

>From d2cf153517b9c21e04418dc0fc3f7d7fd7441c4a Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 16 Jul 2024 20:38:18 +0800
Subject: [PATCH 2/4] [ValueTracking] Consistently propagate `DemandedElts` in
 `isKnownNonZero`

---
 llvm/lib/Analysis/ValueTracking.cpp | 90 ++++++++++++++++++-----------
 1 file changed, 56 insertions(+), 34 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 8bc0e7f23b81c..b715ab6eabf70 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -303,15 +303,21 @@ bool llvm::isKnownNegative(const Value *V, const SimplifyQuery &SQ,
   return computeKnownBits(V, Depth, SQ).isNegative();
 }
 
-static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
+static bool isKnownNonEqual(const Value *V1, const Value *V2,
+                            const APInt &DemandedElts, unsigned Depth,
                             const SimplifyQuery &Q);
 
 bool llvm::isKnownNonEqual(const Value *V1, const Value *V2,
                            const DataLayout &DL, AssumptionCache *AC,
                            const Instruction *CxtI, const DominatorTree *DT,
                            bool UseInstrInfo) {
+  assert(V1->getType() == V2->getType() &&
+         "Testing equality of non-equal types!");
+  auto *FVTy = dyn_cast<FixedVectorType>(V1->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
   return ::isKnownNonEqual(
-      V1, V2, 0,
+      V1, V2, DemandedElts, 0,
       SimplifyQuery(DL, DT, AC, safeCxtI(V2, V1, CxtI), UseInstrInfo));
 }
 
@@ -2654,7 +2660,7 @@ static bool isNonZeroSub(const APInt &DemandedElts, unsigned Depth,
     if (C->isNullValue() && isKnownNonZero(Y, DemandedElts, Q, Depth))
       return true;
 
-  return ::isKnownNonEqual(X, Y, Depth, Q);
+  return ::isKnownNonEqual(X, Y, DemandedElts, Depth, Q);
 }
 
 static bool isNonZeroMul(const APInt &DemandedElts, unsigned Depth,
@@ -2778,8 +2784,11 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
     //    This all implies the 2 i16 elements are non-zero.
     Type *FromTy = I->getOperand(0)->getType();
     if ((FromTy->isIntOrIntVectorTy() || FromTy->isPtrOrPtrVectorTy()) &&
-        (BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0)
+        (BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0) {
+      if (match(I, m_ElementWiseBitCast(m_Value())))
+        return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
       return isKnownNonZero(I->getOperand(0), Q, Depth);
+    }
   } break;
   case Instruction::IntToPtr:
     // Note that we have to take special care to avoid looking through
@@ -2788,7 +2797,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
     if (!isa<ScalableVectorType>(I->getType()) &&
         Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
             Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
-      return isKnownNonZero(I->getOperand(0), Q, Depth);
+      return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
     break;
   case Instruction::PtrToInt:
     // Similar to int2ptr above, we can look through ptr2int here if the cast
@@ -2796,13 +2805,13 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
     if (!isa<ScalableVectorType>(I->getType()) &&
         Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
             Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
-      return isKnownNonZero(I->getOperand(0), Q, Depth);
+      return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
     break;
   case Instruction::Trunc:
     // nuw/nsw trunc preserves zero/non-zero status of input.
     if (auto *TI = dyn_cast<TruncInst>(I))
       if (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap())
-        return isKnownNonZero(TI->getOperand(0), Q, Depth);
+        return isKnownNonZero(TI->getOperand(0), DemandedElts, Q, Depth);
     break;
 
   case Instruction::Sub:
@@ -2823,13 +2832,13 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
   case Instruction::SExt:
   case Instruction::ZExt:
     // ext X != 0 if X != 0.
-    return isKnownNonZero(I->getOperand(0), Q, Depth);
+    return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
 
   case Instruction::Shl: {
     // shl nsw/nuw can't remove any non-zero bits.
     const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(I);
     if (Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO))
-      return isKnownNonZero(I->getOperand(0), Q, Depth);
+      return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
 
     // shl X, Y != 0 if X is odd.  Note that the value of the shift is undefined
     // if the lowest bit is shifted off the end.
@@ -2845,7 +2854,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
     // shr exact can only shift out zero bits.
     const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(I);
     if (BO->isExact())
-      return isKnownNonZero(I->getOperand(0), Q, Depth);
+      return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
 
     // shr X, Y != 0 if X is negative.  Note that the value of the shift is not
     // defined if the sign bit is shifted off the end.
@@ -3100,6 +3109,8 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
                             /*NSW=*/true, /* NUW=*/false);
         // Vec reverse preserves zero/non-zero status from input vec.
       case Intrinsic::vector_reverse:
+        return isKnownNonZero(II->getArgOperand(0), DemandedElts.reverseBits(),
+                              Q, Depth);
         // umin/smin/smax/smin/or of all non-zero elements is always non-zero.
       case Intrinsic::vector_reduce_or:
       case Intrinsic::vector_reduce_umax:
@@ -3424,7 +3435,8 @@ getInvertibleOperands(const Operator *Op1,
 /// Only handle a small subset of binops where (binop V2, X) with non-zero X
 /// implies V2 != V1.
 static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
-                                      unsigned Depth, const SimplifyQuery &Q) {
+                                      const APInt &DemandedElts, unsigned Depth,
+                                      const SimplifyQuery &Q) {
   const BinaryOperator *BO = dyn_cast<BinaryOperator>(V1);
   if (!BO)
     return false;
@@ -3444,39 +3456,43 @@ static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
       Op = BO->getOperand(0);
     else
       return false;
-    return isKnownNonZero(Op, Q, Depth + 1);
+    return isKnownNonZero(Op, DemandedElts, Q, Depth + 1);
   }
   return false;
 }
 
 /// Return true if V2 == V1 * C, where V1 is known non-zero, C is not 0/1 and
 /// the multiplication is nuw or nsw.
-static bool isNonEqualMul(const Value *V1, const Value *V2, unsigned Depth,
+static bool isNonEqualMul(const Value *V1, const Value *V2,
+                          const APInt &DemandedElts, unsigned Depth,
                           const SimplifyQuery &Q) {
   if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
     const APInt *C;
     return match(OBO, m_Mul(m_Specific(V1), m_APInt(C))) &&
            (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
-           !C->isZero() && !C->isOne() && isKnownNonZero(V1, Q, Depth + 1);
+           !C->isZero() && !C->isOne() &&
+           isKnownNonZero(V1, DemandedElts, Q, Depth + 1);
   }
   return false;
 }
 
 /// Return true if V2 == V1 << C, where V1 is known non-zero, C is not 0 and
 /// the shift is nuw or nsw.
-static bool isNonEqualShl(const Value *V1, const Value *V2, unsigned Depth,
+static bool isNonEqualShl(const Value *V1, const Value *V2,
+                          const APInt &DemandedElts, unsigned Depth,
                           const SimplifyQuery &Q) {
   if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
     const APInt *C;
     return match(OBO, m_Shl(m_Specific(V1), m_APInt(C))) &&
            (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
-           !C->isZero() && isKnownNonZero(V1, Q, Depth + 1);
+           !C->isZero() && isKnownNonZero(V1, DemandedElts, Q, Depth + 1);
   }
   return false;
 }
 
 static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2,
-                           unsigned Depth, const SimplifyQuery &Q) {
+                           const APInt &DemandedElts, unsigned Depth,
+                           const SimplifyQuery &Q) {
   // Check two PHIs are in same block.
   if (PN1->getParent() != PN2->getParent())
     return false;
@@ -3498,14 +3514,15 @@ static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2,
 
     SimplifyQuery RecQ = Q;
     RecQ.CxtI = IncomBB->getTerminator();
-    if (!isKnownNonEqual(IV1, IV2, Depth + 1, RecQ))
+    if (!isKnownNonEqual(IV1, IV2, DemandedElts, Depth + 1, RecQ))
       return false;
     UsedFullRecursion = true;
   }
   return true;
 }
 
-static bool isNonEqualSelect(const Value *V1, const Value *V2, unsigned Depth,
+static bool isNonEqualSelect(const Value *V1, const Value *V2,
+                             const APInt &DemandedElts, unsigned Depth,
                              const SimplifyQuery &Q) {
   const SelectInst *SI1 = dyn_cast<SelectInst>(V1);
   if (!SI1)
@@ -3516,12 +3533,12 @@ static bool isNonEqualSelect(const Value *V1, const Value *V2, unsigned Depth,
     const Value *Cond2 = SI2->getCondition();
     if (Cond1 == Cond2)
       return isKnownNonEqual(SI1->getTrueValue(), SI2->getTrueValue(),
-                             Depth + 1, Q) &&
+                             DemandedElts, Depth + 1, Q) &&
              isKnownNonEqual(SI1->getFalseValue(), SI2->getFalseValue(),
-                             Depth + 1, Q);
+                             DemandedElts, Depth + 1, Q);
   }
-  return isKnownNonEqual(SI1->getTrueValue(), V2, Depth + 1, Q) &&
-         isKnownNonEqual(SI1->getFalseValue(), V2, Depth + 1, Q);
+  return isKnownNonEqual(SI1->getTrueValue(), V2, DemandedElts, Depth + 1, Q) &&
+         isKnownNonEqual(SI1->getFalseValue(), V2, DemandedElts, Depth + 1, Q);
 }
 
 // Check to see if A is both a GEP and is the incoming value for a PHI in the
@@ -3577,7 +3594,8 @@ static bool isNonEqualPointersWithRecursiveGEP(const Value *A, const Value *B,
 }
 
 /// Return true if it is known that V1 != V2.
-static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
+static bool isKnownNonEqual(const Value *V1, const Value *V2,
+                            const APInt &DemandedElts, unsigned Depth,
                             const SimplifyQuery &Q) {
   if (V1 == V2)
     return false;
@@ -3595,40 +3613,44 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
   auto *O2 = dyn_cast<Operator>(V2);
   if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) {
     if (auto Values = getInvertibleOperands(O1, O2))
-      return isKnownNonEqual(Values->first, Values->second, Depth + 1, Q);
+      return isKnownNonEqual(Values->first, Values->second, DemandedElts,
+                             Depth + 1, Q);
 
     if (const PHINode *PN1 = dyn_cast<PHINode>(V1)) {
       const PHINode *PN2 = cast<PHINode>(V2);
       // FIXME: This is missing a generalization to handle the case where one is
       // a PHI and another one isn't.
-      if (isNonEqualPHIs(PN1, PN2, Depth, Q))
+      if (isNonEqualPHIs(PN1, PN2, DemandedElts, Depth, Q))
         return true;
     };
   }
 
-  if (isModifyingBinopOfNonZero(V1, V2, Depth, Q) ||
-      isModifyingBinopOfNonZero(V2, V1, Depth, Q))
+  if (isModifyingBinopOfNonZero(V1, V2, DemandedElts, Depth, Q) ||
+      isModifyingBinopOfNonZero(V2, V1, DemandedElts, Depth, Q))
     return true;
 
-  if (isNonEqualMul(V1, V2, Depth, Q) || isNonEqualMul(V2, V1, Depth, Q))
+  if (isNonEqualMul(V1, V2, DemandedElts, Depth, Q) ||
+      isNonEqualMul(V2, V1, DemandedElts, Depth, Q))
     return true;
 
-  if (isNonEqualShl(V1, V2, Depth, Q) || isNonEqualShl(V2, V1, Depth, Q))
+  if (isNonEqualShl(V1, V2, DemandedElts, Depth, Q) ||
+      isNonEqualShl(V2, V1, DemandedElts, Depth, Q))
     return true;
 
   if (V1->getType()->isIntOrIntVectorTy()) {
     // Are any known bits in V1 contradictory to known bits in V2? If V1
     // has a known zero where V2 has a known one, they must not be equal.
-    KnownBits Known1 = computeKnownBits(V1, Depth, Q);
+    KnownBits Known1 = computeKnownBits(V1, DemandedElts, Depth, Q);
     if (!Known1.isUnknown()) {
-      KnownBits Known2 = computeKnownBits(V2, Depth, Q);
+      KnownBits Known2 = computeKnownBits(V2, DemandedElts, Depth, Q);
       if (Known1.Zero.intersects(Known2.One) ||
           Known2.Zero.intersects(Known1.One))
         return true;
     }
   }
 
-  if (isNonEqualSelect(V1, V2, Depth, Q) || isNonEqualSelect(V2, V1, Depth, Q))
+  if (isNonEqualSelect(V1, V2, DemandedElts, Depth, Q) ||
+      isNonEqualSelect(V2, V1, DemandedElts, Depth, Q))
     return true;
 
   if (isNonEqualPointersWithRecursiveGEP(V1, V2, Q) ||
@@ -3640,7 +3662,7 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
   // Check PtrToInt type matches the pointer size.
   if (match(V1, m_PtrToIntSameSize(Q.DL, m_Value(A))) &&
       match(V2, m_PtrToIntSameSize(Q.DL, m_Value(B))))
-    return isKnownNonEqual(A, B, Depth + 1, Q);
+    return isKnownNonEqual(A, B, DemandedElts, Depth + 1, Q);
 
   return false;
 }

>From 2b88065d241568c2ef95701e8105980177bf5a2f Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 16 Jul 2024 20:40:39 +0800
Subject: [PATCH 3/4] [ValueTracking] Consistently propagate `DemandedElts` in
 `ComputeNumSignBits`

---
 llvm/lib/Analysis/ValueTracking.cpp | 67 +++++++++++++++++------------
 1 file changed, 40 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index b715ab6eabf70..f54de030d3344 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -3801,7 +3801,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
     default: break;
     case Instruction::SExt:
       Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits();
-      return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q) + Tmp;
+      return ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q) +
+             Tmp;
 
     case Instruction::SDiv: {
       const APInt *Denominator;
@@ -3813,7 +3814,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
           break;
 
         // Calculate the incoming numerator bits.
-        unsigned NumBits = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+        unsigned NumBits =
+            ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
 
         // Add floor(log(C)) bits to the numerator bits.
         return std::min(TyBits, NumBits + Denominator->logBase2());
@@ -3822,7 +3824,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
     }
 
     case Instruction::SRem: {
-      Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+      Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
 
       const APInt *Denominator;
       // srem X, C -> we know that the result is within [-C+1,C) when C is a
@@ -3853,7 +3855,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
     }
 
     case Instruction::AShr: {
-      Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+      Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
       // ashr X, C   -> adds C sign bits.  Vectors too.
       const APInt *ShAmt;
       if (match(U->getOperand(1), m_APInt(ShAmt))) {
@@ -3869,7 +3871,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
       const APInt *ShAmt;
       if (match(U->getOperand(1), m_APInt(ShAmt))) {
         // shl destroys sign bits.
-        Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+        Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
         if (ShAmt->uge(TyBits) ||   // Bad shift.
             ShAmt->uge(Tmp)) break; // Shifted all sign bits out.
         Tmp2 = ShAmt->getZExtValue();
@@ -3881,9 +3883,9 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
     case Instruction::Or:
     case Instruction::Xor: // NOT is handled here.
       // Logical binary ops preserve the number of sign bits at the worst.
-      Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+      Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
       if (Tmp != 1) {
-        Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
+        Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
         FirstAnswer = std::min(Tmp, Tmp2);
         // We computed what we know about the sign bits as our first
         // answer. Now proceed to the generic code that uses
@@ -3899,9 +3901,10 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
       if (isSignedMinMaxClamp(U, X, CLow, CHigh))
         return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits());
 
-      Tmp = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
-      if (Tmp == 1) break;
-      Tmp2 = ComputeNumSignBits(U->getOperand(2), Depth + 1, Q);
+      Tmp = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
+      if (Tmp == 1)
+        break;
+      Tmp2 = ComputeNumSignBits(U->getOperand(2), DemandedElts, Depth + 1, Q);
       return std::min(Tmp, Tmp2);
     }
 
@@ -3915,7 +3918,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
       if (const auto *CRHS = dyn_cast<Constant>(U->getOperand(1)))
         if (CRHS->isAllOnesValue()) {
           KnownBits Known(TyBits);
-          computeKnownBits(U->getOperand(0), Known, Depth + 1, Q);
+          computeKnownBits(U->getOperand(0), DemandedElts, Known, Depth + 1, Q);
 
           // If the input is known to be 0 or 1, the output is 0/-1, which is
           // all sign bits set.
@@ -3928,19 +3931,21 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
             return Tmp;
         }
 
-      Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
-      if (Tmp2 == 1) break;
+      Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
+      if (Tmp2 == 1)
+        break;
       return std::min(Tmp, Tmp2) - 1;
 
     case Instruction::Sub:
-      Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
-      if (Tmp2 == 1) break;
+      Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
+      if (Tmp2 == 1)
+        break;
 
       // Handle NEG.
       if (const auto *CLHS = dyn_cast<Constant>(U->getOperand(0)))
         if (CLHS->isNullValue()) {
           KnownBits Known(TyBits);
-          computeKnownBits(U->getOperand(1), Known, Depth + 1, Q);
+          computeKnownBits(U->getOperand(1), DemandedElts, Known, Depth + 1, Q);
           // If the input is known to be 0 or 1, the output is 0/-1, which is
           // all sign bits set.
           if ((Known.Zero | 1).isAllOnes())
@@ -3957,17 +3962,22 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
 
       // Sub can have at most one carry bit.  Thus we know that the output
       // is, at worst, one more bit than the inputs.
-      Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
-      if (Tmp == 1) break;
+      Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
+      if (Tmp == 1)
+        break;
       return std::min(Tmp, Tmp2) - 1;
 
     case Instruction::Mul: {
       // The output of the Mul can be at most twice the valid bits in the
       // inputs.
-      unsigned SignBitsOp0 = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
-      if (SignBitsOp0 == 1) break;
-      unsigned SignBitsOp1 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
-      if (SignBitsOp1 == 1) break;
+      unsigned SignBitsOp0 =
+          ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
+      if (SignBitsOp0 == 1)
+        break;
+      unsigned SignBitsOp1 =
+          ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
+      if (SignBitsOp1 == 1)
+        break;
       unsigned OutValidBits =
           (TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1);
       return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
@@ -3988,8 +3998,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
       for (unsigned i = 0, e = NumIncomingValues; i != e; ++i) {
         if (Tmp == 1) return Tmp;
         RecQ.CxtI = PN->getIncomingBlock(i)->getTerminator();
-        Tmp = std::min(
-            Tmp, ComputeNumSignBits(PN->getIncomingValue(i), Depth + 1, RecQ));
+        Tmp = std::min(Tmp, ComputeNumSignBits(PN->getIncomingValue(i),
+                                               DemandedElts, Depth + 1, RecQ));
       }
       return Tmp;
     }
@@ -4050,10 +4060,13 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
     case Instruction::Call: {
       if (const auto *II = dyn_cast<IntrinsicInst>(U)) {
         switch (II->getIntrinsicID()) {
-        default: break;
+        default:
+          break;
         case Intrinsic::abs:
-          Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
-          if (Tmp == 1) break;
+          Tmp =
+              ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
+          if (Tmp == 1)
+            break;
 
           // Absolute value reduces number of sign bits by at most 1.
           return Tmp - 1;

>From 0df283c8117daebd019927aff76f78f43cfb5569 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 16 Jul 2024 21:30:36 +0800
Subject: [PATCH 4/4] [ValueTracking] Consistently propagate `DemandedElts` in
 `computeKnownFPClass`

---
 llvm/include/llvm/Analysis/ValueTracking.h | 22 +++++++++++++++++-----
 llvm/lib/Analysis/ValueTracking.cpp        |  5 +++--
 2 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 354ad5bc95317..2c2f965a3cd6f 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -526,16 +526,17 @@ inline KnownFPClass computeKnownFPClass(
 }
 
 /// Wrapper to account for known fast math flags at the use instruction.
-inline KnownFPClass computeKnownFPClass(const Value *V, FastMathFlags FMF,
-                                        FPClassTest InterestedClasses,
-                                        unsigned Depth,
-                                        const SimplifyQuery &SQ) {
+inline KnownFPClass
+computeKnownFPClass(const Value *V, const APInt &DemandedElts,
+                    FastMathFlags FMF, FPClassTest InterestedClasses,
+                    unsigned Depth, const SimplifyQuery &SQ) {
   if (FMF.noNaNs())
     InterestedClasses &= ~fcNan;
   if (FMF.noInfs())
     InterestedClasses &= ~fcInf;
 
-  KnownFPClass Result = computeKnownFPClass(V, InterestedClasses, Depth, SQ);
+  KnownFPClass Result =
+      computeKnownFPClass(V, DemandedElts, InterestedClasses, Depth, SQ);
 
   if (FMF.noNaNs())
     Result.KnownFPClasses &= ~fcNan;
@@ -544,6 +545,17 @@ inline KnownFPClass computeKnownFPClass(const Value *V, FastMathFlags FMF,
   return Result;
 }
 
+inline KnownFPClass computeKnownFPClass(const Value *V, FastMathFlags FMF,
+                                        FPClassTest InterestedClasses,
+                                        unsigned Depth,
+                                        const SimplifyQuery &SQ) {
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
+  return computeKnownFPClass(V, DemandedElts, FMF, InterestedClasses, Depth,
+                             SQ);
+}
+
 /// Return true if we can prove that the specified FP value is never equal to
 /// -0.0. Users should use caution when considering PreserveSign
 /// denormal-fp-math.
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index f54de030d3344..6e039ad2deadb 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -5274,8 +5274,9 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
     }
       // reverse preserves all characteristics of the input vec's element.
     case Intrinsic::vector_reverse:
-      Known = computeKnownFPClass(II->getArgOperand(0), II->getFastMathFlags(),
-                                  InterestedClasses, Depth + 1, Q);
+      Known = computeKnownFPClass(
+          II->getArgOperand(0), DemandedElts.reverseBits(),
+          II->getFastMathFlags(), InterestedClasses, Depth + 1, Q);
       break;
     case Intrinsic::trunc:
     case Intrinsic::floor:



More information about the llvm-commits mailing list