[llvm] 8c569c9 - [KnownBits] Add implementation for `KnownBits::sdiv`

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Tue May 16 16:58:43 PDT 2023


Author: Noah Goldstein
Date: 2023-05-16T18:58:12-05:00
New Revision: 8c569c922bc5bbe9fb67ff3ff3ac28e17b012360

URL: https://github.com/llvm/llvm-project/commit/8c569c922bc5bbe9fb67ff3ff3ac28e17b012360
DIFF: https://github.com/llvm/llvm-project/commit/8c569c922bc5bbe9fb67ff3ff3ac28e17b012360.diff

LOG: [KnownBits] Add implementation for `KnownBits::sdiv`

We can figure out some of the upper bits (similar to `udiv`) if we know
the sign of the inputs.

Additionally, if we have the `exact` flag, we can sometimes determine some
low bits.

Differential Revision: https://reviews.llvm.org/D150093

Added: 
    

Modified: 
    llvm/include/llvm/Support/KnownBits.h
    llvm/lib/Support/KnownBits.cpp
    llvm/unittests/Support/KnownBitsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index a632bc4973a08..bce234eecbb8a 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -342,6 +342,10 @@ struct KnownBits {
   /// Compute known bits from zero-extended multiply-hi.
   static KnownBits mulhu(const KnownBits &LHS, const KnownBits &RHS);
 
+  /// Compute known bits for sdiv(LHS, RHS), assuming no remainder if \p Exact.
+  static KnownBits sdiv(const KnownBits &LHS, const KnownBits &RHS,
+                        bool Exact = false);
+
   /// Compute known bits for udiv(LHS, RHS).
   static KnownBits udiv(const KnownBits &LHS, const KnownBits &RHS);
 

diff  --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index ddeb6a4961601..a00cc04a676f0 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -536,6 +536,74 @@ KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
   return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
 }
 
+KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
+                          bool Exact) {
+  // Equivalent of `udiv`. We must have caught this before it was folded.
+  if (LHS.isNonNegative() && RHS.isNonNegative())
+    return udiv(LHS, RHS);
+
+  unsigned BitWidth = LHS.getBitWidth();
+  assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
+  KnownBits Known(BitWidth);
+
+  // ResultSign: non-negative -> true, negative -> false, unknown -> nullopt.
+  APInt Num, Denom;
+  std::optional<bool> ResultSign;
+  if (LHS.isNegative() && RHS.isNegative()) {
+    // Result non-negative. Res may wrap for INT_MIN / -1 (UB); that is benign.
+    Denom = RHS.getSignedMaxValue();
+    Num = LHS.getSignedMinValue();
+    ResultSign = true;
+  } else if (LHS.isNegative() && RHS.isStrictlyPositive()) {
+    // Result is negative if Exact (LHS != 0) OR -LHS u>= RHS.
+    if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) {
+      Denom = RHS.getSignedMinValue();
+      Num = LHS.getSignedMinValue();
+      ResultSign = false;
+    }
+  } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
+    // Result is negative if Exact (LHS != 0) OR LHS u>= -RHS.
+    if (Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue())) {
+      Denom = RHS.getSignedMaxValue();
+      Num = LHS.getSignedMaxValue();
+      ResultSign = false;
+    }
+  }
+
+  if (ResultSign) {
+    // Num / Denom is the quotient farthest from zero over the possible inputs.
+    APInt Res = Num.sdiv(Denom);
+    if (*ResultSign) {
+      Known.Zero.setHighBits(Res.countLeadingZeros());
+      Known.makeNonNegative();
+    } else {
+      Known.One.setHighBits(Res.countLeadingOnes());
+      Known.makeNegative();
+    }
+  }
+
+  if (Exact) {
+    // Exact means LHS == Q * RHS, so tz(Q) == tz(LHS) - tz(RHS) when LHS != 0
+    // (this subsumes Even / Odd -> Even and Odd / Odd -> Odd). Bits are
+    // cleared before being set; they only disagree for infeasible inputs.
+    int MinTZ = (int)LHS.countMinTrailingZeros() -
+                (int)RHS.countMaxTrailingZeros();
+    int MaxTZ = (int)LHS.countMaxTrailingZeros() -
+                (int)RHS.countMinTrailingZeros();
+    if (MinTZ >= 0) {
+      Known.One.clearLowBits(MinTZ);
+      Known.Zero.setLowBits(MinTZ);
+      if (MinTZ == MaxTZ && !LHS.isZero()) { // Q == 0 if LHS is 0.
+        Known.Zero.clearBit(MinTZ);
+        Known.One.setBit(MinTZ);
+      }
+    }
+  }
+
+  assert(!Known.hasConflict() && "Bad Output");
+  return Known;
+}
+
 KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS) {
   unsigned BitWidth = LHS.getBitWidth();
   assert(!LHS.hasConflict() && !RHS.hasConflict());

diff  --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index ece7e80147db8..22dcbed76b9ff 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -249,6 +249,27 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return N1.udiv(N2);
       },
       checkCorrectnessOnlyBinary);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::sdiv(Known1, Known2);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        if (N2.isZero() || (N1.isMinSignedValue() && N2.isAllOnes()))
+          return std::nullopt;
+        return N1.sdiv(N2);
+      },
+      checkCorrectnessOnlyBinary);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::sdiv(Known1, Known2, /*Exact*/ true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        if (N2.isZero() || (N1.isMinSignedValue() && N2.isAllOnes()) ||
+            !N1.srem(N2).isZero())
+          return std::nullopt;
+        return N1.sdiv(N2);
+      },
+      checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::urem(Known1, Known2);


        


More information about the llvm-commits mailing list