[llvm] [InstSimplify] Infer icmp from with.overflow intrinsics (PR #75511)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 14 10:16:21 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

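This patch simplifies the pattern `Overflow | icmp pred Res, C` into `Overflow` or `true`, where `Overflow` and `Res` are the two values extracted from the same call to a `with.overflow` intrinsic. If the range `Res` can take when the operation does not overflow never satisfies the icmp, the `or` folds to `Overflow`; if it always satisfies the icmp, the `or` folds to `true`.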
Alive2: https://alive2.llvm.org/ce/z/4-LEV2

Fixes #75360.

---
Full diff: https://github.com/llvm/llvm-project/pull/75511.diff


4 Files Affected:

- (modified) llvm/include/llvm/Analysis/ValueTracking.h (+5) 
- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+35) 
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+4-4) 
- (modified) llvm/test/Transforms/InstCombine/sadd-with-overflow.ll (+78) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index a3186e61b94adf..baa16306ebf5df 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -863,6 +863,11 @@ ConstantRange computeConstantRange(const Value *V, bool ForSigned,
                                    const DominatorTree *DT = nullptr,
                                    unsigned Depth = 0);
 
+/// Combine constant ranges from computeConstantRange() and computeKnownBits().
+ConstantRange
+computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
+                                       bool ForSigned, const SimplifyQuery &SQ);
+
 /// Return true if this function can prove that the instruction I will
 /// always transfer execution to one of its successors (including the next
 /// instruction that follows within a basic block). E.g. this is not
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 2a45acf63aa2ca..fa2e42a4c22e60 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -2313,6 +2313,36 @@ static Value *simplifyOrLogic(Value *X, Value *Y) {
   return nullptr;
 }
 
+/// Given {Res, Overflow} = xxx.with.overflow(X, Y), try to fold the pattern
+/// "Overflow | icmp pred Res, C" to "Overflow" or "true"; else return null.
+static Value *simplifyOrOfICmpAndWithOverflow(Value *Op0, Value *Op1,
+                                              const SimplifyQuery &SQ) {
+  const WithOverflowInst *WO;
+  const APInt *C;
+  ICmpInst::Predicate Pred;
+  if (!match(Op0, m_ExtractValue<1>(m_WithOverflowInst(WO))) ||
+      !match(Op1, m_ICmp(Pred, m_ExtractValue<0>(m_Specific(WO)), m_APInt(C))))
+    return nullptr;
+
+  // Bound Res via operand ranges and no-wrap kind; compare with icmp region.
+  const auto LHS_CR = llvm::computeConstantRangeIncludingKnownBits(
+      WO->getLHS(), ICmpInst::isSigned(Pred), SQ);
+  const auto RHS_CR = llvm::computeConstantRangeIncludingKnownBits(
+      WO->getRHS(), ICmpInst::isSigned(Pred), SQ);
+  const auto DomCR = LHS_CR.overflowingBinaryOp(WO->getBinaryOp(), RHS_CR,
+                                                WO->getNoWrapKind());
+  const auto CR = llvm::ConstantRange::makeExactICmpRegion(Pred, *C);
+
+  ConstantRange Intersection = DomCR.intersectWith(CR);
+  ConstantRange Difference = DomCR.difference(CR);
+  if (Intersection.isEmptySet())
+    return Op0; // No value in DomCR satisfies the icmp.
+  if (Difference.isEmptySet())
+    return ConstantInt::getTrue(Op0->getType()); // All of DomCR satisfies it.
+
+  return nullptr;
+}
+
 /// Given operands for an Or, see if we can fold the result.
 /// If not, this returns null.
 static Value *simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
@@ -2480,6 +2510,11 @@ static Value *simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
       if (*Implied == true)
         return ConstantInt::getTrue(Op1->getType());
     }
+
+    if (auto *V = simplifyOrOfICmpAndWithOverflow(Op0, Op1, Q))
+      return V;
+    if (auto *V = simplifyOrOfICmpAndWithOverflow(Op1, Op0, Q))
+      return V;
   }
 
   if (Value *V = simplifyByDomEq(Instruction::Or, Op0, Op1, Q, MaxRecurse))
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 5445746ab2a1bc..e5469a6d659090 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -6228,10 +6228,10 @@ static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
 }
 
 /// Combine constant ranges from computeConstantRange() and computeKnownBits().
-static ConstantRange
-computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
-                                       bool ForSigned,
-                                       const SimplifyQuery &SQ) {
+ConstantRange
+llvm::computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
+                                             bool ForSigned,
+                                             const SimplifyQuery &SQ) {
   ConstantRange CR1 =
       ConstantRange::fromKnownBits(V.getKnownBits(SQ), ForSigned);
   ConstantRange CR2 = computeConstantRange(V, ForSigned, SQ.IIQ.UseInstrInfo);
diff --git a/llvm/test/Transforms/InstCombine/sadd-with-overflow.ll b/llvm/test/Transforms/InstCombine/sadd-with-overflow.ll
index 4b37ccbe3370b6..a784028ec3a245 100644
--- a/llvm/test/Transforms/InstCombine/sadd-with-overflow.ll
+++ b/llvm/test/Transforms/InstCombine/sadd-with-overflow.ll
@@ -122,3 +122,81 @@ define { i32, i1 } @fold_sub_simple(i32 %x) {
   %b = tail call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %a, i32 30)
   ret { i32, i1 } %b
 }
+
+; Tests from PR75360
+define i1 @ckd_add_unsigned(i31 %num) {
+; CHECK-LABEL: @ckd_add_unsigned(
+; CHECK-NEXT:    [[A2:%.*]] = icmp eq i31 [[NUM:%.*]], -1
+; CHECK-NEXT:    ret i1 [[A2]]
+;
+  %a0 = zext i31 %num to i32
+  %a1 = tail call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %a0, i32 1)
+  %a2 = extractvalue { i32, i1 } %a1, 1
+  %a3 = extractvalue { i32, i1 } %a1, 0
+  %a4 = icmp slt i32 %a3, 0
+  %a5 = or i1 %a2, %a4
+  ret i1 %a5
+}
+
+define i1 @ckd_add_unsigned_commuted(i31 %num) {
+; CHECK-LABEL: @ckd_add_unsigned_commuted(
+; CHECK-NEXT:    [[A2:%.*]] = icmp eq i31 [[NUM:%.*]], -1
+; CHECK-NEXT:    ret i1 [[A2]]
+;
+  %a0 = zext i31 %num to i32
+  %a1 = tail call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %a0, i32 1)
+  %a2 = extractvalue { i32, i1 } %a1, 1
+  %a3 = extractvalue { i32, i1 } %a1, 0
+  %a4 = icmp slt i32 %a3, 0
+  %a5 = or i1 %a4, %a2
+  ret i1 %a5
+}
+
+define i1 @ckd_add_unsigned_imply_true(i31 %num) {
+; CHECK-LABEL: @ckd_add_unsigned_imply_true(
+; CHECK-NEXT:    ret i1 true
+;
+  %a0 = zext i31 %num to i32
+  %a1 = tail call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %a0, i32 1)
+  %a2 = extractvalue { i32, i1 } %a1, 1
+  %a3 = extractvalue { i32, i1 } %a1, 0
+  %a4 = icmp sgt i32 %a3, -1
+  %a5 = or i1 %a2, %a4
+  ret i1 %a5
+}
+
+define i1 @ckd_add_unsigned_fail1(i32 %a0) {
+; CHECK-LABEL: @ckd_add_unsigned_fail1(
+; CHECK-NEXT:    [[A1:%.*]] = tail call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 [[A0:%.*]], i32 1)
+; CHECK-NEXT:    [[A2:%.*]] = extractvalue { i32, i1 } [[A1]], 1
+; CHECK-NEXT:    [[A3:%.*]] = extractvalue { i32, i1 } [[A1]], 0
+; CHECK-NEXT:    [[A4:%.*]] = icmp slt i32 [[A3]], 0
+; CHECK-NEXT:    [[A5:%.*]] = or i1 [[A2]], [[A4]]
+; CHECK-NEXT:    ret i1 [[A5]]
+;
+  %a1 = tail call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %a0, i32 1)
+  %a2 = extractvalue { i32, i1 } %a1, 1
+  %a3 = extractvalue { i32, i1 } %a1, 0
+  %a4 = icmp slt i32 %a3, 0
+  %a5 = or i1 %a2, %a4
+  ret i1 %a5
+}
+
+define i1 @ckd_add_unsigned_fail2(i31 %num) {
+; CHECK-LABEL: @ckd_add_unsigned_fail2(
+; CHECK-NEXT:    [[A0:%.*]] = zext i31 [[NUM:%.*]] to i32
+; CHECK-NEXT:    [[A1:%.*]] = tail call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 [[A0]], i32 1)
+; CHECK-NEXT:    [[A2:%.*]] = extractvalue { i32, i1 } [[A1]], 1
+; CHECK-NEXT:    [[A3:%.*]] = extractvalue { i32, i1 } [[A1]], 0
+; CHECK-NEXT:    [[A4:%.*]] = icmp slt i32 [[A3]], 2
+; CHECK-NEXT:    [[A5:%.*]] = or i1 [[A2]], [[A4]]
+; CHECK-NEXT:    ret i1 [[A5]]
+;
+  %a0 = zext i31 %num to i32
+  %a1 = tail call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %a0, i32 1)
+  %a2 = extractvalue { i32, i1 } %a1, 1
+  %a3 = extractvalue { i32, i1 } %a1, 0
+  %a4 = icmp slt i32 %a3, 2
+  %a5 = or i1 %a2, %a4
+  ret i1 %a5
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/75511


More information about the llvm-commits mailing list