[llvm] [KnownBits] Implement knownbits lshr/ashr with exact flag (PR #84254)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 7 09:33:55 PST 2024
https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/84254
>From 0173cfcf13ff864790a2f212ddbe95e199f2af64 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Thu, 7 Mar 2024 11:28:28 -0600
Subject: [PATCH 1/3] [KnownBits] Add test for computing more information for
`lshr`/`ashr` with `exact` flag; NFC
---
.../Analysis/ValueTracking/knownbits-shift.ll | 24 +++++++++++++++++++
1 file changed, 24 insertions(+)
create mode 100644 llvm/test/Analysis/ValueTracking/knownbits-shift.ll
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-shift.ll b/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
new file mode 100644
index 00000000000000..3235f69b5221a1
--- /dev/null
+++ b/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
@@ -0,0 +1,24 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+define i8 @simplify_lshr_with_exact(i8 %x) {
+; CHECK-LABEL: @simplify_lshr_with_exact(
+; CHECK-NEXT: [[SHR:%.*]] = lshr exact i8 6, [[X:%.*]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[SHR]], 2
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %shr = lshr exact i8 6, %x
+ %r = and i8 %shr, 2
+ ret i8 %r
+}
+
+define i8 @simplify_ashr_with_exact(i8 %x) {
+; CHECK-LABEL: @simplify_ashr_with_exact(
+; CHECK-NEXT: [[SHR:%.*]] = ashr exact i8 -122, [[X:%.*]]
+; CHECK-NEXT: [[R:%.*]] = and i8 [[SHR]], 2
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %shr = ashr exact i8 -122, %x
+ %r = and i8 %shr, 2
+ ret i8 %r
+}
>From cd68183e35a1830e27c3d75576a9265fc545e72c Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 5 Mar 2024 21:56:27 -0600
Subject: [PATCH 2/3] [KnownBits] Add API support for `exact` in `lshr`/`ashr`;
NFC
---
llvm/include/llvm/Support/KnownBits.h | 4 ++--
llvm/lib/Analysis/ValueTracking.cpp | 14 ++++++++------
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 6 ++++--
llvm/lib/Support/KnownBits.cpp | 4 ++--
4 files changed, 16 insertions(+), 12 deletions(-)
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 46dbf0c2baa5fe..06d2c90f7b0f6b 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -402,12 +402,12 @@ struct KnownBits {
/// Compute known bits for lshr(LHS, RHS).
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero = false);
+ bool ShAmtNonZero = false, bool Exact = false);
/// Compute known bits for ashr(LHS, RHS).
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero = false);
+ bool ShAmtNonZero = false, bool Exact = false);
/// Determine if these known bits always give the same ICMP_EQ result.
static std::optional<bool> eq(const KnownBits &LHS, const KnownBits &RHS);
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 52ae9f034e5d34..3304db68e3deae 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1142,9 +1142,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::LShr: {
- auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
- bool ShAmtNonZero) {
- return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero);
+ bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+ auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+ bool ShAmtNonZero) {
+ return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
};
computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
KF);
@@ -1155,9 +1156,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::AShr: {
- auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
- bool ShAmtNonZero) {
- return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero);
+ bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+ auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+ bool ShAmtNonZero) {
+ return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
};
computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
KF);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index f7ace79e8c51d4..92e20cc1304b70 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3485,7 +3485,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::SRL:
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
- Known = KnownBits::lshr(Known, Known2);
+ Known = KnownBits::lshr(Known, Known2, /*ShAmtNonZero=*/false,
+ Op->getFlags().hasExact());
// Minimum shift high bits are known zero.
if (const APInt *ShMinAmt =
@@ -3495,7 +3496,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::SRA:
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
- Known = KnownBits::ashr(Known, Known2);
+ Known = KnownBits::ashr(Known, Known2, /*ShAmtNonZero=*/false,
+ Op->getFlags().hasExact());
break;
case ISD::FSHL:
case ISD::FSHR:
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 74d857457aec1e..ed25e52b9ace67 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
}
KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero) {
+ bool ShAmtNonZero, bool /*Exact*/) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
@@ -389,7 +389,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
}
KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero) {
+ bool ShAmtNonZero, bool /*Exact*/) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
>From 895ce6745fe8abc81ce8ba0215ed4fdc1f965024 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 5 Mar 2024 22:03:44 -0600
Subject: [PATCH 3/3] [KnownBits] Implement knownbits `lshr`/`ashr` with exact
flag
The exact flag allows us to upper-bound the shift amount by the position
of the lowest known 1 in `LHS`: an exact shift must not shift out any set
bits, so any larger shift amount would produce poison.
Typically we deduce exact using knownbits (on non-exact incoming
shifts), so this is not particularly impactful, but it may be useful in
some circumstances.
---
llvm/lib/Support/KnownBits.cpp | 28 +++++++++++++++++--
.../Analysis/ValueTracking/knownbits-shift.ll | 8 ++----
llvm/unittests/Support/KnownBitsTest.cpp | 26 +++++++++++++++++
3 files changed, 54 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index ed25e52b9ace67..c33c3680825a10 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
}
KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero, bool /*Exact*/) {
+ bool ShAmtNonZero, bool Exact) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
@@ -367,6 +367,18 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
// Find the common bits from all possible shifts.
APInt MaxValue = RHS.getMaxValue();
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+ // If exact, bound MaxShiftAmount to first known 1 in LHS.
+ if (Exact) {
+ unsigned FirstOne = LHS.countMaxTrailingZeros();
+ if (FirstOne < MinShiftAmount) {
+ // Always poison. Return zero because we don't like returning conflict.
+ Known.setAllZero();
+ return Known;
+ }
+ MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+ }
+
unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
Known.Zero.setAllBits();
@@ -389,7 +401,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
}
KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero, bool /*Exact*/) {
+ bool ShAmtNonZero, bool Exact) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
@@ -415,6 +427,18 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
// Find the common bits from all possible shifts.
APInt MaxValue = RHS.getMaxValue();
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+ // If exact, bound MaxShiftAmount to first known 1 in LHS.
+ if (Exact) {
+ unsigned FirstOne = LHS.countMaxTrailingZeros();
+ if (FirstOne < MinShiftAmount) {
+ // Always poison. Return zero because we don't like returning conflict.
+ Known.setAllZero();
+ return Known;
+ }
+ MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+ }
+
unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
Known.Zero.setAllBits();
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-shift.ll b/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
index 3235f69b5221a1..5cb355eff5a699 100644
--- a/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
+++ b/llvm/test/Analysis/ValueTracking/knownbits-shift.ll
@@ -3,9 +3,7 @@
define i8 @simplify_lshr_with_exact(i8 %x) {
; CHECK-LABEL: @simplify_lshr_with_exact(
-; CHECK-NEXT: [[SHR:%.*]] = lshr exact i8 6, [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = and i8 [[SHR]], 2
-; CHECK-NEXT: ret i8 [[R]]
+; CHECK-NEXT: ret i8 2
;
%shr = lshr exact i8 6, %x
%r = and i8 %shr, 2
@@ -14,9 +12,7 @@ define i8 @simplify_lshr_with_exact(i8 %x) {
define i8 @simplify_ashr_with_exact(i8 %x) {
; CHECK-LABEL: @simplify_ashr_with_exact(
-; CHECK-NEXT: [[SHR:%.*]] = ashr exact i8 -122, [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = and i8 [[SHR]], 2
-; CHECK-NEXT: ret i8 [[R]]
+; CHECK-NEXT: ret i8 2
;
%shr = ashr exact i8 -122, %x
%r = and i8 %shr, 2
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 658f3796721c4e..7c183e9626f985 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -516,6 +516,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return N1.lshr(N2);
},
checkOptimalityBinary, /* RefinePoisonToZero */ true);
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::lshr(Known1, Known2, /*ShAmtNonZero=*/false,
+ /*Exact=*/true);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ if (N2.uge(N2.getBitWidth()))
+ return std::nullopt;
+ if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
+ return std::nullopt;
+ return N1.lshr(N2);
+ },
+ checkOptimalityBinary, /* RefinePoisonToZero */ true);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::ashr(Known1, Known2);
@@ -526,6 +539,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return N1.ashr(N2);
},
checkOptimalityBinary, /* RefinePoisonToZero */ true);
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::ashr(Known1, Known2, /*ShAmtNonZero=*/false,
+ /*Exact=*/true);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ if (N2.uge(N2.getBitWidth()))
+ return std::nullopt;
+ if (!N1.extractBits(N2.getZExtValue(), 0).isZero())
+ return std::nullopt;
+ return N1.ashr(N2);
+ },
+ checkOptimalityBinary, /* RefinePoisonToZero */ true);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
More information about the llvm-commits
mailing list