[llvm] [ConstantRange] Add support for `shlWithNoWrap` (PR #100594)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 25 09:36:17 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

This patch adds initial support for `ConstantRange::shlWithNoWrap`, working towards folding the pattern reported in https://github.com/dtcxzyw/llvm-tools/issues/22. However, this patch alone does not fix the original issue; improvements to `ConstantRange::[u|s]shl_sat` will be submitted in subsequent patches.




---
Full diff: https://github.com/llvm/llvm-project/pull/100594.diff


4 Files Affected:

- (modified) llvm/include/llvm/IR/ConstantRange.h (+8) 
- (modified) llvm/lib/IR/ConstantRange.cpp (+19) 
- (modified) llvm/test/Transforms/CorrelatedValuePropagation/shl.ll (+59-5) 
- (modified) llvm/unittests/IR/ConstantRangeTest.cpp (+42) 


``````````diff
diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index 86d0a6b35d748..d086c25390fd2 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -501,6 +501,14 @@ class [[nodiscard]] ConstantRange {
   /// TODO: This isn't fully implemented yet.
   ConstantRange shl(const ConstantRange &Other) const;
 
+  /// Return a new range representing the possible values resulting from a
+  /// left shift of a value in this range by a value in \p Other, where
+  /// \p NoWrapKind is a mask of OverflowingBinaryOperator::NoUnsignedWrap
+  /// and/or NoSignedWrap flags. If the result range is disjoint, the
+  /// preferred range is determined by \p RangeType.
+  ConstantRange shlWithNoWrap(const ConstantRange &Other, unsigned NoWrapKind,
+                              PreferredRangeType RangeType = Smallest) const;
+
   /// Return a new range representing the possible values resulting from a
   /// logical right shift of a value in this range and a value in \p Other.
   ConstantRange lshr(const ConstantRange &Other) const;
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 6068540cf08da..50b211a302e8f 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -988,6 +988,8 @@ ConstantRange ConstantRange::overflowingBinaryOp(Instruction::BinaryOps BinOp,
     return subWithNoWrap(Other, NoWrapKind);
   case Instruction::Mul:
     return multiplyWithNoWrap(Other, NoWrapKind);
+  case Instruction::Shl:
+    return shlWithNoWrap(Other, NoWrapKind);
   default:
     // Don't know about this Overflowing Binary Operation.
     // Conservatively fallback to plain binop handling.
@@ -1615,6 +1617,23 @@ ConstantRange::shl(const ConstantRange &Other) const {
   return ConstantRange::getNonEmpty(std::move(Min), std::move(Max) + 1);
 }
 
+ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
+                                           unsigned NoWrapKind,
+                                           PreferredRangeType RangeType) const {
+  if (isEmptySet() || Other.isEmptySet())
+    return getEmpty();
+
+  ConstantRange Result = shl(Other);
+  // Without wrapping, shl agrees with [su]shl_sat, so narrow by intersection.
+  if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap)
+    Result = Result.intersectWith(sshl_sat(Other), RangeType);
+
+  if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap)
+    Result = Result.intersectWith(ushl_sat(Other), RangeType);
+
+  return Result;
+}
+
 ConstantRange
 ConstantRange::lshr(const ConstantRange &Other) const {
   if (isEmptySet() || Other.isEmptySet())
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
index 88311219dee58..8b4dbc98425bf 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
@@ -86,7 +86,7 @@ define i8 @test4(i8 %a, i8 %b) {
 ; CHECK-NEXT:    br i1 [[CMP]], label [[BB:%.*]], label [[EXIT:%.*]]
 ; CHECK:       bb:
 ; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i8 [[A:%.*]], [[B]]
-; CHECK-NEXT:    ret i8 [[SHL]]
+; CHECK-NEXT:    ret i8 -1
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret i8 0
 ;
@@ -382,8 +382,7 @@ define i1 @nuw_range1(i8 %b) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[C:%.*]] = add nuw nsw i8 [[B:%.*]], 1
 ; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i8 [[C]], 2
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[SHL]], 0
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %c = add nuw nsw i8 %b, 1
@@ -397,8 +396,7 @@ define i1 @nuw_range2(i8 %b) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[C:%.*]] = add nuw nsw i8 [[B:%.*]], 3
 ; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i8 [[C]], 2
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[SHL]], 2
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %c = add nuw nsw i8 %b, 3
@@ -420,3 +418,59 @@ entry:
   %cmp = icmp slt i8 %c, %shl
   ret i1 %cmp
 }
+
+define i64 @shl_nuw_nsw_test1(i32 %x) {
+; CHECK-LABEL: @shl_nuw_nsw_test1(
+; CHECK-NEXT:    [[SHL1:%.*]] = shl nuw nsw i32 1, [[X:%.*]]
+; CHECK-NEXT:    [[ADD1:%.*]] = add nsw i32 [[SHL1]], -1
+; CHECK-NEXT:    [[EXT:%.*]] = zext nneg i32 [[ADD1]] to i64
+; CHECK-NEXT:    [[SHL2:%.*]] = shl nuw nsw i64 [[EXT]], 2
+; CHECK-NEXT:    [[ADD2:%.*]] = add nuw nsw i64 [[SHL2]], 39
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i64 [[ADD2]], 3
+; CHECK-NEXT:    ret i64 [[LSHR]]
+;
+  %shl1 = shl nuw nsw i32 1, %x
+  %add1 = add nsw i32 %shl1, -1
+  %ext = sext i32 %add1 to i64
+  %shl2 = shl nsw i64 %ext, 2
+  %add2 = add nsw i64 %shl2, 39
+  %lshr = lshr i64 %add2, 3
+  %and = and i64 %lshr, 4294967295 ; %lshr < 2^32, so this mask is redundant
+  ret i64 %and
+}
+
+define i32 @shl_nuw_nsw_test2(i32 range(i32 -2147483248, 1) %x) {
+; CHECK-LABEL: @shl_nuw_nsw_test2(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[X:%.*]], 1
+; CHECK-NEXT:    ret i32 200
+;
+  %shl = shl nsw i32 %x, 1 ; %x <= 0 and nsw, so %shl <= 0 < 200
+  %smax = call i32 @llvm.smax.i32(i32 %shl, i32 200)
+  ret i32 %smax
+}
+
+define i64 @shl_nuw_nsw_test3(i1 %cond, i64 range(i64 1, 0) %x, i64 range(i64 3, 0) %y) {
+; CHECK-LABEL: @shl_nuw_nsw_test3(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i64 1, [[X:%.*]]
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND:%.*]], i64 [[Y:%.*]], i64 [[SHL]]
+; CHECK-NEXT:    ret i64 [[SEL]]
+;
+  %shl = shl nuw i64 1, %x ; %x >= 1 and nuw, so %shl >= 2 (unsigned)
+  %sel = select i1 %cond, i64 %y, i64 %shl
+  %umax = call i64 @llvm.umax.i64(i64 %sel, i64 2)
+  ret i64 %umax
+}
+
+define i1 @shl_nuw_nsw_test4(i32 %x, i32 range(i32 0, 32) %k) {
+; CHECK-LABEL: @shl_nuw_nsw_test4(
+; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[X:%.*]] to i64
+; CHECK-NEXT:    [[SH_PROM:%.*]] = zext nneg i32 [[K:%.*]] to i64
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i64 [[CONV]], [[SH_PROM]]
+; CHECK-NEXT:    ret i1 false
+;
+  %conv = sext i32 %x to i64
+  %sh_prom = zext nneg i32 %k to i64
+  %shl = shl nsw i64 %conv, %sh_prom ; |%shl| <= 2^62 since %k < 32
+  %cmp = icmp eq i64 %shl, -9223372036854775808
+  ret i1 %cmp
+}
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 7977a78a7d3ec..1705f3e6af977 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -1502,6 +1502,48 @@ TEST_F(ConstantRangeTest, Shl) {
       });
 }
 
+TEST_F(ConstantRangeTest, ShlWithNoWrap) {
+  using OBO = OverflowingBinaryOperator;
+  TestBinaryOpExhaustive(
+      [](const ConstantRange &CR1, const ConstantRange &CR2) {
+        return CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        bool IsOverflow;
+        APInt Res = N1.ushl_ov(N2, IsOverflow);
+        if (IsOverflow) // shl nuw that overflows is poison.
+          return std::nullopt;
+        return Res;
+      },
+      PreferSmallest, CheckCorrectnessOnly);
+  TestBinaryOpExhaustive(
+      [](const ConstantRange &CR1, const ConstantRange &CR2) {
+        return CR1.shlWithNoWrap(CR2, OBO::NoSignedWrap);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        bool IsOverflow;
+        APInt Res = N1.sshl_ov(N2, IsOverflow);
+        if (IsOverflow) // shl nsw that overflows is poison.
+          return std::nullopt;
+        return Res;
+      },
+      PreferSmallest, CheckCorrectnessOnly);
+  TestBinaryOpExhaustive(
+      [](const ConstantRange &CR1, const ConstantRange &CR2) {
+        return CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap | OBO::NoSignedWrap);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        bool IsOverflow1, IsOverflow2;
+        APInt Res1 = N1.ushl_ov(N2, IsOverflow1);
+        APInt Res2 = N1.sshl_ov(N2, IsOverflow2);
+        if (IsOverflow1 || IsOverflow2) // Either flag violated: poison.
+          return std::nullopt;
+        assert(Res1 == Res2 && "Left shift results differ?");
+        return Res1;
+      },
+      PreferSmallest, CheckCorrectnessOnly);
+}
+
 TEST_F(ConstantRangeTest, Lshr) {
   EXPECT_EQ(Full.lshr(Full), Full);
   EXPECT_EQ(Full.lshr(Empty), Empty);

``````````

</details>


https://github.com/llvm/llvm-project/pull/100594


More information about the llvm-commits mailing list