[llvm] 07b29fc - [ConstantRange] Improve `shlWithNoWrap` (#101800)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 6 11:00:36 PDT 2024


Author: Yingwei Zheng
Date: 2024-08-07T02:00:33+08:00
New Revision: 07b29fc808ca0842d02cf4e973381b974bfdf19f

URL: https://github.com/llvm/llvm-project/commit/07b29fc808ca0842d02cf4e973381b974bfdf19f
DIFF: https://github.com/llvm/llvm-project/commit/07b29fc808ca0842d02cf4e973381b974bfdf19f.diff

LOG: [ConstantRange] Improve `shlWithNoWrap` (#101800)

Closes https://github.com/dtcxzyw/llvm-tools/issues/22.

Added: 
    

Modified: 
    llvm/lib/IR/ConstantRange.cpp
    llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
    llvm/unittests/IR/ConstantRangeTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 50b211a302e8ff..c389d7214defca 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1617,21 +1617,107 @@ ConstantRange::shl(const ConstantRange &Other) const {
   return ConstantRange::getNonEmpty(std::move(Min), std::move(Max) + 1);
 }
 
+static ConstantRange computeShlNUW(const ConstantRange &LHS, // Range of LHS << RHS under nuw (no set bits shifted out).
+                                   const ConstantRange &RHS) {
+  unsigned BitWidth = LHS.getBitWidth();
+  bool Overflow;
+  APInt LHSMin = LHS.getUnsignedMin();
+  unsigned RHSMin = RHS.getUnsignedMin().getLimitedValue(BitWidth); // Shift amounts are capped at BitWidth.
+  APInt MinShl = LHSMin.ushl_ov(RHSMin, Overflow); // Least result: smallest LHS by the smallest amount.
+  if (Overflow) // Larger LHS values have no more leading zeros, so every (LHS, amount) pair overflows.
+    return ConstantRange::getEmpty(BitWidth);
+  APInt LHSMax = LHS.getUnsignedMax();
+  unsigned RHSMax = RHS.getUnsignedMax().getLimitedValue(BitWidth);
+  APInt MaxShl = MinShl;
+  unsigned MaxShAmt = LHSMax.countLeadingZeros(); // Largest nuw-safe shift amount for LHSMax.
+  if (RHSMin <= MaxShAmt)
+    MaxShl = LHSMax << std::min(RHSMax, MaxShAmt); // Shift LHSMax as far as both RHS and nuw allow.
+  RHSMin = std::max(RHSMin, MaxShAmt + 1); // Amounts that overflow for LHSMax...
+  RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros()); // ...but are still nuw-safe for LHSMin.
+  if (RHSMin <= RHSMax) // Then a smaller in-range LHS shifted by RHSMin yields the top (BitWidth - RHSMin) bits set.
+    MaxShl = APIntOps::umax(MaxShl,
+                            APInt::getHighBitsSet(BitWidth, BitWidth - RHSMin));
+  return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
+}
+
+static ConstantRange computeShlNSWWithNNegLHS(const APInt &LHSMin, // nsw shl range for 0 <= LHSMin <= LHSMax (signed), amounts in [RHSMin, RHSMax].
+                                              const APInt &LHSMax,
+                                              unsigned RHSMin,
+                                              unsigned RHSMax) {
+  unsigned BitWidth = LHSMin.getBitWidth();
+  bool Overflow;
+  APInt MinShl = LHSMin.sshl_ov(RHSMin, Overflow); // Least result: smallest LHS by the smallest amount.
+  if (Overflow) // Larger non-negative LHS values overflow as well, so nothing is valid.
+    return ConstantRange::getEmpty(BitWidth);
+  APInt MaxShl = MinShl;
+  unsigned MaxShAmt = LHSMax.countLeadingZeros() - 1; // Largest amount keeping LHSMax's sign bit clear (clz >= 1 since LHSMax >= 0).
+  if (RHSMin <= MaxShAmt)
+    MaxShl = LHSMax << std::min(RHSMax, MaxShAmt); // Shift LHSMax as far as both RHS and nsw allow.
+  RHSMin = std::max(RHSMin, MaxShAmt + 1); // Amounts that overflow for LHSMax...
+  RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros() - 1); // ...but are still nsw-safe for LHSMin.
+  if (RHSMin <= RHSMax) // Then a smaller in-range LHS shifted by RHSMin yields bits [RHSMin, BitWidth - 1) set.
+    MaxShl = APIntOps::umax(MaxShl,
+                            APInt::getBitsSet(BitWidth, RHSMin, BitWidth - 1));
+  return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
+}
+
+static ConstantRange computeShlNSWWithNegLHS(const APInt &LHSMin, // nsw shl range for LHSMin <= LHSMax < 0 (signed), amounts in [RHSMin, RHSMax].
+                                             const APInt &LHSMax,
+                                             unsigned RHSMin, unsigned RHSMax) {
+  unsigned BitWidth = LHSMin.getBitWidth();
+  bool Overflow;
+  APInt MaxShl = LHSMax.sshl_ov(RHSMin, Overflow); // Shifting a negative value lowers it, so the greatest result is LHSMax by the smallest amount.
+  if (Overflow) // LHSMax has the most leading ones of any in-range value, so every LHS overflows.
+    return ConstantRange::getEmpty(BitWidth);
+  APInt MinShl = MaxShl;
+  unsigned MaxShAmt = LHSMin.countLeadingOnes() - 1; // Largest amount keeping LHSMin negative (clo >= 1 since LHSMin < 0).
+  if (RHSMin <= MaxShAmt)
+    MinShl = LHSMin.shl(std::min(RHSMax, MaxShAmt)); // Shift LHSMin as far as both RHS and nsw allow.
+  RHSMin = std::max(RHSMin, MaxShAmt + 1); // Amounts that overflow for LHSMin...
+  RHSMax = std::min(RHSMax, LHSMax.countLeadingOnes() - 1); // ...but are still nsw-safe for LHSMax.
+  if (RHSMin <= RHSMax) // Then some in-range LHS can be shifted down to the signed minimum.
+    MinShl = APInt::getSignMask(BitWidth);
+  return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
+}
+
+static ConstantRange computeShlNSW(const ConstantRange &LHS, // Range of LHS << RHS under nsw, handling LHS by sign.
+                                   const ConstantRange &RHS) {
+  unsigned BitWidth = LHS.getBitWidth();
+  unsigned RHSMin = RHS.getUnsignedMin().getLimitedValue(BitWidth); // Shift amounts are capped at BitWidth.
+  unsigned RHSMax = RHS.getUnsignedMax().getLimitedValue(BitWidth);
+  APInt LHSMin = LHS.getSignedMin();
+  APInt LHSMax = LHS.getSignedMax();
+  if (LHSMin.isNonNegative()) // LHS is entirely non-negative.
+    return computeShlNSWWithNNegLHS(LHSMin, LHSMax, RHSMin, RHSMax);
+  else if (LHSMax.isNegative()) // LHS is entirely negative.
+    return computeShlNSWWithNegLHS(LHSMin, LHSMax, RHSMin, RHSMax);
+  return computeShlNSWWithNNegLHS(APInt::getZero(BitWidth), LHSMax, RHSMin, // LHS straddles zero: union of the [0, LHSMax] part...
+                                  RHSMax)
+      .unionWith(computeShlNSWWithNegLHS(LHSMin, APInt::getAllOnes(BitWidth), // ...and the [LHSMin, -1] part,
+                                         RHSMin, RHSMax),
+                 ConstantRange::Signed); // preferring the Signed range type for the union.
+}
+
 ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
                                            unsigned NoWrapKind,
                                            PreferredRangeType RangeType) const {
   if (isEmptySet() || Other.isEmptySet())
     return getEmpty();
 
-  ConstantRange Result = shl(Other);
-
-  if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap)
-    Result = Result.intersectWith(sshl_sat(Other), RangeType);
-
-  if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap)
-    Result = Result.intersectWith(ushl_sat(Other), RangeType);
-
-  return Result;
+  switch (NoWrapKind) { // Dispatch on the exact nuw/nsw flag combination.
+  case 0:
+    return shl(Other); // No flags: plain shl.
+  case OverflowingBinaryOperator::NoSignedWrap:
+    return computeShlNSW(*this, Other);
+  case OverflowingBinaryOperator::NoUnsignedWrap:
+    return computeShlNUW(*this, Other);
+  case OverflowingBinaryOperator::NoSignedWrap | // Both flags: results must satisfy both constraints,
+      OverflowingBinaryOperator::NoUnsignedWrap:
+    return computeShlNSW(*this, Other)
+        .intersectWith(computeShlNUW(*this, Other), RangeType); // so intersect; RangeType is only consulted here.
+  default:
+    llvm_unreachable("Invalid NoWrapKind"); // NoWrapKind carries bits other than nuw/nsw.
+  }
 }
 
 ConstantRange

diff  --git a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
index 8b4dbc98425bf0..1d6e54c9a488a2 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
@@ -86,7 +86,7 @@ define i8 @test4(i8 %a, i8 %b) {
 ; CHECK-NEXT:    br i1 [[CMP]], label [[BB:%.*]], label [[EXIT:%.*]]
 ; CHECK:       bb:
 ; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i8 [[A:%.*]], [[B]]
-; CHECK-NEXT:    ret i8 -1
+; CHECK-NEXT:    ret i8 [[SHL]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret i8 0
 ;
@@ -105,7 +105,7 @@ exit:
 define i8 @test5(i8 %b) {
 ; CHECK-LABEL: @test5(
 ; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i8 0, [[B:%.*]]
-; CHECK-NEXT:    ret i8 [[SHL]]
+; CHECK-NEXT:    ret i8 0
 ;
   %shl = shl i8 0, %b
   ret i8 %shl
@@ -474,3 +474,17 @@ define i1 @shl_nuw_nsw_test4(i32 %x, i32 range(i32 0, 32) %k) {
   %cmp = icmp eq i64 %shl, -9223372036854775808
   ret i1 %cmp
 }
+
+define i1 @shl_nuw_nsw_test5(i32 %x) { ; nsw bounds 768 << %x by 768 << 21 = 0x60000000, so %add stays positive and %cmp folds to true.
+; CHECK-LABEL: @shl_nuw_nsw_test5(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i32 768, [[X:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[SHL]], 1846
+; CHECK-NEXT:    ret i1 true
+;
+entry:
+  %shl = shl nuw nsw i32 768, %x ; 768 has 22 leading zeros, so nsw allows shift amounts up to 21.
+  %add = add nuw i32 %shl, 1846
+  %cmp = icmp sgt i32 %add, 0
+  ret i1 %cmp
+}

diff  --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 1705f3e6af9774..4815117458b9af 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -228,6 +228,12 @@ static bool CheckNonSignWrappedOnly(const ConstantRange &CR1,
   return !CR1.isSignWrappedSet() && !CR2.isSignWrappedSet();
 }
 
+static bool // NOTE(review): presumably selects which input pairs must get an optimal result from TestBinaryOpExhaustive; confirm against the harness.
+CheckNoSignedWrappedLHSAndNoWrappedRHSOnly(const ConstantRange &CR1,
+                                           const ConstantRange &CR2) {
+  return !CR1.isSignWrappedSet() && !CR2.isWrappedSet(); // LHS not sign-wrapped and RHS not (unsigned-)wrapped.
+}
+
 static bool CheckNonWrappedOrSignWrappedOnly(const ConstantRange &CR1,
                                              const ConstantRange &CR2) {
   return !CR1.isWrappedSet() && !CR1.isSignWrappedSet() &&
@@ -1506,7 +1512,9 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
   using OBO = OverflowingBinaryOperator;
   TestBinaryOpExhaustive(
       [](const ConstantRange &CR1, const ConstantRange &CR2) {
-        return CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap);
+        ConstantRange Res = CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap);
+        EXPECT_TRUE(CR1.shl(CR2).contains(Res));
+        return Res;
       },
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         bool IsOverflow;
@@ -1515,7 +1523,7 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
           return std::nullopt;
         return Res;
       },
-      PreferSmallest, CheckCorrectnessOnly);
+      PreferSmallest, CheckNonWrappedOnly);
   TestBinaryOpExhaustive(
       [](const ConstantRange &CR1, const ConstantRange &CR2) {
         return CR1.shlWithNoWrap(CR2, OBO::NoSignedWrap);
@@ -1527,7 +1535,7 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
           return std::nullopt;
         return Res;
       },
-      PreferSmallest, CheckCorrectnessOnly);
+      PreferSmallestSigned, CheckNoSignedWrappedLHSAndNoWrappedRHSOnly);
   TestBinaryOpExhaustive(
       [](const ConstantRange &CR1, const ConstantRange &CR2) {
         return CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap | OBO::NoSignedWrap);
@@ -1542,6 +1550,31 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
         return Res1;
       },
       PreferSmallest, CheckCorrectnessOnly);
+
+  EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoSignedWrap),
+            ConstantRange(APInt(16, 10), APInt(16, 20481)));
+  EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoUnsignedWrap),
+            ConstantRange(APInt(16, 10), APInt(16, -24575)));
+  EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoSignedWrap | OBO::NoUnsignedWrap),
+            ConstantRange(APInt(16, 10), APInt(16, 20481)));
+  ConstantRange NegOne(APInt(16, 0xffff));
+  EXPECT_EQ(NegOne.shlWithNoWrap(Full, OBO::NoSignedWrap),
+            ConstantRange(APInt(16, -32768), APInt(16, 0)));
+  EXPECT_EQ(NegOne.shlWithNoWrap(Full, OBO::NoUnsignedWrap), NegOne);
+  EXPECT_EQ(ConstantRange(APInt(16, 768))
+                .shlWithNoWrap(Full, OBO::NoSignedWrap | OBO::NoUnsignedWrap),
+            ConstantRange(APInt(16, 768), APInt(16, 24577)));
+  EXPECT_EQ(Full.shlWithNoWrap(ConstantRange(APInt(16, 1), APInt(16, 16)),
+                               OBO::NoUnsignedWrap),
+            ConstantRange(APInt(16, 0), APInt(16, -1)));
+  EXPECT_EQ(ConstantRange(APInt(4, 3), APInt(4, -8))
+                .shlWithNoWrap(ConstantRange(APInt(4, 0), APInt(4, 4)),
+                               OBO::NoSignedWrap),
+            ConstantRange(APInt(4, 3), APInt(4, -8)));
+  EXPECT_EQ(ConstantRange(APInt(4, -1), APInt(4, 0))
+                .shlWithNoWrap(ConstantRange(APInt(4, 1), APInt(4, 4)),
+                               OBO::NoSignedWrap),
+            ConstantRange(APInt(4, -8), APInt(4, -1)));
 }
 
 TEST_F(ConstantRangeTest, Lshr) {


        


More information about the llvm-commits mailing list