[llvm] [KnownBits] Implement knownbits lshr/ashr with exact flag (PR #84254)

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 6 15:03:40 PST 2024


https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/84254

>From 7ccab7d59de56964dde40d5f29e2dcdc666f7311 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 5 Mar 2024 21:56:27 -0600
Subject: [PATCH 1/2] [KnownBits] Add API support for `exact` in `lshr`/`ashr`;
 NFC

---
 llvm/include/llvm/Support/KnownBits.h          |  4 ++--
 llvm/lib/Analysis/ValueTracking.cpp            | 14 ++++++++------
 llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp |  4 +++-
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  6 ++++--
 llvm/lib/Support/KnownBits.cpp                 |  4 ++--
 5 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 46dbf0c2baa5fe..06d2c90f7b0f6b 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -402,12 +402,12 @@ struct KnownBits {
   /// Compute known bits for lshr(LHS, RHS).
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
   static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS,
-                        bool ShAmtNonZero = false);
+                        bool ShAmtNonZero = false, bool Exact = false);
 
   /// Compute known bits for ashr(LHS, RHS).
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
   static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS,
-                        bool ShAmtNonZero = false);
+                        bool ShAmtNonZero = false, bool Exact = false);
 
   /// Determine if these known bits always give the same ICMP_EQ result.
   static std::optional<bool> eq(const KnownBits &LHS, const KnownBits &RHS);
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 52ae9f034e5d34..3304db68e3deae 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1142,9 +1142,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
     break;
   }
   case Instruction::LShr: {
-    auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
-                 bool ShAmtNonZero) {
-      return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero);
+    bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+    auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+                      bool ShAmtNonZero) {
+      return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
     };
     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
                                       KF);
@@ -1155,9 +1156,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
     break;
   }
   case Instruction::AShr: {
-    auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
-                 bool ShAmtNonZero) {
-      return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero);
+    bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+    auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+                      bool ShAmtNonZero) {
+      return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
     };
     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
                                       KF);
diff --git a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
index 099bf45b2734cb..21e0b7b2b68fc7 100644
--- a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
@@ -565,7 +565,9 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
     KnownBits ExtKnown = KnownBits::makeConstant(APInt(BitWidth, BitWidth));
     KnownBits ShiftKnown = KnownBits::computeForAddSub(
         /*Add=*/false, /*NSW=*/false, /* NUW=*/false, ExtKnown, WidthKnown);
-    Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown);
+    Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown,
+                            /*ShAmtNonZero=*/false,
+                            /*Exact*/ true);
     break;
   }
   case TargetOpcode::G_UADDO:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index f7ace79e8c51d4..92e20cc1304b70 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3485,7 +3485,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SRL:
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known = KnownBits::lshr(Known, Known2);
+    Known = KnownBits::lshr(Known, Known2, /*ShAmtNonZero=*/false,
+                            Op->getFlags().hasExact());
 
     // Minimum shift high bits are known zero.
     if (const APInt *ShMinAmt =
@@ -3495,7 +3496,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SRA:
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known = KnownBits::ashr(Known, Known2);
+    Known = KnownBits::ashr(Known, Known2, /*ShAmtNonZero=*/false,
+                            Op->getFlags().hasExact());
     break;
   case ISD::FSHL:
   case ISD::FSHR:
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 74d857457aec1e..ed25e52b9ace67 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
 }
 
 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
-                          bool ShAmtNonZero) {
+                          bool ShAmtNonZero, bool /*Exact*/) {
   unsigned BitWidth = LHS.getBitWidth();
   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known = LHS;
@@ -389,7 +389,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
 }
 
 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
-                          bool ShAmtNonZero) {
+                          bool ShAmtNonZero, bool /*Exact*/) {
   unsigned BitWidth = LHS.getBitWidth();
   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known = LHS;

>From b1f874327f235884726c84ba6fb489633abc0b04 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 5 Mar 2024 22:03:44 -0600
Subject: [PATCH 2/2] [KnownBits] Implement knownbits `lshr`/`ashr` with exact
 flag

The exact flag allows us to set an upper bound on the shift amount
when we have a known 1 in `LHS`.

Typically we deduce exact using knownbits (on non-exact incoming
shifts), so this is not particularly impactful, but may be useful in
some circumstances.
---
 llvm/lib/Support/KnownBits.cpp           | 28 ++++++++++++++--
 llvm/unittests/Support/KnownBitsTest.cpp | 42 ++++++++++++++++++++++++
 2 files changed, 68 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index ed25e52b9ace67..c33c3680825a10 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
 }
 
 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
-                          bool ShAmtNonZero, bool /*Exact*/) {
+                          bool ShAmtNonZero, bool Exact) {
   unsigned BitWidth = LHS.getBitWidth();
   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known = LHS;
@@ -367,6 +367,18 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
   // Find the common bits from all possible shifts.
   APInt MaxValue = RHS.getMaxValue();
   unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+  // If exact, bound MaxShiftAmount to first known 1 in LHS.
+  if (Exact) {
+    unsigned FirstOne = LHS.countMaxTrailingZeros();
+    if (FirstOne < MinShiftAmount) {
+      // Always poison. Return zero because we don't like returning conflict.
+      Known.setAllZero();
+      return Known;
+    }
+    MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+  }
+
   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
   Known.Zero.setAllBits();
@@ -389,7 +401,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
 }
 
 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
-                          bool ShAmtNonZero, bool /*Exact*/) {
+                          bool ShAmtNonZero, bool Exact) {
   unsigned BitWidth = LHS.getBitWidth();
   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known = LHS;
@@ -415,6 +427,18 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
   // Find the common bits from all possible shifts.
   APInt MaxValue = RHS.getMaxValue();
   unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+  // If exact, bound MaxShiftAmount to first known 1 in LHS.
+  if (Exact) {
+    unsigned FirstOne = LHS.countMaxTrailingZeros();
+    if (FirstOne < MinShiftAmount) {
+      // Always poison. Return zero because we don't like returning conflict.
+      Known.setAllZero();
+      return Known;
+    }
+    MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+  }
+
   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
   Known.Zero.setAllBits();
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 658f3796721c4e..1add11f4f46069 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -320,6 +320,22 @@ TEST(KnownBitsTest, AbsDiffSpecialCase) {
   EXPECT_EQ(0b0000ul, Res.Zero.getZExtValue());
 }
 
+TEST(KnownBitsTest, ShrExactSpecialCase) {
+  const unsigned N = 4;
+  KnownBits LHS(N), RHS(N);
+
+  LHS.One.setBit(1);
+  LHS.One.setBit(2);
+
+  EXPECT_FALSE(KnownBits::lshr(LHS, RHS).One[1]);
+  EXPECT_FALSE(KnownBits::ashr(LHS, RHS).One[1]);
+
+  EXPECT_TRUE(
+      KnownBits::lshr(LHS, RHS, /*ShAmtNonZero=*/false, /*Exact=*/true).One[1]);
+  EXPECT_TRUE(
+      KnownBits::ashr(LHS, RHS, /*ShAmtNonZero=*/false, /*Exact=*/true).One[1]);
+}
+
 TEST(KnownBitsTest, BinaryExhaustive) {
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
@@ -516,6 +532,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return N1.lshr(N2);
       },
       checkOptimalityBinary, /* RefinePoisonToZero */ true);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::lshr(Known1, Known2, /*ShAmtNonZero=*/false,
+                               /*Exact=*/true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        if (N2.uge(N2.getBitWidth()))
+          return std::nullopt;
+        if (!N2.isZero() && !N1.extractBits(N2.getZExtValue(), 0).isZero())
+          return std::nullopt;
+        return N1.lshr(N2);
+      },
+      checkOptimalityBinary, /* RefinePoisonToZero */ true);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::ashr(Known1, Known2);
@@ -526,6 +555,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return N1.ashr(N2);
       },
       checkOptimalityBinary, /* RefinePoisonToZero */ true);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::ashr(Known1, Known2, /*ShAmtNonZero=*/false,
+                               /*Exact=*/true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        if (N2.uge(N2.getBitWidth()))
+          return std::nullopt;
+        if (!N2.isZero() && !N1.extractBits(N2.getZExtValue(), 0).isZero())
+          return std::nullopt;
+        return N1.ashr(N2);
+      },
+      checkOptimalityBinary, /* RefinePoisonToZero */ true);
 
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {



More information about the llvm-commits mailing list