[llvm] goldsteinn/demanded elts consistent (PR #99080)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 16 11:50:04 PDT 2024
https://github.com/goldsteinn created https://github.com/llvm/llvm-project/pull/99080
- **[ValueTracking] Consistently propagate `DemandedElts` in `computeKnownBits`**
- **[ValueTracking] Consistently propagate `DemandedElts` in `isKnownNonZero`**
- **[ValueTracking] Consistently propagate `DemandedElts` in `ComputeNumSignBits`**
- **[ValueTracking] Consistently propagate `DemandedElts` in `computeKnownFPClass`**
>From 012caddcb9277297dd9beb0cd74f55ab4b380bac Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 16 Jul 2024 20:27:55 +0800
Subject: [PATCH 1/4] [ValueTracking] Consistently propagate `DemandedElts` in
`computeKnownBits`
---
llvm/lib/Analysis/ValueTracking.cpp | 88 +++++++++++++++--------------
1 file changed, 47 insertions(+), 41 deletions(-)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index f8ec868398323..8bc0e7f23b81c 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1091,15 +1091,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::UDiv: {
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known =
KnownBits::udiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
break;
}
case Instruction::SDiv: {
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known =
KnownBits::sdiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
break;
@@ -1107,7 +1107,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
case Instruction::Select: {
auto ComputeForArm = [&](Value *Arm, bool Invert) {
KnownBits Res(Known.getBitWidth());
- computeKnownBits(Arm, Res, Depth + 1, Q);
+ computeKnownBits(Arm, DemandedElts, Res, Depth + 1, Q);
adjustKnownBitsForSelectArm(Res, I->getOperand(0), Arm, Invert, Depth, Q);
return Res;
};
@@ -1142,7 +1142,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
assert(SrcBitWidth && "SrcBitWidth can't be zero");
Known = Known.anyextOrTrunc(SrcBitWidth);
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
if (auto *Inst = dyn_cast<PossiblyNonNegInst>(I);
Inst && Inst->hasNonNeg() && !Known.isNegative())
Known.makeNonNegative();
@@ -1164,7 +1164,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
if (match(I, m_ElementWiseBitCast(m_Value(V))) &&
V->getType()->isFPOrFPVectorTy()) {
Type *FPType = V->getType()->getScalarType();
- KnownFPClass Result = computeKnownFPClass(V, fcAllFlags, Depth + 1, Q);
+ KnownFPClass Result =
+ computeKnownFPClass(V, DemandedElts, fcAllFlags, Depth + 1, Q);
FPClassTest FPClasses = Result.KnownFPClasses;
// TODO: Treat it as zero/poison if the use of I is unreachable.
@@ -1245,7 +1246,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
Known = Known.trunc(SrcBitWidth);
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
// If the sign bit of the input is known set or clear, then we know the
// top bits of the result.
Known = Known.sext(BitWidth);
@@ -1305,14 +1306,14 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::SRem:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::srem(Known, Known2);
break;
case Instruction::URem:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::urem(Known, Known2);
break;
case Instruction::Alloca:
@@ -1465,17 +1466,17 @@ static void computeKnownBitsFromOperator(const Operator *I,
unsigned OpNum = P->getOperand(0) == R ? 0 : 1;
Instruction *RInst = P->getIncomingBlock(OpNum)->getTerminator();
- Instruction *LInst = P->getIncomingBlock(1-OpNum)->getTerminator();
+ Instruction *LInst = P->getIncomingBlock(1 - OpNum)->getTerminator();
// Ok, we have a PHI of the form L op= R. Check for low
// zero bits.
RecQ.CxtI = RInst;
- computeKnownBits(R, Known2, Depth + 1, RecQ);
+ computeKnownBits(R, DemandedElts, Known2, Depth + 1, RecQ);
// We need to take the minimum number of known bits
KnownBits Known3(BitWidth);
RecQ.CxtI = LInst;
- computeKnownBits(L, Known3, Depth + 1, RecQ);
+ computeKnownBits(L, DemandedElts, Known3, Depth + 1, RecQ);
Known.Zero.setLowBits(std::min(Known2.countMinTrailingZeros(),
Known3.countMinTrailingZeros()));
@@ -1548,7 +1549,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
// want to waste time spinning around in loops.
// TODO: See if we can base recursion limiter on number of incoming phi
// edges so we don't overly clamp analysis.
- computeKnownBits(IncValue, Known2, MaxAnalysisRecursionDepth - 1, RecQ);
+ computeKnownBits(IncValue, DemandedElts, Known2,
+ MaxAnalysisRecursionDepth - 1, RecQ);
// See if we can further use a conditional branch into the phi
// to help us determine the range of the value.
@@ -1619,9 +1621,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
}
if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
switch (II->getIntrinsicID()) {
- default: break;
+ default:
+ break;
case Intrinsic::abs: {
- computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
bool IntMinIsPoison = match(II->getArgOperand(1), m_One());
Known = Known2.abs(IntMinIsPoison);
break;
@@ -1637,7 +1640,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
Known.One |= Known2.One.byteSwap();
break;
case Intrinsic::ctlz: {
- computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
// If we have a known 1, its position is our upper bound.
unsigned PossibleLZ = Known2.countMaxLeadingZeros();
// If this call is poison for 0 input, the result will be less than 2^n.
@@ -1648,7 +1651,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Intrinsic::cttz: {
- computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
// If we have a known 1, its position is our upper bound.
unsigned PossibleTZ = Known2.countMaxTrailingZeros();
// If this call is poison for 0 input, the result will be less than 2^n.
@@ -1659,7 +1662,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Intrinsic::ctpop: {
- computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
// We can bound the space the count needs. Also, bits known to be zero
// can't contribute to the population.
unsigned BitsPossiblySet = Known2.countMaxPopulation();
@@ -1681,8 +1684,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
ShiftAmt = BitWidth - ShiftAmt;
KnownBits Known3(BitWidth);
- computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known3, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known3, Depth + 1, Q);
Known.Zero =
Known2.Zero.shl(ShiftAmt) | Known3.Zero.lshr(BitWidth - ShiftAmt);
@@ -1691,27 +1694,30 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Intrinsic::uadd_sat:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::uadd_sat(Known, Known2);
break;
case Intrinsic::usub_sat:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::usub_sat(Known, Known2);
break;
case Intrinsic::sadd_sat:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::sadd_sat(Known, Known2);
break;
case Intrinsic::ssub_sat:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::ssub_sat(Known, Known2);
break;
// Vec reverse preserves bits from input vec.
case Intrinsic::vector_reverse:
+ computeKnownBits(I->getOperand(0), DemandedElts.reverseBits(), Known,
+ Depth + 1, Q);
+ break;
// for min/max/and/or reduce, any bit common to each element in the
// input vec is set in the output.
case Intrinsic::vector_reduce_and:
@@ -1738,31 +1744,31 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Intrinsic::umin:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::umin(Known, Known2);
break;
case Intrinsic::umax:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::umax(Known, Known2);
break;
case Intrinsic::smin:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::smin(Known, Known2);
break;
case Intrinsic::smax:
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::smax(Known, Known2);
break;
case Intrinsic::ptrmask: {
- computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
const Value *Mask = I->getOperand(1);
Known2 = KnownBits(Mask->getType()->getScalarSizeInBits());
- computeKnownBits(Mask, Known2, Depth + 1, Q);
+ computeKnownBits(Mask, DemandedElts, Known2, Depth + 1, Q);
// TODO: 1-extend would be more precise.
Known &= Known2.anyextOrTrunc(BitWidth);
break;
>From d2cf153517b9c21e04418dc0fc3f7d7fd7441c4a Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 16 Jul 2024 20:38:18 +0800
Subject: [PATCH 2/4] [ValueTracking] Consistently propagate `DemandedElts` in
`isKnownNonZero`
---
llvm/lib/Analysis/ValueTracking.cpp | 90 ++++++++++++++++++-----------
1 file changed, 56 insertions(+), 34 deletions(-)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 8bc0e7f23b81c..b715ab6eabf70 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -303,15 +303,21 @@ bool llvm::isKnownNegative(const Value *V, const SimplifyQuery &SQ,
return computeKnownBits(V, Depth, SQ).isNegative();
}
-static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
+static bool isKnownNonEqual(const Value *V1, const Value *V2,
+ const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q);
bool llvm::isKnownNonEqual(const Value *V1, const Value *V2,
const DataLayout &DL, AssumptionCache *AC,
const Instruction *CxtI, const DominatorTree *DT,
bool UseInstrInfo) {
+ assert(V1->getType() == V2->getType() &&
+ "Testing equality of non-equal types!");
+ auto *FVTy = dyn_cast<FixedVectorType>(V1->getType());
+ APInt DemandedElts =
+ FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
return ::isKnownNonEqual(
- V1, V2, 0,
+ V1, V2, DemandedElts, 0,
SimplifyQuery(DL, DT, AC, safeCxtI(V2, V1, CxtI), UseInstrInfo));
}
@@ -2654,7 +2660,7 @@ static bool isNonZeroSub(const APInt &DemandedElts, unsigned Depth,
if (C->isNullValue() && isKnownNonZero(Y, DemandedElts, Q, Depth))
return true;
- return ::isKnownNonEqual(X, Y, Depth, Q);
+ return ::isKnownNonEqual(X, Y, DemandedElts, Depth, Q);
}
static bool isNonZeroMul(const APInt &DemandedElts, unsigned Depth,
@@ -2778,8 +2784,11 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
// This all implies the 2 i16 elements are non-zero.
Type *FromTy = I->getOperand(0)->getType();
if ((FromTy->isIntOrIntVectorTy() || FromTy->isPtrOrPtrVectorTy()) &&
- (BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0)
+ (BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0) {
+ if (match(I, m_ElementWiseBitCast(m_Value())))
+ return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
return isKnownNonZero(I->getOperand(0), Q, Depth);
+ }
} break;
case Instruction::IntToPtr:
// Note that we have to take special care to avoid looking through
@@ -2788,7 +2797,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
if (!isa<ScalableVectorType>(I->getType()) &&
Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
- return isKnownNonZero(I->getOperand(0), Q, Depth);
+ return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
break;
case Instruction::PtrToInt:
// Similar to int2ptr above, we can look through ptr2int here if the cast
@@ -2796,13 +2805,13 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
if (!isa<ScalableVectorType>(I->getType()) &&
Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
- return isKnownNonZero(I->getOperand(0), Q, Depth);
+ return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
break;
case Instruction::Trunc:
// nuw/nsw trunc preserves zero/non-zero status of input.
if (auto *TI = dyn_cast<TruncInst>(I))
if (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap())
- return isKnownNonZero(TI->getOperand(0), Q, Depth);
+ return isKnownNonZero(TI->getOperand(0), DemandedElts, Q, Depth);
break;
case Instruction::Sub:
@@ -2823,13 +2832,13 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
case Instruction::SExt:
case Instruction::ZExt:
// ext X != 0 if X != 0.
- return isKnownNonZero(I->getOperand(0), Q, Depth);
+ return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
case Instruction::Shl: {
// shl nsw/nuw can't remove any non-zero bits.
const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(I);
if (Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO))
- return isKnownNonZero(I->getOperand(0), Q, Depth);
+ return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
// shl X, Y != 0 if X is odd. Note that the value of the shift is undefined
// if the lowest bit is shifted off the end.
@@ -2845,7 +2854,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
// shr exact can only shift out zero bits.
const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(I);
if (BO->isExact())
- return isKnownNonZero(I->getOperand(0), Q, Depth);
+ return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
// shr X, Y != 0 if X is negative. Note that the value of the shift is not
// defined if the sign bit is shifted off the end.
@@ -3100,6 +3109,8 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
/*NSW=*/true, /* NUW=*/false);
// Vec reverse preserves zero/non-zero status from input vec.
case Intrinsic::vector_reverse:
+ return isKnownNonZero(II->getArgOperand(0), DemandedElts.reverseBits(),
+ Q, Depth);
// umin/smin/smax/smin/or of all non-zero elements is always non-zero.
case Intrinsic::vector_reduce_or:
case Intrinsic::vector_reduce_umax:
@@ -3424,7 +3435,8 @@ getInvertibleOperands(const Operator *Op1,
/// Only handle a small subset of binops where (binop V2, X) with non-zero X
/// implies V2 != V1.
static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
- unsigned Depth, const SimplifyQuery &Q) {
+ const APInt &DemandedElts, unsigned Depth,
+ const SimplifyQuery &Q) {
const BinaryOperator *BO = dyn_cast<BinaryOperator>(V1);
if (!BO)
return false;
@@ -3444,39 +3456,43 @@ static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
Op = BO->getOperand(0);
else
return false;
- return isKnownNonZero(Op, Q, Depth + 1);
+ return isKnownNonZero(Op, DemandedElts, Q, Depth + 1);
}
return false;
}
/// Return true if V2 == V1 * C, where V1 is known non-zero, C is not 0/1 and
/// the multiplication is nuw or nsw.
-static bool isNonEqualMul(const Value *V1, const Value *V2, unsigned Depth,
+static bool isNonEqualMul(const Value *V1, const Value *V2,
+ const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q) {
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
const APInt *C;
return match(OBO, m_Mul(m_Specific(V1), m_APInt(C))) &&
(OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
- !C->isZero() && !C->isOne() && isKnownNonZero(V1, Q, Depth + 1);
+ !C->isZero() && !C->isOne() &&
+ isKnownNonZero(V1, DemandedElts, Q, Depth + 1);
}
return false;
}
/// Return true if V2 == V1 << C, where V1 is known non-zero, C is not 0 and
/// the shift is nuw or nsw.
-static bool isNonEqualShl(const Value *V1, const Value *V2, unsigned Depth,
+static bool isNonEqualShl(const Value *V1, const Value *V2,
+ const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q) {
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
const APInt *C;
return match(OBO, m_Shl(m_Specific(V1), m_APInt(C))) &&
(OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
- !C->isZero() && isKnownNonZero(V1, Q, Depth + 1);
+ !C->isZero() && isKnownNonZero(V1, DemandedElts, Q, Depth + 1);
}
return false;
}
static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2,
- unsigned Depth, const SimplifyQuery &Q) {
+ const APInt &DemandedElts, unsigned Depth,
+ const SimplifyQuery &Q) {
// Check two PHIs are in same block.
if (PN1->getParent() != PN2->getParent())
return false;
@@ -3498,14 +3514,15 @@ static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2,
SimplifyQuery RecQ = Q;
RecQ.CxtI = IncomBB->getTerminator();
- if (!isKnownNonEqual(IV1, IV2, Depth + 1, RecQ))
+ if (!isKnownNonEqual(IV1, IV2, DemandedElts, Depth + 1, RecQ))
return false;
UsedFullRecursion = true;
}
return true;
}
-static bool isNonEqualSelect(const Value *V1, const Value *V2, unsigned Depth,
+static bool isNonEqualSelect(const Value *V1, const Value *V2,
+ const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q) {
const SelectInst *SI1 = dyn_cast<SelectInst>(V1);
if (!SI1)
@@ -3516,12 +3533,12 @@ static bool isNonEqualSelect(const Value *V1, const Value *V2, unsigned Depth,
const Value *Cond2 = SI2->getCondition();
if (Cond1 == Cond2)
return isKnownNonEqual(SI1->getTrueValue(), SI2->getTrueValue(),
- Depth + 1, Q) &&
+ DemandedElts, Depth + 1, Q) &&
isKnownNonEqual(SI1->getFalseValue(), SI2->getFalseValue(),
- Depth + 1, Q);
+ DemandedElts, Depth + 1, Q);
}
- return isKnownNonEqual(SI1->getTrueValue(), V2, Depth + 1, Q) &&
- isKnownNonEqual(SI1->getFalseValue(), V2, Depth + 1, Q);
+ return isKnownNonEqual(SI1->getTrueValue(), V2, DemandedElts, Depth + 1, Q) &&
+ isKnownNonEqual(SI1->getFalseValue(), V2, DemandedElts, Depth + 1, Q);
}
// Check to see if A is both a GEP and is the incoming value for a PHI in the
@@ -3577,7 +3594,8 @@ static bool isNonEqualPointersWithRecursiveGEP(const Value *A, const Value *B,
}
/// Return true if it is known that V1 != V2.
-static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
+static bool isKnownNonEqual(const Value *V1, const Value *V2,
+ const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q) {
if (V1 == V2)
return false;
@@ -3595,40 +3613,44 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
auto *O2 = dyn_cast<Operator>(V2);
if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) {
if (auto Values = getInvertibleOperands(O1, O2))
- return isKnownNonEqual(Values->first, Values->second, Depth + 1, Q);
+ return isKnownNonEqual(Values->first, Values->second, DemandedElts,
+ Depth + 1, Q);
if (const PHINode *PN1 = dyn_cast<PHINode>(V1)) {
const PHINode *PN2 = cast<PHINode>(V2);
// FIXME: This is missing a generalization to handle the case where one is
// a PHI and another one isn't.
- if (isNonEqualPHIs(PN1, PN2, Depth, Q))
+ if (isNonEqualPHIs(PN1, PN2, DemandedElts, Depth, Q))
return true;
};
}
- if (isModifyingBinopOfNonZero(V1, V2, Depth, Q) ||
- isModifyingBinopOfNonZero(V2, V1, Depth, Q))
+ if (isModifyingBinopOfNonZero(V1, V2, DemandedElts, Depth, Q) ||
+ isModifyingBinopOfNonZero(V2, V1, DemandedElts, Depth, Q))
return true;
- if (isNonEqualMul(V1, V2, Depth, Q) || isNonEqualMul(V2, V1, Depth, Q))
+ if (isNonEqualMul(V1, V2, DemandedElts, Depth, Q) ||
+ isNonEqualMul(V2, V1, DemandedElts, Depth, Q))
return true;
- if (isNonEqualShl(V1, V2, Depth, Q) || isNonEqualShl(V2, V1, Depth, Q))
+ if (isNonEqualShl(V1, V2, DemandedElts, Depth, Q) ||
+ isNonEqualShl(V2, V1, DemandedElts, Depth, Q))
return true;
if (V1->getType()->isIntOrIntVectorTy()) {
// Are any known bits in V1 contradictory to known bits in V2? If V1
// has a known zero where V2 has a known one, they must not be equal.
- KnownBits Known1 = computeKnownBits(V1, Depth, Q);
+ KnownBits Known1 = computeKnownBits(V1, DemandedElts, Depth, Q);
if (!Known1.isUnknown()) {
- KnownBits Known2 = computeKnownBits(V2, Depth, Q);
+ KnownBits Known2 = computeKnownBits(V2, DemandedElts, Depth, Q);
if (Known1.Zero.intersects(Known2.One) ||
Known2.Zero.intersects(Known1.One))
return true;
}
}
- if (isNonEqualSelect(V1, V2, Depth, Q) || isNonEqualSelect(V2, V1, Depth, Q))
+ if (isNonEqualSelect(V1, V2, DemandedElts, Depth, Q) ||
+ isNonEqualSelect(V2, V1, DemandedElts, Depth, Q))
return true;
if (isNonEqualPointersWithRecursiveGEP(V1, V2, Q) ||
@@ -3640,7 +3662,7 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
// Check PtrToInt type matches the pointer size.
if (match(V1, m_PtrToIntSameSize(Q.DL, m_Value(A))) &&
match(V2, m_PtrToIntSameSize(Q.DL, m_Value(B))))
- return isKnownNonEqual(A, B, Depth + 1, Q);
+ return isKnownNonEqual(A, B, DemandedElts, Depth + 1, Q);
return false;
}
>From 2b88065d241568c2ef95701e8105980177bf5a2f Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 16 Jul 2024 20:40:39 +0800
Subject: [PATCH 3/4] [ValueTracking] Consistently propagate `DemandedElts` in
`ComputeNumSignBits`
---
llvm/lib/Analysis/ValueTracking.cpp | 67 +++++++++++++++++------------
1 file changed, 40 insertions(+), 27 deletions(-)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index b715ab6eabf70..f54de030d3344 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -3801,7 +3801,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
default: break;
case Instruction::SExt:
Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits();
- return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q) + Tmp;
+ return ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q) +
+ Tmp;
case Instruction::SDiv: {
const APInt *Denominator;
@@ -3813,7 +3814,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
break;
// Calculate the incoming numerator bits.
- unsigned NumBits = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+ unsigned NumBits =
+ ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
// Add floor(log(C)) bits to the numerator bits.
return std::min(TyBits, NumBits + Denominator->logBase2());
@@ -3822,7 +3824,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
}
case Instruction::SRem: {
- Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+ Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
const APInt *Denominator;
// srem X, C -> we know that the result is within [-C+1,C) when C is a
@@ -3853,7 +3855,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
}
case Instruction::AShr: {
- Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+ Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
// ashr X, C -> adds C sign bits. Vectors too.
const APInt *ShAmt;
if (match(U->getOperand(1), m_APInt(ShAmt))) {
@@ -3869,7 +3871,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
const APInt *ShAmt;
if (match(U->getOperand(1), m_APInt(ShAmt))) {
// shl destroys sign bits.
- Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+ Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
if (ShAmt->uge(TyBits) || // Bad shift.
ShAmt->uge(Tmp)) break; // Shifted all sign bits out.
Tmp2 = ShAmt->getZExtValue();
@@ -3881,9 +3883,9 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
case Instruction::Or:
case Instruction::Xor: // NOT is handled here.
// Logical binary ops preserve the number of sign bits at the worst.
- Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+ Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
if (Tmp != 1) {
- Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
+ Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
FirstAnswer = std::min(Tmp, Tmp2);
// We computed what we know about the sign bits as our first
// answer. Now proceed to the generic code that uses
@@ -3899,9 +3901,10 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
if (isSignedMinMaxClamp(U, X, CLow, CHigh))
return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits());
- Tmp = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
- if (Tmp == 1) break;
- Tmp2 = ComputeNumSignBits(U->getOperand(2), Depth + 1, Q);
+ Tmp = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
+ if (Tmp == 1)
+ break;
+ Tmp2 = ComputeNumSignBits(U->getOperand(2), DemandedElts, Depth + 1, Q);
return std::min(Tmp, Tmp2);
}
@@ -3915,7 +3918,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
if (const auto *CRHS = dyn_cast<Constant>(U->getOperand(1)))
if (CRHS->isAllOnesValue()) {
KnownBits Known(TyBits);
- computeKnownBits(U->getOperand(0), Known, Depth + 1, Q);
+ computeKnownBits(U->getOperand(0), DemandedElts, Known, Depth + 1, Q);
// If the input is known to be 0 or 1, the output is 0/-1, which is
// all sign bits set.
@@ -3928,19 +3931,21 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
return Tmp;
}
- Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
- if (Tmp2 == 1) break;
+ Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
+ if (Tmp2 == 1)
+ break;
return std::min(Tmp, Tmp2) - 1;
case Instruction::Sub:
- Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
- if (Tmp2 == 1) break;
+ Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
+ if (Tmp2 == 1)
+ break;
// Handle NEG.
if (const auto *CLHS = dyn_cast<Constant>(U->getOperand(0)))
if (CLHS->isNullValue()) {
KnownBits Known(TyBits);
- computeKnownBits(U->getOperand(1), Known, Depth + 1, Q);
+ computeKnownBits(U->getOperand(1), DemandedElts, Known, Depth + 1, Q);
// If the input is known to be 0 or 1, the output is 0/-1, which is
// all sign bits set.
if ((Known.Zero | 1).isAllOnes())
@@ -3957,17 +3962,22 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
// Sub can have at most one carry bit. Thus we know that the output
// is, at worst, one more bit than the inputs.
- Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
- if (Tmp == 1) break;
+ Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
+ if (Tmp == 1)
+ break;
return std::min(Tmp, Tmp2) - 1;
case Instruction::Mul: {
// The output of the Mul can be at most twice the valid bits in the
// inputs.
- unsigned SignBitsOp0 = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
- if (SignBitsOp0 == 1) break;
- unsigned SignBitsOp1 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
- if (SignBitsOp1 == 1) break;
+ unsigned SignBitsOp0 =
+ ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
+ if (SignBitsOp0 == 1)
+ break;
+ unsigned SignBitsOp1 =
+ ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
+ if (SignBitsOp1 == 1)
+ break;
unsigned OutValidBits =
(TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1);
return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
@@ -3988,8 +3998,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
for (unsigned i = 0, e = NumIncomingValues; i != e; ++i) {
if (Tmp == 1) return Tmp;
RecQ.CxtI = PN->getIncomingBlock(i)->getTerminator();
- Tmp = std::min(
- Tmp, ComputeNumSignBits(PN->getIncomingValue(i), Depth + 1, RecQ));
+ Tmp = std::min(Tmp, ComputeNumSignBits(PN->getIncomingValue(i),
+ DemandedElts, Depth + 1, RecQ));
}
return Tmp;
}
@@ -4050,10 +4060,13 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
case Instruction::Call: {
if (const auto *II = dyn_cast<IntrinsicInst>(U)) {
switch (II->getIntrinsicID()) {
- default: break;
+ default:
+ break;
case Intrinsic::abs:
- Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
- if (Tmp == 1) break;
+ Tmp =
+ ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
+ if (Tmp == 1)
+ break;
// Absolute value reduces number of sign bits by at most 1.
return Tmp - 1;
>From 0df283c8117daebd019927aff76f78f43cfb5569 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 16 Jul 2024 21:30:36 +0800
Subject: [PATCH 4/4] [ValueTracking] Consistently propagate `DemandedElts` in
`computeKnownFPClass`
---
llvm/include/llvm/Analysis/ValueTracking.h | 22 +++++++++++++++++-----
llvm/lib/Analysis/ValueTracking.cpp | 5 +++--
2 files changed, 20 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 354ad5bc95317..2c2f965a3cd6f 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -526,16 +526,17 @@ inline KnownFPClass computeKnownFPClass(
}
/// Wrapper to account for known fast math flags at the use instruction.
-inline KnownFPClass computeKnownFPClass(const Value *V, FastMathFlags FMF,
- FPClassTest InterestedClasses,
- unsigned Depth,
- const SimplifyQuery &SQ) {
+inline KnownFPClass
+computeKnownFPClass(const Value *V, const APInt &DemandedElts,
+ FastMathFlags FMF, FPClassTest InterestedClasses,
+ unsigned Depth, const SimplifyQuery &SQ) {
if (FMF.noNaNs())
InterestedClasses &= ~fcNan;
if (FMF.noInfs())
InterestedClasses &= ~fcInf;
- KnownFPClass Result = computeKnownFPClass(V, InterestedClasses, Depth, SQ);
+ KnownFPClass Result =
+ computeKnownFPClass(V, DemandedElts, InterestedClasses, Depth, SQ);
if (FMF.noNaNs())
Result.KnownFPClasses &= ~fcNan;
@@ -544,6 +545,17 @@ inline KnownFPClass computeKnownFPClass(const Value *V, FastMathFlags FMF,
return Result;
}
+inline KnownFPClass computeKnownFPClass(const Value *V, FastMathFlags FMF,
+ FPClassTest InterestedClasses,
+ unsigned Depth,
+ const SimplifyQuery &SQ) {
+ auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+ APInt DemandedElts =
+ FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
+ return computeKnownFPClass(V, DemandedElts, FMF, InterestedClasses, Depth,
+ SQ);
+}
+
/// Return true if we can prove that the specified FP value is never equal to
/// -0.0. Users should use caution when considering PreserveSign
/// denormal-fp-math.
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index f54de030d3344..6e039ad2deadb 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -5274,8 +5274,9 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
}
// reverse preserves all characteristics of the input vec's element.
case Intrinsic::vector_reverse:
- Known = computeKnownFPClass(II->getArgOperand(0), II->getFastMathFlags(),
- InterestedClasses, Depth + 1, Q);
+ Known = computeKnownFPClass(
+ II->getArgOperand(0), DemandedElts.reverseBits(),
+ II->getFastMathFlags(), InterestedClasses, Depth + 1, Q);
break;
case Intrinsic::trunc:
case Intrinsic::floor:
More information about the llvm-commits
mailing list