[llvm] goldsteinn/demanded elts consistent (PR #99080)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 16 11:50:35 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-llvm-analysis
Author: None (goldsteinn)
Changes:
- **[ValueTracking] Consistently propagate `DemandedElts` in `computeKnownBits`**
- **[ValueTracking] Consistently propagate `DemandedElts` in `isKnownNonZero`**
- **[ValueTracking] Consistently propagate `DemandedElts` in `ComputeNumSignBits`**
- **[ValueTracking] Consistently propagate `DemandedElts` in `computeKnownFPClass`**
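
All four changes follow the same convention for building the demanded-elements mask at the public entry points: demand every lane of a fixed-width vector, and fall back to a single-bit "demand everything" mask for scalars (scalable vectors land in the same fallback). Below is a minimal standalone sketch of that convention; the helper name `getDefaultDemandedElts` is hypothetical and not part of the patch, which inlines the pattern at each call site.

```cpp
// Sketch only, assuming the usual LLVM headers; not code from the patch.
#include "llvm/ADT/APInt.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"

using namespace llvm;

// For a fixed-width vector value, demand every lane. Scalars (and scalable
// vectors, which dyn_cast<FixedVectorType> rejects) get a 1-bit all-ones mask.
static APInt getDefaultDemandedElts(const Value *V) {
  if (auto *FVTy = dyn_cast<FixedVectorType>(V->getType()))
    return APInt::getAllOnes(FVTy->getNumElements());
  return APInt(1, 1);
}
```

The patch then threads this mask through the recursive helpers instead of silently rebuilding an all-ones mask, so per-lane information available at the top-level query is no longer dropped partway down the recursion.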
---
Patch is 34.76 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/99080.diff
2 Files Affected:
- (modified) llvm/include/llvm/Analysis/ValueTracking.h (+17-5)
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+146-104)
``````````diff
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 354ad5bc95317..2c2f965a3cd6f 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -526,16 +526,17 @@ inline KnownFPClass computeKnownFPClass(
}
/// Wrapper to account for known fast math flags at the use instruction.
-inline KnownFPClass computeKnownFPClass(const Value *V, FastMathFlags FMF,
- FPClassTest InterestedClasses,
- unsigned Depth,
- const SimplifyQuery &SQ) {
+inline KnownFPClass
+computeKnownFPClass(const Value *V, const APInt &DemandedElts,
+ FastMathFlags FMF, FPClassTest InterestedClasses,
+ unsigned Depth, const SimplifyQuery &SQ) {
if (FMF.noNaNs())
InterestedClasses &= ~fcNan;
if (FMF.noInfs())
InterestedClasses &= ~fcInf;
- KnownFPClass Result = computeKnownFPClass(V, InterestedClasses, Depth, SQ);
+ KnownFPClass Result =
+ computeKnownFPClass(V, DemandedElts, InterestedClasses, Depth, SQ);
if (FMF.noNaNs())
Result.KnownFPClasses &= ~fcNan;
@@ -544,6 +545,17 @@ inline KnownFPClass computeKnownFPClass(const Value *V, FastMathFlags FMF,
return Result;
}
+inline KnownFPClass computeKnownFPClass(const Value *V, FastMathFlags FMF,
+ FPClassTest InterestedClasses,
+ unsigned Depth,
+ const SimplifyQuery &SQ) {
+ auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+ APInt DemandedElts =
+ FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
+ return computeKnownFPClass(V, DemandedElts, FMF, InterestedClasses, Depth,
+ SQ);
+}
+
/// Return true if we can prove that the specified FP value is never equal to
/// -0.0. Users should use caution when considering PreserveSign
/// denormal-fp-math.
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index f8ec868398323..6e039ad2deadb 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -303,15 +303,21 @@ bool llvm::isKnownNegative(const Value *V, const SimplifyQuery &SQ,
return computeKnownBits(V, Depth, SQ).isNegative();
}
-static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
+static bool isKnownNonEqual(const Value *V1, const Value *V2,
+ const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q);
bool llvm::isKnownNonEqual(const Value *V1, const Value *V2,
const DataLayout &DL, AssumptionCache *AC,
const Instruction *CxtI, const DominatorTree *DT,
bool UseInstrInfo) {
+ assert(V1->getType() == V2->getType() &&
+ "Testing equality of non-equal types!");
+ auto *FVTy = dyn_cast<FixedVectorType>(V1->getType());
+ APInt DemandedElts =
+ FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
return ::isKnownNonEqual(
- V1, V2, 0,
+ V1, V2, DemandedElts, 0,
SimplifyQuery(DL, DT, AC, safeCxtI(V2, V1, CxtI), UseInstrInfo));
}
@@ -1091,15 +1097,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::UDiv: {
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known =
KnownBits::udiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
break;
}
case Instruction::SDiv: {
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known =
KnownBits::sdiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
break;
@@ -1107,7 +1113,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
case Instruction::Select: {
auto ComputeForArm = [&](Value *Arm, bool Invert) {
KnownBits Res(Known.getBitWidth());
- computeKnownBits(Arm, Res, Depth + 1, Q);
+ computeKnownBits(Arm, DemandedElts, Res, Depth + 1, Q);
adjustKnownBitsForSelectArm(Res, I->getOperand(0), Arm, Invert, Depth, Q);
return Res;
};
@@ -1142,7 +1148,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
assert(SrcBitWidth && "SrcBitWidth can't be zero");
Known = Known.anyextOrTrunc(SrcBitWidth);
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
if (auto *Inst = dyn_cast<PossiblyNonNegInst>(I);
Inst && Inst->hasNonNeg() && !Known.isNegative())
Known.makeNonNegative();
@@ -1164,7 +1170,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
if (match(I, m_ElementWiseBitCast(m_Value(V))) &&
V->getType()->isFPOrFPVectorTy()) {
Type *FPType = V->getType()->getScalarType();
- KnownFPClass Result = computeKnownFPClass(V, fcAllFlags, Depth + 1, Q);
+ KnownFPClass Result =
+ computeKnownFPClass(V, DemandedElts, fcAllFlags, Depth + 1, Q);
FPClassTest FPClasses = Result.KnownFPClasses;
// TODO: Treat it as zero/poison if the use of I is unreachable.
@@ -1245,7 +1252,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
Known = Known.trunc(SrcBitWidth);
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
// If the sign bit of the input is known set or clear, then we know the
// top bits of the result.
Known = Known.sext(BitWidth);
@@ -1305,14 +1312,14 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::SRem:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::srem(Known, Known2);
break;
case Instruction::URem:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::urem(Known, Known2);
break;
case Instruction::Alloca:
@@ -1465,17 +1472,17 @@ static void computeKnownBitsFromOperator(const Operator *I,
unsigned OpNum = P->getOperand(0) == R ? 0 : 1;
Instruction *RInst = P->getIncomingBlock(OpNum)->getTerminator();
- Instruction *LInst = P->getIncomingBlock(1-OpNum)->getTerminator();
+ Instruction *LInst = P->getIncomingBlock(1 - OpNum)->getTerminator();
// Ok, we have a PHI of the form L op= R. Check for low
// zero bits.
RecQ.CxtI = RInst;
- computeKnownBits(R, Known2, Depth + 1, RecQ);
+ computeKnownBits(R, DemandedElts, Known2, Depth + 1, RecQ);
// We need to take the minimum number of known bits
KnownBits Known3(BitWidth);
RecQ.CxtI = LInst;
- computeKnownBits(L, Known3, Depth + 1, RecQ);
+ computeKnownBits(L, DemandedElts, Known3, Depth + 1, RecQ);
Known.Zero.setLowBits(std::min(Known2.countMinTrailingZeros(),
Known3.countMinTrailingZeros()));
@@ -1548,7 +1555,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
// want to waste time spinning around in loops.
// TODO: See if we can base recursion limiter on number of incoming phi
// edges so we don't overly clamp analysis.
- computeKnownBits(IncValue, Known2, MaxAnalysisRecursionDepth - 1, RecQ);
+ computeKnownBits(IncValue, DemandedElts, Known2,
+ MaxAnalysisRecursionDepth - 1, RecQ);
// See if we can further use a conditional branch into the phi
// to help us determine the range of the value.
@@ -1619,9 +1627,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
}
if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
switch (II->getIntrinsicID()) {
- default: break;
+ default:
+ break;
case Intrinsic::abs: {
- computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
bool IntMinIsPoison = match(II->getArgOperand(1), m_One());
Known = Known2.abs(IntMinIsPoison);
break;
@@ -1637,7 +1646,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
Known.One |= Known2.One.byteSwap();
break;
case Intrinsic::ctlz: {
- computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
// If we have a known 1, its position is our upper bound.
unsigned PossibleLZ = Known2.countMaxLeadingZeros();
// If this call is poison for 0 input, the result will be less than 2^n.
@@ -1648,7 +1657,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Intrinsic::cttz: {
- computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
// If we have a known 1, its position is our upper bound.
unsigned PossibleTZ = Known2.countMaxTrailingZeros();
// If this call is poison for 0 input, the result will be less than 2^n.
@@ -1659,7 +1668,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Intrinsic::ctpop: {
- computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
// We can bound the space the count needs. Also, bits known to be zero
// can't contribute to the population.
unsigned BitsPossiblySet = Known2.countMaxPopulation();
@@ -1681,8 +1690,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
ShiftAmt = BitWidth - ShiftAmt;
KnownBits Known3(BitWidth);
- computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known3, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known3, Depth + 1, Q);
Known.Zero =
Known2.Zero.shl(ShiftAmt) | Known3.Zero.lshr(BitWidth - ShiftAmt);
@@ -1691,27 +1700,30 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Intrinsic::uadd_sat:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::uadd_sat(Known, Known2);
break;
case Intrinsic::usub_sat:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::usub_sat(Known, Known2);
break;
case Intrinsic::sadd_sat:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::sadd_sat(Known, Known2);
break;
case Intrinsic::ssub_sat:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::ssub_sat(Known, Known2);
break;
// Vec reverse preserves bits from input vec.
case Intrinsic::vector_reverse:
+ computeKnownBits(I->getOperand(0), DemandedElts.reverseBits(), Known,
+ Depth + 1, Q);
+ break;
// for min/max/and/or reduce, any bit common to each element in the
// input vec is set in the output.
case Intrinsic::vector_reduce_and:
@@ -1738,31 +1750,31 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Intrinsic::umin:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::umin(Known, Known2);
break;
case Intrinsic::umax:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::umax(Known, Known2);
break;
case Intrinsic::smin:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::smin(Known, Known2);
break;
case Intrinsic::smax:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::smax(Known, Known2);
break;
case Intrinsic::ptrmask: {
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
const Value *Mask = I->getOperand(1);
Known2 = KnownBits(Mask->getType()->getScalarSizeInBits());
- computeKnownBits(Mask, Known2, Depth + 1, Q);
+ computeKnownBits(Mask, DemandedElts, Known2, Depth + 1, Q);
// TODO: 1-extend would be more precise.
Known &= Known2.anyextOrTrunc(BitWidth);
break;
@@ -2648,7 +2660,7 @@ static bool isNonZeroSub(const APInt &DemandedElts, unsigned Depth,
if (C->isNullValue() && isKnownNonZero(Y, DemandedElts, Q, Depth))
return true;
- return ::isKnownNonEqual(X, Y, Depth, Q);
+ return ::isKnownNonEqual(X, Y, DemandedElts, Depth, Q);
}
static bool isNonZeroMul(const APInt &DemandedElts, unsigned Depth,
@@ -2772,8 +2784,11 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
// This all implies the 2 i16 elements are non-zero.
Type *FromTy = I->getOperand(0)->getType();
if ((FromTy->isIntOrIntVectorTy() || FromTy->isPtrOrPtrVectorTy()) &&
- (BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0)
+ (BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0) {
+ if (match(I, m_ElementWiseBitCast(m_Value())))
+ return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
return isKnownNonZero(I->getOperand(0), Q, Depth);
+ }
} break;
case Instruction::IntToPtr:
// Note that we have to take special care to avoid looking through
@@ -2782,7 +2797,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
if (!isa<ScalableVectorType>(I->getType()) &&
Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
- return isKnownNonZero(I->getOperand(0), Q, Depth);
+ return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
break;
case Instruction::PtrToInt:
// Similar to int2ptr above, we can look through ptr2int here if the cast
@@ -2790,13 +2805,13 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
if (!isa<ScalableVectorType>(I->getType()) &&
Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
- return isKnownNonZero(I->getOperand(0), Q, Depth);
+ return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
break;
case Instruction::Trunc:
// nuw/nsw trunc preserves zero/non-zero status of input.
if (auto *TI = dyn_cast<TruncInst>(I))
if (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap())
- return isKnownNonZero(TI->getOperand(0), Q, Depth);
+ return isKnownNonZero(TI->getOperand(0), DemandedElts, Q, Depth);
break;
case Instruction::Sub:
@@ -2817,13 +2832,13 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
case Instruction::SExt:
case Instruction::ZExt:
// ext X != 0 if X != 0.
- return isKnownNonZero(I->getOperand(0), Q, Depth);
+ return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
case Instruction::Shl: {
// shl nsw/nuw can't remove any non-zero bits.
const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(I);
if (Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO))
- return isKnownNonZero(I->getOperand(0), Q, Depth);
+ return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
// shl X, Y != 0 if X is odd. Note that the value of the shift is undefined
// if the lowest bit is shifted off the end.
@@ -2839,7 +2854,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
// shr exact can only shift out zero bits.
const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(I);
if (BO->isExact())
- return isKnownNonZero(I->getOperand(0), Q, Depth);
+ return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
// shr X, Y != 0 if X is negative. Note that the value of the shift is not
// defined if the sign bit is shifted off the end.
@@ -3094,6 +3109,8 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
/*NSW=*/true, /* NUW=*/false);
// Vec reverse preserves zero/non-zero status from input vec.
case Intrinsic::vector_reverse:
+ return isKnownNonZero(II->getArgOperand(0), DemandedElts.reverseBits(),
+ Q, Depth);
// umin/smin/smax/smin/or of all non-zero elements is always non-zero.
case Intrinsic::vector_reduce_or:
case Intrinsic::vector_reduce_umax:
@@ -3418,7 +3435,8 @@ getInvertibleOperands(const Operator *Op1,
/// Only handle a small subset of binops where (binop V2, X) with non-zero X
/// implies V2 != V1.
static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
- unsigned Depth, const SimplifyQuery &Q) {
+ const APInt &DemandedElts, unsigned Depth,
+ const SimplifyQuery &Q) {
const BinaryOperator *BO = dyn_cast<BinaryOperator>(V1);
if (!BO)
return false;
@@ -3438,39 +3456,43 @@ static bool isModifyingBinopOfNonZero(const Value *V1, con...
[truncated]
``````````
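
One detail worth calling out from the truncated diff: the `vector_reverse` cases now pass `DemandedElts.reverseBits()` to the operand, since lane `i` of the reversed result comes from lane `N-1-i` of the input. A standalone illustration of that index flip (hypothetical example code, not part of the patch):

```cpp
// Demanding lanes {0, 2} of a reversed 4-lane vector means demanding
// lanes {3, 1} of the input; APInt::reverseBits performs exactly that flip.
#include "llvm/ADT/APInt.h"
#include <cassert>

using namespace llvm;

int main() {
  APInt DemandedResult(4, 0b0101);                 // result lanes 0 and 2
  APInt DemandedInput = DemandedResult.reverseBits();
  assert(DemandedInput == APInt(4, 0b1010));       // input lanes 1 and 3
  return 0;
}
```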
https://github.com/llvm/llvm-project/pull/99080