[llvm] [DAG] isKnownNeverZero - add more cases for UDIV, SDIV, SRA, and SRL operations (PR #89522)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 16 11:38:30 PDT 2024


================
@@ -5580,27 +5580,88 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
     if (ValKnown.isNegative())
       return true;
     // If max shift cnt of known ones is non-zero, result is non-zero.
-    APInt MaxCnt = computeKnownBits(Op.getOperand(1), Depth + 1).getMaxValue();
+    const KnownBits Shift = computeKnownBits(Op.getOperand(1), Depth + 1);
+    APInt MaxCnt = Shift.getMaxValue();
     if (MaxCnt.ult(ValKnown.getBitWidth()) &&
         !ValKnown.One.lshr(MaxCnt).isZero())
       return true;
+    // Similar to udiv but we try to see if we can turn it into a division
+    const KnownBits One =
+        KnownBits::makeConstant(APInt(ValKnown.getBitWidth(), 1));
+
+    std::optional<bool> uge =
+        KnownBits::uge(ValKnown, KnownBits::shl(One, Shift));
+    if (uge && *uge)
+      return true;
     break;
   }
-  case ISD::UDIV:
-  case ISD::SDIV:
+  case ISD::UDIV: {
+    if (Op->getFlags().hasExact())
+      return isKnownNeverZero(Op.getOperand(0), Depth + 1);
+    KnownBits Op0 = computeKnownBits(Op.getOperand(0), Depth + 1);
+    KnownBits Op1 = computeKnownBits(Op.getOperand(1), Depth + 1);
+    // True if Op0 u>= Op1
+
+    std::optional<bool> Uge = KnownBits::uge(Op0, Op1);
+    if (Uge && *Uge)
+      return true;
+    break;
+  }
+  case ISD::SDIV: {
     // div exact can only produce a zero if the dividend is zero.
-    // TODO: For udiv this is also true if Op1 u<= Op0
     if (Op->getFlags().hasExact())
       return isKnownNeverZero(Op.getOperand(0), Depth + 1);
+    KnownBits Op0 = computeKnownBits(Op.getOperand(0), Depth + 1);
+    KnownBits Op1 = computeKnownBits(Op.getOperand(1), Depth + 1);
+
+    // For signed division need to compare abs value of the operands.
+    Op0 = Op0.abs(/*IntMinIsPoison*/ false);
+    Op1 = Op1.abs(/*IntMinIsPoison*/ false);
+
+    std::optional<bool> Uge = KnownBits::uge(Op0, Op1);
+    if (Uge && *Uge)
+      return true;
     break;
+  }
 
-  case ISD::ADD:
+  case ISD::ADD: {
     if (Op->getFlags().hasNoUnsignedWrap())
       if (isKnownNeverZero(Op.getOperand(1), Depth + 1) ||
           isKnownNeverZero(Op.getOperand(0), Depth + 1))
         return true;
+
+    KnownBits Op0 = computeKnownBits(Op.getOperand(0), Depth + 1);
+    KnownBits Op1 = computeKnownBits(Op.getOperand(1), Depth + 1);
+
+    // If X and Y are both non-negative (as signed values) then their sum is not
+    // zero unless both X and Y are zero.
+    if (Op0.isNonNegative() && Op1.isNonNegative())
+      if (isKnownNeverZero(Op.getOperand(1), Depth + 1) ||
+          isKnownNeverZero(Op.getOperand(0), Depth + 1))
+        return true;
+    // If X and Y are both negative (as signed values) then their sum is not
+    // zero unless both X and Y equal INT_MIN.
+    if (Op0.isNegative() && Op1.isNegative()) {
+      APInt Mask = APInt::getSignedMaxValue(Op0.getBitWidth());
+      // The sign bit of X is set.  If some other bit is set then X is not equal
+      // to INT_MIN.
+      if (Op0.One.intersects(Mask))
+        return true;
+      // The sign bit of Y is set.  If some other bit is set then Y is not equal
+      // to INT_MIN.
+      if (Op1.One.intersects(Mask))
+        return true;
+    }
+
+    if (KnownBits::computeForAddSub(
+            /*Add=*/true, Op->getFlags().hasNoSignedWrap(),
+            Op->getFlags().hasNoUnsignedWrap(), Op0, Op1)
+            .isNonZero())
----------------
arsenm wrote:

I believe there's a KnownBits::add helper

https://github.com/llvm/llvm-project/pull/89522


More information about the llvm-commits mailing list