[llvm] SCEV: teach isImpliedViaOperations about samesign (PR #124270)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 24 05:42:20 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: Ramkumar Ramachandra (artagnon)
<details>
<summary>Changes</summary>
Use CmpPredicate::getMatching in isImpliedCondBalancedTypes to pass samesign information to isImpliedViaOperations, and teach isImpliedViaOperations to call CmpPredicate::getPreferredSignedPredicate, so that it can make use of samesign information when proving implications.
---
Full diff: https://github.com/llvm/llvm-project/pull/124270.diff
2 Files Affected:
- (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+17-16)
- (modified) llvm/test/Analysis/ScalarEvolution/implied-via-division.ll (+131-33)
``````````diff
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 7d7d37b3d228dd..d3e060e4f2be84 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11863,15 +11863,13 @@ bool ScalarEvolution::isImpliedCondBalancedTypes(
}
// Check whether the found predicate is the same as the desired predicate.
- // FIXME: use CmpPredicate::getMatching here.
- if (FoundPred == static_cast<CmpInst::Predicate>(Pred))
- return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
+ if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
+ return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
// Check whether swapping the found predicate makes it the same as the
// desired predicate.
- // FIXME: use CmpPredicate::getMatching here.
- if (ICmpInst::getSwappedCmpPredicate(FoundPred) ==
- static_cast<CmpInst::Predicate>(Pred)) {
+ if (auto P = CmpPredicate::getMatching(
+ ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
// We can write the implication
// 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
// using one of the following ways:
@@ -11882,22 +11880,23 @@ bool ScalarEvolution::isImpliedCondBalancedTypes(
// Forms 1. and 2. require swapping the operands of one condition. Don't
// do this if it would break canonical constant/addrec ordering.
if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
- return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
- CtxI);
+ return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
+ LHS, FoundLHS, FoundRHS, CtxI);
if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
- return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);
+ return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
// There's no clear preference between forms 3. and 4., try both. Avoid
// forming getNotSCEV of pointer values as the resulting subtract is
// not legal.
if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
- isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
- FoundLHS, FoundRHS, CtxI))
+ isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
+ getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
+ FoundRHS, CtxI))
return true;
if (!FoundLHS->getType()->isPointerTy() &&
!FoundRHS->getType()->isPointerTy() &&
- isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
+ isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
getNotSCEV(FoundRHS), CtxI))
return true;
@@ -12567,14 +12566,16 @@ bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
return false;
// We only want to work with GT comparison so far.
- if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
+ if (ICmpInst::isLT(Pred)) {
Pred = ICmpInst::getSwappedCmpPredicate(Pred);
std::swap(LHS, RHS);
std::swap(FoundLHS, FoundRHS);
}
+ CmpInst::Predicate P = Pred.getPreferredSignedPredicate();
+
// For unsigned, try to reduce it to corresponding signed comparison.
- if (Pred == ICmpInst::ICMP_UGT)
+ if (P == ICmpInst::ICMP_UGT)
// We can replace unsigned predicate with its signed counterpart if all
// involved values are non-negative.
// TODO: We could have better support for unsigned.
@@ -12587,10 +12588,10 @@ bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
FoundRHS) &&
isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
FoundRHS))
- Pred = ICmpInst::ICMP_SGT;
+ P = ICmpInst::ICMP_SGT;
}
- if (Pred != ICmpInst::ICMP_SGT)
+ if (P != ICmpInst::ICMP_SGT)
return false;
auto GetOpFromSExt = [&](const SCEV *S) {
diff --git a/llvm/test/Analysis/ScalarEvolution/implied-via-division.ll b/llvm/test/Analysis/ScalarEvolution/implied-via-division.ll
index a1d30406095ec5..d83301243ef30b 100644
--- a/llvm/test/Analysis/ScalarEvolution/implied-via-division.ll
+++ b/llvm/test/Analysis/ScalarEvolution/implied-via-division.ll
@@ -2,12 +2,10 @@
; RUN: opt < %s -disable-output -passes="print<scalar-evolution>" \
; RUN: -scalar-evolution-classify-expressions=0 2>&1 | FileCheck %s
-declare void @llvm.experimental.guard(i1, ...)
-
-define void @test_1(i32 %n) nounwind {
-; Prove that (n > 1) ===> (n / 2 > 0).
-; CHECK-LABEL: 'test_1'
-; CHECK-NEXT: Determining loop execution counts for: @test_1
+define void @implied1(i32 %n) {
+; Prove that (n s> 1) ===> (n / 2 s> 0).
+; CHECK-LABEL: 'implied1'
+; CHECK-NEXT: Determining loop execution counts for: @implied1
; CHECK-NEXT: Loop %header: backedge-taken count is (-1 + %n.div.2)<nsw>
; CHECK-NEXT: Loop %header: constant max backedge-taken count is i32 1073741822
; CHECK-NEXT: Loop %header: symbolic max backedge-taken count is (-1 + %n.div.2)<nsw>
@@ -29,10 +27,35 @@ exit:
ret void
}
-define void @test_1neg(i32 %n) nounwind {
-; Prove that (n > 0) =\=> (n / 2 > 0).
-; CHECK-LABEL: 'test_1neg'
-; CHECK-NEXT: Determining loop execution counts for: @test_1neg
+define void @implied1_samesign(i32 %n) {
+; Prove that (n > 1) ===> (n / 2 s> 0).
+; CHECK-LABEL: 'implied1_samesign'
+; CHECK-NEXT: Determining loop execution counts for: @implied1_samesign
+; CHECK-NEXT: Loop %header: backedge-taken count is (-1 + %n.div.2)<nsw>
+; CHECK-NEXT: Loop %header: constant max backedge-taken count is i32 1073741822
+; CHECK-NEXT: Loop %header: symbolic max backedge-taken count is (-1 + %n.div.2)<nsw>
+; CHECK-NEXT: Loop %header: Trip multiple is 1
+;
+entry:
+ %cmp1 = icmp samesign ugt i32 %n, 1
+ %n.div.2 = sdiv i32 %n, 2
+ call void @llvm.assume(i1 %cmp1)
+ br label %header
+
+header:
+ %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ]
+ %indvar.next = add i32 %indvar, 1
+ %exitcond = icmp sgt i32 %n.div.2, %indvar.next
+ br i1 %exitcond, label %header, label %exit
+
+exit:
+ ret void
+}
+
+define void @implied1_neg(i32 %n) {
+; Prove that (n s> 0) =\=> (n / 2 s> 0).
+; CHECK-LABEL: 'implied1_neg'
+; CHECK-NEXT: Determining loop execution counts for: @implied1_neg
; CHECK-NEXT: Loop %header: backedge-taken count is (-1 + (1 smax %n.div.2))<nsw>
; CHECK-NEXT: Loop %header: constant max backedge-taken count is i32 1073741822
; CHECK-NEXT: Loop %header: symbolic max backedge-taken count is (-1 + (1 smax %n.div.2))<nsw>
@@ -54,10 +77,10 @@ exit:
ret void
}
-define void @test_2(i32 %n) nounwind {
-; Prove that (n >= 2) ===> (n / 2 > 0).
-; CHECK-LABEL: 'test_2'
-; CHECK-NEXT: Determining loop execution counts for: @test_2
+define void @implied2(i32 %n) {
+; Prove that (n s>= 2) ===> (n / 2 s> 0).
+; CHECK-LABEL: 'implied2'
+; CHECK-NEXT: Determining loop execution counts for: @implied2
; CHECK-NEXT: Loop %header: backedge-taken count is (-1 + %n.div.2)<nsw>
; CHECK-NEXT: Loop %header: constant max backedge-taken count is i32 1073741822
; CHECK-NEXT: Loop %header: symbolic max backedge-taken count is (-1 + %n.div.2)<nsw>
@@ -79,10 +102,35 @@ exit:
ret void
}
-define void @test_2neg(i32 %n) nounwind {
-; Prove that (n >= 1) =\=> (n / 2 > 0).
-; CHECK-LABEL: 'test_2neg'
-; CHECK-NEXT: Determining loop execution counts for: @test_2neg
+define void @implied2_samesign(i32 %n) {
+; Prove that (n >= 2) ===> (n / 2 s> 0).
+; CHECK-LABEL: 'implied2_samesign'
+; CHECK-NEXT: Determining loop execution counts for: @implied2_samesign
+; CHECK-NEXT: Loop %header: backedge-taken count is (-1 + (1 smax %n.div.2))<nsw>
+; CHECK-NEXT: Loop %header: constant max backedge-taken count is i32 1073741822
+; CHECK-NEXT: Loop %header: symbolic max backedge-taken count is (-1 + (1 smax %n.div.2))<nsw>
+; CHECK-NEXT: Loop %header: Trip multiple is 1
+;
+entry:
+ %cmp1 = icmp samesign uge i32 %n, 2
+ %n.div.2 = sdiv i32 %n, 2
+ call void @llvm.assume(i1 %cmp1)
+ br label %header
+
+header:
+ %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ]
+ %indvar.next = add i32 %indvar, 1
+ %exitcond = icmp sgt i32 %n.div.2, %indvar.next
+ br i1 %exitcond, label %header, label %exit
+
+exit:
+ ret void
+}
+
+define void @implied2_neg(i32 %n) {
+; Prove that (n s>= 1) =\=> (n / 2 s> 0).
+; CHECK-LABEL: 'implied2_neg'
+; CHECK-NEXT: Determining loop execution counts for: @implied2_neg
; CHECK-NEXT: Loop %header: backedge-taken count is (-1 + (1 smax %n.div.2))<nsw>
; CHECK-NEXT: Loop %header: constant max backedge-taken count is i32 1073741822
; CHECK-NEXT: Loop %header: symbolic max backedge-taken count is (-1 + (1 smax %n.div.2))<nsw>
@@ -104,10 +152,10 @@ exit:
ret void
}
-define void @test_3(i32 %n) nounwind {
-; Prove that (n > -2) ===> (n / 2 >= 0).
-; CHECK-LABEL: 'test_3'
-; CHECK-NEXT: Determining loop execution counts for: @test_3
+define void @implied3(i32 %n) {
+; Prove that (n s> -2) ===> (n / 2 s>= 0).
+; CHECK-LABEL: 'implied3'
+; CHECK-NEXT: Determining loop execution counts for: @implied3
; CHECK-NEXT: Loop %header: backedge-taken count is (1 + %n.div.2)<nsw>
; CHECK-NEXT: Loop %header: constant max backedge-taken count is i32 1073741824
; CHECK-NEXT: Loop %header: symbolic max backedge-taken count is (1 + %n.div.2)<nsw>
@@ -129,10 +177,35 @@ exit:
ret void
}
-define void @test_3neg(i32 %n) nounwind {
+define void @implied3_samesign(i32 %n) {
+; Prove that (n > -2) ===> (n / 2 s>= 0).
+; CHECK-LABEL: 'implied3_samesign'
+; CHECK-NEXT: Determining loop execution counts for: @implied3_samesign
+; CHECK-NEXT: Loop %header: backedge-taken count is (1 + %n.div.2)<nsw>
+; CHECK-NEXT: Loop %header: constant max backedge-taken count is i32 1
+; CHECK-NEXT: Loop %header: symbolic max backedge-taken count is (1 + %n.div.2)<nsw>
+; CHECK-NEXT: Loop %header: Trip multiple is 1
+;
+entry:
+ %cmp1 = icmp samesign ugt i32 %n, -2
+ %n.div.2 = sdiv i32 %n, 2
+ call void @llvm.assume(i1 %cmp1)
+ br label %header
+
+header:
+ %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ]
+ %indvar.next = add i32 %indvar, 1
+ %exitcond = icmp sge i32 %n.div.2, %indvar
+ br i1 %exitcond, label %header, label %exit
+
+exit:
+ ret void
+}
+
+define void @implied3_neg(i32 %n) {
; Prove that (n > -3) =\=> (n / 2 >= 0).
-; CHECK-LABEL: 'test_3neg'
-; CHECK-NEXT: Determining loop execution counts for: @test_3neg
+; CHECK-LABEL: 'implied3_neg'
+; CHECK-NEXT: Determining loop execution counts for: @implied3_neg
; CHECK-NEXT: Loop %header: backedge-taken count is (0 smax (1 + %n.div.2)<nsw>)
; CHECK-NEXT: Loop %header: constant max backedge-taken count is i32 1073741824
; CHECK-NEXT: Loop %header: symbolic max backedge-taken count is (0 smax (1 + %n.div.2)<nsw>)
@@ -154,10 +227,10 @@ exit:
ret void
}
-define void @test_4(i32 %n) nounwind {
-; Prove that (n >= -1) ===> (n / 2 >= 0).
-; CHECK-LABEL: 'test_4'
-; CHECK-NEXT: Determining loop execution counts for: @test_4
+define void @implied4(i32 %n) {
+; Prove that (n s>= -1) ===> (n / 2 s>= 0).
+; CHECK-LABEL: 'implied4'
+; CHECK-NEXT: Determining loop execution counts for: @implied4
; CHECK-NEXT: Loop %header: backedge-taken count is (1 + %n.div.2)<nsw>
; CHECK-NEXT: Loop %header: constant max backedge-taken count is i32 1073741824
; CHECK-NEXT: Loop %header: symbolic max backedge-taken count is (1 + %n.div.2)<nsw>
@@ -179,10 +252,35 @@ exit:
ret void
}
-define void @test_4neg(i32 %n) nounwind {
-; Prove that (n >= -2) =\=> (n / 2 >= 0).
-; CHECK-LABEL: 'test_4neg'
-; CHECK-NEXT: Determining loop execution counts for: @test_4neg
+define void @implied4_samesign(i32 %n) {
+; Prove that (n >= -1) ===> (n / 2 s>= 0).
+; CHECK-LABEL: 'implied4_samesign'
+; CHECK-NEXT: Determining loop execution counts for: @implied4_samesign
+; CHECK-NEXT: Loop %header: backedge-taken count is (1 + %n.div.2)<nsw>
+; CHECK-NEXT: Loop %header: constant max backedge-taken count is i32 1
+; CHECK-NEXT: Loop %header: symbolic max backedge-taken count is (1 + %n.div.2)<nsw>
+; CHECK-NEXT: Loop %header: Trip multiple is 1
+;
+entry:
+ %cmp1 = icmp samesign uge i32 %n, -1
+ %n.div.2 = sdiv i32 %n, 2
+ call void @llvm.assume(i1 %cmp1)
+ br label %header
+
+header:
+ %indvar = phi i32 [ %indvar.next, %header ], [ 0, %entry ]
+ %indvar.next = add i32 %indvar, 1
+ %exitcond = icmp sge i32 %n.div.2, %indvar
+ br i1 %exitcond, label %header, label %exit
+
+exit:
+ ret void
+}
+
+define void @implied4_neg(i32 %n) {
+; Prove that (n s>= -2) =\=> (n / 2 s>= 0).
+; CHECK-LABEL: 'implied4_neg'
+; CHECK-NEXT: Determining loop execution counts for: @implied4_neg
; CHECK-NEXT: Loop %header: backedge-taken count is (0 smax (1 + %n.div.2)<nsw>)
; CHECK-NEXT: Loop %header: constant max backedge-taken count is i32 1073741824
; CHECK-NEXT: Loop %header: symbolic max backedge-taken count is (0 smax (1 + %n.div.2)<nsw>)
``````````
</details>
https://github.com/llvm/llvm-project/pull/124270
More information about the llvm-commits mailing list