[llvm] [ConstantRange] Improve `shlWithNoWrap` (PR #101800)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Sun Aug 4 08:47:22 PDT 2024


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/101800

>From c10b3d4d4ff0ed15766f5872dc6fc740dca7a42d Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 3 Aug 2024 15:54:57 +0800
Subject: [PATCH 1/7] [ConstantRange] Add pre-commit tests. NFC.

---
 .../Transforms/CorrelatedValuePropagation/shl.ll  | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
index 8b4dbc98425bf..b5b943a20bff2 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
@@ -474,3 +474,18 @@ define i1 @shl_nuw_nsw_test4(i32 %x, i32 range(i32 0, 32) %k) {
   %cmp = icmp eq i64 %shl, -9223372036854775808
   ret i1 %cmp
 }
+
+define i1 @shl_nuw_nsw_test5(i32 %x) {
+; CHECK-LABEL: @shl_nuw_nsw_test5(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i32 768, [[X:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw i32 [[SHL]], 1846
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[ADD]], 0
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %shl = shl nuw nsw i32 768, %x
+  %add = add nuw i32 %shl, 1846
+  %cmp = icmp sgt i32 %add, 0
+  ret i1 %cmp
+}

>From 763bea5cd1f1b5a7005fb2ac9aea0da6227c5d79 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 3 Aug 2024 16:04:46 +0800
Subject: [PATCH 2/7] [ConstantRange] Improve shlWithNoWrap

---
 llvm/lib/IR/ConstantRange.cpp                 | 28 +++++++++++++++----
 .../CorrelatedValuePropagation/shl.ll         |  7 ++---
 2 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 50b211a302e8f..24c86641f76df 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1624,12 +1624,30 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
     return getEmpty();
 
   ConstantRange Result = shl(Other);
+  KnownBits Known = toKnownBits();
+
+  if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap) {
+    ConstantRange ShAmtRange = Other;
+    if (isAllNonNegative())
+      ShAmtRange = ShAmtRange.intersectWith(
+          ConstantRange(APInt::getZero(getBitWidth()),
+                        APInt(getBitWidth(), Known.countMaxLeadingZeros())),
+          Unsigned);
+    else if (isAllNegative())
+      ShAmtRange = ShAmtRange.intersectWith(
+          ConstantRange(APInt::getZero(getBitWidth()),
+                        APInt(getBitWidth(), Known.countMaxLeadingOnes())),
+          Unsigned);
+    Result = Result.intersectWith(sshl_sat(ShAmtRange), RangeType);
+  }
 
-  if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap)
-    Result = Result.intersectWith(sshl_sat(Other), RangeType);
-
-  if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap)
-    Result = Result.intersectWith(ushl_sat(Other), RangeType);
+  if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap) {
+    ConstantRange ShAmtRange =
+        getNonEmpty(APInt::getZero(getBitWidth()),
+                    APInt(getBitWidth(), Known.countMaxLeadingZeros() + 1));
+    Result = Result.intersectWith(
+        ushl_sat(Other.intersectWith(ShAmtRange, Unsigned)), RangeType);
+  }
 
   return Result;
 }
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
index b5b943a20bff2..a55081e1604e6 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
@@ -105,7 +105,7 @@ exit:
 define i8 @test5(i8 %b) {
 ; CHECK-LABEL: @test5(
 ; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i8 0, [[B:%.*]]
-; CHECK-NEXT:    ret i8 [[SHL]]
+; CHECK-NEXT:    ret i8 0
 ;
   %shl = shl i8 0, %b
   ret i8 %shl
@@ -479,9 +479,8 @@ define i1 @shl_nuw_nsw_test5(i32 %x) {
 ; CHECK-LABEL: @shl_nuw_nsw_test5(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i32 768, [[X:%.*]]
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw i32 [[SHL]], 1846
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[ADD]], 0
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[SHL]], 1846
+; CHECK-NEXT:    ret i1 true
 ;
 entry:
   %shl = shl nuw nsw i32 768, %x

>From 9f7da466a27b14629831b9e3d8a3d8ea3b51df1b Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 3 Aug 2024 16:09:16 +0800
Subject: [PATCH 3/7] [ConstantRange] Early exit

---
 llvm/lib/IR/ConstantRange.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 24c86641f76df..062e0e7a3b251 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1624,6 +1624,9 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
     return getEmpty();
 
   ConstantRange Result = shl(Other);
+  if (!NoWrapKind)
+    return Result;
+
   KnownBits Known = toKnownBits();
 
   if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap) {

>From a9efe755565400ee488aa4fc294d6a9a5de9fe1a Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 3 Aug 2024 22:59:06 +0800
Subject: [PATCH 4/7] [ConstantRange] Address review comments.

---
 llvm/lib/IR/ConstantRange.cpp                 | 20 ++++++++++++++-----
 .../CorrelatedValuePropagation/shl.ll         |  2 +-
 llvm/unittests/IR/ConstantRangeTest.cpp       | 14 +++++++++++++
 3 files changed, 30 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 062e0e7a3b251..64b679a74cad1 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1645,11 +1645,21 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
   }
 
   if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap) {
-    ConstantRange ShAmtRange =
-        getNonEmpty(APInt::getZero(getBitWidth()),
-                    APInt(getBitWidth(), Known.countMaxLeadingZeros() + 1));
-    Result = Result.intersectWith(
-        ushl_sat(Other.intersectWith(ShAmtRange, Unsigned)), RangeType);
+    bool Overflow;
+    APInt LHSMin = getUnsignedMin();
+    APInt MinShl = LHSMin.ushl_ov(Other.getUnsignedMin(), Overflow);
+    if (Overflow)
+      return getEmpty();
+    APInt LHSMax = getUnsignedMax();
+    APInt MaxShl = LHSMax << Other.getUnsignedMax().getLimitedValue(
+                       LHSMax.countLeadingZeros());
+    if (LHSMin.countLeadingZeros() != LHSMax.countLeadingZeros())
+      MaxShl = APIntOps::umax(
+          MaxShl, APInt::getHighBitsSet(
+                      getBitWidth(),
+                      getBitWidth() - Other.getUnsignedMax().getLimitedValue(
+                                          LHSMax.countLeadingZeros() + 1)));
+    Result = Result.intersectWith(getNonEmpty(MinShl, MaxShl + 1), RangeType);
   }
 
   return Result;
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
index a55081e1604e6..1d6e54c9a488a 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
@@ -86,7 +86,7 @@ define i8 @test4(i8 %a, i8 %b) {
 ; CHECK-NEXT:    br i1 [[CMP]], label [[BB:%.*]], label [[EXIT:%.*]]
 ; CHECK:       bb:
 ; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i8 [[A:%.*]], [[B]]
-; CHECK-NEXT:    ret i8 -1
+; CHECK-NEXT:    ret i8 [[SHL]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret i8 0
 ;
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 1705f3e6af977..9063e034a65b2 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -1542,6 +1542,20 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
         return Res1;
       },
       PreferSmallest, CheckCorrectnessOnly);
+
+  EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoSignedWrap),
+            ConstantRange(APInt(16, 10), APInt(16, 20481)));
+  EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoUnsignedWrap),
+            ConstantRange(APInt(16, 10), APInt(16, -24575)));
+  EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoSignedWrap | OBO::NoUnsignedWrap),
+            ConstantRange(APInt(16, 10), APInt(16, 20481)));
+  ConstantRange NegOne(APInt(16, 0xffff));
+  EXPECT_EQ(NegOne.shlWithNoWrap(Full, OBO::NoSignedWrap),
+            ConstantRange(APInt(16, -32768), APInt(16, 0)));
+  EXPECT_EQ(NegOne.shlWithNoWrap(Full, OBO::NoUnsignedWrap), NegOne);
+  EXPECT_EQ(ConstantRange(APInt(16, 768))
+                .shlWithNoWrap(Full, OBO::NoSignedWrap | OBO::NoUnsignedWrap),
+            ConstantRange(APInt(16, 768), APInt(16, 24577)));
 }
 
 TEST_F(ConstantRangeTest, Lshr) {

>From a973fd2e03dd9642dd806dda9b9f7ccd581d9dcf Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 4 Aug 2024 03:17:09 +0800
Subject: [PATCH 5/7] [ConstantRange] Make shl nuw optimal

---
 llvm/lib/IR/ConstantRange.cpp           | 35 +++++++++++++------------
 llvm/unittests/IR/ConstantRangeTest.cpp |  5 +++-
 2 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 64b679a74cad1..9aec72c7dd8c2 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1627,19 +1627,17 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
   if (!NoWrapKind)
     return Result;
 
-  KnownBits Known = toKnownBits();
-
   if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap) {
-    ConstantRange ShAmtRange = Other;
+    std::optional<unsigned> ShAmtBound;
     if (isAllNonNegative())
-      ShAmtRange = ShAmtRange.intersectWith(
-          ConstantRange(APInt::getZero(getBitWidth()),
-                        APInt(getBitWidth(), Known.countMaxLeadingZeros())),
-          Unsigned);
+      ShAmtBound = getSignedMin().countLeadingZeros();
     else if (isAllNegative())
+      ShAmtBound = getSignedMax().countLeadingOnes();
+    ConstantRange ShAmtRange = Other;
+    if (ShAmtBound)
       ShAmtRange = ShAmtRange.intersectWith(
-          ConstantRange(APInt::getZero(getBitWidth()),
-                        APInt(getBitWidth(), Known.countMaxLeadingOnes())),
+          ConstantRange(APInt(getBitWidth(), 0),
+                        APInt(getBitWidth(), *ShAmtBound)),
           Unsigned);
     Result = Result.intersectWith(sshl_sat(ShAmtRange), RangeType);
   }
@@ -1647,18 +1645,21 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
   if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap) {
     bool Overflow;
     APInt LHSMin = getUnsignedMin();
-    APInt MinShl = LHSMin.ushl_ov(Other.getUnsignedMin(), Overflow);
+    unsigned RHSMin = Other.getUnsignedMin().getLimitedValue(getBitWidth());
+    APInt MinShl = LHSMin.ushl_ov(RHSMin, Overflow);
     if (Overflow)
       return getEmpty();
     APInt LHSMax = getUnsignedMax();
-    APInt MaxShl = LHSMax << Other.getUnsignedMax().getLimitedValue(
-                       LHSMax.countLeadingZeros());
-    if (LHSMin.countLeadingZeros() != LHSMax.countLeadingZeros())
+    unsigned RHSMax = Other.getUnsignedMax().getLimitedValue(getBitWidth());
+    APInt MaxShl = MinShl;
+    unsigned MaxShAmt = LHSMax.countLeadingZeros();
+    if (RHSMin <= MaxShAmt)
+      MaxShl = LHSMax << std::min(RHSMax, MaxShAmt);
+    RHSMin = std::max(RHSMin, MaxShAmt + 1);
+    RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros());
+    if (RHSMin <= RHSMax)
       MaxShl = APIntOps::umax(
-          MaxShl, APInt::getHighBitsSet(
-                      getBitWidth(),
-                      getBitWidth() - Other.getUnsignedMax().getLimitedValue(
-                                          LHSMax.countLeadingZeros() + 1)));
+          MaxShl, APInt::getHighBitsSet(getBitWidth(), getBitWidth() - RHSMin));
     Result = Result.intersectWith(getNonEmpty(MinShl, MaxShl + 1), RangeType);
   }
 
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 9063e034a65b2..9363462cec65b 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -1515,7 +1515,7 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
           return std::nullopt;
         return Res;
       },
-      PreferSmallest, CheckCorrectnessOnly);
+      PreferSmallest, CheckNonWrappedOnly);
   TestBinaryOpExhaustive(
       [](const ConstantRange &CR1, const ConstantRange &CR2) {
         return CR1.shlWithNoWrap(CR2, OBO::NoSignedWrap);
@@ -1556,6 +1556,9 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
   EXPECT_EQ(ConstantRange(APInt(16, 768))
                 .shlWithNoWrap(Full, OBO::NoSignedWrap | OBO::NoUnsignedWrap),
             ConstantRange(APInt(16, 768), APInt(16, 24577)));
+  EXPECT_EQ(Full.shlWithNoWrap(ConstantRange(APInt(16, 1), APInt(16, 16)),
+                               OBO::NoUnsignedWrap),
+            ConstantRange(APInt(16, 0), APInt(16, -1)));
 }
 
 TEST_F(ConstantRangeTest, Lshr) {

>From 15a72fceb6b501577e4afc64bc883f59c393832b Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 4 Aug 2024 16:08:12 +0800
Subject: [PATCH 6/7] [ConstantRange] Avoid unnecessary computations

---
 llvm/lib/IR/ConstantRange.cpp           | 30 ++++++++++++++++---------
 llvm/unittests/IR/ConstantRangeTest.cpp |  7 ++++--
 2 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 9aec72c7dd8c2..6efc60a4c1034 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1623,11 +1623,7 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
   if (isEmptySet() || Other.isEmptySet())
     return getEmpty();
 
-  ConstantRange Result = shl(Other);
-  if (!NoWrapKind)
-    return Result;
-
-  if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap) {
+  auto ComputeShlWithNSW = [&] {
     std::optional<unsigned> ShAmtBound;
     if (isAllNonNegative())
       ShAmtBound = getSignedMin().countLeadingZeros();
@@ -1639,10 +1635,10 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
           ConstantRange(APInt(getBitWidth(), 0),
                         APInt(getBitWidth(), *ShAmtBound)),
           Unsigned);
-    Result = Result.intersectWith(sshl_sat(ShAmtRange), RangeType);
-  }
+    return sshl_sat(ShAmtRange);
+  };
 
-  if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap) {
+  auto ComputeShlWithNUW = [&] {
     bool Overflow;
     APInt LHSMin = getUnsignedMin();
     unsigned RHSMin = Other.getUnsignedMin().getLimitedValue(getBitWidth());
@@ -1660,10 +1656,22 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
     if (RHSMin <= RHSMax)
       MaxShl = APIntOps::umax(
           MaxShl, APInt::getHighBitsSet(getBitWidth(), getBitWidth() - RHSMin));
-    Result = Result.intersectWith(getNonEmpty(MinShl, MaxShl + 1), RangeType);
-  }
+    return getNonEmpty(MinShl, MaxShl + 1);
+  };
 
-  return Result;
+  switch (NoWrapKind) {
+  case 0:
+    return shl(Other);
+  case OverflowingBinaryOperator::NoSignedWrap:
+    return shl(Other).intersectWith(ComputeShlWithNSW(), RangeType);
+  case OverflowingBinaryOperator::NoUnsignedWrap:
+    return ComputeShlWithNUW();
+  case OverflowingBinaryOperator::NoSignedWrap |
+      OverflowingBinaryOperator::NoUnsignedWrap:
+    return ComputeShlWithNSW().intersectWith(ComputeShlWithNUW(), RangeType);
+  default:
+    llvm_unreachable("Invalid NoWrapKind");
+  }
 }
 
 ConstantRange
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 9363462cec65b..370617ea9fb29 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -1506,7 +1506,9 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
   using OBO = OverflowingBinaryOperator;
   TestBinaryOpExhaustive(
       [](const ConstantRange &CR1, const ConstantRange &CR2) {
-        return CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap);
+        ConstantRange Res = CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap);
+        EXPECT_TRUE(CR1.shl(CR2).contains(Res));
+        return Res;
       },
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         bool IsOverflow;
@@ -1518,7 +1520,8 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
       PreferSmallest, CheckNonWrappedOnly);
   TestBinaryOpExhaustive(
       [](const ConstantRange &CR1, const ConstantRange &CR2) {
-        return CR1.shlWithNoWrap(CR2, OBO::NoSignedWrap);
+        ConstantRange Res = CR1.shlWithNoWrap(CR2, OBO::NoSignedWrap);
+        return Res;
       },
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         bool IsOverflow;

>From 677d8eea077a3fdd9f0dde377663efdcb71d4449 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 4 Aug 2024 23:42:53 +0800
Subject: [PATCH 7/7] [ConstantRange] Make shl nsw optimal

---
 llvm/lib/IR/ConstantRange.cpp           | 124 ++++++++++++++++--------
 llvm/unittests/IR/ConstantRangeTest.cpp |  19 +++-
 2 files changed, 101 insertions(+), 42 deletions(-)

diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 6efc60a4c1034..c389d7214defc 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1617,58 +1617,104 @@ ConstantRange::shl(const ConstantRange &Other) const {
   return ConstantRange::getNonEmpty(std::move(Min), std::move(Max) + 1);
 }
 
+static ConstantRange computeShlNUW(const ConstantRange &LHS,
+                                   const ConstantRange &RHS) {
+  unsigned BitWidth = LHS.getBitWidth();
+  bool Overflow;
+  APInt LHSMin = LHS.getUnsignedMin();
+  unsigned RHSMin = RHS.getUnsignedMin().getLimitedValue(BitWidth);
+  APInt MinShl = LHSMin.ushl_ov(RHSMin, Overflow);
+  if (Overflow)
+    return ConstantRange::getEmpty(BitWidth);
+  APInt LHSMax = LHS.getUnsignedMax();
+  unsigned RHSMax = RHS.getUnsignedMax().getLimitedValue(BitWidth);
+  APInt MaxShl = MinShl;
+  unsigned MaxShAmt = LHSMax.countLeadingZeros();
+  if (RHSMin <= MaxShAmt)
+    MaxShl = LHSMax << std::min(RHSMax, MaxShAmt);
+  RHSMin = std::max(RHSMin, MaxShAmt + 1);
+  RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros());
+  if (RHSMin <= RHSMax)
+    MaxShl = APIntOps::umax(MaxShl,
+                            APInt::getHighBitsSet(BitWidth, BitWidth - RHSMin));
+  return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
+}
+
+static ConstantRange computeShlNSWWithNNegLHS(const APInt &LHSMin,
+                                              const APInt &LHSMax,
+                                              unsigned RHSMin,
+                                              unsigned RHSMax) {
+  unsigned BitWidth = LHSMin.getBitWidth();
+  bool Overflow;
+  APInt MinShl = LHSMin.sshl_ov(RHSMin, Overflow);
+  if (Overflow)
+    return ConstantRange::getEmpty(BitWidth);
+  APInt MaxShl = MinShl;
+  unsigned MaxShAmt = LHSMax.countLeadingZeros() - 1;
+  if (RHSMin <= MaxShAmt)
+    MaxShl = LHSMax << std::min(RHSMax, MaxShAmt);
+  RHSMin = std::max(RHSMin, MaxShAmt + 1);
+  RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros() - 1);
+  if (RHSMin <= RHSMax)
+    MaxShl = APIntOps::umax(MaxShl,
+                            APInt::getBitsSet(BitWidth, RHSMin, BitWidth - 1));
+  return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
+}
+
+static ConstantRange computeShlNSWWithNegLHS(const APInt &LHSMin,
+                                             const APInt &LHSMax,
+                                             unsigned RHSMin, unsigned RHSMax) {
+  unsigned BitWidth = LHSMin.getBitWidth();
+  bool Overflow;
+  APInt MaxShl = LHSMax.sshl_ov(RHSMin, Overflow);
+  if (Overflow)
+    return ConstantRange::getEmpty(BitWidth);
+  APInt MinShl = MaxShl;
+  unsigned MaxShAmt = LHSMin.countLeadingOnes() - 1;
+  if (RHSMin <= MaxShAmt)
+    MinShl = LHSMin.shl(std::min(RHSMax, MaxShAmt));
+  RHSMin = std::max(RHSMin, MaxShAmt + 1);
+  RHSMax = std::min(RHSMax, LHSMax.countLeadingOnes() - 1);
+  if (RHSMin <= RHSMax)
+    MinShl = APInt::getSignMask(BitWidth);
+  return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
+}
+
+static ConstantRange computeShlNSW(const ConstantRange &LHS,
+                                   const ConstantRange &RHS) {
+  unsigned BitWidth = LHS.getBitWidth();
+  unsigned RHSMin = RHS.getUnsignedMin().getLimitedValue(BitWidth);
+  unsigned RHSMax = RHS.getUnsignedMax().getLimitedValue(BitWidth);
+  APInt LHSMin = LHS.getSignedMin();
+  APInt LHSMax = LHS.getSignedMax();
+  if (LHSMin.isNonNegative())
+    return computeShlNSWWithNNegLHS(LHSMin, LHSMax, RHSMin, RHSMax);
+  else if (LHSMax.isNegative())
+    return computeShlNSWWithNegLHS(LHSMin, LHSMax, RHSMin, RHSMax);
+  return computeShlNSWWithNNegLHS(APInt::getZero(BitWidth), LHSMax, RHSMin,
+                                  RHSMax)
+      .unionWith(computeShlNSWWithNegLHS(LHSMin, APInt::getAllOnes(BitWidth),
+                                         RHSMin, RHSMax),
+                 ConstantRange::Signed);
+}
+
 ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
                                            unsigned NoWrapKind,
                                            PreferredRangeType RangeType) const {
   if (isEmptySet() || Other.isEmptySet())
     return getEmpty();
 
-  auto ComputeShlWithNSW = [&] {
-    std::optional<unsigned> ShAmtBound;
-    if (isAllNonNegative())
-      ShAmtBound = getSignedMin().countLeadingZeros();
-    else if (isAllNegative())
-      ShAmtBound = getSignedMax().countLeadingOnes();
-    ConstantRange ShAmtRange = Other;
-    if (ShAmtBound)
-      ShAmtRange = ShAmtRange.intersectWith(
-          ConstantRange(APInt(getBitWidth(), 0),
-                        APInt(getBitWidth(), *ShAmtBound)),
-          Unsigned);
-    return sshl_sat(ShAmtRange);
-  };
-
-  auto ComputeShlWithNUW = [&] {
-    bool Overflow;
-    APInt LHSMin = getUnsignedMin();
-    unsigned RHSMin = Other.getUnsignedMin().getLimitedValue(getBitWidth());
-    APInt MinShl = LHSMin.ushl_ov(RHSMin, Overflow);
-    if (Overflow)
-      return getEmpty();
-    APInt LHSMax = getUnsignedMax();
-    unsigned RHSMax = Other.getUnsignedMax().getLimitedValue(getBitWidth());
-    APInt MaxShl = MinShl;
-    unsigned MaxShAmt = LHSMax.countLeadingZeros();
-    if (RHSMin <= MaxShAmt)
-      MaxShl = LHSMax << std::min(RHSMax, MaxShAmt);
-    RHSMin = std::max(RHSMin, MaxShAmt + 1);
-    RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros());
-    if (RHSMin <= RHSMax)
-      MaxShl = APIntOps::umax(
-          MaxShl, APInt::getHighBitsSet(getBitWidth(), getBitWidth() - RHSMin));
-    return getNonEmpty(MinShl, MaxShl + 1);
-  };
-
   switch (NoWrapKind) {
   case 0:
     return shl(Other);
   case OverflowingBinaryOperator::NoSignedWrap:
-    return shl(Other).intersectWith(ComputeShlWithNSW(), RangeType);
+    return computeShlNSW(*this, Other);
   case OverflowingBinaryOperator::NoUnsignedWrap:
-    return ComputeShlWithNUW();
+    return computeShlNUW(*this, Other);
   case OverflowingBinaryOperator::NoSignedWrap |
       OverflowingBinaryOperator::NoUnsignedWrap:
-    return ComputeShlWithNSW().intersectWith(ComputeShlWithNUW(), RangeType);
+    return computeShlNSW(*this, Other)
+        .intersectWith(computeShlNUW(*this, Other), RangeType);
   default:
     llvm_unreachable("Invalid NoWrapKind");
   }
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 370617ea9fb29..4815117458b9a 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -228,6 +228,12 @@ static bool CheckNonSignWrappedOnly(const ConstantRange &CR1,
   return !CR1.isSignWrappedSet() && !CR2.isSignWrappedSet();
 }
 
+static bool
+CheckNoSignedWrappedLHSAndNoWrappedRHSOnly(const ConstantRange &CR1,
+                                           const ConstantRange &CR2) {
+  return !CR1.isSignWrappedSet() && !CR2.isWrappedSet();
+}
+
 static bool CheckNonWrappedOrSignWrappedOnly(const ConstantRange &CR1,
                                              const ConstantRange &CR2) {
   return !CR1.isWrappedSet() && !CR1.isSignWrappedSet() &&
@@ -1520,8 +1526,7 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
       PreferSmallest, CheckNonWrappedOnly);
   TestBinaryOpExhaustive(
       [](const ConstantRange &CR1, const ConstantRange &CR2) {
-        ConstantRange Res = CR1.shlWithNoWrap(CR2, OBO::NoSignedWrap);
-        return Res;
+        return CR1.shlWithNoWrap(CR2, OBO::NoSignedWrap);
       },
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         bool IsOverflow;
@@ -1530,7 +1535,7 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
           return std::nullopt;
         return Res;
       },
-      PreferSmallest, CheckCorrectnessOnly);
+      PreferSmallestSigned, CheckNoSignedWrappedLHSAndNoWrappedRHSOnly);
   TestBinaryOpExhaustive(
       [](const ConstantRange &CR1, const ConstantRange &CR2) {
         return CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap | OBO::NoSignedWrap);
@@ -1562,6 +1567,14 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
   EXPECT_EQ(Full.shlWithNoWrap(ConstantRange(APInt(16, 1), APInt(16, 16)),
                                OBO::NoUnsignedWrap),
             ConstantRange(APInt(16, 0), APInt(16, -1)));
+  EXPECT_EQ(ConstantRange(APInt(4, 3), APInt(4, -8))
+                .shlWithNoWrap(ConstantRange(APInt(4, 0), APInt(4, 4)),
+                               OBO::NoSignedWrap),
+            ConstantRange(APInt(4, 3), APInt(4, -8)));
+  EXPECT_EQ(ConstantRange(APInt(4, -1), APInt(4, 0))
+                .shlWithNoWrap(ConstantRange(APInt(4, 1), APInt(4, 4)),
+                               OBO::NoSignedWrap),
+            ConstantRange(APInt(4, -8), APInt(4, -1)));
 }
 
 TEST_F(ConstantRangeTest, Lshr) {



More information about the llvm-commits mailing list