[llvm] 17162b6 - [KnownBits] Make `nuw` and `nsw` support in `computeForAddSub` optimal

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 5 11:00:14 PST 2024


Author: Noah Goldstein
Date: 2024-03-05T12:59:58-06:00
New Revision: 17162b61c2e6968482fab928f89bdca8b4ac06d9

URL: https://github.com/llvm/llvm-project/commit/17162b61c2e6968482fab928f89bdca8b4ac06d9
DIFF: https://github.com/llvm/llvm-project/commit/17162b61c2e6968482fab928f89bdca8b4ac06d9.diff

LOG: [KnownBits] Make `nuw` and `nsw` support in `computeForAddSub` optimal

Make the `nuw` and `nsw` handling in `computeForAddSub` precise instead of merely conservative; this should strengthen known-bits analysis wherever add/sub wrap flags are available.

Closes #83580
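
As a concrete illustration of the added strength, consider `sub nuw` where
only the LHS has known bits. A minimal sketch (not part of the patch,
assuming a build that contains this revision):

    #include "llvm/Support/KnownBits.h"
    #include <cassert>

    using namespace llvm;

    int main() {
      KnownBits X(8), Y(8);
      X.Zero.setHighBits(4); // X is in [0, 15]; Y stays fully unknown.

      // (sub nuw X, Y) cannot wrap, so the result is at most X's maximum
      // of 15 and its high four bits must be zero. Previously nothing was
      // known here, since the carry-based path needs known bits on both
      // sides.
      KnownBits Res = KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
                                                  /*NUW=*/true, X, Y);
      assert(Res.countMinLeadingZeros() == 4);
      return 0;
    }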

Added: 
    

Modified: 
    llvm/include/llvm/Support/KnownBits.h
    llvm/lib/Support/KnownBits.cpp
    llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
    llvm/test/CodeGen/AArch64/sve-extract-element.ll
    llvm/test/CodeGen/AMDGPU/ds-sub-offset.ll
    llvm/test/Transforms/InstCombine/fold-log2-ceil-idiom.ll
    llvm/test/Transforms/InstCombine/icmp-sub.ll
    llvm/test/Transforms/InstCombine/sub.ll
    llvm/unittests/Support/KnownBitsTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index f5fce296fefe70..46dbf0c2baa5fe 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -62,6 +62,11 @@ struct KnownBits {
   /// Returns true if we don't know any bits.
   bool isUnknown() const { return Zero.isZero() && One.isZero(); }
 
+  /// Returns true if we don't know the sign bit.
+  bool isSignUnknown() const {
+    return !Zero.isSignBitSet() && !One.isSignBitSet();
+  }
+
   /// Resets the known state of all bits.
   void resetAll() {
     Zero.clearAllBits();
@@ -330,7 +335,7 @@ struct KnownBits {
 
   /// Compute known bits resulting from adding LHS and RHS.
   static KnownBits computeForAddSub(bool Add, bool NSW, bool NUW,
-                                    const KnownBits &LHS, KnownBits RHS);
+                                    const KnownBits &LHS, const KnownBits &RHS);
 
   /// Compute known bits results from subtracting RHS from LHS with 1-bit
   /// Borrow.
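
The new `isSignUnknown()` predicate is the sign-bit analogue of
`isUnknown()`. A small hedged sketch of a possible caller (hypothetical
helper, not from this patch):

    #include "llvm/Support/KnownBits.h"

    using namespace llvm;

    // Returns +1 or -1 when the sign bit is pinned down, 0 otherwise.
    static int knownSign(const KnownBits &K) {
      if (K.isSignUnknown())
        return 0;                     // neither Zero nor One covers the sign bit
      return K.isNegative() ? -1 : 1; // sign bit known one / known zero
    }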

diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index f999abe7dd14e6..74d857457aec1e 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -54,34 +54,89 @@ KnownBits KnownBits::computeForAddCarry(
       LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
 }
 
-KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool /*NUW*/,
-                                      const KnownBits &LHS, KnownBits RHS) {
-  KnownBits KnownOut;
-  if (Add) {
-    // Sum = LHS + RHS + 0
-    KnownOut = ::computeForAddCarry(
-        LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
-  } else {
-    // Sum = LHS + ~RHS + 1
-    std::swap(RHS.Zero, RHS.One);
-    KnownOut = ::computeForAddCarry(
-        LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
+KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
+                                      const KnownBits &LHS,
+                                      const KnownBits &RHS) {
+  unsigned BitWidth = LHS.getBitWidth();
+  KnownBits KnownOut(BitWidth);
+  // This can be a relatively expensive helper, so optimistically save some
+  // work.
+  if (LHS.isUnknown() && RHS.isUnknown())
+    return KnownOut;
+
+  if (!LHS.isUnknown() && !RHS.isUnknown()) {
+    if (Add) {
+      // Sum = LHS + RHS + 0
+      KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero=*/true,
+                                      /*CarryOne=*/false);
+    } else {
+      // Sum = LHS + ~RHS + 1
+      KnownBits NotRHS = RHS;
+      std::swap(NotRHS.Zero, NotRHS.One);
+      KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero=*/false,
+                                      /*CarryOne=*/true);
+    }
   }
 
-  // Are we still trying to solve for the sign bit?
-  if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
-    if (NSW) {
-      // Adding two non-negative numbers, or subtracting a negative number from
-      // a non-negative one, can't wrap into negative.
-      if (LHS.isNonNegative() && RHS.isNonNegative())
-        KnownOut.makeNonNegative();
-      // Adding two negative numbers, or subtracting a non-negative number from
-      // a negative one, can't wrap into non-negative.
-      else if (LHS.isNegative() && RHS.isNegative())
-        KnownOut.makeNegative();
+  // Handle add/sub given nsw and/or nuw.
+  if (NUW) {
+    if (Add) {
+      // (add nuw X, Y)
+      APInt MinVal = LHS.getMinValue().uadd_sat(RHS.getMinValue());
+      // None of the adds can end up overflowing, so min consecutive highbits
+      // in minimum possible of X + Y must all remain set.
+      if (NSW) {
+        unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
+        // If we have NSW as well, we also know we can't overflow the signbit so
+        // can start counting from 1 bit back.
+        KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
+      }
+      KnownOut.One.setHighBits(MinVal.countl_one());
+    } else {
+      // (sub nuw X, Y)
+      APInt MaxVal = LHS.getMaxValue().usub_sat(RHS.getMinValue());
+      // None of the subs can overflow at any point, so any common high bits
+      // will subtract away and result in zeros.
+      if (NSW) {
+        // If we have NSW as well, we also know we can't overflow the signbit so
+        // can start counting from 1 bit back.
+        unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero();
+        KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
+      }
+      KnownOut.Zero.setHighBits(MaxVal.countl_zero());
+    }
+  }
+
+  if (NSW) {
+    APInt MinVal;
+    APInt MaxVal;
+    if (Add) {
+      // (add nsw X, Y)
+      MinVal = LHS.getSignedMinValue().sadd_sat(RHS.getSignedMinValue());
+      MaxVal = LHS.getSignedMaxValue().sadd_sat(RHS.getSignedMaxValue());
+    } else {
+      // (sub nsw X, Y)
+      MinVal = LHS.getSignedMinValue().ssub_sat(RHS.getSignedMaxValue());
+      MaxVal = LHS.getSignedMaxValue().ssub_sat(RHS.getSignedMinValue());
+    }
+    if (MinVal.isNonNegative()) {
+      // If min is non-negative, result will always be non-neg (can't overflow
+      // around).
+      unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
+      KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
+      KnownOut.Zero.setSignBit();
+    }
+    if (MaxVal.isNegative()) {
+      // If max is negative, result will always be neg (can't overflow around).
+      unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero();
+      KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
+      KnownOut.One.setSignBit();
     }
   }
 
+  // Just return 0 if the nsw/nuw is violated and we have poison.
+  if (KnownOut.hasConflict())
+    KnownOut.setAllZero();
   return KnownOut;
 }
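
To make the new nsw clamping concrete, here is a worked sketch
(illustrative values, assuming this revision): if X matches 0b011?????
(so X is in [96, 127]) and Y matches 0b000000?? (so Y is in [0, 3]),
the saturating signed minimum of the sum is 96 and nsw rules out
wrapping past the sign bit, so bits 7..5 of the sum are fully known.

    #include "llvm/Support/KnownBits.h"
    #include <cassert>

    using namespace llvm;

    int main() {
      KnownBits X(8), Y(8);
      X.Zero.setBit(7);      // sign bit known zero
      X.One.setBit(6);
      X.One.setBit(5);       // X matches 0b011?????, i.e. X in [96, 127]
      Y.Zero.setHighBits(6); // Y matches 0b000000??, i.e. Y in [0, 3]

      // (add nsw X, Y) stays in [96, 127]: the minimum possible sum is
      // 96 + 0 = 96 and nsw forbids crossing the sign bit, so the top
      // three bits of the result are known to be 0b011.
      KnownBits Sum = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true,
                                                  /*NUW=*/false, X, Y);
      assert(Sum.Zero[7] && Sum.One[6] && Sum.One[5]);
      return 0;
    }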
 

diff --git a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
index beded623272c13..c8a36e47efca6e 100644
--- a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
+++ b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
@@ -114,9 +114,12 @@ define i1 @foo_last(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
 ; CHECK-LABEL: foo_last:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.s
-; CHECK-NEXT:    fcmeq p1.s, p0/z, z0.s, z1.s
-; CHECK-NEXT:    ptest p0, p1.b
-; CHECK-NEXT:    cset w0, lo
+; CHECK-NEXT:    mov x8, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    whilels p1.s, xzr, x8
+; CHECK-NEXT:    fcmeq p0.s, p0/z, z0.s, z1.s
+; CHECK-NEXT:    mov z0.s, p0/z, #1 // =0x1
+; CHECK-NEXT:    lastb w8, p1, z0.s
+; CHECK-NEXT:    and w0, w8, #0x1
 ; CHECK-NEXT:    ret
   %vcond = fcmp oeq <vscale x 4 x float> %a, %b
   %vscale = call i64 @llvm.vscale.i64()

diff --git a/llvm/test/CodeGen/AArch64/sve-extract-element.ll b/llvm/test/CodeGen/AArch64/sve-extract-element.ll
index 273785f2436404..a3c34b53baa079 100644
--- a/llvm/test/CodeGen/AArch64/sve-extract-element.ll
+++ b/llvm/test/CodeGen/AArch64/sve-extract-element.ll
@@ -614,9 +614,11 @@ define i1 @test_lane9_8xi1(<vscale x 8 x i1> %a) #0 {
 define i1 @test_last_8xi1(<vscale x 8 x i1> %a) #0 {
 ; CHECK-LABEL: test_last_8xi1:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.h
-; CHECK-NEXT:    ptest p1, p0.b
-; CHECK-NEXT:    cset w0, lo
+; CHECK-NEXT:    mov x8, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mov z0.h, p0/z, #1 // =0x1
+; CHECK-NEXT:    whilels p1.h, xzr, x8
+; CHECK-NEXT:    lastb w8, p1, z0.h
+; CHECK-NEXT:    and w0, w8, #0x1
 ; CHECK-NEXT:    ret
   %vscale = call i64 @llvm.vscale.i64()
   %shl = shl nuw nsw i64 %vscale, 3

diff --git a/llvm/test/CodeGen/AMDGPU/ds-sub-offset.ll b/llvm/test/CodeGen/AMDGPU/ds-sub-offset.ll
index 6e6b204031c0f0..7b9b130e1cf796 100644
--- a/llvm/test/CodeGen/AMDGPU/ds-sub-offset.ll
+++ b/llvm/test/CodeGen/AMDGPU/ds-sub-offset.ll
@@ -137,19 +137,18 @@ define amdgpu_kernel void @write_ds_sub_max_offset_global_clamp_bit(float %dummy
 ; CI:       ; %bb.0:
 ; CI-NEXT:    s_load_dword s0, s[0:1], 0x0
 ; CI-NEXT:    s_mov_b64 vcc, 0
-; CI-NEXT:    v_not_b32_e32 v0, v0
-; CI-NEXT:    v_lshlrev_b32_e32 v0, 2, v0
-; CI-NEXT:    v_mov_b32_e32 v2, 0x7b
+; CI-NEXT:    v_mov_b32_e32 v1, 0x7b
+; CI-NEXT:    v_mov_b32_e32 v2, 0
+; CI-NEXT:    s_mov_b32 m0, -1
 ; CI-NEXT:    s_waitcnt lgkmcnt(0)
-; CI-NEXT:    v_mov_b32_e32 v1, s0
-; CI-NEXT:    v_div_fmas_f32 v1, v1, v1, v1
+; CI-NEXT:    v_mov_b32_e32 v0, s0
+; CI-NEXT:    v_div_fmas_f32 v0, v0, v0, v0
 ; CI-NEXT:    s_mov_b32 s0, 0
-; CI-NEXT:    s_mov_b32 m0, -1
 ; CI-NEXT:    s_mov_b32 s3, 0xf000
 ; CI-NEXT:    s_mov_b32 s2, -1
 ; CI-NEXT:    s_mov_b32 s1, s0
-; CI-NEXT:    ds_write_b32 v0, v2 offset:65532
-; CI-NEXT:    buffer_store_dword v1, off, s[0:3], 0
+; CI-NEXT:    ds_write_b32 v2, v1
+; CI-NEXT:    buffer_store_dword v0, off, s[0:3], 0
 ; CI-NEXT:    s_waitcnt vmcnt(0)
 ; CI-NEXT:    s_endpgm
 ;
@@ -157,15 +156,14 @@ define amdgpu_kernel void @write_ds_sub_max_offset_global_clamp_bit(float %dummy
 ; GFX9:       ; %bb.0:
 ; GFX9-NEXT:    s_load_dword s0, s[0:1], 0x0
 ; GFX9-NEXT:    s_mov_b64 vcc, 0
-; GFX9-NEXT:    v_not_b32_e32 v0, v0
-; GFX9-NEXT:    v_lshlrev_b32_e32 v3, 2, v0
-; GFX9-NEXT:    v_mov_b32_e32 v4, 0x7b
+; GFX9-NEXT:    v_mov_b32_e32 v3, 0x7b
+; GFX9-NEXT:    v_mov_b32_e32 v4, 0
+; GFX9-NEXT:    ds_write_b32 v4, v3
 ; GFX9-NEXT:    s_waitcnt lgkmcnt(0)
-; GFX9-NEXT:    v_mov_b32_e32 v1, s0
-; GFX9-NEXT:    v_div_fmas_f32 v2, v1, v1, v1
+; GFX9-NEXT:    v_mov_b32_e32 v0, s0
+; GFX9-NEXT:    v_div_fmas_f32 v2, v0, v0, v0
 ; GFX9-NEXT:    v_mov_b32_e32 v0, 0
 ; GFX9-NEXT:    v_mov_b32_e32 v1, 0
-; GFX9-NEXT:    ds_write_b32 v3, v4 offset:65532
 ; GFX9-NEXT:    global_store_dword v[0:1], v2, off
 ; GFX9-NEXT:    s_waitcnt vmcnt(0)
 ; GFX9-NEXT:    s_endpgm
@@ -173,13 +171,12 @@ define amdgpu_kernel void @write_ds_sub_max_offset_global_clamp_bit(float %dummy
 ; GFX10-LABEL: write_ds_sub_max_offset_global_clamp_bit:
 ; GFX10:       ; %bb.0:
 ; GFX10-NEXT:    s_load_dword s0, s[0:1], 0x0
-; GFX10-NEXT:    v_not_b32_e32 v0, v0
 ; GFX10-NEXT:    s_mov_b32 vcc_lo, 0
-; GFX10-NEXT:    v_mov_b32_e32 v3, 0x7b
-; GFX10-NEXT:    v_lshlrev_b32_e32 v2, 2, v0
 ; GFX10-NEXT:    v_mov_b32_e32 v0, 0
+; GFX10-NEXT:    v_mov_b32_e32 v2, 0x7b
+; GFX10-NEXT:    v_mov_b32_e32 v3, 0
 ; GFX10-NEXT:    v_mov_b32_e32 v1, 0
-; GFX10-NEXT:    ds_write_b32 v2, v3 offset:65532
+; GFX10-NEXT:    ds_write_b32 v3, v2
 ; GFX10-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX10-NEXT:    v_div_fmas_f32 v4, s0, s0, s0
 ; GFX10-NEXT:    global_store_dword v[0:1], v4, off
@@ -189,13 +186,11 @@ define amdgpu_kernel void @write_ds_sub_max_offset_global_clamp_bit(float %dummy
 ; GFX11-LABEL: write_ds_sub_max_offset_global_clamp_bit:
 ; GFX11:       ; %bb.0:
 ; GFX11-NEXT:    s_load_b32 s0, s[0:1], 0x0
-; GFX11-NEXT:    v_not_b32_e32 v0, v0
 ; GFX11-NEXT:    s_mov_b32 vcc_lo, 0
-; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1)
-; GFX11-NEXT:    v_dual_mov_b32 v3, 0x7b :: v_dual_lshlrev_b32 v2, 2, v0
 ; GFX11-NEXT:    v_mov_b32_e32 v0, 0
+; GFX11-NEXT:    v_dual_mov_b32 v2, 0x7b :: v_dual_mov_b32 v3, 0
 ; GFX11-NEXT:    v_mov_b32_e32 v1, 0
-; GFX11-NEXT:    ds_store_b32 v2, v3 offset:65532
+; GFX11-NEXT:    ds_store_b32 v3, v2
 ; GFX11-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX11-NEXT:    v_div_fmas_f32 v4, s0, s0, s0
 ; GFX11-NEXT:    global_store_b32 v[0:1], v4, off dlc

diff --git a/llvm/test/Transforms/InstCombine/fold-log2-ceil-idiom.ll b/llvm/test/Transforms/InstCombine/fold-log2-ceil-idiom.ll
index 2594c3fce81464..434d98449f99c4 100644
--- a/llvm/test/Transforms/InstCombine/fold-log2-ceil-idiom.ll
+++ b/llvm/test/Transforms/InstCombine/fold-log2-ceil-idiom.ll
@@ -43,7 +43,7 @@ define i64 @log2_ceil_idiom_zext(i32 %x) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[X]], -1
 ; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.ctlz.i32(i32 [[TMP1]], i1 false), !range [[RNG0]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = sub nuw nsw i32 32, [[TMP2]]
-; CHECK-NEXT:    [[RET:%.*]] = zext i32 [[TMP3]] to i64
+; CHECK-NEXT:    [[RET:%.*]] = zext nneg i32 [[TMP3]] to i64
 ; CHECK-NEXT:    ret i64 [[RET]]
 ;
   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %x, i1 true)
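
Why the `zext` gains `nneg` here: the `ctlz` result fits in six bits, so
`sub nuw nsw i32 32, %tmp2` is now provably non-negative. A hedged sketch
of the underlying query (illustrative, assuming this revision):

    #include "llvm/Support/KnownBits.h"
    #include <cassert>

    using namespace llvm;

    int main() {
      // LHS is the constant 32; Z models the ctlz result, so bits 31..6
      // are known zero.
      KnownBits C = KnownBits::makeConstant(APInt(32, 32));
      KnownBits Z(32);
      Z.Zero.setBitsFrom(6);

      // (sub nuw 32, Z) is at most 32, so the high 26 bits - including
      // the sign bit - are known zero, which justifies `zext nneg`.
      KnownBits Res = KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/true,
                                                  /*NUW=*/true, C, Z);
      assert(Res.isNonNegative());
      return 0;
    }

The same non-negativity reasoning drives the icmp-sub.ll changes below,
where signed predicates become their unsigned equivalents.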

diff --git a/llvm/test/Transforms/InstCombine/icmp-sub.ll b/llvm/test/Transforms/InstCombine/icmp-sub.ll
index 2dad575fede83c..5645dededf2e4b 100644
--- a/llvm/test/Transforms/InstCombine/icmp-sub.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-sub.ll
@@ -36,7 +36,7 @@ define i1 @test_nuw_nsw_and_unsigned_pred(i64 %x) {
 
 define i1 @test_nuw_nsw_and_signed_pred(i64 %x) {
 ; CHECK-LABEL: @test_nuw_nsw_and_signed_pred(
-; CHECK-NEXT:    [[Z:%.*]] = icmp sgt i64 [[X:%.*]], 7
+; CHECK-NEXT:    [[Z:%.*]] = icmp ugt i64 [[X:%.*]], 7
 ; CHECK-NEXT:    ret i1 [[Z]]
 ;
   %y = sub nuw nsw i64 10, %x
@@ -46,8 +46,7 @@ define i1 @test_nuw_nsw_and_signed_pred(i64 %x) {
 
 define i1 @test_negative_nuw_and_signed_pred(i64 %x) {
 ; CHECK-LABEL: @test_negative_nuw_and_signed_pred(
-; CHECK-NEXT:    [[NOTSUB:%.*]] = add nuw i64 [[X:%.*]], -11
-; CHECK-NEXT:    [[Z:%.*]] = icmp sgt i64 [[NOTSUB]], -4
+; CHECK-NEXT:    [[Z:%.*]] = icmp ugt i64 [[X:%.*]], 7
 ; CHECK-NEXT:    ret i1 [[Z]]
 ;
   %y = sub nuw i64 10, %x

diff --git a/llvm/test/Transforms/InstCombine/sub.ll b/llvm/test/Transforms/InstCombine/sub.ll
index 76cd7ab5c10cd1..249b5673c8acfd 100644
--- a/llvm/test/Transforms/InstCombine/sub.ll
+++ b/llvm/test/Transforms/InstCombine/sub.ll
@@ -2367,7 +2367,7 @@ define <2 x i8> @sub_to_and_vector3(<2 x i8> %x) {
 ; CHECK-LABEL: @sub_to_and_vector3(
 ; CHECK-NEXT:    [[SUB:%.*]] = sub nuw <2 x i8> <i8 71, i8 71>, [[X:%.*]]
 ; CHECK-NEXT:    [[AND:%.*]] = and <2 x i8> [[SUB]], <i8 120, i8 undef>
-; CHECK-NEXT:    [[R:%.*]] = sub <2 x i8> <i8 44, i8 44>, [[AND]]
+; CHECK-NEXT:    [[R:%.*]] = sub nsw <2 x i8> <i8 44, i8 44>, [[AND]]
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %sub = sub nuw <2 x i8> <i8 71, i8 71>, %x

diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index d0ea1095056663..658f3796721c4e 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -169,41 +169,69 @@ static void TestAddSubExhaustive(bool IsAdd) {
   unsigned Bits = 4;
   ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
     ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
-      KnownBits Known(Bits), KnownNSW(Bits);
+      KnownBits Known(Bits), KnownNSW(Bits), KnownNUW(Bits),
+          KnownNSWAndNUW(Bits);
       Known.Zero.setAllBits();
       Known.One.setAllBits();
       KnownNSW.Zero.setAllBits();
       KnownNSW.One.setAllBits();
+      KnownNUW.Zero.setAllBits();
+      KnownNUW.One.setAllBits();
+      KnownNSWAndNUW.Zero.setAllBits();
+      KnownNSWAndNUW.One.setAllBits();
 
       ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
         ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
-          bool Overflow;
+          bool SignedOverflow;
+          bool UnsignedOverflow;
           APInt Res;
-          if (IsAdd)
-            Res = N1.sadd_ov(N2, Overflow);
-          else
-            Res = N1.ssub_ov(N2, Overflow);
+          if (IsAdd) {
+            Res = N1.uadd_ov(N2, UnsignedOverflow);
+            Res = N1.sadd_ov(N2, SignedOverflow);
+          } else {
+            Res = N1.usub_ov(N2, UnsignedOverflow);
+            Res = N1.ssub_ov(N2, SignedOverflow);
+          }
 
           Known.One &= Res;
           Known.Zero &= ~Res;
 
-          if (!Overflow) {
+          if (!SignedOverflow) {
             KnownNSW.One &= Res;
             KnownNSW.Zero &= ~Res;
           }
+
+          if (!UnsignedOverflow) {
+            KnownNUW.One &= Res;
+            KnownNUW.Zero &= ~Res;
+          }
+
+          if (!UnsignedOverflow && !SignedOverflow) {
+            KnownNSWAndNUW.One &= Res;
+            KnownNSWAndNUW.Zero &= ~Res;
+          }
         });
       });
 
       KnownBits KnownComputed = KnownBits::computeForAddSub(
           IsAdd, /*NSW=*/false, /*NUW=*/false, Known1, Known2);
-      EXPECT_EQ(Known, KnownComputed);
+      EXPECT_TRUE(isOptimal(Known, KnownComputed, {Known1, Known2}));
 
-      // The NSW calculation is not precise, only check that it's
-      // conservatively correct.
       KnownBits KnownNSWComputed = KnownBits::computeForAddSub(
           IsAdd, /*NSW=*/true, /*NUW=*/false, Known1, Known2);
-      EXPECT_TRUE(KnownNSWComputed.Zero.isSubsetOf(KnownNSW.Zero));
-      EXPECT_TRUE(KnownNSWComputed.One.isSubsetOf(KnownNSW.One));
+      if (!KnownNSW.hasConflict())
+        EXPECT_TRUE(isOptimal(KnownNSW, KnownNSWComputed, {Known1, Known2}));
+
+      KnownBits KnownNUWComputed = KnownBits::computeForAddSub(
+          IsAdd, /*NSW=*/false, /*NUW=*/true, Known1, Known2);
+      if (!KnownNUW.hasConflict())
+        EXPECT_TRUE(isOptimal(KnownNUW, KnownNUWComputed, {Known1, Known2}));
+
+      KnownBits KnownNSWAndNUWComputed = KnownBits::computeForAddSub(
+          IsAdd, /*NSW=*/true, /*NUW=*/true, Known1, Known2);
+      if (!KnownNSWAndNUW.hasConflict())
+        EXPECT_TRUE(isOptimal(KnownNSWAndNUW, KnownNSWAndNUWComputed,
+                              {Known1, Known2}));
     });
   });
 }
@@ -244,6 +272,28 @@ TEST(KnownBitsTest, SubBorrowExhaustive) {
   });
 }
 
+TEST(KnownBitsTest, SignBitUnknown) {
+  KnownBits Known(2);
+  EXPECT_TRUE(Known.isSignUnknown());
+  Known.Zero.setBit(0);
+  EXPECT_TRUE(Known.isSignUnknown());
+  Known.Zero.setBit(1);
+  EXPECT_FALSE(Known.isSignUnknown());
+  Known.Zero.clearBit(0);
+  EXPECT_FALSE(Known.isSignUnknown());
+  Known.Zero.clearBit(1);
+  EXPECT_TRUE(Known.isSignUnknown());
+
+  Known.One.setBit(0);
+  EXPECT_TRUE(Known.isSignUnknown());
+  Known.One.setBit(1);
+  EXPECT_FALSE(Known.isSignUnknown());
+  Known.One.clearBit(0);
+  EXPECT_FALSE(Known.isSignUnknown());
+  Known.One.clearBit(1);
+  EXPECT_TRUE(Known.isSignUnknown());
+}
+
 TEST(KnownBitsTest, AbsDiffSpecialCase) {
  // There are 2 implementations of absdiff - both are currently needed to cover
  // extra cases.
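
For reference, the exhaustive strategy above enumerates every KnownBits
pair at a small bit width, every concrete value pair consistent with
them, and compares the analysis against the intersection of all
reachable results (`isOptimal` is a helper defined earlier in this
file). A condensed, self-contained soundness check in the same spirit
(a sketch, not the actual test code):

    #include "llvm/Support/KnownBits.h"
    #include <cassert>

    using namespace llvm;

    int main() {
      const unsigned Bits = 4, E = 1u << Bits;
      for (unsigned Z = 0; Z != E; ++Z) {
        for (unsigned O = 0; O != E; ++O) {
          if (Z & O)
            continue; // Zero and One must not overlap
          KnownBits A(Bits);
          A.Zero = Z;
          A.One = O;
          KnownBits B(Bits); // B stays fully unknown for brevity

          // Intersect all sums reachable without unsigned wrap.
          APInt CommonZero = APInt::getAllOnes(Bits);
          APInt CommonOne = APInt::getAllOnes(Bits);
          for (unsigned X = 0; X != E; ++X) {
            if ((X & Z) || ((X & O) != O))
              continue; // X is inconsistent with A
            for (unsigned Y = 0; Y != E; ++Y) {
              bool Ov;
              APInt R = APInt(Bits, X).uadd_ov(APInt(Bits, Y), Ov);
              if (Ov)
                continue; // nuw excludes wrapping sums
              CommonZero &= ~R;
              CommonOne &= R;
            }
          }

          KnownBits Res = KnownBits::computeForAddSub(
              /*Add=*/true, /*NSW=*/false, /*NUW=*/true, A, B);
          // Soundness: the analysis may only claim bits the brute force
          // confirms; the real unit test additionally checks optimality.
          assert(Res.Zero.isSubsetOf(CommonZero));
          assert(Res.One.isSubsetOf(CommonOne));
        }
      }
      return 0;
    }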


        

