[llvm] d81db0e - [KnownBits] Implement knownbits `lshr`/`ashr` with exact flag
Noah Goldstein via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 11 13:51:34 PDT 2024
Author: Noah Goldstein
Date: 2024-03-11T15:51:07-05:00
New Revision: d81db0e5f5b1404ff4813af3050d671528ad45cc
URL: https://github.com/llvm/llvm-project/commit/d81db0e5f5b1404ff4813af3050d671528ad45cc
DIFF: https://github.com/llvm/llvm-project/commit/d81db0e5f5b1404ff4813af3050d671528ad45cc.diff
LOG: [KnownBits] Implement knownbits `lshr`/`ashr` with exact flag
The exact flag basically allows us to set an upper bound on shift
amount when we have a known 1 in `LHS`.
Typically we deduce exact using knownbits (on non-exact incoming
shifts), so this is unlikely to be particularly impactful, but may be
useful in some circumstances.
Closes #84254
Added:
Modified:
llvm/lib/Support/KnownBits.cpp
llvm/test/Analysis/ValueTracking/knownbits-shift.ll
llvm/unittests/Support/KnownBitsTest.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index ed25e52b9ace67..c33c3680825a10 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
}
KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero, bool /*Exact*/) {
+ bool ShAmtNonZero, bool Exact) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
@@ -367,6 +367,18 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
// Find the common bits from all possible shifts.
APInt MaxValue = RHS.getMaxValue();
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+ // If exact, bound MaxShiftAmount to first known 1 in LHS.
+ if (Exact) {
+ unsigned FirstOne = LHS.countMaxTrailingZeros();
+ if (FirstOne < MinShiftAmount) {
+ // Always poison. Return zero because we don't like returning conflict.
+ Known.setAllZero();
+ return Known;
+ }
+ MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+ }
+
unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
Known.Zero.setAllBits();
@@ -389,7 +401,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
}
KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero, bool /*Exact*/) {
+ bool ShAmtNonZero, bool Exact) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
@@ -415,6 +427,18 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
// Find the common bits from all possible shifts.
APInt MaxValue = RHS.getMaxValue();
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+ // If exact, bound MaxShiftAmount to first known 1 in LHS.
+ if (Exact) {
+ unsigned FirstOne = LHS.countMaxTrailingZeros();
+ if (FirstOne < MinShiftAmount) {
+ // Always poison. Return zero because we don't like returning conflict.
+ Known.setAllZero();
+ return Known;
+ }
+ MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+ }
+
unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
Known.Zero.setAllBits();
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-shift.ll b/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
index 3235f69b5221a1..5cb355eff5a699 100644
--- a/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
+++ b/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
@@ -3,9 +3,7 @@
define i8 @simplify_lshr_with_exact(i8 %x) {
; CHECK-LABEL: @simplify_lshr_with_exact(
-; CHECK-NEXT: [[SHR:%.*]] = lshr exact i8 6, [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = and i8 [[SHR]], 2
-; CHECK-NEXT: ret i8 [[R]]
+; CHECK-NEXT: ret i8 2
;
%shr = lshr exact i8 6, %x
%r = and i8 %shr, 2
@@ -14,9 +12,7 @@ define i8 @simplify_lshr_with_exact(i8 %x) {
define i8 @simplify_ashr_with_exact(i8 %x) {
; CHECK-LABEL: @simplify_ashr_with_exact(
-; CHECK-NEXT: [[SHR:%.*]] = ashr exact i8 -122, [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = and i8 [[SHR]], 2
-; CHECK-NEXT: ret i8 [[R]]
+; CHECK-NEXT: ret i8 2
;
%shr = ashr exact i8 -122, %x
%r = and i8 %shr, 2
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 658f3796721c4e..7c183e9626f985 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -516,6 +516,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return N1.lshr(N2);
},
checkOptimalityBinary, /* RefinePoisonToZero */ true);
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::lshr(Known1, Known2, /*ShAmtNonZero=*/false,
+ /*Exact=*/true);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ if (N2.uge(N2.getBitWidth()))
+ return std::nullopt;
+ if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
+ return std::nullopt;
+ return N1.lshr(N2);
+ },
+ checkOptimalityBinary, /* RefinePoisonToZero */ true);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::ashr(Known1, Known2);
@@ -526,6 +539,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return N1.ashr(N2);
},
checkOptimalityBinary, /* RefinePoisonToZero */ true);
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::ashr(Known1, Known2, /*ShAmtNonZero=*/false,
+ /*Exact=*/true);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ if (N2.uge(N2.getBitWidth()))
+ return std::nullopt;
+ if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
+ return std::nullopt;
+ return N1.ashr(N2);
+ },
+ checkOptimalityBinary, /* RefinePoisonToZero */ true);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
More information about the llvm-commits
mailing list