[llvm] d2502eb - [KnownBits] Add support for nuw/nsw on shifts

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu May 25 01:17:19 PDT 2023


Author: Nikita Popov
Date: 2023-05-25T10:17:10+02:00
New Revision: d2502eb091fabc36463e491b066bb002b47ba521

URL: https://github.com/llvm/llvm-project/commit/d2502eb091fabc36463e491b066bb002b47ba521
DIFF: https://github.com/llvm/llvm-project/commit/d2502eb091fabc36463e491b066bb002b47ba521.diff

LOG: [KnownBits] Add support for nuw/nsw on shifts

Implement precise nuw/nsw support in the KnownBits implementation,
replacing the rather crude handling in ValueTracking.

Differential Revision: https://reviews.llvm.org/D151208

Added: 
    

Modified: 
    llvm/include/llvm/Support/KnownBits.h
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/lib/Support/KnownBits.cpp
    llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
    llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
    llvm/unittests/Support/KnownBitsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 6ba4dd4f82540..9229a4d61d4b4 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -382,7 +382,8 @@ struct KnownBits {
 
   /// Compute known bits for shl(LHS, RHS).
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
-  static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS);
+  static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS,
+                       bool NUW = false, bool NSW = false);
 
   /// Compute known bits for lshr(LHS, RHS).
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 90dbfcb85fb2c..7ec34cdca0be5 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1353,20 +1353,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
     break;
   }
   case Instruction::Shl: {
+    bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
-    auto KF = [NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt) {
-      KnownBits Result = KnownBits::shl(KnownVal, KnownAmt);
-      // If this shift has "nsw" keyword, then the result is either a poison
-      // value or has the same sign bit as the first operand.
-      if (NSW) {
-        if (KnownVal.Zero.isSignBitSet())
-          Result.Zero.setSignBit();
-        if (KnownVal.One.isSignBitSet())
-          Result.One.setSignBit();
-        if (Result.hasConflict())
-          Result.setAllZero();
-      }
-      return Result;
+    auto KF = [NUW, NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt) {
+      return KnownBits::shl(KnownVal, KnownAmt, NUW, NSW);
     };
     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
                                       KF);

diff  --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index d52f08739be79..8bb236baf4ae5 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -164,21 +164,51 @@ KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
   return Flip(umax(Flip(LHS), Flip(RHS)));
 }
 
-KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
+KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
+                         bool NSW) {
   unsigned BitWidth = LHS.getBitWidth();
-  KnownBits Known(BitWidth);
+  auto ShiftByConst = [&](const KnownBits &LHS,
+                          uint64_t ShiftAmt) -> std::optional<KnownBits> {
+    KnownBits Known;
+    Known.Zero = LHS.Zero << ShiftAmt;
+    Known.Zero.setLowBits(ShiftAmt);
+    Known.One = LHS.One << ShiftAmt;
+    if ((!NUW && !NSW) || ShiftAmt == 0)
+      return Known;
+
+    KnownBits ShiftedOutBits = LHS.extractBits(ShiftAmt, BitWidth - ShiftAmt);
+    if (NUW && !ShiftedOutBits.One.isZero())
+      // One bit has been shifted out.
+      return std::nullopt;
+    if (NSW) {
+      if (!ShiftedOutBits.Zero.isZero() && !ShiftedOutBits.One.isZero())
+        // Both zeros and ones have been shifted out.
+        return std::nullopt;
+      if (NUW || !ShiftedOutBits.Zero.isZero()) {
+        if (Known.isNegative())
+          // Zero bit has been shifted out, but result sign is negative.
+          return std::nullopt;
+        Known.makeNonNegative();
+      } else if (!ShiftedOutBits.One.isZero()) {
+        if (Known.isNonNegative())
+          // One bit has been shifted out, but result sign is non-negative.
+          return std::nullopt;
+        Known.makeNegative();
+      }
+    }
+    return Known;
+  };
 
   // If the shift amount is a valid constant then transform LHS directly.
   if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
-    unsigned Shift = RHS.getConstant().getZExtValue();
-    Known = LHS;
-    Known.Zero <<= Shift;
-    Known.One <<= Shift;
-    // Low bits are known zero.
-    Known.Zero.setLowBits(Shift);
+    if (auto Res = ShiftByConst(LHS, RHS.getConstant().getZExtValue()))
+      return *Res;
+    KnownBits Known(BitWidth);
+    Known.setAllZero();
     return Known;
   }
 
+  KnownBits Known(BitWidth);
   APInt MinShiftAmount = RHS.getMinValue();
   if (MinShiftAmount.uge(BitWidth)) {
     // Always poison. Return zero because we don't like returning conflict.
@@ -193,6 +223,8 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
     MinTrailingZeros += MinShiftAmount.getZExtValue();
     MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
     Known.Zero.setLowBits(MinTrailingZeros);
+    if (NUW && NSW && !MinShiftAmount.isZero())
+      Known.makeNonNegative();
     return Known;
   }
 
@@ -210,15 +242,20 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
     if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
         (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
       continue;
-    KnownBits SpecificShift;
-    SpecificShift.Zero = LHS.Zero << ShiftAmt;
-    SpecificShift.Zero.setLowBits(ShiftAmt);
-    SpecificShift.One = LHS.One << ShiftAmt;
-    Known = Known.intersectWith(SpecificShift);
+    auto Res = ShiftByConst(LHS, ShiftAmt);
+    if (!Res)
+      // All larger shift amounts will overflow as well.
+      break;
+    Known = Known.intersectWith(*Res);
     if (Known.isUnknown())
       break;
   }
 
+  // Every possible shift amount results in poison.
+  if (Known.hasConflict()) {
+    assert((NUW || NSW) && "Can only happen with nowrap flags");
+    Known.setAllZero();
+  }
   return Known;
 }
 

diff  --git a/llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll b/llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
index 524a13fadf661..51eb96c38c1c8 100644
--- a/llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
+++ b/llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
@@ -513,7 +513,7 @@ define void @test.ult.gep.shl(ptr readonly %src, ptr readnone %max, i8 %idx) {
 ; CHECK-NEXT:    [[IDX_SHL_1:%.*]] = shl nuw nsw i8 [[IDX]], 1
 ; CHECK-NEXT:    [[ADD_PTR_SHL_1:%.*]] = getelementptr inbounds i32, ptr [[SRC]], i8 [[IDX_SHL_1]]
 ; CHECK-NEXT:    [[C_MAX_0:%.*]] = icmp ult ptr [[ADD_PTR_SHL_1]], [[MAX]]
-; CHECK-NEXT:    call void @use(i1 [[C_MAX_0]])
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    [[IDX_SHL_2:%.*]] = shl nuw i8 [[IDX]], 2
 ; CHECK-NEXT:    [[ADD_PTR_SHL_2:%.*]] = getelementptr inbounds i32, ptr [[SRC]], i8 [[IDX_SHL_2]]
 ; CHECK-NEXT:    [[C_MAX_1:%.*]] = icmp ult ptr [[ADD_PTR_SHL_2]], [[MAX]]

diff  --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
index 24f6a5a0b1eb3..7c91170bd051b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
@@ -82,7 +82,7 @@ define void @test_array_load2_store2(i32 %C, i32 %D) #1 {
 ; CHECK-NEXT:    [[ARRAYIDX3:%.*]] = getelementptr inbounds [1024 x i32], ptr @CD, i64 0, i64 [[OR]]
 ; CHECK-NEXT:    store i32 [[MUL]], ptr [[ARRAYIDX3]], align 4
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 2
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i64 [[INDVARS_IV]], 1022
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[INDVARS_IV]], 1022
 ; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END]], !llvm.loop [[LOOP3:![0-9]+]]
 ; CHECK:       for.end:
 ; CHECK-NEXT:    ret void

diff  --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 1607ddfdec502..427b323912abb 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -343,6 +343,41 @@ TEST(KnownBitsTest, BinaryExhaustive) {
           return std::nullopt;
         return N1.shl(N2);
       });
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::shl(Known1, Known2, /* NUW */ true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        bool Overflow;
+        APInt Res = N1.ushl_ov(N2, Overflow);
+        if (Overflow)
+          return std::nullopt;
+        return Res;
+      });
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::shl(Known1, Known2, /* NUW */ false, /* NSW */ true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        bool Overflow;
+        APInt Res = N1.sshl_ov(N2, Overflow);
+        if (Overflow)
+          return std::nullopt;
+        return Res;
+      });
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::shl(Known1, Known2, /* NUW */ true, /* NSW */ true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        bool OverflowUnsigned, OverflowSigned;
+        APInt Res = N1.ushl_ov(N2, OverflowUnsigned);
+        (void)N1.sshl_ov(N2, OverflowSigned);
+        if (OverflowUnsigned || OverflowSigned)
+          return std::nullopt;
+        return Res;
+      });
+
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::lshr(Known1, Known2);


        


More information about the llvm-commits mailing list