[llvm] [InstCombine][InstSimplify] Pass `SimplifyQuery` to `computeKnownBits` directly. NFC. (PR #74246)

via llvm-commits llvm-commits at lists.llvm.org
Sun Dec 3 09:41:03 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

This patch passes `SimplifyQuery` to `computeKnownBits` directly in `InstSimplify` and `InstCombine`.
Because the `DomConditionCache` from #73662 is only used in `InstCombine`, adding a separate `DC` argument to `computeKnownBits` would be inconvenient; passing the whole `SimplifyQuery` avoids the need for one.

Together, this patch and #73662 will fix https://github.com/llvm/llvm-project/issues/74242.


---
Full diff: https://github.com/llvm/llvm-project/pull/74246.diff


2 Files Affected:

- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+19-21) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+2-4) 


``````````diff
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index cef9f6ec179ba..2a45acf63aa2c 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -811,7 +811,7 @@ static Value *simplifySubInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
     if (IsNUW)
       return Constant::getNullValue(Op0->getType());
 
-    KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q);
     if (Known.Zero.isMaxSignedValue()) {
       // Op1 is either 0 or the minimum signed value. If the sub is NSW, then
       // Op1 must be 0 because negating the minimum signed value is undefined.
@@ -1063,7 +1063,7 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
   //      ("computeConstantRangeIncludingKnownBits")?
   const APInt *C;
   if (match(Y, m_APInt(C)) &&
-      computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT).getMaxValue().ult(*C))
+      computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(*C))
     return true;
 
   // Try again for any divisor:
@@ -1125,8 +1125,7 @@ static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0,
   if (Op0 == Op1)
     return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty);
 
-
-  KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+  KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q);
   // X / 0 -> poison
   // X % 0 -> poison
   // If the divisor is known to be zero, just return poison. This can happen in
@@ -1195,7 +1194,7 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
   // less trailing zeros, then the result must be poison.
   const APInt *DivC;
   if (IsExact && match(Op1, m_APInt(DivC)) && DivC->countr_zero()) {
-    KnownBits KnownOp0 = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits KnownOp0 = computeKnownBits(Op0, /* Depth */ 0, Q);
     if (KnownOp0.countMaxTrailingZeros() < DivC->countr_zero())
       return PoisonValue::get(Op0->getType());
   }
@@ -1355,7 +1354,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
 
   // If any bits in the shift amount make that value greater than or equal to
   // the number of bits in the type, the shift is undefined.
-  KnownBits KnownAmt = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+  KnownBits KnownAmt = computeKnownBits(Op1, /* Depth */ 0, Q);
   if (KnownAmt.getMinValue().uge(KnownAmt.getBitWidth()))
     return PoisonValue::get(Op0->getType());
 
@@ -1368,7 +1367,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
   // Check for nsw shl leading to a poison value.
   if (IsNSW) {
     assert(Opcode == Instruction::Shl && "Expected shl for nsw instruction");
-    KnownBits KnownVal = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits KnownVal = computeKnownBits(Op0, /* Depth */ 0, Q);
     KnownBits KnownShl = KnownBits::shl(KnownVal, KnownAmt);
 
     if (KnownVal.Zero.isSignBitSet())
@@ -1404,8 +1403,7 @@ static Value *simplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
   // The low bit cannot be shifted out of an exact shift if it is set.
   // TODO: Generalize by counting trailing zeros (see fold for exact division).
   if (IsExact) {
-    KnownBits Op0Known =
-        computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits Op0Known = computeKnownBits(Op0, /* Depth */ 0, Q);
     if (Op0Known.One[0])
       return Op0;
   }
@@ -1477,7 +1475,7 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
   if (Q.IIQ.UseInstrInfo && match(Op1, m_APInt(ShRAmt)) &&
       match(Op0, m_c_Or(m_NUWShl(m_Value(X), m_APInt(ShLAmt)), m_Value(Y))) &&
       *ShRAmt == *ShLAmt) {
-    const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
     const unsigned EffWidthY = YKnown.countMaxActiveBits();
     if (ShRAmt->uge(EffWidthY))
       return X;
@@ -2105,7 +2103,7 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
       match(Op0, m_Add(m_Value(Shift), m_AllOnes())) &&
       isKnownToBeAPowerOfTwo(Shift, Q.DL, /*OrZero*/ false, 0, Q.AC, Q.CxtI,
                              Q.DT)) {
-    KnownBits Known = computeKnownBits(Shift, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits Known = computeKnownBits(Shift, /* Depth */ 0, Q);
     // Use getActiveBits() to make use of the additional power of two knowledge
     if (PowerC->getActiveBits() >= Known.getMaxValue().getActiveBits())
       return ConstantInt::getNullValue(Op1->getType());
@@ -2169,10 +2167,10 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
                         m_Value(Y)))) {
     const unsigned Width = Op0->getType()->getScalarSizeInBits();
     const unsigned ShftCnt = ShAmt->getLimitedValue(Width);
-    const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
     const unsigned EffWidthY = YKnown.countMaxActiveBits();
     if (EffWidthY <= ShftCnt) {
-      const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+      const KnownBits XKnown = computeKnownBits(X, /* Depth */ 0, Q);
       const unsigned EffWidthX = XKnown.countMaxActiveBits();
       const APInt EffBitsY = APInt::getLowBitsSet(Width, EffWidthY);
       const APInt EffBitsX = APInt::getLowBitsSet(Width, EffWidthX) << ShftCnt;
@@ -2968,7 +2966,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
       return getTrue(ITy);
     break;
   case ICmpInst::ICMP_SLT: {
-    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
     if (LHSKnown.isNegative())
       return getTrue(ITy);
     if (LHSKnown.isNonNegative())
@@ -2976,7 +2974,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
     break;
   }
   case ICmpInst::ICMP_SLE: {
-    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
     if (LHSKnown.isNegative())
       return getTrue(ITy);
     if (LHSKnown.isNonNegative() &&
@@ -2985,7 +2983,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
     break;
   }
   case ICmpInst::ICMP_SGE: {
-    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
     if (LHSKnown.isNegative())
       return getFalse(ITy);
     if (LHSKnown.isNonNegative())
@@ -2993,7 +2991,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
     break;
   }
   case ICmpInst::ICMP_SGT: {
-    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
     if (LHSKnown.isNegative())
       return getFalse(ITy);
     if (LHSKnown.isNonNegative() &&
@@ -3070,8 +3068,8 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
       return getTrue(ITy);
 
     if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
-      KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
-      KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+      KnownBits RHSKnown = computeKnownBits(RHS, /* Depth */ 0, Q);
+      KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
       if (RHSKnown.isNonNegative() && YKnown.isNegative())
         return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy);
       if (RHSKnown.isNegative() || YKnown.isNonNegative())
@@ -3094,7 +3092,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
       break;
     case ICmpInst::ICMP_SGT:
     case ICmpInst::ICMP_SGE: {
-      KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+      KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q);
       if (!Known.isNonNegative())
         break;
       [[fallthrough]];
@@ -3105,7 +3103,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
       return getFalse(ITy);
     case ICmpInst::ICMP_SLT:
     case ICmpInst::ICMP_SLE: {
-      KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+      KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q);
       if (!Known.isNonNegative())
         break;
       [[fallthrough]];
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index c72eb0f74de8e..b7958978c450c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -962,15 +962,13 @@ static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
   }
 
   // Compute what we know about shift count.
-  KnownBits KnownCnt =
-      computeKnownBits(I.getOperand(1), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+  KnownBits KnownCnt = computeKnownBits(I.getOperand(1), /* Depth */ 0, Q);
   unsigned BitWidth = KnownCnt.getBitWidth();
   // Since shift produces a poison value if RHS is equal to or larger than the
   // bit width, we can safely assume that RHS is less than the bit width.
   uint64_t MaxCnt = KnownCnt.getMaxValue().getLimitedValue(BitWidth - 1);
 
-  KnownBits KnownAmt =
-      computeKnownBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+  KnownBits KnownAmt = computeKnownBits(I.getOperand(0), /* Depth */ 0, Q);
   bool Changed = false;
 
   if (I.getOpcode() == Instruction::Shl) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/74246


More information about the llvm-commits mailing list