[llvm] [KnownBits] Implement knownbits lshr/ashr with exact flag (PR #84254)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 7 09:33:55 PST 2024


https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/84254

>From 0173cfcf13ff864790a2f212ddbe95e199f2af64 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Thu, 7 Mar 2024 11:28:28 -0600
Subject: [PATCH 1/3] [KnownBits] Add test for computing more information for
 `lshr`/`ashr` with `exact` flag; NFC

---
 .../Analysis/ValueTracking/knownbits-shift.ll | 24 +++++++++++++++++++
 1 file changed, 24 insertions(+)
 create mode 100644 llvm/test/Analysis/ValueTracking/knownbits-shift.ll

diff --git a/llvm/test/Analysis/ValueTracking/knownbits-shift.ll b/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
new file mode 100644
index 00000000000000..3235f69b5221a1
--- /dev/null
+++ b/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
@@ -0,0 +1,24 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+define i8 @simplify_lshr_with_exact(i8 %x) {
+; CHECK-LABEL: @simplify_lshr_with_exact(
+; CHECK-NEXT:    [[SHR:%.*]] = lshr exact i8 6, [[X:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[SHR]], 2
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %shr = lshr exact i8 6, %x
+  %r = and i8 %shr, 2
+  ret i8 %r
+}
+
+define i8 @simplify_ashr_with_exact(i8 %x) {
+; CHECK-LABEL: @simplify_ashr_with_exact(
+; CHECK-NEXT:    [[SHR:%.*]] = ashr exact i8 -122, [[X:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[SHR]], 2
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %shr = ashr exact i8 -122, %x
+  %r = and i8 %shr, 2
+  ret i8 %r
+}

>From cd68183e35a1830e27c3d75576a9265fc545e72c Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 5 Mar 2024 21:56:27 -0600
Subject: [PATCH 2/3] [KnownBits] Add API support for `exact` in `lshr`/`ashr`;
 NFC

---
 llvm/include/llvm/Support/KnownBits.h          |  4 ++--
 llvm/lib/Analysis/ValueTracking.cpp            | 14 ++++++++------
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  6 ++++--
 llvm/lib/Support/KnownBits.cpp                 |  4 ++--
 4 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 46dbf0c2baa5fe..06d2c90f7b0f6b 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -402,12 +402,12 @@ struct KnownBits {
   /// Compute known bits for lshr(LHS, RHS).
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
   static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS,
-                        bool ShAmtNonZero = false);
+                        bool ShAmtNonZero = false, bool Exact = false);
 
   /// Compute known bits for ashr(LHS, RHS).
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
   static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS,
-                        bool ShAmtNonZero = false);
+                        bool ShAmtNonZero = false, bool Exact = false);
 
   /// Determine if these known bits always give the same ICMP_EQ result.
   static std::optional<bool> eq(const KnownBits &LHS, const KnownBits &RHS);
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 52ae9f034e5d34..3304db68e3deae 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1142,9 +1142,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
     break;
   }
   case Instruction::LShr: {
-    auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
-                 bool ShAmtNonZero) {
-      return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero);
+    bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+    auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+                      bool ShAmtNonZero) {
+      return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
     };
     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
                                       KF);
@@ -1155,9 +1156,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
     break;
   }
   case Instruction::AShr: {
-    auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
-                 bool ShAmtNonZero) {
-      return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero);
+    bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+    auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+                      bool ShAmtNonZero) {
+      return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
     };
     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
                                       KF);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index f7ace79e8c51d4..92e20cc1304b70 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3485,7 +3485,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SRL:
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known = KnownBits::lshr(Known, Known2);
+    Known = KnownBits::lshr(Known, Known2, /*ShAmtNonZero=*/false,
+                            Op->getFlags().hasExact());
 
     // Minimum shift high bits are known zero.
     if (const APInt *ShMinAmt =
@@ -3495,7 +3496,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SRA:
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known = KnownBits::ashr(Known, Known2);
+    Known = KnownBits::ashr(Known, Known2, /*ShAmtNonZero=*/false,
+                            Op->getFlags().hasExact());
     break;
   case ISD::FSHL:
   case ISD::FSHR:
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 74d857457aec1e..ed25e52b9ace67 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
 }
 
 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
-                          bool ShAmtNonZero) {
+                          bool ShAmtNonZero, bool /*Exact*/) {
   unsigned BitWidth = LHS.getBitWidth();
   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known = LHS;
@@ -389,7 +389,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
 }
 
 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
-                          bool ShAmtNonZero) {
+                          bool ShAmtNonZero, bool /*Exact*/) {
   unsigned BitWidth = LHS.getBitWidth();
   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known = LHS;

>From 895ce6745fe8abc81ce8ba0215ed4fdc1f965024 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 5 Mar 2024 22:03:44 -0600
Subject: [PATCH 3/3] [KnownBits] Implement knownbits `lshr`/`ashr` with exact
 flag

The exact flag allows us to set an upper bound on the shift amount
when we have a known 1 in `LHS`: shifting that bit out would make the
result poison, so the shift amount cannot exceed the position of the
lowest known 1.

Typically we deduce exact using knownbits (on non-exact incoming
shifts), so this is not particularly impactful, but may be useful in
some circumstances.
---
 llvm/lib/Support/KnownBits.cpp                | 28 +++++++++++++++++--
 .../Analysis/ValueTracking/knownbits-shift.ll |  8 ++----
 llvm/unittests/Support/KnownBitsTest.cpp      | 26 +++++++++++++++++
 3 files changed, 54 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index ed25e52b9ace67..c33c3680825a10 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
 }
 
 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
-                          bool ShAmtNonZero, bool /*Exact*/) {
+                          bool ShAmtNonZero, bool Exact) {
   unsigned BitWidth = LHS.getBitWidth();
   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known = LHS;
@@ -367,6 +367,18 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
   // Find the common bits from all possible shifts.
   APInt MaxValue = RHS.getMaxValue();
   unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+  // If exact, bound MaxShiftAmount to first known 1 in LHS.
+  if (Exact) {
+    unsigned FirstOne = LHS.countMaxTrailingZeros();
+    if (FirstOne < MinShiftAmount) {
+      // Always poison. Return zero because we don't like returning conflict.
+      Known.setAllZero();
+      return Known;
+    }
+    MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+  }
+
   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
   Known.Zero.setAllBits();
@@ -389,7 +401,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
 }
 
 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
-                          bool ShAmtNonZero, bool /*Exact*/) {
+                          bool ShAmtNonZero, bool Exact) {
   unsigned BitWidth = LHS.getBitWidth();
   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known = LHS;
@@ -415,6 +427,18 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
   // Find the common bits from all possible shifts.
   APInt MaxValue = RHS.getMaxValue();
   unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+  // If exact, bound MaxShiftAmount to first known 1 in LHS.
+  if (Exact) {
+    unsigned FirstOne = LHS.countMaxTrailingZeros();
+    if (FirstOne < MinShiftAmount) {
+      // Always poison. Return zero because we don't like returning conflict.
+      Known.setAllZero();
+      return Known;
+    }
+    MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+  }
+
   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
   Known.Zero.setAllBits();
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-shift.ll b/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
index 3235f69b5221a1..5cb355eff5a699 100644
--- a/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
+++ b/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
@@ -3,9 +3,7 @@
 
 define i8 @simplify_lshr_with_exact(i8 %x) {
 ; CHECK-LABEL: @simplify_lshr_with_exact(
-; CHECK-NEXT:    [[SHR:%.*]] = lshr exact i8 6, [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = and i8 [[SHR]], 2
-; CHECK-NEXT:    ret i8 [[R]]
+; CHECK-NEXT:    ret i8 2
 ;
   %shr = lshr exact i8 6, %x
   %r = and i8 %shr, 2
@@ -14,9 +12,7 @@ define i8 @simplify_lshr_with_exact(i8 %x) {
 
 define i8 @simplify_ashr_with_exact(i8 %x) {
 ; CHECK-LABEL: @simplify_ashr_with_exact(
-; CHECK-NEXT:    [[SHR:%.*]] = ashr exact i8 -122, [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = and i8 [[SHR]], 2
-; CHECK-NEXT:    ret i8 [[R]]
+; CHECK-NEXT:    ret i8 2
 ;
   %shr = ashr exact i8 -122, %x
   %r = and i8 %shr, 2
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 658f3796721c4e..7c183e9626f985 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -516,6 +516,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return N1.lshr(N2);
       },
       checkOptimalityBinary, /* RefinePoisonToZero */ true);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::lshr(Known1, Known2, /*ShAmtNonZero=*/false,
+                               /*Exact=*/true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        if (N2.uge(N2.getBitWidth()))
+          return std::nullopt;
+        if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
+          return std::nullopt;
+        return N1.lshr(N2);
+      },
+      checkOptimalityBinary, /* RefinePoisonToZero */ true);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::ashr(Known1, Known2);
@@ -526,6 +539,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return N1.ashr(N2);
       },
       checkOptimalityBinary, /* RefinePoisonToZero */ true);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::ashr(Known1, Known2, /*ShAmtNonZero=*/false,
+                               /*Exact=*/true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        if (N2.uge(N2.getBitWidth()))
+          return std::nullopt;
+        if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
+          return std::nullopt;
+        return N1.ashr(N2);
+      },
+      checkOptimalityBinary, /* RefinePoisonToZero */ true);
 
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {



More information about the llvm-commits mailing list