[llvm] [KnownBits] Make nuw and nsw support in computeForAddSub optimal (PR #83382)

via llvm-commits llvm-commits at lists.llvm.org
Sat Mar 2 10:22:55 PST 2024


https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/83382

>From aa76e975e69282c4d409c452e701350c6acf8b7e Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Thu, 29 Feb 2024 11:22:30 -0600
Subject: [PATCH 1/2] [KnownBits] Add API for `nuw` flag in `computeForAddSub`;
 NFC
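
A minimal sketch of the new call shape (the signature is taken from the header
change below; the wrapper name is hypothetical, not part of the patch).
Passing /*NUW=*/false everywhere is what keeps this first patch NFC:

#include "llvm/Support/KnownBits.h"
using namespace llvm;

// Hypothetical wrapper illustrating the post-patch signature: callers now
// thread an explicit NUW flag through computeForAddSub. Passing false
// preserves the old behavior.
KnownBits knownBitsOfAdd(const KnownBits &LHS, const KnownBits &RHS) {
  return KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
                                     /*NUW=*/false, LHS, RHS);
}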

---
 llvm/include/llvm/Support/KnownBits.h         |  4 +-
 llvm/lib/Analysis/ValueTracking.cpp           | 46 ++++++++++---------
 .../lib/CodeGen/GlobalISel/GISelKnownBits.cpp | 10 ++--
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  5 +-
 .../CodeGen/SelectionDAG/TargetLowering.cpp   |  6 +--
 llvm/lib/Support/KnownBits.cpp                | 16 ++++---
 llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp |  3 +-
 .../AMDGPU/AMDGPUInstructionSelector.cpp      |  2 +-
 llvm/lib/Target/ARM/ARMISelLowering.cpp       |  3 +-
 .../InstCombineSimplifyDemanded.cpp           | 14 ++++--
 llvm/unittests/Support/KnownBitsTest.cpp      |  6 +--
 11 files changed, 66 insertions(+), 49 deletions(-)

diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 69c569b97ccae1..f5fce296fefe70 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -329,8 +329,8 @@ struct KnownBits {
       const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry);
 
   /// Compute known bits resulting from adding LHS and RHS.
-  static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
-                                    KnownBits RHS);
+  static KnownBits computeForAddSub(bool Add, bool NSW, bool NUW,
+                                    const KnownBits &LHS, KnownBits RHS);
 
   /// Compute known bits results from subtracting RHS from LHS with 1-bit
   /// Borrow.
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e591ac504e9f05..6d8216ddbe29e2 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -350,18 +350,19 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
 }
 
 static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
-                                   bool NSW, const APInt &DemandedElts,
+                                   bool NSW, bool NUW,
+                                   const APInt &DemandedElts,
                                    KnownBits &KnownOut, KnownBits &Known2,
                                    unsigned Depth, const SimplifyQuery &Q) {
   computeKnownBits(Op1, DemandedElts, KnownOut, Depth + 1, Q);
 
   // If one operand is unknown and we have no nowrap information,
   // the result will be unknown independently of the second operand.
-  if (KnownOut.isUnknown() && !NSW)
+  if (KnownOut.isUnknown() && !NSW && !NUW)
     return;
 
   computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q);
-  KnownOut = KnownBits::computeForAddSub(Add, NSW, Known2, KnownOut);
+  KnownOut = KnownBits::computeForAddSub(Add, NSW, NUW, Known2, KnownOut);
 }
 
 static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
@@ -1145,13 +1146,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
   }
   case Instruction::Sub: {
     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
-    computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW,
+    bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
+    computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, NUW,
                            DemandedElts, Known, Known2, Depth, Q);
     break;
   }
   case Instruction::Add: {
     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
-    computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW,
+    bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
+    computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, NUW,
                            DemandedElts, Known, Known2, Depth, Q);
     break;
   }
@@ -1245,12 +1248,12 @@ static void computeKnownBitsFromOperator(const Operator *I,
       // Note that inbounds does *not* guarantee nsw for the addition, as only
       // the offset is signed, while the base address is unsigned.
       Known = KnownBits::computeForAddSub(
-          /*Add=*/true, /*NSW=*/false, Known, IndexBits);
+          /*Add=*/true, /*NSW=*/false, /*NUW=*/false, Known, IndexBits);
     }
     if (!Known.isUnknown() && !AccConstIndices.isZero()) {
       KnownBits Index = KnownBits::makeConstant(AccConstIndices);
       Known = KnownBits::computeForAddSub(
-          /*Add=*/true, /*NSW=*/false, Known, Index);
+          /*Add=*/true, /*NSW=*/false, /*NUW=*/false, Known, Index);
     }
     break;
   }
@@ -1689,15 +1692,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
         default: break;
         case Intrinsic::uadd_with_overflow:
         case Intrinsic::sadd_with_overflow:
-          computeKnownBitsAddSub(true, II->getArgOperand(0),
-                                 II->getArgOperand(1), false, DemandedElts,
-                                 Known, Known2, Depth, Q);
+          computeKnownBitsAddSub(
+              true, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false,
+              /*NUW=*/false, DemandedElts, Known, Known2, Depth, Q);
           break;
         case Intrinsic::usub_with_overflow:
         case Intrinsic::ssub_with_overflow:
-          computeKnownBitsAddSub(false, II->getArgOperand(0),
-                                 II->getArgOperand(1), false, DemandedElts,
-                                 Known, Known2, Depth, Q);
+          computeKnownBitsAddSub(
+              false, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false,
+              /*NUW=*/false, DemandedElts, Known, Known2, Depth, Q);
           break;
         case Intrinsic::umul_with_overflow:
         case Intrinsic::smul_with_overflow:
@@ -2318,7 +2321,11 @@ static bool isNonZeroRecurrence(const PHINode *PN) {
 
 static bool isNonZeroAdd(const APInt &DemandedElts, unsigned Depth,
                          const SimplifyQuery &Q, unsigned BitWidth, Value *X,
-                         Value *Y, bool NSW) {
+                         Value *Y, bool NSW, bool NUW) {
+  if (NUW)
+    return isKnownNonZero(Y, DemandedElts, Depth, Q) ||
+           isKnownNonZero(X, DemandedElts, Depth, Q);
+
   KnownBits XKnown = computeKnownBits(X, DemandedElts, Depth, Q);
   KnownBits YKnown = computeKnownBits(Y, DemandedElts, Depth, Q);
 
@@ -2351,7 +2358,7 @@ static bool isNonZeroAdd(const APInt &DemandedElts, unsigned Depth,
       isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q))
     return true;
 
-  return KnownBits::computeForAddSub(/*Add*/ true, NSW, XKnown, YKnown)
+  return KnownBits::computeForAddSub(/*Add=*/true, NSW, NUW, XKnown, YKnown)
       .isNonZero();
 }
 
@@ -2556,12 +2563,9 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
     // If Add has nuw wrap flag, then if either X or Y is non-zero the result is
     // non-zero.
     auto *BO = cast<OverflowingBinaryOperator>(I);
-    if (Q.IIQ.hasNoUnsignedWrap(BO))
-      return isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q) ||
-             isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q);
-
     return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth, I->getOperand(0),
-                        I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO));
+                        I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO),
+                        Q.IIQ.hasNoUnsignedWrap(BO));
   }
   case Instruction::Mul: {
     // If X and Y are non-zero then so is X * Y as long as the multiplication
@@ -2716,7 +2720,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
       case Intrinsic::sadd_sat:
         return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth,
                             II->getArgOperand(0), II->getArgOperand(1),
-                            /*NSW*/ true);
+                            /*NSW=*/true, /*NUW=*/false);
       case Intrinsic::umax:
       case Intrinsic::uadd_sat:
         return isKnownNonZero(II->getArgOperand(1), DemandedElts, Depth, Q) ||
diff --git a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
index ea8c20cdcd45d6..099bf45b2734cb 100644
--- a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
@@ -269,8 +269,8 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
                          Depth + 1);
     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
                          Depth + 1);
-    Known = KnownBits::computeForAddSub(/*Add*/ false, /*NSW*/ false, Known,
-                                        Known2);
+    Known = KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
+                                        /*NUW=*/false, Known, Known2);
     break;
   }
   case TargetOpcode::G_XOR: {
@@ -296,8 +296,8 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
                          Depth + 1);
     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
                          Depth + 1);
-    Known =
-        KnownBits::computeForAddSub(/*Add*/ true, /*NSW*/ false, Known, Known2);
+    Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
+                                        /*NUW=*/false, Known, Known2);
     break;
   }
   case TargetOpcode::G_AND: {
@@ -564,7 +564,7 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
     // right.
     KnownBits ExtKnown = KnownBits::makeConstant(APInt(BitWidth, BitWidth));
     KnownBits ShiftKnown = KnownBits::computeForAddSub(
-        /*Add*/ false, /*NSW*/ false, ExtKnown, WidthKnown);
+        /*Add=*/false, /*NSW=*/false, /*NUW=*/false, ExtKnown, WidthKnown);
     Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown);
     break;
   }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 5b1b7c7c627723..f7ace79e8c51d4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3763,8 +3763,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     SDNodeFlags Flags = Op.getNode()->getFlags();
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known = KnownBits::computeForAddSub(Op.getOpcode() == ISD::ADD,
-                                        Flags.hasNoSignedWrap(), Known, Known2);
+    Known = KnownBits::computeForAddSub(
+        Op.getOpcode() == ISD::ADD, Flags.hasNoSignedWrap(),
+        Flags.hasNoUnsignedWrap(), Known, Known2);
     break;
   }
   case ISD::USUBO:
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 6970b230837fb9..a639cba5e35a80 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2876,9 +2876,9 @@ bool TargetLowering::SimplifyDemandedBits(
     if (Op.getOpcode() == ISD::MUL) {
       Known = KnownBits::mul(KnownOp0, KnownOp1);
     } else { // Op.getOpcode() is either ISD::ADD or ISD::SUB.
-      Known = KnownBits::computeForAddSub(Op.getOpcode() == ISD::ADD,
-                                          Flags.hasNoSignedWrap(), KnownOp0,
-                                          KnownOp1);
+      Known = KnownBits::computeForAddSub(
+          Op.getOpcode() == ISD::ADD, Flags.hasNoSignedWrap(),
+          Flags.hasNoUnsignedWrap(), KnownOp0, KnownOp1);
     }
     break;
   }
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index c44a08cc1c2e6a..f999abe7dd14e6 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -54,7 +54,7 @@ KnownBits KnownBits::computeForAddCarry(
       LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
 }
 
-KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
+KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool /*NUW*/,
                                       const KnownBits &LHS, KnownBits RHS) {
   KnownBits KnownOut;
   if (Add) {
@@ -180,11 +180,14 @@ KnownBits KnownBits::absdiff(const KnownBits &LHS, const KnownBits &RHS) {
   // absdiff(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
   KnownBits UMaxValue = umax(LHS, RHS);
   KnownBits UMinValue = umin(LHS, RHS);
-  KnownBits MinMaxDiff = computeForAddSub(false, false, UMaxValue, UMinValue);
+  KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
+                                          /*NUW=*/true, UMaxValue, UMinValue);
 
   // find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
-  KnownBits Diff0 = computeForAddSub(false, false, LHS, RHS);
-  KnownBits Diff1 = computeForAddSub(false, false, RHS, LHS);
+  KnownBits Diff0 =
+      computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
+  KnownBits Diff1 =
+      computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS, LHS);
   KnownBits SubDiff = Diff0.intersectWith(Diff1);
 
   KnownBits KnownAbsDiff = MinMaxDiff.unionWith(SubDiff);
@@ -459,7 +462,7 @@ KnownBits KnownBits::abs(bool IntMinIsPoison) const {
       Tmp.One.setBit(countMinTrailingZeros());
 
     KnownAbs = computeForAddSub(
-        /*Add*/ false, IntMinIsPoison,
+        /*Add=*/false, IntMinIsPoison, /*NUW=*/false,
         KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp);
 
     // One more special case for IntMinIsPoison. If we don't know any ones other
@@ -505,7 +508,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
   assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
   // We don't see NSW even for sadd/ssub as we want to check if the result has
   // signed overflow.
-  KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW*/ false, LHS, RHS);
+  KnownBits Res =
+      KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
   unsigned BitWidth = Res.getBitWidth();
   auto SignBitKnown = [&](const KnownBits &K) {
     return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index 4896ae8bad9ef3..2e5b02fbe85660 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -1903,7 +1903,8 @@ bool AMDGPUDAGToDAGISel::checkFlatScratchSVSSwizzleBug(
   // voffset to (soffset + inst_offset).
   KnownBits VKnown = CurDAG->computeKnownBits(VAddr);
   KnownBits SKnown = KnownBits::computeForAddSub(
-      true, false, CurDAG->computeKnownBits(SAddr),
+      /*Add=*/true, /*NSW=*/false, /*NUW=*/false,
+      CurDAG->computeKnownBits(SAddr),
       KnownBits::makeConstant(APInt(32, ImmOffset)));
   uint64_t VMax = VKnown.getMaxValue().getZExtValue();
   uint64_t SMax = SKnown.getMaxValue().getZExtValue();
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
index b2c65e61b0097c..94cc1d90e0ca58 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
@@ -4573,7 +4573,7 @@ bool AMDGPUInstructionSelector::checkFlatScratchSVSSwizzleBug(
   // voffset to (soffset + inst_offset).
   auto VKnown = KB->getKnownBits(VAddr);
   auto SKnown = KnownBits::computeForAddSub(
-      true, false, KB->getKnownBits(SAddr),
+      /*Add=*/true, /*NSW=*/false, /*NUW=*/false, KB->getKnownBits(SAddr),
       KnownBits::makeConstant(APInt(32, ImmOffset)));
   uint64_t VMax = VKnown.getMaxValue().getZExtValue();
   uint64_t SMax = SKnown.getMaxValue().getZExtValue();
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 06d4a39cde7747..dc81178311b6d8 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -20154,7 +20154,8 @@ void ARMTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     // CSNEG: KnownOp0 or KnownOp1 * -1
     if (Op.getOpcode() == ARMISD::CSINC)
       KnownOp1 = KnownBits::computeForAddSub(
-          true, false, KnownOp1, KnownBits::makeConstant(APInt(32, 1)));
+          /*Add=*/true, /*NSW=*/false, /*NUW=*/false, KnownOp1,
+          KnownBits::makeConstant(APInt(32, 1)));
     else if (Op.getOpcode() == ARMISD::CSINV)
       std::swap(KnownOp1.Zero, KnownOp1.One);
     else if (Op.getOpcode() == ARMISD::CSNEG)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 97ae980a7cba70..1b963a7de4a8ae 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -565,7 +565,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
 
     // Otherwise just compute the known bits of the result.
     bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
-    Known = KnownBits::computeForAddSub(true, NSW, LHSKnown, RHSKnown);
+    bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
+    Known = KnownBits::computeForAddSub(true, NSW, NUW, LHSKnown, RHSKnown);
     break;
   }
   case Instruction::Sub: {
@@ -598,7 +599,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
 
     // Otherwise just compute the known bits of the result.
     bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
-    Known = KnownBits::computeForAddSub(false, NSW, LHSKnown, RHSKnown);
+    bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
+    Known = KnownBits::computeForAddSub(false, NSW, NUW, LHSKnown, RHSKnown);
     break;
   }
   case Instruction::Mul: {
@@ -1206,7 +1208,9 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
       return I->getOperand(1);
 
     bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
-    Known = KnownBits::computeForAddSub(/*Add*/ true, NSW, LHSKnown, RHSKnown);
+    bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
+    Known =
+        KnownBits::computeForAddSub(/*Add=*/true, NSW, NUW, LHSKnown, RHSKnown);
     computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI));
     break;
   }
@@ -1221,8 +1225,10 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
       return I->getOperand(0);
 
     bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
+    bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
     computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
-    Known = KnownBits::computeForAddSub(/*Add*/ false, NSW, LHSKnown, RHSKnown);
+    Known = KnownBits::computeForAddSub(/*Add=*/false, NSW, NUW, LHSKnown,
+                                        RHSKnown);
     computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI));
     break;
   }
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index fb9210bffcb433..d0ea1095056663 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -194,14 +194,14 @@ static void TestAddSubExhaustive(bool IsAdd) {
         });
       });
 
-      KnownBits KnownComputed =
-          KnownBits::computeForAddSub(IsAdd, /*NSW*/ false, Known1, Known2);
+      KnownBits KnownComputed = KnownBits::computeForAddSub(
+          IsAdd, /*NSW=*/false, /*NUW=*/false, Known1, Known2);
       EXPECT_EQ(Known, KnownComputed);
 
       // The NSW calculation is not precise, only check that it's
       // conservatively correct.
       KnownBits KnownNSWComputed = KnownBits::computeForAddSub(
-          IsAdd, /*NSW*/true, Known1, Known2);
+          IsAdd, /*NSW=*/true, /*NUW=*/false, Known1, Known2);
       EXPECT_TRUE(KnownNSWComputed.Zero.isSubsetOf(KnownNSW.Zero));
       EXPECT_TRUE(KnownNSWComputed.One.isSubsetOf(KnownNSW.One));
     });

>From 99422c763bee53d2602f29c1bda8a6c79b05aaca Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Wed, 28 Feb 2024 21:59:38 -0600
Subject: [PATCH 2/2] [KnownBits] Make `nuw` and `nsw` support in
 `computeForAddSub` optimal

Make the `nsw`/`nuw` handling in `computeForAddSub` precise; the exhaustive
unit tests below verify the computed bits are optimal. This should strengthen
any analysis built on top of it.
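
As a rough illustration of what the stronger nuw handling buys (a sketch
against the post-patch API, mirroring the icmp-sub.ll diff below): for
`sub nuw i64 10, %x` the result can never exceed 10, so the high bits are
known zero, the value is known non-negative, and a signed compare can be
rewritten as unsigned.

#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"
#include <cassert>
using namespace llvm;

int main() {
  // sub nuw 10, %x: the result is bounded by max(LHS) - min(RHS) = 10,
  // so the top 60 bits of the 64-bit result are known zero.
  KnownBits Ten = KnownBits::makeConstant(APInt(64, 10));
  KnownBits X(64); // nothing known about %x
  KnownBits Res = KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
                                              /*NUW=*/true, Ten, X);
  assert(Res.countMinLeadingZeros() >= 60); // sign bit was unknown before
  assert(Res.isNonNegative());
  return 0;
}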
---
 llvm/include/llvm/Support/KnownBits.h         |   7 +-
 llvm/lib/Support/KnownBits.cpp                | 183 ++++++++++++++++--
 llvm/test/CodeGen/AArch64/sve-cmp-folds.ll    |   9 +-
 .../CodeGen/AArch64/sve-extract-element.ll    |   8 +-
 llvm/test/CodeGen/AMDGPU/ds-sub-offset.ll     |  39 ++--
 .../InstCombine/fold-log2-ceil-idiom.ll       |   2 +-
 llvm/test/Transforms/InstCombine/icmp-sub.ll  |   5 +-
 llvm/test/Transforms/InstCombine/sub.ll       |   2 +-
 llvm/unittests/Support/KnownBitsTest.cpp      |  74 +++++--
 9 files changed, 267 insertions(+), 62 deletions(-)

diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index f5fce296fefe70..46dbf0c2baa5fe 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -62,6 +62,11 @@ struct KnownBits {
   /// Returns true if we don't know any bits.
   bool isUnknown() const { return Zero.isZero() && One.isZero(); }
 
+  /// Returns true if we don't know the sign bit.
+  bool isSignUnknown() const {
+    return !Zero.isSignBitSet() && !One.isSignBitSet();
+  }
+
   /// Resets the known state of all bits.
   void resetAll() {
     Zero.clearAllBits();
@@ -330,7 +335,7 @@ struct KnownBits {
 
   /// Compute known bits resulting from adding LHS and RHS.
   static KnownBits computeForAddSub(bool Add, bool NSW, bool NUW,
-                                    const KnownBits &LHS, KnownBits RHS);
+                                    const KnownBits &LHS, const KnownBits &RHS);
 
   /// Compute known bits results from subtracting RHS from LHS with 1-bit
   /// Borrow.
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index f999abe7dd14e6..8b4be7ab417190 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -54,32 +54,183 @@ KnownBits KnownBits::computeForAddCarry(
       LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
 }
 
-KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool /*NUW*/,
-                                      const KnownBits &LHS, KnownBits RHS) {
+KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
+                                      const KnownBits &LHS,
+                                      const KnownBits &RHS) {
+  // This can be a relatively expensive helper, so optimistically save some
+  // work.
+  if (LHS.isUnknown() && RHS.isUnknown())
+    return LHS;
   KnownBits KnownOut;
   if (Add) {
     // Sum = LHS + RHS + 0
-    KnownOut = ::computeForAddCarry(
-        LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
+    KnownOut =
+        ::computeForAddCarry(LHS, RHS, /*CarryZero*/ true, /*CarryOne*/ false);
   } else {
     // Sum = LHS + ~RHS + 1
-    std::swap(RHS.Zero, RHS.One);
-    KnownOut = ::computeForAddCarry(
-        LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
+    KnownBits NotRHS = RHS;
+    std::swap(NotRHS.Zero, NotRHS.One);
+    KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero*/ false,
+                                    /*CarryOne*/ true);
   }
+  if (!NSW && !NUW)
+    return KnownOut;
 
-  // Are we still trying to solve for the sign bit?
-  if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
-    if (NSW) {
-      // Adding two non-negative numbers, or subtracting a negative number from
-      // a non-negative one, can't wrap into negative.
-      if (LHS.isNonNegative() && RHS.isNonNegative())
+  auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
+                            const KnownBits &R, bool &OV) {
+    APInt LVal = ForMax ? L.getMaxValue() : L.getMinValue();
+    APInt RVal = Add == ForMax ? R.getMaxValue() : R.getMinValue();
+
+    if (ForNSW) {
+      LVal.clearSignBit();
+      RVal.clearSignBit();
+    }
+    APInt Res = Add ? LVal.uadd_ov(RVal, OV) : LVal.usub_ov(RVal, OV);
+    if (ForNSW) {
+      OV = Res.isSignBitSet();
+      Res.clearSignBit();
+      if (Res.getBitWidth() > 1 && Res[Res.getBitWidth() - 2])
+        Res.setSignBit();
+    }
+    return Res;
+  };
+
+  auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+                                   const KnownBits &R, bool &OV) {
+    return GetMinMaxVal(ForNSW, /*ForMax=*/true, L, R, OV);
+  };
+
+  auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+                                   const KnownBits &R, bool &OV) {
+    return GetMinMaxVal(ForNSW, /*ForMax=*/false, L, R, OV);
+  };
+
+  auto ForceNegative = [](KnownBits &Known) {
+    Known.Zero.clearSignBit();
+    Known.One.setSignBit();
+  };
+
+  auto ForcePositive = [](KnownBits &Known) {
+    Known.One.clearSignBit();
+    Known.Zero.setSignBit();
+  };
+
+  // Handle add/sub given nsw and/or nuw.
+  //
+  // Possible TODO: Add/Sub implementations mirror one another in many ways.
+  // They could probably be compressed into a single implementation of roughly
+  // half the total LOC. Leaving them separate for now to increase clarity.
+  // NB: We handle NSW by essentially treating it as nuw at bitwidth - 1, then
+  // deducing bits based on the known sign of the result.
+  if (Add) {
+    if (NUW || (LHS.isNonNegative() && RHS.isNonNegative())) {
+      bool OverflowMin;
+      APInt MinVal;
+      if (NSW) {
+        MinVal = GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
+        // (add nsw nuw) or (add nsw PosX, PosY)
+
+        // None of the adds can end up overflowing, so the consecutive high
+        // bits of the minimum possible value of X + Y must all remain set.
+        KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+
+        // NSW and Positive arguments leads to positive result.
+        if (LHS.isNonNegative() && RHS.isNonNegative())
+          ForcePositive(KnownOut);
+      }
+      if (NUW) {
+        KnownOut.One.clearSignBit();
+        // (add nuw X, Y)
+        MinVal = GetMinVal(/*ForNSW=*/false, LHS, RHS, OverflowMin);
+        // Same as (add nsw PosX, PosY): since we can't overflow, the high
+        // bits of the minimum possible X + Y must remain set.
+        KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+      }
+    } else if (LHS.isNegative() && RHS.isNegative()) {
+      bool OverflowMax;
+      APInt MaxVal = GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
+      // (add nsw NegX, NegY)
+
+      // We need to re-overflow the signbit, so we are looking for a sequence
+      // of 0s produced by consecutive overflows.
+      KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+      ForceNegative(KnownOut);
+    } else if (!KnownOut.isSignUnknown()) {
+      // Pass, avoid extra work if we already know the sign bit.
+    } else if (LHS.isNonNegative() || RHS.isNonNegative()) {
+      bool OverflowMin;
+      (void)GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
+      // (add nsw PosX, ?Y)
+
+      // If the minimum possible value of X + Y overflows the signbit, then Y
+      // must have been negative (otherwise the unsigned overflow would
+      // violate nsw), so the result is non-negative.
+      if (OverflowMin)
         KnownOut.makeNonNegative();
-      // Adding two negative numbers, or subtracting a non-negative number from
-      // a negative one, can't wrap into non-negative.
-      else if (LHS.isNegative() && RHS.isNegative())
+    } else if (LHS.isNegative() || RHS.isNegative()) {
+      bool OverflowMax;
+      (void)GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
+      // (add nsw NegX, ?Y)
+
+      // If the maximum possible of X + Y doesn't overflows the signbit,
+      // then Y must have been unsigned (otherwise nsw violated) so NegX +
+      // PosY w.o overflowing the signbit results in Negative.
+      if (!OverflowMax)
         KnownOut.makeNegative();
     }
+  } else {
+    if (NUW || (LHS.isNegative() && RHS.isNonNegative())) {
+      bool OverflowMax;
+      APInt MaxVal;
+      if (NSW) {
+        MaxVal = GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
+        // (sub nsw nuw) or (sub nsw NegX, PosY)
+
+        // None of the subs can overflow at any point, so any common high bits
+        // will subtract away and result in zeros.
+        KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+        if (LHS.isNegative() && RHS.isNonNegative())
+          ForceNegative(KnownOut);
+      }
+      if (NUW) {
+        KnownOut.Zero.clearSignBit();
+        // (sub nuw X, Y)
+        MaxVal = GetMaxVal(/*ForNSW=*/false, LHS, RHS, OverflowMax);
+
+        // All common high bits between X and Y will cancel out as leading
+        // zeros.
+        KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+      }
+    } else if (LHS.isNonNegative() && RHS.isNegative()) {
+      bool OverflowMin;
+      APInt MinVal = GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
+      // (sub nsw PosX, NegY)
+
+      // Opposite of the case above: we must "re-overflow" the signbit, so the
+      // minimal set of high bits will be fixed (known set).
+      KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+      ForcePositive(KnownOut);
+    } else if (!KnownOut.isSignUnknown()) {
+      // Pass, avoid extra work if we already know the sign bit.
+    } else if (LHS.isNegative() || RHS.isNonNegative()) {
+      bool OverflowMax;
+      (void)GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
+      // (sub nsw NegX/?X, ?Y/PosY)
+      if (OverflowMax)
+        KnownOut.makeNegative();
+    } else if (LHS.isNonNegative() || RHS.isNegative()) {
+      bool OverflowMin;
+      (void)GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
+      // (sub nsw PosX/?X, ?Y/NegY)
+      if (!OverflowMin)
+        KnownOut.makeNonNegative();
+    }
+  }
+
+  // Just return 0 if the nsw/nuw is violated and we have poison.
+  if (KnownOut.hasConflict()) {
+    KnownOut.setAllZero();
+    return KnownOut;
   }
 
   return KnownOut;
diff --git a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
index beded623272c13..c8a36e47efca6e 100644
--- a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
+++ b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
@@ -114,9 +114,12 @@ define i1 @foo_last(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
 ; CHECK-LABEL: foo_last:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.s
-; CHECK-NEXT:    fcmeq p1.s, p0/z, z0.s, z1.s
-; CHECK-NEXT:    ptest p0, p1.b
-; CHECK-NEXT:    cset w0, lo
+; CHECK-NEXT:    mov x8, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    whilels p1.s, xzr, x8
+; CHECK-NEXT:    fcmeq p0.s, p0/z, z0.s, z1.s
+; CHECK-NEXT:    mov z0.s, p0/z, #1 // =0x1
+; CHECK-NEXT:    lastb w8, p1, z0.s
+; CHECK-NEXT:    and w0, w8, #0x1
 ; CHECK-NEXT:    ret
   %vcond = fcmp oeq <vscale x 4 x float> %a, %b
   %vscale = call i64 @llvm.vscale.i64()
diff --git a/llvm/test/CodeGen/AArch64/sve-extract-element.ll b/llvm/test/CodeGen/AArch64/sve-extract-element.ll
index 273785f2436404..a3c34b53baa079 100644
--- a/llvm/test/CodeGen/AArch64/sve-extract-element.ll
+++ b/llvm/test/CodeGen/AArch64/sve-extract-element.ll
@@ -614,9 +614,11 @@ define i1 @test_lane9_8xi1(<vscale x 8 x i1> %a) #0 {
 define i1 @test_last_8xi1(<vscale x 8 x i1> %a) #0 {
 ; CHECK-LABEL: test_last_8xi1:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.h
-; CHECK-NEXT:    ptest p1, p0.b
-; CHECK-NEXT:    cset w0, lo
+; CHECK-NEXT:    mov x8, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mov z0.h, p0/z, #1 // =0x1
+; CHECK-NEXT:    whilels p1.h, xzr, x8
+; CHECK-NEXT:    lastb w8, p1, z0.h
+; CHECK-NEXT:    and w0, w8, #0x1
 ; CHECK-NEXT:    ret
   %vscale = call i64 @llvm.vscale.i64()
   %shl = shl nuw nsw i64 %vscale, 3
diff --git a/llvm/test/CodeGen/AMDGPU/ds-sub-offset.ll b/llvm/test/CodeGen/AMDGPU/ds-sub-offset.ll
index 6e6b204031c0f0..7b9b130e1cf796 100644
--- a/llvm/test/CodeGen/AMDGPU/ds-sub-offset.ll
+++ b/llvm/test/CodeGen/AMDGPU/ds-sub-offset.ll
@@ -137,19 +137,18 @@ define amdgpu_kernel void @write_ds_sub_max_offset_global_clamp_bit(float %dummy
 ; CI:       ; %bb.0:
 ; CI-NEXT:    s_load_dword s0, s[0:1], 0x0
 ; CI-NEXT:    s_mov_b64 vcc, 0
-; CI-NEXT:    v_not_b32_e32 v0, v0
-; CI-NEXT:    v_lshlrev_b32_e32 v0, 2, v0
-; CI-NEXT:    v_mov_b32_e32 v2, 0x7b
+; CI-NEXT:    v_mov_b32_e32 v1, 0x7b
+; CI-NEXT:    v_mov_b32_e32 v2, 0
+; CI-NEXT:    s_mov_b32 m0, -1
 ; CI-NEXT:    s_waitcnt lgkmcnt(0)
-; CI-NEXT:    v_mov_b32_e32 v1, s0
-; CI-NEXT:    v_div_fmas_f32 v1, v1, v1, v1
+; CI-NEXT:    v_mov_b32_e32 v0, s0
+; CI-NEXT:    v_div_fmas_f32 v0, v0, v0, v0
 ; CI-NEXT:    s_mov_b32 s0, 0
-; CI-NEXT:    s_mov_b32 m0, -1
 ; CI-NEXT:    s_mov_b32 s3, 0xf000
 ; CI-NEXT:    s_mov_b32 s2, -1
 ; CI-NEXT:    s_mov_b32 s1, s0
-; CI-NEXT:    ds_write_b32 v0, v2 offset:65532
-; CI-NEXT:    buffer_store_dword v1, off, s[0:3], 0
+; CI-NEXT:    ds_write_b32 v2, v1
+; CI-NEXT:    buffer_store_dword v0, off, s[0:3], 0
 ; CI-NEXT:    s_waitcnt vmcnt(0)
 ; CI-NEXT:    s_endpgm
 ;
@@ -157,15 +156,14 @@ define amdgpu_kernel void @write_ds_sub_max_offset_global_clamp_bit(float %dummy
 ; GFX9:       ; %bb.0:
 ; GFX9-NEXT:    s_load_dword s0, s[0:1], 0x0
 ; GFX9-NEXT:    s_mov_b64 vcc, 0
-; GFX9-NEXT:    v_not_b32_e32 v0, v0
-; GFX9-NEXT:    v_lshlrev_b32_e32 v3, 2, v0
-; GFX9-NEXT:    v_mov_b32_e32 v4, 0x7b
+; GFX9-NEXT:    v_mov_b32_e32 v3, 0x7b
+; GFX9-NEXT:    v_mov_b32_e32 v4, 0
+; GFX9-NEXT:    ds_write_b32 v4, v3
 ; GFX9-NEXT:    s_waitcnt lgkmcnt(0)
-; GFX9-NEXT:    v_mov_b32_e32 v1, s0
-; GFX9-NEXT:    v_div_fmas_f32 v2, v1, v1, v1
+; GFX9-NEXT:    v_mov_b32_e32 v0, s0
+; GFX9-NEXT:    v_div_fmas_f32 v2, v0, v0, v0
 ; GFX9-NEXT:    v_mov_b32_e32 v0, 0
 ; GFX9-NEXT:    v_mov_b32_e32 v1, 0
-; GFX9-NEXT:    ds_write_b32 v3, v4 offset:65532
 ; GFX9-NEXT:    global_store_dword v[0:1], v2, off
 ; GFX9-NEXT:    s_waitcnt vmcnt(0)
 ; GFX9-NEXT:    s_endpgm
@@ -173,13 +171,12 @@ define amdgpu_kernel void @write_ds_sub_max_offset_global_clamp_bit(float %dummy
 ; GFX10-LABEL: write_ds_sub_max_offset_global_clamp_bit:
 ; GFX10:       ; %bb.0:
 ; GFX10-NEXT:    s_load_dword s0, s[0:1], 0x0
-; GFX10-NEXT:    v_not_b32_e32 v0, v0
 ; GFX10-NEXT:    s_mov_b32 vcc_lo, 0
-; GFX10-NEXT:    v_mov_b32_e32 v3, 0x7b
-; GFX10-NEXT:    v_lshlrev_b32_e32 v2, 2, v0
 ; GFX10-NEXT:    v_mov_b32_e32 v0, 0
+; GFX10-NEXT:    v_mov_b32_e32 v2, 0x7b
+; GFX10-NEXT:    v_mov_b32_e32 v3, 0
 ; GFX10-NEXT:    v_mov_b32_e32 v1, 0
-; GFX10-NEXT:    ds_write_b32 v2, v3 offset:65532
+; GFX10-NEXT:    ds_write_b32 v3, v2
 ; GFX10-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX10-NEXT:    v_div_fmas_f32 v4, s0, s0, s0
 ; GFX10-NEXT:    global_store_dword v[0:1], v4, off
@@ -189,13 +186,11 @@ define amdgpu_kernel void @write_ds_sub_max_offset_global_clamp_bit(float %dummy
 ; GFX11-LABEL: write_ds_sub_max_offset_global_clamp_bit:
 ; GFX11:       ; %bb.0:
 ; GFX11-NEXT:    s_load_b32 s0, s[0:1], 0x0
-; GFX11-NEXT:    v_not_b32_e32 v0, v0
 ; GFX11-NEXT:    s_mov_b32 vcc_lo, 0
-; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1)
-; GFX11-NEXT:    v_dual_mov_b32 v3, 0x7b :: v_dual_lshlrev_b32 v2, 2, v0
 ; GFX11-NEXT:    v_mov_b32_e32 v0, 0
+; GFX11-NEXT:    v_dual_mov_b32 v2, 0x7b :: v_dual_mov_b32 v3, 0
 ; GFX11-NEXT:    v_mov_b32_e32 v1, 0
-; GFX11-NEXT:    ds_store_b32 v2, v3 offset:65532
+; GFX11-NEXT:    ds_store_b32 v3, v2
 ; GFX11-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX11-NEXT:    v_div_fmas_f32 v4, s0, s0, s0
 ; GFX11-NEXT:    global_store_b32 v[0:1], v4, off dlc
diff --git a/llvm/test/Transforms/InstCombine/fold-log2-ceil-idiom.ll b/llvm/test/Transforms/InstCombine/fold-log2-ceil-idiom.ll
index 2594c3fce81464..434d98449f99c4 100644
--- a/llvm/test/Transforms/InstCombine/fold-log2-ceil-idiom.ll
+++ b/llvm/test/Transforms/InstCombine/fold-log2-ceil-idiom.ll
@@ -43,7 +43,7 @@ define i64 @log2_ceil_idiom_zext(i32 %x) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[X]], -1
 ; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.ctlz.i32(i32 [[TMP1]], i1 false), !range [[RNG0]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = sub nuw nsw i32 32, [[TMP2]]
-; CHECK-NEXT:    [[RET:%.*]] = zext i32 [[TMP3]] to i64
+; CHECK-NEXT:    [[RET:%.*]] = zext nneg i32 [[TMP3]] to i64
 ; CHECK-NEXT:    ret i64 [[RET]]
 ;
   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %x, i1 true)
diff --git a/llvm/test/Transforms/InstCombine/icmp-sub.ll b/llvm/test/Transforms/InstCombine/icmp-sub.ll
index 2dad575fede83c..5645dededf2e4b 100644
--- a/llvm/test/Transforms/InstCombine/icmp-sub.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-sub.ll
@@ -36,7 +36,7 @@ define i1 @test_nuw_nsw_and_unsigned_pred(i64 %x) {
 
 define i1 @test_nuw_nsw_and_signed_pred(i64 %x) {
 ; CHECK-LABEL: @test_nuw_nsw_and_signed_pred(
-; CHECK-NEXT:    [[Z:%.*]] = icmp sgt i64 [[X:%.*]], 7
+; CHECK-NEXT:    [[Z:%.*]] = icmp ugt i64 [[X:%.*]], 7
 ; CHECK-NEXT:    ret i1 [[Z]]
 ;
   %y = sub nuw nsw i64 10, %x
@@ -46,8 +46,7 @@ define i1 @test_nuw_nsw_and_signed_pred(i64 %x) {
 
 define i1 @test_negative_nuw_and_signed_pred(i64 %x) {
 ; CHECK-LABEL: @test_negative_nuw_and_signed_pred(
-; CHECK-NEXT:    [[NOTSUB:%.*]] = add nuw i64 [[X:%.*]], -11
-; CHECK-NEXT:    [[Z:%.*]] = icmp sgt i64 [[NOTSUB]], -4
+; CHECK-NEXT:    [[Z:%.*]] = icmp ugt i64 [[X:%.*]], 7
 ; CHECK-NEXT:    ret i1 [[Z]]
 ;
   %y = sub nuw i64 10, %x
diff --git a/llvm/test/Transforms/InstCombine/sub.ll b/llvm/test/Transforms/InstCombine/sub.ll
index 76cd7ab5c10cd1..249b5673c8acfd 100644
--- a/llvm/test/Transforms/InstCombine/sub.ll
+++ b/llvm/test/Transforms/InstCombine/sub.ll
@@ -2367,7 +2367,7 @@ define <2 x i8> @sub_to_and_vector3(<2 x i8> %x) {
 ; CHECK-LABEL: @sub_to_and_vector3(
 ; CHECK-NEXT:    [[SUB:%.*]] = sub nuw <2 x i8> <i8 71, i8 71>, [[X:%.*]]
 ; CHECK-NEXT:    [[AND:%.*]] = and <2 x i8> [[SUB]], <i8 120, i8 undef>
-; CHECK-NEXT:    [[R:%.*]] = sub <2 x i8> <i8 44, i8 44>, [[AND]]
+; CHECK-NEXT:    [[R:%.*]] = sub nsw <2 x i8> <i8 44, i8 44>, [[AND]]
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %sub = sub nuw <2 x i8> <i8 71, i8 71>, %x
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index d0ea1095056663..658f3796721c4e 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -169,41 +169,69 @@ static void TestAddSubExhaustive(bool IsAdd) {
   unsigned Bits = 4;
   ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
     ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
-      KnownBits Known(Bits), KnownNSW(Bits);
+      KnownBits Known(Bits), KnownNSW(Bits), KnownNUW(Bits),
+          KnownNSWAndNUW(Bits);
       Known.Zero.setAllBits();
       Known.One.setAllBits();
       KnownNSW.Zero.setAllBits();
       KnownNSW.One.setAllBits();
+      KnownNUW.Zero.setAllBits();
+      KnownNUW.One.setAllBits();
+      KnownNSWAndNUW.Zero.setAllBits();
+      KnownNSWAndNUW.One.setAllBits();
 
       ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
         ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
-          bool Overflow;
+          bool SignedOverflow;
+          bool UnsignedOverflow;
           APInt Res;
-          if (IsAdd)
-            Res = N1.sadd_ov(N2, Overflow);
-          else
-            Res = N1.ssub_ov(N2, Overflow);
+          if (IsAdd) {
+            Res = N1.uadd_ov(N2, UnsignedOverflow);
+            Res = N1.sadd_ov(N2, SignedOverflow);
+          } else {
+            Res = N1.usub_ov(N2, UnsignedOverflow);
+            Res = N1.ssub_ov(N2, SignedOverflow);
+          }
 
           Known.One &= Res;
           Known.Zero &= ~Res;
 
-          if (!Overflow) {
+          if (!SignedOverflow) {
             KnownNSW.One &= Res;
             KnownNSW.Zero &= ~Res;
           }
+
+          if (!UnsignedOverflow) {
+            KnownNUW.One &= Res;
+            KnownNUW.Zero &= ~Res;
+          }
+
+          if (!UnsignedOverflow && !SignedOverflow) {
+            KnownNSWAndNUW.One &= Res;
+            KnownNSWAndNUW.Zero &= ~Res;
+          }
         });
       });
 
       KnownBits KnownComputed = KnownBits::computeForAddSub(
           IsAdd, /*NSW=*/false, /*NUW=*/false, Known1, Known2);
-      EXPECT_EQ(Known, KnownComputed);
+      EXPECT_TRUE(isOptimal(Known, KnownComputed, {Known1, Known2}));
 
-      // The NSW calculation is not precise, only check that it's
-      // conservatively correct.
       KnownBits KnownNSWComputed = KnownBits::computeForAddSub(
           IsAdd, /*NSW=*/true, /*NUW=*/false, Known1, Known2);
-      EXPECT_TRUE(KnownNSWComputed.Zero.isSubsetOf(KnownNSW.Zero));
-      EXPECT_TRUE(KnownNSWComputed.One.isSubsetOf(KnownNSW.One));
+      if (!KnownNSW.hasConflict())
+        EXPECT_TRUE(isOptimal(KnownNSW, KnownNSWComputed, {Known1, Known2}));
+
+      KnownBits KnownNUWComputed = KnownBits::computeForAddSub(
+          IsAdd, /*NSW=*/false, /*NUW=*/true, Known1, Known2);
+      if (!KnownNUW.hasConflict())
+        EXPECT_TRUE(isOptimal(KnownNUW, KnownNUWComputed, {Known1, Known2}));
+
+      KnownBits KnownNSWAndNUWComputed = KnownBits::computeForAddSub(
+          IsAdd, /*NSW=*/true, /*NUW=*/true, Known1, Known2);
+      if (!KnownNSWAndNUW.hasConflict())
+        EXPECT_TRUE(isOptimal(KnownNSWAndNUW, KnownNSWAndNUWComputed,
+                              {Known1, Known2}));
     });
   });
 }
@@ -244,6 +272,28 @@ TEST(KnownBitsTest, SubBorrowExhaustive) {
   });
 }
 
+TEST(KnownBitsTest, SignBitUnknown) {
+  KnownBits Known(2);
+  EXPECT_TRUE(Known.isSignUnknown());
+  Known.Zero.setBit(0);
+  EXPECT_TRUE(Known.isSignUnknown());
+  Known.Zero.setBit(1);
+  EXPECT_FALSE(Known.isSignUnknown());
+  Known.Zero.clearBit(0);
+  EXPECT_FALSE(Known.isSignUnknown());
+  Known.Zero.clearBit(1);
+  EXPECT_TRUE(Known.isSignUnknown());
+
+  Known.One.setBit(0);
+  EXPECT_TRUE(Known.isSignUnknown());
+  Known.One.setBit(1);
+  EXPECT_FALSE(Known.isSignUnknown());
+  Known.One.clearBit(0);
+  EXPECT_FALSE(Known.isSignUnknown());
+  Known.One.clearBit(1);
+  EXPECT_TRUE(Known.isSignUnknown());
+}
+
 TEST(KnownBitsTest, AbsDiffSpecialCase) {
   // There are 2 implementation of absdiff - both are currently needed to cover
   // extra cases.
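
For readers tracing the nsw paths in KnownBits.cpp above (nsw handled as nuw
at bitwidth - 1, then deducing the sign): a small worked example, again only
a sketch against the post-patch API.

#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"
#include <cassert>
using namespace llvm;

int main() {
  // add nsw 0x60, %y with %y known non-negative (8 bits): nsw caps the sum
  // at 0x7f, so it lies in [0x60, 0x7f] = 0b011?????.
  KnownBits LHS = KnownBits::makeConstant(APInt(8, 0x60));
  KnownBits RHS(8);
  RHS.Zero.setSignBit(); // only %y's sign bit is known (it is zero)
  KnownBits Sum = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true,
                                              /*NUW=*/false, LHS, RHS);
  assert(Sum.Zero[7] && Sum.One[6] && Sum.One[5]); // sign 0, bits 6-5 one
  return 0;
}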


