[llvm] [InstCombine][InstSimplify] Pass `SimplifyQuery` to `computeKnownBits` directly. NFC. (PR #74246)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Dec 3 09:41:03 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Yingwei Zheng (dtcxzyw)
<details>
<summary>Changes</summary>
This patch passes `SimplifyQuery` to `computeKnownBits` directly in `InstSimplify` and `InstCombine`.
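Because the `DomConditionCache` introduced in #73662 is only used in `InstCombine`, adding a new `DC` argument to `computeKnownBits` would be inconvenient; passing the whole `SimplifyQuery` lets the cache travel with the existing query object instead.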
Together with #73662, this patch fixes https://github.com/llvm/llvm-project/issues/74242.
---
Full diff: https://github.com/llvm/llvm-project/pull/74246.diff
2 Files Affected:
- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+19-21)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+2-4)
``````````diff
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index cef9f6ec179ba..2a45acf63aa2c 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -811,7 +811,7 @@ static Value *simplifySubInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
if (IsNUW)
return Constant::getNullValue(Op0->getType());
- KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q);
if (Known.Zero.isMaxSignedValue()) {
// Op1 is either 0 or the minimum signed value. If the sub is NSW, then
// Op1 must be 0 because negating the minimum signed value is undefined.
@@ -1063,7 +1063,7 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
// ("computeConstantRangeIncludingKnownBits")?
const APInt *C;
if (match(Y, m_APInt(C)) &&
- computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT).getMaxValue().ult(*C))
+ computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(*C))
return true;
// Try again for any divisor:
@@ -1125,8 +1125,7 @@ static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0,
if (Op0 == Op1)
return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty);
-
- KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q);
// X / 0 -> poison
// X % 0 -> poison
// If the divisor is known to be zero, just return poison. This can happen in
@@ -1195,7 +1194,7 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
// less trailing zeros, then the result must be poison.
const APInt *DivC;
if (IsExact && match(Op1, m_APInt(DivC)) && DivC->countr_zero()) {
- KnownBits KnownOp0 = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits KnownOp0 = computeKnownBits(Op0, /* Depth */ 0, Q);
if (KnownOp0.countMaxTrailingZeros() < DivC->countr_zero())
return PoisonValue::get(Op0->getType());
}
@@ -1355,7 +1354,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
// If any bits in the shift amount make that value greater than or equal to
// the number of bits in the type, the shift is undefined.
- KnownBits KnownAmt = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits KnownAmt = computeKnownBits(Op1, /* Depth */ 0, Q);
if (KnownAmt.getMinValue().uge(KnownAmt.getBitWidth()))
return PoisonValue::get(Op0->getType());
@@ -1368,7 +1367,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
// Check for nsw shl leading to a poison value.
if (IsNSW) {
assert(Opcode == Instruction::Shl && "Expected shl for nsw instruction");
- KnownBits KnownVal = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits KnownVal = computeKnownBits(Op0, /* Depth */ 0, Q);
KnownBits KnownShl = KnownBits::shl(KnownVal, KnownAmt);
if (KnownVal.Zero.isSignBitSet())
@@ -1404,8 +1403,7 @@ static Value *simplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
// The low bit cannot be shifted out of an exact shift if it is set.
// TODO: Generalize by counting trailing zeros (see fold for exact division).
if (IsExact) {
- KnownBits Op0Known =
- computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits Op0Known = computeKnownBits(Op0, /* Depth */ 0, Q);
if (Op0Known.One[0])
return Op0;
}
@@ -1477,7 +1475,7 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
if (Q.IIQ.UseInstrInfo && match(Op1, m_APInt(ShRAmt)) &&
match(Op0, m_c_Or(m_NUWShl(m_Value(X), m_APInt(ShLAmt)), m_Value(Y))) &&
*ShRAmt == *ShLAmt) {
- const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
const unsigned EffWidthY = YKnown.countMaxActiveBits();
if (ShRAmt->uge(EffWidthY))
return X;
@@ -2105,7 +2103,7 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
match(Op0, m_Add(m_Value(Shift), m_AllOnes())) &&
isKnownToBeAPowerOfTwo(Shift, Q.DL, /*OrZero*/ false, 0, Q.AC, Q.CxtI,
Q.DT)) {
- KnownBits Known = computeKnownBits(Shift, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits Known = computeKnownBits(Shift, /* Depth */ 0, Q);
// Use getActiveBits() to make use of the additional power of two knowledge
if (PowerC->getActiveBits() >= Known.getMaxValue().getActiveBits())
return ConstantInt::getNullValue(Op1->getType());
@@ -2169,10 +2167,10 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
m_Value(Y)))) {
const unsigned Width = Op0->getType()->getScalarSizeInBits();
const unsigned ShftCnt = ShAmt->getLimitedValue(Width);
- const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
const unsigned EffWidthY = YKnown.countMaxActiveBits();
if (EffWidthY <= ShftCnt) {
- const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ const KnownBits XKnown = computeKnownBits(X, /* Depth */ 0, Q);
const unsigned EffWidthX = XKnown.countMaxActiveBits();
const APInt EffBitsY = APInt::getLowBitsSet(Width, EffWidthY);
const APInt EffBitsX = APInt::getLowBitsSet(Width, EffWidthX) << ShftCnt;
@@ -2968,7 +2966,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
return getTrue(ITy);
break;
case ICmpInst::ICMP_SLT: {
- KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
if (LHSKnown.isNegative())
return getTrue(ITy);
if (LHSKnown.isNonNegative())
@@ -2976,7 +2974,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
break;
}
case ICmpInst::ICMP_SLE: {
- KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
if (LHSKnown.isNegative())
return getTrue(ITy);
if (LHSKnown.isNonNegative() &&
@@ -2985,7 +2983,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
break;
}
case ICmpInst::ICMP_SGE: {
- KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
if (LHSKnown.isNegative())
return getFalse(ITy);
if (LHSKnown.isNonNegative())
@@ -2993,7 +2991,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
break;
}
case ICmpInst::ICMP_SGT: {
- KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
if (LHSKnown.isNegative())
return getFalse(ITy);
if (LHSKnown.isNonNegative() &&
@@ -3070,8 +3068,8 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
return getTrue(ITy);
if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
- KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
- KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits RHSKnown = computeKnownBits(RHS, /* Depth */ 0, Q);
+ KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
if (RHSKnown.isNonNegative() && YKnown.isNegative())
return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy);
if (RHSKnown.isNegative() || YKnown.isNonNegative())
@@ -3094,7 +3092,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
break;
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE: {
- KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q);
if (!Known.isNonNegative())
break;
[[fallthrough]];
@@ -3105,7 +3103,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
return getFalse(ITy);
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE: {
- KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q);
if (!Known.isNonNegative())
break;
[[fallthrough]];
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index c72eb0f74de8e..b7958978c450c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -962,15 +962,13 @@ static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
}
// Compute what we know about shift count.
- KnownBits KnownCnt =
- computeKnownBits(I.getOperand(1), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits KnownCnt = computeKnownBits(I.getOperand(1), /* Depth */ 0, Q);
unsigned BitWidth = KnownCnt.getBitWidth();
// Since shift produces a poison value if RHS is equal to or larger than the
// bit width, we can safely assume that RHS is less than the bit width.
uint64_t MaxCnt = KnownCnt.getMaxValue().getLimitedValue(BitWidth - 1);
- KnownBits KnownAmt =
- computeKnownBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits KnownAmt = computeKnownBits(I.getOperand(0), /* Depth */ 0, Q);
bool Changed = false;
if (I.getOpcode() == Instruction::Shl) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/74246
More information about the llvm-commits mailing list