[llvm] InstSimplify: teach simplifyICmpWithConstant about samesign (PR #125899)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 5 09:57:51 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Ramkumar Ramachandra (artagnon)

<details>
<summary>Changes</summary>

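This patch changes ConstantRange::makeAllowedICmpRegion to take a CmpPredicate and respect samesign information on the icmp: when samesign is set, the result is the union of the regions for the predicate and for its flipped-signedness counterpart. simplifyICmpWithConstant now calls makeAllowedICmpRegion instead of makeExactICmpRegion so that it benefits from this. ConstantRange::makeExactICmpRegion is intentionally left unmodified.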

---
Full diff: https://github.com/llvm/llvm-project/pull/125899.diff


4 Files Affected:

- (modified) llvm/include/llvm/IR/ConstantRange.h (+4-1) 
- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+1-1) 
- (modified) llvm/lib/IR/ConstantRange.cpp (+50-44) 
- (modified) llvm/test/Analysis/ValueTracking/constant-ranges.ll (+36) 


``````````diff
diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index d086c25390fd227..a40de0a792ef0aa 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -32,6 +32,7 @@
 #define LLVM_IR_CONSTANTRANGE_H
 
 #include "llvm/ADT/APInt.h"
+#include "llvm/IR/CmpPredicate.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/Support/Compiler.h"
@@ -99,8 +100,10 @@ class [[nodiscard]] ConstantRange {
   /// answer is not representable as a ConstantRange, the return value will be a
   /// proper superset of the above.
   ///
+  /// Note that we respect samesign information on the icmp.
+  ///
   /// Example: Pred = ult and Other = i8 [2, 5) returns Result = [0, 4)
-  static ConstantRange makeAllowedICmpRegion(CmpInst::Predicate Pred,
+  static ConstantRange makeAllowedICmpRegion(CmpPredicate Pred,
                                              const ConstantRange &Other);
 
   /// Produce the largest range such that all values in the returned range
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 3cbc4107433ef3d..7a5a7a39efb1eb9 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3012,7 +3012,7 @@ static Value *simplifyICmpWithConstant(CmpPredicate Pred, Value *LHS,
   }
 
   // Rule out tautological comparisons (eg., ult 0 or uge 0).
-  ConstantRange RHS_CR = ConstantRange::makeExactICmpRegion(Pred, *C);
+  ConstantRange RHS_CR = ConstantRange::makeAllowedICmpRegion(Pred, *C);
   if (RHS_CR.isEmptySet())
     return ConstantInt::getFalse(ITy);
   if (RHS_CR.isFullSet())
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 35664353989929d..e776ec38b27bfe2 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -95,54 +95,60 @@ KnownBits ConstantRange::toKnownBits() const {
   return Known;
 }
 
-ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred,
+ConstantRange ConstantRange::makeAllowedICmpRegion(CmpPredicate Pred,
                                                    const ConstantRange &CR) {
   if (CR.isEmptySet())
     return CR;
 
-  uint32_t W = CR.getBitWidth();
-  switch (Pred) {
-  default:
-    llvm_unreachable("Invalid ICmp predicate to makeAllowedICmpRegion()");
-  case CmpInst::ICMP_EQ:
-    return CR;
-  case CmpInst::ICMP_NE:
-    if (CR.isSingleElement())
-      return ConstantRange(CR.getUpper(), CR.getLower());
-    return getFull(W);
-  case CmpInst::ICMP_ULT: {
-    APInt UMax(CR.getUnsignedMax());
-    if (UMax.isMinValue())
-      return getEmpty(W);
-    return ConstantRange(APInt::getMinValue(W), std::move(UMax));
-  }
-  case CmpInst::ICMP_SLT: {
-    APInt SMax(CR.getSignedMax());
-    if (SMax.isMinSignedValue())
-      return getEmpty(W);
-    return ConstantRange(APInt::getSignedMinValue(W), std::move(SMax));
-  }
-  case CmpInst::ICMP_ULE:
-    return getNonEmpty(APInt::getMinValue(W), CR.getUnsignedMax() + 1);
-  case CmpInst::ICMP_SLE:
-    return getNonEmpty(APInt::getSignedMinValue(W), CR.getSignedMax() + 1);
-  case CmpInst::ICMP_UGT: {
-    APInt UMin(CR.getUnsignedMin());
-    if (UMin.isMaxValue())
-      return getEmpty(W);
-    return ConstantRange(std::move(UMin) + 1, APInt::getZero(W));
-  }
-  case CmpInst::ICMP_SGT: {
-    APInt SMin(CR.getSignedMin());
-    if (SMin.isMaxSignedValue())
-      return getEmpty(W);
-    return ConstantRange(std::move(SMin) + 1, APInt::getSignedMinValue(W));
-  }
-  case CmpInst::ICMP_UGE:
-    return getNonEmpty(CR.getUnsignedMin(), APInt::getZero(W));
-  case CmpInst::ICMP_SGE:
-    return getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W));
-  }
+  auto CheckPred = [CR](CmpInst::Predicate P) {
+    uint32_t W = CR.getBitWidth();
+    switch (P) {
+    default:
+      llvm_unreachable("Invalid ICmp predicate to makeAllowedICmpRegion()");
+    case CmpInst::ICMP_EQ:
+      return CR;
+    case CmpInst::ICMP_NE:
+      if (CR.isSingleElement())
+        return ConstantRange(CR.getUpper(), CR.getLower());
+      return getFull(W);
+    case CmpInst::ICMP_ULT: {
+      APInt UMax(CR.getUnsignedMax());
+      if (UMax.isMinValue())
+        return getEmpty(W);
+      return ConstantRange(APInt::getMinValue(W), std::move(UMax));
+    }
+    case CmpInst::ICMP_SLT: {
+      APInt SMax(CR.getSignedMax());
+      if (SMax.isMinSignedValue())
+        return getEmpty(W);
+      return ConstantRange(APInt::getSignedMinValue(W), std::move(SMax));
+    }
+    case CmpInst::ICMP_ULE:
+      return getNonEmpty(APInt::getMinValue(W), CR.getUnsignedMax() + 1);
+    case CmpInst::ICMP_SLE:
+      return getNonEmpty(APInt::getSignedMinValue(W), CR.getSignedMax() + 1);
+    case CmpInst::ICMP_UGT: {
+      APInt UMin(CR.getUnsignedMin());
+      if (UMin.isMaxValue())
+        return getEmpty(W);
+      return ConstantRange(std::move(UMin) + 1, APInt::getZero(W));
+    }
+    case CmpInst::ICMP_SGT: {
+      APInt SMin(CR.getSignedMin());
+      if (SMin.isMaxSignedValue())
+        return getEmpty(W);
+      return ConstantRange(std::move(SMin) + 1, APInt::getSignedMinValue(W));
+    }
+    case CmpInst::ICMP_UGE:
+      return getNonEmpty(CR.getUnsignedMin(), APInt::getZero(W));
+    case CmpInst::ICMP_SGE:
+      return getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W));
+    }
+  };
+  if (Pred.hasSameSign())
+    return CheckPred(Pred).unionWith(
+        CheckPred(ICmpInst::getFlippedSignednessPredicate(Pred)));
+  return CheckPred(Pred);
 }
 
 ConstantRange ConstantRange::makeSatisfyingICmpRegion(CmpInst::Predicate Pred,
diff --git a/llvm/test/Analysis/ValueTracking/constant-ranges.ll b/llvm/test/Analysis/ValueTracking/constant-ranges.ll
index c440cfad889d3b4..2e9731895bff3ce 100644
--- a/llvm/test/Analysis/ValueTracking/constant-ranges.ll
+++ b/llvm/test/Analysis/ValueTracking/constant-ranges.ll
@@ -160,6 +160,15 @@ define i1 @srem_posC_okay0(i8 %x) {
   ret i1 %r
 }
 
+define i1 @srem_posC_okay0_samesign(i8 %x) {
+; CHECK-LABEL: @srem_posC_okay0_samesign(
+; CHECK-NEXT:    ret i1 true
+;
+  %val = srem i8 34, %x
+  %r = icmp samesign ule i8 %val, 34
+  ret i1 %r
+}
+
 define i1 @srem_posC_okay1(i8 %x) {
 ; CHECK-LABEL: @srem_posC_okay1(
 ; CHECK-NEXT:    ret i1 true
@@ -169,6 +178,15 @@ define i1 @srem_posC_okay1(i8 %x) {
   ret i1 %r
 }
 
+define i1 @srem_posC_okay1_samesign(i8 %x) {
+; CHECK-LABEL: @srem_posC_okay1_samesign(
+; CHECK-NEXT:    ret i1 true
+;
+  %val = srem i8 34, %x
+  %r = icmp samesign uge i8 %val, -3
+  ret i1 %r
+}
+
 define i1 @srem_negC_okay0(i8 %x) {
 ; CHECK-LABEL: @srem_negC_okay0(
 ; CHECK-NEXT:    ret i1 true
@@ -178,6 +196,15 @@ define i1 @srem_negC_okay0(i8 %x) {
   ret i1 %r
 }
 
+define i1 @srem_negC_okay0_samesign(i8 %x) {
+; CHECK-LABEL: @srem_negC_okay0_samesign(
+; CHECK-NEXT:    ret i1 true
+;
+  %val = srem i8 -34, %x
+  %r = icmp samesign ule i8 %val, 0
+  ret i1 %r
+}
+
 define i1 @srem_negC_okay1(i8 %x) {
 ; CHECK-LABEL: @srem_negC_okay1(
 ; CHECK-NEXT:    ret i1 true
@@ -187,6 +214,15 @@ define i1 @srem_negC_okay1(i8 %x) {
   ret i1 %r
 }
 
+define i1 @srem_negC_okay1_samesign(i8 %x) {
+; CHECK-LABEL: @srem_negC_okay1_samesign(
+; CHECK-NEXT:    ret i1 true
+;
+  %val = srem i8 -34, %x
+  %r = icmp samesign uge i8 %val, -34
+  ret i1 %r
+}
+
 define i1 @srem_posC_fail0(i8 %x) {
 ; CHECK-LABEL: @srem_posC_fail0(
 ; CHECK-NEXT:    [[VAL:%.*]] = srem i8 34, [[X:%.*]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/125899


More information about the llvm-commits mailing list