[llvm] d81db0e - [KnownBits] Implement knownbits `lshr`/`ashr` with exact flag

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 11 13:51:34 PDT 2024


Author: Noah Goldstein
Date: 2024-03-11T15:51:07-05:00
New Revision: d81db0e5f5b1404ff4813af3050d671528ad45cc

URL: https://github.com/llvm/llvm-project/commit/d81db0e5f5b1404ff4813af3050d671528ad45cc
DIFF: https://github.com/llvm/llvm-project/commit/d81db0e5f5b1404ff4813af3050d671528ad45cc.diff

LOG: [KnownBits] Implement knownbits `lshr`/`ashr` with exact flag

The exact flag basically allows us to set an upper bound on shift
amount when we have a known 1 in `LHS`.

Typically we deduce `exact` using knownbits (on non-exact incoming
shifts), so this is not particularly impactful, but may be useful in some
circumstances.

Closes #84254

Added: 
    

Modified: 
    llvm/lib/Support/KnownBits.cpp
    llvm/test/Analysis/ValueTracking/knownbits-shift.ll
    llvm/unittests/Support/KnownBitsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index ed25e52b9ace67..c33c3680825a10 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
 }
 
 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
-                          bool ShAmtNonZero, bool /*Exact*/) {
+                          bool ShAmtNonZero, bool Exact) {
   unsigned BitWidth = LHS.getBitWidth();
   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known = LHS;
@@ -367,6 +367,18 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
   // Find the common bits from all possible shifts.
   APInt MaxValue = RHS.getMaxValue();
   unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+  // If exact, bound MaxShiftAmount to first known 1 in LHS.
+  if (Exact) {
+    unsigned FirstOne = LHS.countMaxTrailingZeros();
+    if (FirstOne < MinShiftAmount) {
+      // Always poison. Return zero because we don't like returning conflict.
+      Known.setAllZero();
+      return Known;
+    }
+    MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+  }
+
   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
   Known.Zero.setAllBits();
@@ -389,7 +401,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
 }
 
 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
-                          bool ShAmtNonZero, bool /*Exact*/) {
+                          bool ShAmtNonZero, bool Exact) {
   unsigned BitWidth = LHS.getBitWidth();
   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known = LHS;
@@ -415,6 +427,18 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
   // Find the common bits from all possible shifts.
   APInt MaxValue = RHS.getMaxValue();
   unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+  // If exact, bound MaxShiftAmount to first known 1 in LHS.
+  if (Exact) {
+    unsigned FirstOne = LHS.countMaxTrailingZeros();
+    if (FirstOne < MinShiftAmount) {
+      // Always poison. Return zero because we don't like returning conflict.
+      Known.setAllZero();
+      return Known;
+    }
+    MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+  }
+
   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
   Known.Zero.setAllBits();

diff  --git a/llvm/test/Analysis/ValueTracking/knownbits-shift.ll b/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
index 3235f69b5221a1..5cb355eff5a699 100644
--- a/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
+++ b/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
@@ -3,9 +3,7 @@
 
 define i8 @simplify_lshr_with_exact(i8 %x) {
 ; CHECK-LABEL: @simplify_lshr_with_exact(
-; CHECK-NEXT:    [[SHR:%.*]] = lshr exact i8 6, [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = and i8 [[SHR]], 2
-; CHECK-NEXT:    ret i8 [[R]]
+; CHECK-NEXT:    ret i8 2
 ;
   %shr = lshr exact i8 6, %x
   %r = and i8 %shr, 2
@@ -14,9 +12,7 @@ define i8 @simplify_lshr_with_exact(i8 %x) {
 
 define i8 @simplify_ashr_with_exact(i8 %x) {
 ; CHECK-LABEL: @simplify_ashr_with_exact(
-; CHECK-NEXT:    [[SHR:%.*]] = ashr exact i8 -122, [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = and i8 [[SHR]], 2
-; CHECK-NEXT:    ret i8 [[R]]
+; CHECK-NEXT:    ret i8 2
 ;
   %shr = ashr exact i8 -122, %x
   %r = and i8 %shr, 2

diff  --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 658f3796721c4e..7c183e9626f985 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -516,6 +516,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return N1.lshr(N2);
       },
       checkOptimalityBinary, /* RefinePoisonToZero */ true);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::lshr(Known1, Known2, /*ShAmtNonZero=*/false,
+                               /*Exact=*/true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        if (N2.uge(N2.getBitWidth()))
+          return std::nullopt;
+        if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
+          return std::nullopt;
+        return N1.lshr(N2);
+      },
+      checkOptimalityBinary, /* RefinePoisonToZero */ true);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::ashr(Known1, Known2);
@@ -526,6 +539,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return N1.ashr(N2);
       },
       checkOptimalityBinary, /* RefinePoisonToZero */ true);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::ashr(Known1, Known2, /*ShAmtNonZero=*/false,
+                               /*Exact=*/true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        if (N2.uge(N2.getBitWidth()))
+          return std::nullopt;
+        if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
+          return std::nullopt;
+        return N1.ashr(N2);
+      },
+      checkOptimalityBinary, /* RefinePoisonToZero */ true);
 
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {


        


More information about the llvm-commits mailing list