[llvm] [DAG] isKnownNeverZero - add DemandedElts element mask to isKnownNeverZero calls. (PR #135951)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sat Apr 19 05:07:26 PDT 2025


https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/135951

>From 68da740a5cc17256453c28254915b8894a815142 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Wed, 16 Apr 2025 12:13:27 +0100
Subject: [PATCH] [DAG] isKnownNeverZero - add DemandedElts element mask to
 isKnownNeverZero calls.

Matches what we've already done for computeKnownBits etc. to improve vector handling.
---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  5 ++
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 87 ++++++++++++-------
 2 files changed, 59 insertions(+), 33 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 63423463eeee2..165a60e366b8b 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2156,6 +2156,11 @@ class SelectionDAG {
   /// positive or negative zero.
   bool isKnownNeverZeroFloat(SDValue Op) const;

+  /// Test whether the given SDValue is known to be non-zero in every lane
+  /// set in DemandedElts (a single bit for scalars and scalable vectors).
+  bool isKnownNeverZero(SDValue Op, const APInt &DemandedElts,
+                        unsigned Depth = 0) const;
+
   /// Test whether the given SDValue is known to contain non-zero value(s).
   bool isKnownNeverZero(SDValue Op, unsigned Depth = 0) const;

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 46fc8856640de..c6b92772eb696 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5743,6 +5743,19 @@ bool SelectionDAG::isKnownNeverZeroFloat(SDValue Op) const {
 }

 bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
+  EVT VT = Op.getValueType();
+
+  // Since the number of lanes in a scalable vector is unknown at compile time,
+  // we track one bit which is implicitly broadcast to all lanes.  This means
+  // that all lanes in a scalable vector are considered demanded.
+  APInt DemandedElts = VT.isFixedLengthVector()
+                           ? APInt::getAllOnes(VT.getVectorNumElements())
+                           : APInt(1, 1);
+  return isKnownNeverZero(Op, DemandedElts, Depth);
+}
+
+bool SelectionDAG::isKnownNeverZero(SDValue Op, const APInt &DemandedElts,
+                                    unsigned Depth) const {
   if (Depth >= MaxRecursionDepth)
     return false; // Limit search depth.

@@ -5754,6 +5767,9 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
                                [](ConstantSDNode *C) { return !C->isZero(); }))
     return true;

+  if (!DemandedElts)
+    return false; // No demanded elts, better to assume we don't know anything.
+
   // TODO: Recognize more cases here. Most of the cases are also incomplete to
   // some degree.
   switch (Op.getOpcode()) {
@@ -5761,23 +5777,25 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
     break;

   case ISD::OR:
-    return isKnownNeverZero(Op.getOperand(1), Depth + 1) ||
-           isKnownNeverZero(Op.getOperand(0), Depth + 1);
+    return isKnownNeverZero(Op.getOperand(1), DemandedElts, Depth + 1) ||
+           isKnownNeverZero(Op.getOperand(0), DemandedElts, Depth + 1);

   case ISD::VSELECT:
   case ISD::SELECT:
-    return isKnownNeverZero(Op.getOperand(1), Depth + 1) &&
-           isKnownNeverZero(Op.getOperand(2), Depth + 1);
+    return isKnownNeverZero(Op.getOperand(1), DemandedElts, Depth + 1) &&
+           isKnownNeverZero(Op.getOperand(2), DemandedElts, Depth + 1);

   case ISD::SHL: {
     if (Op->getFlags().hasNoSignedWrap() || Op->getFlags().hasNoUnsignedWrap())
-      return isKnownNeverZero(Op.getOperand(0), Depth + 1);
-    KnownBits ValKnown = computeKnownBits(Op.getOperand(0), Depth + 1);
+      return isKnownNeverZero(Op.getOperand(0), DemandedElts, Depth + 1);
+    KnownBits ValKnown =
+        computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     // 1 << X is never zero.
     if (ValKnown.One[0])
       return true;
     // If max shift cnt of known ones is non-zero, result is non-zero.
-    APInt MaxCnt = computeKnownBits(Op.getOperand(1), Depth + 1).getMaxValue();
+    APInt MaxCnt = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1)
+                       .getMaxValue();
     if (MaxCnt.ult(ValKnown.getBitWidth()) &&
         !ValKnown.One.shl(MaxCnt).isZero())
       return true;
@@ -5785,44 +5803,44 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
   }
   case ISD::UADDSAT:
   case ISD::UMAX:
-    return isKnownNeverZero(Op.getOperand(1), Depth + 1) ||
-           isKnownNeverZero(Op.getOperand(0), Depth + 1);
+    return isKnownNeverZero(Op.getOperand(1), DemandedElts, Depth + 1) ||
+           isKnownNeverZero(Op.getOperand(0), DemandedElts, Depth + 1);

     // For smin/smax: If either operand is known negative/positive
     // respectively we don't need the other to be known at all.
   case ISD::SMAX: {
-    KnownBits Op1 = computeKnownBits(Op.getOperand(1), Depth + 1);
+    KnownBits Op1 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
     if (Op1.isStrictlyPositive())
       return true;

-    KnownBits Op0 = computeKnownBits(Op.getOperand(0), Depth + 1);
+    KnownBits Op0 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     if (Op0.isStrictlyPositive())
       return true;

     if (Op1.isNonZero() && Op0.isNonZero())
       return true;

-    return isKnownNeverZero(Op.getOperand(1), Depth + 1) &&
-           isKnownNeverZero(Op.getOperand(0), Depth + 1);
+    return isKnownNeverZero(Op.getOperand(1), DemandedElts, Depth + 1) &&
+           isKnownNeverZero(Op.getOperand(0), DemandedElts, Depth + 1);
   }
   case ISD::SMIN: {
-    KnownBits Op1 = computeKnownBits(Op.getOperand(1), Depth + 1);
+    KnownBits Op1 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
     if (Op1.isNegative())
       return true;

-    KnownBits Op0 = computeKnownBits(Op.getOperand(0), Depth + 1);
+    KnownBits Op0 = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     if (Op0.isNegative())
       return true;

     if (Op1.isNonZero() && Op0.isNonZero())
       return true;

-    return isKnownNeverZero(Op.getOperand(1), Depth + 1) &&
-           isKnownNeverZero(Op.getOperand(0), Depth + 1);
+    return isKnownNeverZero(Op.getOperand(1), DemandedElts, Depth + 1) &&
+           isKnownNeverZero(Op.getOperand(0), DemandedElts, Depth + 1);
   }
   case ISD::UMIN:
-    return isKnownNeverZero(Op.getOperand(1), Depth + 1) &&
-           isKnownNeverZero(Op.getOperand(0), Depth + 1);
+    return isKnownNeverZero(Op.getOperand(1), DemandedElts, Depth + 1) &&
+           isKnownNeverZero(Op.getOperand(0), DemandedElts, Depth + 1);

   case ISD::ROTL:
   case ISD::ROTR:
@@ -5830,17 +5848,19 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
   case ISD::BSWAP:
   case ISD::CTPOP:
   case ISD::ABS:
-    return isKnownNeverZero(Op.getOperand(0), Depth + 1);
+    return isKnownNeverZero(Op.getOperand(0), DemandedElts, Depth + 1);

   case ISD::SRA:
   case ISD::SRL: {
     if (Op->getFlags().hasExact())
-      return isKnownNeverZero(Op.getOperand(0), Depth + 1);
-    KnownBits ValKnown = computeKnownBits(Op.getOperand(0), Depth + 1);
+      return isKnownNeverZero(Op.getOperand(0), DemandedElts, Depth + 1);
+    KnownBits ValKnown =
+        computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     if (ValKnown.isNegative())
       return true;
     // If max shift cnt of known ones is non-zero, result is non-zero.
-    APInt MaxCnt = computeKnownBits(Op.getOperand(1), Depth + 1).getMaxValue();
+    APInt MaxCnt = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1)
+                       .getMaxValue();
     if (MaxCnt.ult(ValKnown.getBitWidth()) &&
         !ValKnown.One.lshr(MaxCnt).isZero())
       return true;
@@ -5851,37 +5871,38 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
     // div exact can only produce a zero if the dividend is zero.
     // TODO: For udiv this is also true if Op1 u<= Op0
     if (Op->getFlags().hasExact())
-      return isKnownNeverZero(Op.getOperand(0), Depth + 1);
+      return isKnownNeverZero(Op.getOperand(0), DemandedElts, Depth + 1);
     break;

   case ISD::ADD:
     if (Op->getFlags().hasNoUnsignedWrap())
-      if (isKnownNeverZero(Op.getOperand(1), Depth + 1) ||
-          isKnownNeverZero(Op.getOperand(0), Depth + 1))
+      if (isKnownNeverZero(Op.getOperand(1), DemandedElts, Depth + 1) ||
+          isKnownNeverZero(Op.getOperand(0), DemandedElts, Depth + 1))
         return true;
     // TODO: There are a lot more cases we can prove for add.
     break;

   case ISD::SUB: {
     if (isNullConstant(Op.getOperand(0)))
-      return isKnownNeverZero(Op.getOperand(1), Depth + 1);
+      return isKnownNeverZero(Op.getOperand(1), DemandedElts, Depth + 1);

-    std::optional<bool> ne =
-        KnownBits::ne(computeKnownBits(Op.getOperand(0), Depth + 1),
-                      computeKnownBits(Op.getOperand(1), Depth + 1));
+    std::optional<bool> ne = KnownBits::ne(
+        computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1),
+        computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1));
     return ne && *ne;
   }

   case ISD::MUL:
     if (Op->getFlags().hasNoSignedWrap() || Op->getFlags().hasNoUnsignedWrap())
-      if (isKnownNeverZero(Op.getOperand(1), Depth + 1) &&
-          isKnownNeverZero(Op.getOperand(0), Depth + 1))
+      if (isKnownNeverZero(Op.getOperand(1), DemandedElts, Depth + 1) &&
+          isKnownNeverZero(Op.getOperand(0), DemandedElts, Depth + 1))
         return true;
     break;

   case ISD::ZERO_EXTEND:
   case ISD::SIGN_EXTEND:
-    return isKnownNeverZero(Op.getOperand(0), Depth + 1);
+    return isKnownNeverZero(Op.getOperand(0), DemandedElts, Depth + 1);
+
   case ISD::VSCALE: {
     const Function &F = getMachineFunction().getFunction();
     const APInt &Multiplier = Op.getConstantOperandAPInt(0);
@@ -5893,7 +5914,7 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
   }
   }

-  return computeKnownBits(Op, Depth).isNonZero();
+  return computeKnownBits(Op, DemandedElts, Depth).isNonZero();
 }

 bool SelectionDAG::cannotBeOrderedNegativeFP(SDValue Op) const {



More information about the llvm-commits mailing list