[llvm] 8c569c9 - [KnownBits] Add implementation for `KnownBits::sdiv`
Noah Goldstein via llvm-commits
llvm-commits at lists.llvm.org
Tue May 16 16:58:43 PDT 2023
Author: Noah Goldstein
Date: 2023-05-16T18:58:12-05:00
New Revision: 8c569c922bc5bbe9fb67ff3ff3ac28e17b012360
URL: https://github.com/llvm/llvm-project/commit/8c569c922bc5bbe9fb67ff3ff3ac28e17b012360
DIFF: https://github.com/llvm/llvm-project/commit/8c569c922bc5bbe9fb67ff3ff3ac28e17b012360.diff
LOG: [KnownBits] Add implementation for `KnownBits::sdiv`
We can determine some of the upper bits of the result (similar to `udiv`)
when the signs of the inputs are known.
Additionally, when the `exact` flag is set, we can sometimes determine the
lowest bit of the result.
Differential Revision: https://reviews.llvm.org/D150093
Added:
Modified:
llvm/include/llvm/Support/KnownBits.h
llvm/lib/Support/KnownBits.cpp
llvm/unittests/Support/KnownBitsTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index a632bc4973a08..bce234eecbb8a 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -342,6 +342,10 @@ struct KnownBits {
/// Compute known bits from zero-extended multiply-hi.
static KnownBits mulhu(const KnownBits &LHS, const KnownBits &RHS);
+ /// Compute known bits for sdiv(LHS, RHS). If \p Exact is true, the
+ /// division is assumed to leave no remainder (as with `sdiv exact`).
+ static KnownBits sdiv(const KnownBits &LHS, const KnownBits &RHS,
+ bool Exact = false);
+
/// Compute known bits for udiv(LHS, RHS).
static KnownBits udiv(const KnownBits &LHS, const KnownBits &RHS);
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index ddeb6a4961601..a00cc04a676f0 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -536,6 +536,74 @@ KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
}
+// Compute known bits for a signed division LHS / RHS. When the operand signs
+// determine the quotient's sign, a quotient of largest magnitude bounds how
+// many leading zeros (or ones) every possible result shares. With \p Exact,
+// the low bit can also follow from the operands' parities.
+KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
+ bool Exact) {
+ // Both operands known non-negative: sdiv is equivalent to `udiv` here.
+ // (Such an sdiv would usually have been folded to udiv before reaching us.)
+ if (LHS.isNonNegative() && RHS.isNonNegative())
+ return udiv(LHS, RHS);
+
+ unsigned BitWidth = LHS.getBitWidth();
+ assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
+ KnownBits Known(BitWidth);
+
+ // Num / Denum is the quotient of largest magnitude over all possible inputs.
+ APInt Num, Denum;
+ // Sign of every possible (defined) quotient:
+ // Non-negative (may be 0) -> true
+ // Negative -> false
+ // Unknown -> nullopt
+ std::optional<bool> ResultSign;
+ if (LHS.isNegative() && RHS.isNegative()) {
+ // Both negative: the result is non-negative. Its largest value comes from
+ // the most negative LHS over the negative RHS closest to zero.
+ Denum = RHS.getSignedMaxValue();
+ Num = LHS.getSignedMinValue();
+ ResultSign = true;
+ } else if (LHS.isNegative() && RHS.isStrictlyPositive()) {
+ // Result is negative if Exact OR -LHS u>= RHS, i.e. the smallest magnitude
+ // of LHS is at least the largest RHS, so the quotient cannot truncate to 0.
+ // (An exact quotient of a non-zero LHS is never 0.)
+ if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) {
+ // Most negative quotient: most negative LHS over the smallest positive RHS.
+ Denum = RHS.getSignedMinValue();
+ Num = LHS.getSignedMinValue();
+ ResultSign = false;
+ }
+ } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
+ // Result is negative if Exact OR LHS u>= -RHS, i.e. the smallest LHS is at
+ // least the largest magnitude of RHS, so the quotient cannot truncate to 0.
+ // If RHS may be INT_MIN, -RHS wraps to 2^(BitWidth-1) (unsigned), which no
+ // positive LHS reaches, so the check conservatively fails.
+ if (Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue())) {
+ // Most negative quotient: largest LHS over the negative RHS closest to 0.
+ Denum = RHS.getSignedMaxValue();
+ Num = LHS.getSignedMaxValue();
+ ResultSign = false;
+ }
+ }
+
+ if (ResultSign) {
+ // Every possible quotient lies in [0, Res] (non-negative case) or in
+ // [Res, -1] (negative case), so it shares Res's leading zeros or ones.
+ // NOTE(review): for Num == INT_MIN and Denum == -1 the quotient is not
+ // representable; Res presumably wraps to INT_MIN, giving LeadZ == 0, which
+ // is conservative (makeNonNegative still records the sign bit).
+ APInt Res = Num.sdiv(Denum);
+ if (*ResultSign) {
+ unsigned LeadZ = Res.countLeadingZeros();
+ Known.Zero.setHighBits(LeadZ);
+ Known.makeNonNegative();
+ } else {
+ unsigned LeadO = Res.countLeadingOnes();
+ Known.One.setHighBits(LeadO);
+ Known.makeNegative();
+ }
+ }
+
+ if (Exact) {
+ // An exact division means LHS == Quotient * RHS, so with an odd RHS the
+ // quotient has the same parity as LHS.
+ // Odd / Odd -> Odd
+ if (LHS.One[0] && RHS.One[0]) {
+ Known.Zero.clearBit(0);
+ Known.One.setBit(0);
+ }
+ // Even / Odd -> Even
+ else if (LHS.Zero[0] && RHS.One[0]) {
+ Known.One.clearBit(0);
+ Known.Zero.setBit(0);
+ }
+ // Odd / Even -> impossible under Exact (no valid result); left unknown.
+ // Even / Even -> unknown
+ }
+
+ // Sanity check: the bits derived above must not contradict each other.
+ assert(!Known.hasConflict() && "Bad Output");
+ return Known;
+}
+
KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS) {
unsigned BitWidth = LHS.getBitWidth();
assert(!LHS.hasConflict() && !RHS.hasConflict());
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index ece7e80147db8..22dcbed76b9ff 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -249,6 +249,27 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return N1.udiv(N2);
},
checkCorrectnessOnlyBinary);
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::sdiv(Known1, Known2);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ if (N2.isZero() || (N1.isMinSignedValue() && N2.isAllOnes()))
+ return std::nullopt;
+ return N1.sdiv(N2);
+ },
+ checkCorrectnessOnlyBinary);
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::sdiv(Known1, Known2, /*Exact*/ true);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ if (N2.isZero() || (N1.isMinSignedValue() && N2.isAllOnes()) ||
+ !N1.srem(N2).isZero())
+ return std::nullopt;
+ return N1.sdiv(N2);
+ },
+ checkCorrectnessOnlyBinary);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::urem(Known1, Known2);
More information about the llvm-commits
mailing list