[llvm] [ConstantRange] Improve `shlWithNoWrap` (PR #101800)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Sun Aug 4 01:08:41 PDT 2024
https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/101800
>From c10b3d4d4ff0ed15766f5872dc6fc740dca7a42d Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 3 Aug 2024 15:54:57 +0800
Subject: [PATCH 1/6] [ConstantRange] Add pre-commit tests. NFC.
---
.../Transforms/CorrelatedValuePropagation/shl.ll | 15 +++++++++++++++
1 file changed, 15 insertions(+)
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
index 8b4dbc98425bf..b5b943a20bff2 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
@@ -474,3 +474,18 @@ define i1 @shl_nuw_nsw_test4(i32 %x, i32 range(i32 0, 32) %k) {
%cmp = icmp eq i64 %shl, -9223372036854775808
ret i1 %cmp
}
+
+define i1 @shl_nuw_nsw_test5(i32 %x) {
+; CHECK-LABEL: @shl_nuw_nsw_test5(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i32 768, [[X:%.*]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw i32 [[SHL]], 1846
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[ADD]], 0
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %shl = shl nuw nsw i32 768, %x
+ %add = add nuw i32 %shl, 1846
+ %cmp = icmp sgt i32 %add, 0
+ ret i1 %cmp
+}
>From 763bea5cd1f1b5a7005fb2ac9aea0da6227c5d79 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 3 Aug 2024 16:04:46 +0800
Subject: [PATCH 2/6] [ConstantRange] Improve shlWithNoWrap
---
llvm/lib/IR/ConstantRange.cpp | 28 +++++++++++++++----
.../CorrelatedValuePropagation/shl.ll | 7 ++---
2 files changed, 26 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 50b211a302e8f..24c86641f76df 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1624,12 +1624,30 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
return getEmpty();
ConstantRange Result = shl(Other);
+ KnownBits Known = toKnownBits();
+
+ if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap) {
+ ConstantRange ShAmtRange = Other;
+ if (isAllNonNegative())
+ ShAmtRange = ShAmtRange.intersectWith(
+ ConstantRange(APInt::getZero(getBitWidth()),
+ APInt(getBitWidth(), Known.countMaxLeadingZeros())),
+ Unsigned);
+ else if (isAllNegative())
+ ShAmtRange = ShAmtRange.intersectWith(
+ ConstantRange(APInt::getZero(getBitWidth()),
+ APInt(getBitWidth(), Known.countMaxLeadingOnes())),
+ Unsigned);
+ Result = Result.intersectWith(sshl_sat(ShAmtRange), RangeType);
+ }
- if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap)
- Result = Result.intersectWith(sshl_sat(Other), RangeType);
-
- if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap)
- Result = Result.intersectWith(ushl_sat(Other), RangeType);
+ if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap) {
+ ConstantRange ShAmtRange =
+ getNonEmpty(APInt::getZero(getBitWidth()),
+ APInt(getBitWidth(), Known.countMaxLeadingZeros() + 1));
+ Result = Result.intersectWith(
+ ushl_sat(Other.intersectWith(ShAmtRange, Unsigned)), RangeType);
+ }
return Result;
}
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
index b5b943a20bff2..a55081e1604e6 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
@@ -105,7 +105,7 @@ exit:
define i8 @test5(i8 %b) {
; CHECK-LABEL: @test5(
; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i8 0, [[B:%.*]]
-; CHECK-NEXT: ret i8 [[SHL]]
+; CHECK-NEXT: ret i8 0
;
%shl = shl i8 0, %b
ret i8 %shl
@@ -479,9 +479,8 @@ define i1 @shl_nuw_nsw_test5(i32 %x) {
; CHECK-LABEL: @shl_nuw_nsw_test5(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i32 768, [[X:%.*]]
-; CHECK-NEXT: [[ADD:%.*]] = add nuw i32 [[SHL]], 1846
-; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[ADD]], 0
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[SHL]], 1846
+; CHECK-NEXT: ret i1 true
;
entry:
%shl = shl nuw nsw i32 768, %x
>From 9f7da466a27b14629831b9e3d8a3d8ea3b51df1b Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 3 Aug 2024 16:09:16 +0800
Subject: [PATCH 3/6] [ConstantRange] Early exit
---
llvm/lib/IR/ConstantRange.cpp | 3 +++
1 file changed, 3 insertions(+)
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 24c86641f76df..062e0e7a3b251 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1624,6 +1624,9 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
return getEmpty();
ConstantRange Result = shl(Other);
+ if (!NoWrapKind)
+ return Result;
+
KnownBits Known = toKnownBits();
if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap) {
>From a9efe755565400ee488aa4fc294d6a9a5de9fe1a Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 3 Aug 2024 22:59:06 +0800
Subject: [PATCH 4/6] [ConstantRange] Address review comments.
---
llvm/lib/IR/ConstantRange.cpp | 20 ++++++++++++++-----
.../CorrelatedValuePropagation/shl.ll | 2 +-
llvm/unittests/IR/ConstantRangeTest.cpp | 14 +++++++++++++
3 files changed, 30 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 062e0e7a3b251..64b679a74cad1 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1645,11 +1645,21 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
}
if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap) {
- ConstantRange ShAmtRange =
- getNonEmpty(APInt::getZero(getBitWidth()),
- APInt(getBitWidth(), Known.countMaxLeadingZeros() + 1));
- Result = Result.intersectWith(
- ushl_sat(Other.intersectWith(ShAmtRange, Unsigned)), RangeType);
+ bool Overflow;
+ APInt LHSMin = getUnsignedMin();
+ APInt MinShl = LHSMin.ushl_ov(Other.getUnsignedMin(), Overflow);
+ if (Overflow)
+ return getEmpty();
+ APInt LHSMax = getUnsignedMax();
+ APInt MaxShl = LHSMax << Other.getUnsignedMax().getLimitedValue(
+ LHSMax.countLeadingZeros());
+ if (LHSMin.countLeadingZeros() != LHSMax.countLeadingZeros())
+ MaxShl = APIntOps::umax(
+ MaxShl, APInt::getHighBitsSet(
+ getBitWidth(),
+ getBitWidth() - Other.getUnsignedMax().getLimitedValue(
+ LHSMax.countLeadingZeros() + 1)));
+ Result = Result.intersectWith(getNonEmpty(MinShl, MaxShl + 1), RangeType);
}
return Result;
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
index a55081e1604e6..1d6e54c9a488a 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
@@ -86,7 +86,7 @@ define i8 @test4(i8 %a, i8 %b) {
; CHECK-NEXT: br i1 [[CMP]], label [[BB:%.*]], label [[EXIT:%.*]]
; CHECK: bb:
; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i8 [[A:%.*]], [[B]]
-; CHECK-NEXT: ret i8 -1
+; CHECK-NEXT: ret i8 [[SHL]]
; CHECK: exit:
; CHECK-NEXT: ret i8 0
;
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 1705f3e6af977..9063e034a65b2 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -1542,6 +1542,20 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
return Res1;
},
PreferSmallest, CheckCorrectnessOnly);
+
+ EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoSignedWrap),
+ ConstantRange(APInt(16, 10), APInt(16, 20481)));
+ EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoUnsignedWrap),
+ ConstantRange(APInt(16, 10), APInt(16, -24575)));
+ EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoSignedWrap | OBO::NoUnsignedWrap),
+ ConstantRange(APInt(16, 10), APInt(16, 20481)));
+ ConstantRange NegOne(APInt(16, 0xffff));
+ EXPECT_EQ(NegOne.shlWithNoWrap(Full, OBO::NoSignedWrap),
+ ConstantRange(APInt(16, -32768), APInt(16, 0)));
+ EXPECT_EQ(NegOne.shlWithNoWrap(Full, OBO::NoUnsignedWrap), NegOne);
+ EXPECT_EQ(ConstantRange(APInt(16, 768))
+ .shlWithNoWrap(Full, OBO::NoSignedWrap | OBO::NoUnsignedWrap),
+ ConstantRange(APInt(16, 768), APInt(16, 24577)));
}
TEST_F(ConstantRangeTest, Lshr) {
>From a973fd2e03dd9642dd806dda9b9f7ccd581d9dcf Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 4 Aug 2024 03:17:09 +0800
Subject: [PATCH 5/6] [ConstantRange] Make shl nuw optimal
---
llvm/lib/IR/ConstantRange.cpp | 35 +++++++++++++------------
llvm/unittests/IR/ConstantRangeTest.cpp | 5 +++-
2 files changed, 22 insertions(+), 18 deletions(-)
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 64b679a74cad1..9aec72c7dd8c2 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1627,19 +1627,17 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
if (!NoWrapKind)
return Result;
- KnownBits Known = toKnownBits();
-
if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap) {
- ConstantRange ShAmtRange = Other;
+ std::optional<unsigned> ShAmtBound;
if (isAllNonNegative())
- ShAmtRange = ShAmtRange.intersectWith(
- ConstantRange(APInt::getZero(getBitWidth()),
- APInt(getBitWidth(), Known.countMaxLeadingZeros())),
- Unsigned);
+ ShAmtBound = getSignedMin().countLeadingZeros();
else if (isAllNegative())
+ ShAmtBound = getSignedMax().countLeadingOnes();
+ ConstantRange ShAmtRange = Other;
+ if (ShAmtBound)
ShAmtRange = ShAmtRange.intersectWith(
- ConstantRange(APInt::getZero(getBitWidth()),
- APInt(getBitWidth(), Known.countMaxLeadingOnes())),
+ ConstantRange(APInt(getBitWidth(), 0),
+ APInt(getBitWidth(), *ShAmtBound)),
Unsigned);
Result = Result.intersectWith(sshl_sat(ShAmtRange), RangeType);
}
@@ -1647,18 +1645,21 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap) {
bool Overflow;
APInt LHSMin = getUnsignedMin();
- APInt MinShl = LHSMin.ushl_ov(Other.getUnsignedMin(), Overflow);
+ unsigned RHSMin = Other.getUnsignedMin().getLimitedValue(getBitWidth());
+ APInt MinShl = LHSMin.ushl_ov(RHSMin, Overflow);
if (Overflow)
return getEmpty();
APInt LHSMax = getUnsignedMax();
- APInt MaxShl = LHSMax << Other.getUnsignedMax().getLimitedValue(
- LHSMax.countLeadingZeros());
- if (LHSMin.countLeadingZeros() != LHSMax.countLeadingZeros())
+ unsigned RHSMax = Other.getUnsignedMax().getLimitedValue(getBitWidth());
+ APInt MaxShl = MinShl;
+ unsigned MaxShAmt = LHSMax.countLeadingZeros();
+ if (RHSMin <= MaxShAmt)
+ MaxShl = LHSMax << std::min(RHSMax, MaxShAmt);
+ RHSMin = std::max(RHSMin, MaxShAmt + 1);
+ RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros());
+ if (RHSMin <= RHSMax)
MaxShl = APIntOps::umax(
- MaxShl, APInt::getHighBitsSet(
- getBitWidth(),
- getBitWidth() - Other.getUnsignedMax().getLimitedValue(
- LHSMax.countLeadingZeros() + 1)));
+ MaxShl, APInt::getHighBitsSet(getBitWidth(), getBitWidth() - RHSMin));
Result = Result.intersectWith(getNonEmpty(MinShl, MaxShl + 1), RangeType);
}
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 9063e034a65b2..9363462cec65b 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -1515,7 +1515,7 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
return std::nullopt;
return Res;
},
- PreferSmallest, CheckCorrectnessOnly);
+ PreferSmallest, CheckNonWrappedOnly);
TestBinaryOpExhaustive(
[](const ConstantRange &CR1, const ConstantRange &CR2) {
return CR1.shlWithNoWrap(CR2, OBO::NoSignedWrap);
@@ -1556,6 +1556,9 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
EXPECT_EQ(ConstantRange(APInt(16, 768))
.shlWithNoWrap(Full, OBO::NoSignedWrap | OBO::NoUnsignedWrap),
ConstantRange(APInt(16, 768), APInt(16, 24577)));
+ EXPECT_EQ(Full.shlWithNoWrap(ConstantRange(APInt(16, 1), APInt(16, 16)),
+ OBO::NoUnsignedWrap),
+ ConstantRange(APInt(16, 0), APInt(16, -1)));
}
TEST_F(ConstantRangeTest, Lshr) {
>From 2c335c3bf2407b5ef994b2f1af1b5e283567056e Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 4 Aug 2024 16:08:12 +0800
Subject: [PATCH 6/6] [ConstantRange] Avoid unnecessary computations
---
llvm/lib/IR/ConstantRange.cpp | 28 +++++++++++++++----------
llvm/unittests/IR/ConstantRangeTest.cpp | 7 +++++--
2 files changed, 22 insertions(+), 13 deletions(-)
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 9aec72c7dd8c2..1e0f24713b7b6 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1623,11 +1623,7 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
if (isEmptySet() || Other.isEmptySet())
return getEmpty();
- ConstantRange Result = shl(Other);
- if (!NoWrapKind)
- return Result;
-
- if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap) {
+ auto ComputeShlWithNSW = [&]{
std::optional<unsigned> ShAmtBound;
if (isAllNonNegative())
ShAmtBound = getSignedMin().countLeadingZeros();
@@ -1639,10 +1635,10 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
ConstantRange(APInt(getBitWidth(), 0),
APInt(getBitWidth(), *ShAmtBound)),
Unsigned);
- Result = Result.intersectWith(sshl_sat(ShAmtRange), RangeType);
- }
+ return sshl_sat(ShAmtRange);
+ };
- if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap) {
+ auto ComputeShlWithNUW = [&] {
bool Overflow;
APInt LHSMin = getUnsignedMin();
unsigned RHSMin = Other.getUnsignedMin().getLimitedValue(getBitWidth());
@@ -1660,10 +1656,20 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
if (RHSMin <= RHSMax)
MaxShl = APIntOps::umax(
MaxShl, APInt::getHighBitsSet(getBitWidth(), getBitWidth() - RHSMin));
- Result = Result.intersectWith(getNonEmpty(MinShl, MaxShl + 1), RangeType);
- }
+ return getNonEmpty(MinShl, MaxShl + 1);
+ };
- return Result;
+ switch(NoWrapKind) {
+ case 0: return shl(Other);
+ case OverflowingBinaryOperator::NoSignedWrap:
+ return shl(Other).intersectWith(ComputeShlWithNSW(), RangeType);
+ case OverflowingBinaryOperator::NoUnsignedWrap:
+ return ComputeShlWithNUW();
+ case OverflowingBinaryOperator::NoSignedWrap | OverflowingBinaryOperator::NoUnsignedWrap:
+ return ComputeShlWithNSW().intersectWith(ComputeShlWithNUW(), RangeType);
+ default:
+ llvm_unreachable("Invalid NoWrapKind");
+ }
}
ConstantRange
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 9363462cec65b..370617ea9fb29 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -1506,7 +1506,9 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
using OBO = OverflowingBinaryOperator;
TestBinaryOpExhaustive(
[](const ConstantRange &CR1, const ConstantRange &CR2) {
- return CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap);
+ ConstantRange Res = CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap);
+ EXPECT_TRUE(CR1.shl(CR2).contains(Res));
+ return Res;
},
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
bool IsOverflow;
@@ -1518,7 +1520,8 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
PreferSmallest, CheckNonWrappedOnly);
TestBinaryOpExhaustive(
[](const ConstantRange &CR1, const ConstantRange &CR2) {
- return CR1.shlWithNoWrap(CR2, OBO::NoSignedWrap);
+ ConstantRange Res = CR1.shlWithNoWrap(CR2, OBO::NoSignedWrap);
+ return Res;
},
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
bool IsOverflow;
More information about the llvm-commits
mailing list