[llvm] [ConstantRange] Improve `shlWithNoWrap` (PR #101800)

via llvm-commits llvm-commits at lists.llvm.org
Sat Aug 3 01:10:51 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

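Improves `ConstantRange::shlWithNoWrap` by using the known leading zero/one bits of the shifted value (via `toKnownBits()`) to bound the shift amount before intersecting the result with the `sshl_sat`/`ushl_sat` ranges; it also returns the plain `shl` result early when no wrap flags are set.

Closes https://github.com/dtcxzyw/llvm-tools/issues/22.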


---
Full diff: https://github.com/llvm/llvm-project/pull/101800.diff


2 Files Affected:

- (modified) llvm/lib/IR/ConstantRange.cpp (+25-4) 
- (modified) llvm/test/Transforms/CorrelatedValuePropagation/shl.ll (+15-1) 


``````````diff
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 50b211a302e8f..062e0e7a3b251 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1624,12 +1624,33 @@ ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
     return getEmpty();
 
   ConstantRange Result = shl(Other);
+  if (!NoWrapKind)
+    return Result;
 
-  if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap)
-    Result = Result.intersectWith(sshl_sat(Other), RangeType);
+  KnownBits Known = toKnownBits();
+
+  if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap) {
+    ConstantRange ShAmtRange = Other;
+    if (isAllNonNegative())
+      ShAmtRange = ShAmtRange.intersectWith(
+          ConstantRange(APInt::getZero(getBitWidth()),
+                        APInt(getBitWidth(), Known.countMaxLeadingZeros())),
+          Unsigned);
+    else if (isAllNegative())
+      ShAmtRange = ShAmtRange.intersectWith(
+          ConstantRange(APInt::getZero(getBitWidth()),
+                        APInt(getBitWidth(), Known.countMaxLeadingOnes())),
+          Unsigned);
+    Result = Result.intersectWith(sshl_sat(ShAmtRange), RangeType);
+  }
 
-  if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap)
-    Result = Result.intersectWith(ushl_sat(Other), RangeType);
+  if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap) {
+    ConstantRange ShAmtRange =
+        getNonEmpty(APInt::getZero(getBitWidth()),
+                    APInt(getBitWidth(), Known.countMaxLeadingZeros() + 1));
+    Result = Result.intersectWith(
+        ushl_sat(Other.intersectWith(ShAmtRange, Unsigned)), RangeType);
+  }
 
   return Result;
 }
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
index 8b4dbc98425bf..a55081e1604e6 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/shl.ll
@@ -105,7 +105,7 @@ exit:
 define i8 @test5(i8 %b) {
 ; CHECK-LABEL: @test5(
 ; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i8 0, [[B:%.*]]
-; CHECK-NEXT:    ret i8 [[SHL]]
+; CHECK-NEXT:    ret i8 0
 ;
   %shl = shl i8 0, %b
   ret i8 %shl
@@ -474,3 +474,17 @@ define i1 @shl_nuw_nsw_test4(i32 %x, i32 range(i32 0, 32) %k) {
   %cmp = icmp eq i64 %shl, -9223372036854775808
   ret i1 %cmp
 }
+
+define i1 @shl_nuw_nsw_test5(i32 %x) {
+; CHECK-LABEL: @shl_nuw_nsw_test5(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i32 768, [[X:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[SHL]], 1846
+; CHECK-NEXT:    ret i1 true
+;
+entry:
+  %shl = shl nuw nsw i32 768, %x
+  %add = add nuw i32 %shl, 1846
+  %cmp = icmp sgt i32 %add, 0
+  ret i1 %cmp
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/101800


More information about the llvm-commits mailing list