[llvm] [ValueTracking] Use assume to compute overflowResult. (PR #121665)

via llvm-commits llvm-commits at lists.llvm.org
Sat Jan 4 13:13:12 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Andreas Jonson (andjo403)

<details>
<summary>Changes</summary>

Rust code contains many `llvm.assume` calls whose condition is the negated overflow flag of a `with.overflow` intrinsic. Because of the assume, such an intrinsic can be folded to the plain operation (with the matching `nuw`/`nsw` flag) without an overflow check.

For example, this PR looks for the following pattern:
```
  %call = call { i8, i1 } @<!-- -->llvm.uadd.with.overflow.i8(i8 %a, i8 %b)
  %overflow = extractvalue { i8, i1 } %call, 1
  %not = xor i1 %overflow, true
  call void @<!-- -->llvm.assume(i1 %not)
```
that can be folded to:
`%call = add nuw i8 %a, %b`

For example, the hashbrown crate has [this match with use of hint::unreachable_unchecked()](https://github.com/rust-lang/hashbrown/blob/16044fe2f8784f665a563849cc63867cbafb80c6/src/raw/mod.rs#L1364), which is combined with checked arithmetic in [calculate_layout_for](https://github.com/rust-lang/hashbrown/blob/16044fe2f8784f665a563849cc63867cbafb80c6/src/raw/mod.rs#L168). After inlining and SimplifyCFG, this results in the pattern above.

Proof: https://alive2.llvm.org/ce/z/oRu3oi

---
Full diff: https://github.com/llvm/llvm-project/pull/121665.diff


2 Files Affected:

- (modified) llvm/lib/Analysis/ValueTracking.cpp (+41) 
- (modified) llvm/test/Transforms/InstCombine/with_overflow.ll (+129) 


``````````diff
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 2f6e869ae7b735..51b3627056fbb6 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -7177,10 +7177,33 @@ llvm::computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
   return CR1.intersectWith(CR2, RangeType);
 }
 
+static bool isKnownToNotOverflowFromAssume(const SimplifyQuery &Q) {
+  // Use of assumptions is context-sensitive. If we don't have a context, we
+  // cannot use them!
+  if (!Q.AC || !Q.CxtI)
+    return false;
+
+  for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(Q.CxtI)) {
+    if (!Elem.Assume)
+      continue;
+
+    AssumeInst *I = cast<AssumeInst>(Elem.Assume);
+    if (match(I->getArgOperand(0),
+              m_Not(m_ExtractValue<1>(m_Specific(Q.CxtI)))) &&
+        isValidAssumeForContext(I, Q.CxtI, /*DT=*/nullptr,
+                                /*AllowEphemerals=*/true))
+      return true;
+  }
+  return false;
+}
+
 OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
                                                    const Value *RHS,
                                                    const SimplifyQuery &SQ,
                                                    bool IsNSW) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
   KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
 
@@ -7196,6 +7219,9 @@ OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
 OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
                                                  const Value *RHS,
                                                  const SimplifyQuery &SQ) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   // Multiplying n * m significant bits yields a result of n + m significant
   // bits. If the total number of significant bits does not exceed the
   // result bit width (minus 1), there is no overflow.
@@ -7236,6 +7262,9 @@ OverflowResult
 llvm::computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
                                     const WithCache<const Value *> &RHS,
                                     const SimplifyQuery &SQ) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   ConstantRange LHSRange =
       computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
   ConstantRange RHSRange =
@@ -7251,6 +7280,9 @@ computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
     return OverflowResult::NeverOverflows;
   }
 
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   // If LHS and RHS each have at least two sign bits, the addition will look
   // like
   //
@@ -7305,6 +7337,9 @@ computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
 OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
                                                    const Value *RHS,
                                                    const SimplifyQuery &SQ) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   // X - (X % ?)
   // The remainder of a value can't have greater magnitude than itself,
   // so the subtraction can't overflow.
@@ -7338,6 +7373,9 @@ OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
 OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
                                                  const Value *RHS,
                                                  const SimplifyQuery &SQ) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   // X - (X % ?)
   // The remainder of a value can't have greater magnitude than itself,
   // so the subtraction can't overflow.
@@ -10180,6 +10218,9 @@ void llvm::findValuesAffectedByCondition(
                                                            m_Value()))) {
       // Handle patterns that computeKnownFPClass() support.
       AddAffected(A);
+    } else if (IsAssume && match(V, m_Not(m_ExtractValue<1>(m_Value(A)))) &&
+               isa<WithOverflowInst>(A)) {
+      AddAffected(A);
     }
   }
 }
diff --git a/llvm/test/Transforms/InstCombine/with_overflow.ll b/llvm/test/Transforms/InstCombine/with_overflow.ll
index fa810408730e1b..dd64e993bb620c 100644
--- a/llvm/test/Transforms/InstCombine/with_overflow.ll
+++ b/llvm/test/Transforms/InstCombine/with_overflow.ll
@@ -1064,3 +1064,132 @@ define i8 @smul_7(i8 %x, ptr %p) {
   store i1 %ov, ptr %p
   ret i8 %r
 }
+
+define i8 @uadd_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @uadd_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = add nuw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @sadd_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @sadd_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = add nsw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.sadd.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @usub_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @usub_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = sub nuw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.usub.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @ssub_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @ssub_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = sub nsw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.ssub.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @umul_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @umul_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = mul nuw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @smul_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @smul_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = mul nsw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i1 @ephemeral_call_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @ephemeral_call_assume_no_overflow(
+; CHECK-NEXT:    ret i1 true
+;
+  %call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i1 %not
+}
+
+define i8 @neg_assume_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @neg_assume_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[CALL]], 1
+; CHECK-NEXT:    [[RET:%.*]] = extractvalue { i8, i1 } [[CALL]], 0
+; CHECK-NEXT:    call void @llvm.assume(i1 [[OVERFLOW]])
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  call void @llvm.assume(i1 %overflow)
+  ret i8 %ret
+}
+
+define i8 @neg_assume_not_guaranteed_to_execute(i8 noundef %a, i8 noundef %b, i1 %cond) {
+; CHECK-LABEL: @neg_assume_not_guaranteed_to_execute(
+; CHECK-NEXT:    [[CALL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    br i1 [[COND:%.*]], label [[BB1:%.*]], label [[BB2:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[CALL]], 1
+; CHECK-NEXT:    [[NOT:%.*]] = xor i1 [[OVERFLOW]], true
+; CHECK-NEXT:    call void @llvm.assume(i1 [[NOT]])
+; CHECK-NEXT:    br label [[BB2]]
+; CHECK:       bb2:
+; CHECK-NEXT:    [[RET:%.*]] = extractvalue { i8, i1 } [[CALL]], 0
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  br i1 %cond, label %bb1, label %bb2
+bb1:
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  br label %bb2
+bb2:
+  ret i8 %ret
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/121665


More information about the llvm-commits mailing list