[llvm] [KnownBits] Implement knownbits lshr/ashr with exact flag (PR #84254)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 6 15:03:40 PST 2024
https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/84254
>From 7ccab7d59de56964dde40d5f29e2dcdc666f7311 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 5 Mar 2024 21:56:27 -0600
Subject: [PATCH 1/2] [KnownBits] Add API support for `exact` in `lshr`/`ashr`;
NFC
---
llvm/include/llvm/Support/KnownBits.h | 4 ++--
llvm/lib/Analysis/ValueTracking.cpp | 14 ++++++++------
llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp | 4 +++-
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 6 ++++--
llvm/lib/Support/KnownBits.cpp | 4 ++--
5 files changed, 19 insertions(+), 13 deletions(-)
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 46dbf0c2baa5fe..06d2c90f7b0f6b 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -402,12 +402,12 @@ struct KnownBits {
/// Compute known bits for lshr(LHS, RHS).
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero = false);
+ bool ShAmtNonZero = false, bool Exact = false);
/// Compute known bits for ashr(LHS, RHS).
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero = false);
+ bool ShAmtNonZero = false, bool Exact = false);
/// Determine if these known bits always give the same ICMP_EQ result.
static std::optional<bool> eq(const KnownBits &LHS, const KnownBits &RHS);
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 52ae9f034e5d34..3304db68e3deae 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1142,9 +1142,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::LShr: {
- auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
- bool ShAmtNonZero) {
- return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero);
+ bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+ auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+ bool ShAmtNonZero) {
+ return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
};
computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
KF);
@@ -1155,9 +1156,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::AShr: {
- auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
- bool ShAmtNonZero) {
- return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero);
+ bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+ auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+ bool ShAmtNonZero) {
+ return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
};
computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
KF);
diff --git a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
index 099bf45b2734cb..21e0b7b2b68fc7 100644
--- a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
@@ -565,7 +565,9 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
KnownBits ExtKnown = KnownBits::makeConstant(APInt(BitWidth, BitWidth));
KnownBits ShiftKnown = KnownBits::computeForAddSub(
/*Add=*/false, /*NSW=*/false, /* NUW=*/false, ExtKnown, WidthKnown);
- Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown);
+ Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown,
+ /*ShAmtNonZero=*/false,
+ /*Exact*/ true);
break;
}
case TargetOpcode::G_UADDO:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index f7ace79e8c51d4..92e20cc1304b70 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3485,7 +3485,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::SRL:
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
- Known = KnownBits::lshr(Known, Known2);
+ Known = KnownBits::lshr(Known, Known2, /*ShAmtNonZero=*/false,
+ Op->getFlags().hasExact());
// Minimum shift high bits are known zero.
if (const APInt *ShMinAmt =
@@ -3495,7 +3496,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::SRA:
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
- Known = KnownBits::ashr(Known, Known2);
+ Known = KnownBits::ashr(Known, Known2, /*ShAmtNonZero=*/false,
+ Op->getFlags().hasExact());
break;
case ISD::FSHL:
case ISD::FSHR:
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 74d857457aec1e..ed25e52b9ace67 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
}
KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero) {
+ bool ShAmtNonZero, bool /*Exact*/) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
@@ -389,7 +389,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
}
KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero) {
+ bool ShAmtNonZero, bool /*Exact*/) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
>From b1f874327f235884726c84ba6fb489633abc0b04 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 5 Mar 2024 22:03:44 -0600
Subject: [PATCH 2/2] [KnownBits] Implement knownbits `lshr`/`ashr` with exact
flag
The exact flag basically allows us to set an upper bound on the shift
amount when we have a known 1 in `LHS`: an exact shift may not shift out
any set bit, so the shift amount cannot exceed the position of the lowest
known 1.
Typically we deduce exact using knownbits (on non-exact incoming
shifts), so this isn't particularly impactful, but may be useful in some
circumstances.
---
llvm/lib/Support/KnownBits.cpp | 28 ++++++++++++++--
llvm/unittests/Support/KnownBitsTest.cpp | 42 ++++++++++++++++++++++++
2 files changed, 68 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index ed25e52b9ace67..c33c3680825a10 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
}
KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero, bool /*Exact*/) {
+ bool ShAmtNonZero, bool Exact) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
@@ -367,6 +367,18 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
// Find the common bits from all possible shifts.
APInt MaxValue = RHS.getMaxValue();
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+ // If exact, bound MaxShiftAmount to first known 1 in LHS.
+ if (Exact) {
+ unsigned FirstOne = LHS.countMaxTrailingZeros();
+ if (FirstOne < MinShiftAmount) {
+ // Always poison. Return zero because we don't like returning conflict.
+ Known.setAllZero();
+ return Known;
+ }
+ MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+ }
+
unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
Known.Zero.setAllBits();
@@ -389,7 +401,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
}
KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero, bool /*Exact*/) {
+ bool ShAmtNonZero, bool Exact) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
@@ -415,6 +427,18 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
// Find the common bits from all possible shifts.
APInt MaxValue = RHS.getMaxValue();
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+ // If exact, bound MaxShiftAmount to first known 1 in LHS.
+ if (Exact) {
+ unsigned FirstOne = LHS.countMaxTrailingZeros();
+ if (FirstOne < MinShiftAmount) {
+ // Always poison. Return zero because we don't like returning conflict.
+ Known.setAllZero();
+ return Known;
+ }
+ MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+ }
+
unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
Known.Zero.setAllBits();
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 658f3796721c4e..1add11f4f46069 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -320,6 +320,22 @@ TEST(KnownBitsTest, AbsDiffSpecialCase) {
EXPECT_EQ(0b0000ul, Res.Zero.getZExtValue());
}
+TEST(KnownBitsTest, ShrExactSpecialCase) {
+ const unsigned N = 4;
+ KnownBits LHS(N), RHS(N);
+
+ LHS.One.setBit(1);
+ LHS.One.setBit(2);
+
+ EXPECT_FALSE(KnownBits::lshr(LHS, RHS).One[1]);
+ EXPECT_FALSE(KnownBits::ashr(LHS, RHS).One[1]);
+
+ EXPECT_TRUE(
+ KnownBits::lshr(LHS, RHS, /*ShAmtNonZero=*/false, /*Exact=*/true).One[1]);
+ EXPECT_TRUE(
+ KnownBits::ashr(LHS, RHS, /*ShAmtNonZero=*/false, /*Exact=*/true).One[1]);
+}
+
TEST(KnownBitsTest, BinaryExhaustive) {
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
@@ -516,6 +532,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return N1.lshr(N2);
},
checkOptimalityBinary, /* RefinePoisonToZero */ true);
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::lshr(Known1, Known2, /*ShAmtNonZero=*/false,
+ /*Exact=*/true);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ if (N2.uge(N2.getBitWidth()))
+ return std::nullopt;
+ if (!N2.isZero() && !N1.extractBits(N2.getZExtValue(), 0).isZero())
+ return std::nullopt;
+ return N1.lshr(N2);
+ },
+ checkOptimalityBinary, /* RefinePoisonToZero */ true);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::ashr(Known1, Known2);
@@ -526,6 +555,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return N1.ashr(N2);
},
checkOptimalityBinary, /* RefinePoisonToZero */ true);
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::ashr(Known1, Known2, /*ShAmtNonZero=*/false,
+ /*Exact=*/true);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ if (N2.uge(N2.getBitWidth()))
+ return std::nullopt;
+ if (!N2.isZero() && !N1.extractBits(N2.getZExtValue(), 0).isZero())
+ return std::nullopt;
+ return N1.ashr(N2);
+ },
+ checkOptimalityBinary, /* RefinePoisonToZero */ true);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
More information about the llvm-commits
mailing list