[llvm] 07b29fc - [ConstantRange] Improve `shlWithNoWrap` (#101800)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 6 11:00:36 PDT 2024
Author: Yingwei Zheng
Date: 2024-08-07T02:00:33+08:00
New Revision: 07b29fc808ca0842d02cf4e973381b974bfdf19f
URL: https://github.com/llvm/llvm-project/commit/07b29fc808ca0842d02cf4e973381b974bfdf19f
DIFF: https://github.com/llvm/llvm-project/commit/07b29fc808ca0842d02cf4e973381b974bfdf19f.diff
LOG: [ConstantRange] Improve `shlWithNoWrap` (#101800)
Closes https://github.com/dtcxzyw/llvm-tools/issues/22.
Added:
Modified:
llvm/lib/IR/ConstantRange.cpp
llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
llvm/unittests/IR/ConstantRangeTest.cpp
Removed:
################################################################################
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 50b211a302e8ff..c389d7214defca 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1617,21 +1617,107 @@ ConstantRange::shl(const ConstantRange &Other) const {
return ConstantRange::getNonEmpty(std::move(Min), std::move(Max) + 1);
}
+static ConstantRange computeShlNUW(const ConstantRange &LHS, // Returns a range containing every LHS << RHS that does not unsigned-wrap (nuw); empty if no operand pair qualifies.
+ const ConstantRange &RHS) {
+ unsigned BitWidth = LHS.getBitWidth();
+ bool Overflow;
+ APInt LHSMin = LHS.getUnsignedMin();
+ unsigned RHSMin = RHS.getUnsignedMin().getLimitedValue(BitWidth); // Shift amounts >= BitWidth are clamped to BitWidth.
+ APInt MinShl = LHSMin.ushl_ov(RHSMin, Overflow); // Smallest LHS shifted by the smallest amount gives the smallest non-wrapping result.
+ if (Overflow) // ushl_ov also reports overflow for a shift amount >= BitWidth.
+ return ConstantRange::getEmpty(BitWidth); // If the least-overflow-prone pair wraps, every larger LHS / larger shift wraps too.
+ APInt LHSMax = LHS.getUnsignedMax();
+ unsigned RHSMax = RHS.getUnsignedMax().getLimitedValue(BitWidth);
+ APInt MaxShl = MinShl;
+ unsigned MaxShAmt = LHSMax.countLeadingZeros(); // Largest shift LHSMax tolerates without shifting out a set bit.
+ if (RHSMin <= MaxShAmt)
+ MaxShl = LHSMax << std::min(RHSMax, MaxShAmt); // Best candidate that uses LHSMax itself.
+ RHSMin = std::max(RHSMin, MaxShAmt + 1); // Remaining shift amounts wrap for LHSMax but may not wrap for smaller LHS values...
+ RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros()); // ...provided at least LHSMin survives them.
+ if (RHSMin <= RHSMax)
+ MaxShl = APIntOps::umax(MaxShl, // A value shifted by >= RHSMin has its low RHSMin bits clear, so it is at most the mask of the high BitWidth - RHSMin bits.
+ APInt::getHighBitsSet(BitWidth, BitWidth - RHSMin));
+ return ConstantRange::getNonEmpty(MinShl, MaxShl + 1); // Upper bound is exclusive, hence + 1.
+}
+
+static ConstantRange computeShlNSWWithNNegLHS(const APInt &LHSMin, // nsw shift range for LHS in [LHSMin, LHSMax] with LHSMin >= 0 (ensured by computeShlNSW) and shift amounts in [RHSMin, RHSMax].
+ const APInt &LHSMax,
+ unsigned RHSMin,
+ unsigned RHSMax) {
+ unsigned BitWidth = LHSMin.getBitWidth();
+ bool Overflow;
+ APInt MinShl = LHSMin.sshl_ov(RHSMin, Overflow); // Smallest non-negative LHS shifted least gives the minimum.
+ if (Overflow)
+ return ConstantRange::getEmpty(BitWidth); // Larger LHS or larger shifts only overflow sooner, so no pair is valid.
+ APInt MaxShl = MinShl;
+ unsigned MaxShAmt = LHSMax.countLeadingZeros() - 1; // Keep the sign bit clear; LHSMax >= 0 so clz >= 1 and this cannot underflow.
+ if (RHSMin <= MaxShAmt)
+ MaxShl = LHSMax << std::min(RHSMax, MaxShAmt);
+ RHSMin = std::max(RHSMin, MaxShAmt + 1); // Shift amounts that overflow for LHSMax...
+ RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros() - 1); // ...but not for LHSMin.
+ if (RHSMin <= RHSMax)
+ MaxShl = APIntOps::umax(MaxShl, // Bound for those shifts: low RHSMin bits and the sign bit clear, every other bit set.
+ APInt::getBitsSet(BitWidth, RHSMin, BitWidth - 1));
+ return ConstantRange::getNonEmpty(MinShl, MaxShl + 1); // Exclusive upper bound.
+}
+
+static ConstantRange computeShlNSWWithNegLHS(const APInt &LHSMin, // nsw shift range for LHS in [LHSMin, LHSMax] with LHSMax < 0 (ensured by computeShlNSW) and shift amounts in [RHSMin, RHSMax].
+ const APInt &LHSMax,
+ unsigned RHSMin, unsigned RHSMax) {
+ unsigned BitWidth = LHSMin.getBitWidth();
+ bool Overflow;
+ APInt MaxShl = LHSMax.sshl_ov(RHSMin, Overflow); // Shifting a negative value without overflow only makes it more negative, so LHSMax shifted least is the maximum.
+ if (Overflow)
+ return ConstantRange::getEmpty(BitWidth); // LHSMax has the most leading ones in range, so if it overflows every pair does.
+ APInt MinShl = MaxShl;
+ unsigned MaxShAmt = LHSMin.countLeadingOnes() - 1; // Largest shift LHSMin survives with its sign preserved; LHSMin < 0 so clo >= 1.
+ if (RHSMin <= MaxShAmt)
+ MinShl = LHSMin.shl(std::min(RHSMax, MaxShAmt));
+ RHSMin = std::max(RHSMin, MaxShAmt + 1); // Shift amounts that overflow for LHSMin...
+ RHSMax = std::min(RHSMax, LHSMax.countLeadingOnes() - 1); // ...but not for LHSMax.
+ if (RHSMin <= RHSMax)
+ MinShl = APInt::getSignMask(BitWidth); // For such a shift s, -2^(BitWidth-1-s) lies in [LHSMin, LHSMax] and shifts exactly to the signed minimum.
+ return ConstantRange::getNonEmpty(MinShl, MaxShl + 1); // Exclusive upper bound.
+}
+
+// Returns a range containing every LHS << RHS that does not overflow in the
+// signed sense (nsw). Shift amounts are clamped to BitWidth via
+// getLimitedValue. The signed bounds of LHS are split at zero because each
+// per-sign helper handles operands of a single sign only.
+static ConstantRange computeShlNSW(const ConstantRange &LHS,
+                                   const ConstantRange &RHS) {
+  unsigned BitWidth = LHS.getBitWidth();
+  unsigned RHSMin = RHS.getUnsignedMin().getLimitedValue(BitWidth);
+  unsigned RHSMax = RHS.getUnsignedMax().getLimitedValue(BitWidth);
+  APInt LHSMin = LHS.getSignedMin();
+  APInt LHSMax = LHS.getSignedMax();
+  // LHS is entirely non-negative.
+  if (LHSMin.isNonNegative())
+    return computeShlNSWWithNNegLHS(LHSMin, LHSMax, RHSMin, RHSMax);
+  // LHS is entirely negative.
+  if (LHSMax.isNegative())
+    return computeShlNSWWithNegLHS(LHSMin, LHSMax, RHSMin, RHSMax);
+  // LHS straddles zero: handle [0, LHSMax] and [LHSMin, -1] separately and
+  // take their union, preferring a result that is not sign-wrapped.
+  return computeShlNSWWithNNegLHS(APInt::getZero(BitWidth), LHSMax, RHSMin,
+                                  RHSMax)
+      .unionWith(computeShlNSWWithNegLHS(LHSMin, APInt::getAllOnes(BitWidth),
+                                         RHSMin, RHSMax),
+                 ConstantRange::Signed);
+}
+
ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
unsigned NoWrapKind,
PreferredRangeType RangeType) const {
if (isEmptySet() || Other.isEmptySet())
return getEmpty();
- ConstantRange Result = shl(Other);
-
- if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap)
- Result = Result.intersectWith(sshl_sat(Other), RangeType);
-
- if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap)
- Result = Result.intersectWith(ushl_sat(Other), RangeType);
-
- return Result;
+ switch (NoWrapKind) { // Dispatch on the exact flag combination; each case computes its bound directly.
+ case 0:
+ return shl(Other); // No flags: plain (possibly wrapping) shl.
+ case OverflowingBinaryOperator::NoSignedWrap:
+ return computeShlNSW(*this, Other);
+ case OverflowingBinaryOperator::NoUnsignedWrap:
+ return computeShlNUW(*this, Other);
+ case OverflowingBinaryOperator::NoSignedWrap |
+ OverflowingBinaryOperator::NoUnsignedWrap:
+ return computeShlNSW(*this, Other) // Both flags: a result must satisfy both bounds, so intersect them.
+ .intersectWith(computeShlNUW(*this, Other), RangeType); // RangeType is only consulted here.
+ default:
+ llvm_unreachable("Invalid NoWrapKind"); // Only combinations of NSW/NUW are expected.
+ }
}
ConstantRange
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
index 8b4dbc98425bf0..1d6e54c9a488a2 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
@@ -86,7 +86,7 @@ define i8 @test4(i8 %a, i8 %b) {
; CHECK-NEXT: br i1 [[CMP]], label [[BB:%.*]], label [[EXIT:%.*]]
; CHECK: bb:
; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i8 [[A:%.*]], [[B]]
-; CHECK-NEXT: ret i8 -1
+; CHECK-NEXT: ret i8 [[SHL]]
; CHECK: exit:
; CHECK-NEXT: ret i8 0
;
@@ -105,7 +105,7 @@ exit:
define i8 @test5(i8 %b) {
; CHECK-LABEL: @test5(
; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i8 0, [[B:%.*]]
-; CHECK-NEXT: ret i8 [[SHL]]
+; CHECK-NEXT: ret i8 0
;
%shl = shl i8 0, %b
ret i8 %shl
@@ -474,3 +474,17 @@ define i1 @shl_nuw_nsw_test4(i32 %x, i32 range(i32 0, 32) %k) {
%cmp = icmp eq i64 %shl, -9223372036854775808
ret i1 %cmp
}
+
+define i1 @shl_nuw_nsw_test5(i32 %x) { ; nsw bounds 768 << %x to at most 768 << 21 = 0x60000000, so %add cannot become negative.
+; CHECK-LABEL: @shl_nuw_nsw_test5(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i32 768, [[X:%.*]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[SHL]], 1846
+; CHECK-NEXT: ret i1 true
+;
+entry:
+ %shl = shl nuw nsw i32 768, %x ; range [768, 0x60000000] once both flags are honored
+ %add = add nuw i32 %shl, 1846 ; max 0x60000736 < 2^31, so CVP can also add nsw
+ %cmp = icmp sgt i32 %add, 0 ; expected to fold to true
+ ret i1 %cmp
+}
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 1705f3e6af9774..4815117458b9af 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -228,6 +228,12 @@ static bool CheckNonSignWrappedOnly(const ConstantRange &CR1,
return !CR1.isSignWrappedSet() && !CR2.isSignWrappedSet();
}
+// Operand filter passed to TestBinaryOpExhaustive: returns true only when the
+// left range is not sign-wrapped and the right range is not (unsigned) wrapped.
+static bool CheckNoSignedWrappedLHSAndNoWrappedRHSOnly(
+    const ConstantRange &CR1, const ConstantRange &CR2) {
+  return !(CR1.isSignWrappedSet() || CR2.isWrappedSet());
+}
+
static bool CheckNonWrappedOrSignWrappedOnly(const ConstantRange &CR1,
const ConstantRange &CR2) {
return !CR1.isWrappedSet() && !CR1.isSignWrappedSet() &&
@@ -1506,7 +1512,9 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
using OBO = OverflowingBinaryOperator;
TestBinaryOpExhaustive(
[](const ConstantRange &CR1, const ConstantRange &CR2) {
- return CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap);
+ ConstantRange Res = CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap);
+ EXPECT_TRUE(CR1.shl(CR2).contains(Res));
+ return Res;
},
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
bool IsOverflow;
@@ -1515,7 +1523,7 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
return std::nullopt;
return Res;
},
- PreferSmallest, CheckCorrectnessOnly);
+ PreferSmallest, CheckNonWrappedOnly);
TestBinaryOpExhaustive(
[](const ConstantRange &CR1, const ConstantRange &CR2) {
return CR1.shlWithNoWrap(CR2, OBO::NoSignedWrap);
@@ -1527,7 +1535,7 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
return std::nullopt;
return Res;
},
- PreferSmallest, CheckCorrectnessOnly);
+ PreferSmallestSigned, CheckNoSignedWrappedLHSAndNoWrappedRHSOnly);
TestBinaryOpExhaustive(
[](const ConstantRange &CR1, const ConstantRange &CR2) {
return CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap | OBO::NoSignedWrap);
@@ -1542,6 +1550,31 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
return Res1;
},
PreferSmallest, CheckCorrectnessOnly);
+
+ EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoSignedWrap),
+ ConstantRange(APInt(16, 10), APInt(16, 20481)));
+ EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoUnsignedWrap),
+ ConstantRange(APInt(16, 10), APInt(16, -24575)));
+ EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoSignedWrap | OBO::NoUnsignedWrap),
+ ConstantRange(APInt(16, 10), APInt(16, 20481)));
+ ConstantRange NegOne(APInt(16, 0xffff));
+ EXPECT_EQ(NegOne.shlWithNoWrap(Full, OBO::NoSignedWrap),
+ ConstantRange(APInt(16, -32768), APInt(16, 0)));
+ EXPECT_EQ(NegOne.shlWithNoWrap(Full, OBO::NoUnsignedWrap), NegOne);
+ EXPECT_EQ(ConstantRange(APInt(16, 768))
+ .shlWithNoWrap(Full, OBO::NoSignedWrap | OBO::NoUnsignedWrap),
+ ConstantRange(APInt(16, 768), APInt(16, 24577)));
+ EXPECT_EQ(Full.shlWithNoWrap(ConstantRange(APInt(16, 1), APInt(16, 16)),
+ OBO::NoUnsignedWrap),
+ ConstantRange(APInt(16, 0), APInt(16, -1)));
+ EXPECT_EQ(ConstantRange(APInt(4, 3), APInt(4, -8))
+ .shlWithNoWrap(ConstantRange(APInt(4, 0), APInt(4, 4)),
+ OBO::NoSignedWrap),
+ ConstantRange(APInt(4, 3), APInt(4, -8)));
+ EXPECT_EQ(ConstantRange(APInt(4, -1), APInt(4, 0))
+ .shlWithNoWrap(ConstantRange(APInt(4, 1), APInt(4, 4)),
+ OBO::NoSignedWrap),
+ ConstantRange(APInt(4, -8), APInt(4, -1)));
}
TEST_F(ConstantRangeTest, Lshr) {
More information about the llvm-commits
mailing list