[llvm] [SCEV] Improve applyLoopGuards to support Mul (PR #83428)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 29 07:24:26 PST 2024
https://github.com/komalon1 updated https://github.com/llvm/llvm-project/pull/83428
>From 9e0555c485cfe624ccaaffc2326f2f1adb8dbfae Mon Sep 17 00:00:00 2001
From: komalon1 <alon.kom at mobileye.com>
Date: Thu, 29 Feb 2024 14:51:18 +0200
Subject: [PATCH] [SCEV] Improve applyLoopGuards to support Mul
Improve applyLoopGuards to preserve divisibility information
of SCEVMul expressions.
It does so by preserving the information that a SCEVMul with a SCEVConstant operand is divisible by that constant.
It also generalizes the propagation of divisibility facts from the assumption cache to expressions other than SCEVUnknowns.
For example, given:
TC = TC1 * TC2;
__builtin_assume(TC % 8 == 0);
we now propagate the fact that the SCEVMul (TC1 * TC2) is divisible by 8 by rewriting it to (((TC1 * TC2) /u 8) * 8).
This fixes #82367.
---
llvm/lib/Analysis/ScalarEvolution.cpp | 33 ++++++++--------
.../ScalarEvolution/trip-count-minmax.ll | 4 +-
.../trip-multiple-guard-info.ll | 38 +++++++++++++++++++
3 files changed, 57 insertions(+), 18 deletions(-)
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 4b2db80bc1ec30..ea6cb82b6f1c85 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15000,6 +15000,13 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
return I->second;
}
+
+ const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
+ auto I = Map.find(Expr);
+ if (I == Map.end())
+ return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitMulExpr(Expr);
+ return I->second;
+ }
};
const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
@@ -15149,17 +15156,14 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
const SCEV *URemLHS = nullptr;
const SCEV *URemRHS = nullptr;
if (matchURem(LHS, URemLHS, URemRHS)) {
- if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
- auto I = RewriteMap.find(LHSUnknown);
- const SCEV *RewrittenLHS =
- I != RewriteMap.end() ? I->second : LHSUnknown;
- RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
- const auto *Multiple =
- getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
- RewriteMap[LHSUnknown] = Multiple;
- ExprsToRewrite.push_back(LHSUnknown);
- return;
- }
+ auto I = RewriteMap.find(URemLHS);
+ const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
+ RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
+ const auto *Multiple =
+ getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
+ RewriteMap[URemLHS] = Multiple;
+ ExprsToRewrite.push_back(URemLHS);
+ return;
}
}
@@ -15208,11 +15212,8 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
auto *MulRHS = Mul->getOperand(1);
if (isa<SCEVConstant>(MulLHS))
std::swap(MulLHS, MulRHS);
- if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
- if (Div->getOperand(1) == MulRHS) {
- DividesBy = MulRHS;
- return true;
- }
+ DividesBy = MulRHS;
+ return true;
}
if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
index 7d4876baa9e5d9..482accc38cb391 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
@@ -65,7 +65,7 @@ define void @umin(i32 noundef %a, i32 noundef %b) {
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
; CHECK-NEXT: Predicates:
-; CHECK-NEXT: Loop %for.body: Trip multiple is 1
+; CHECK-NEXT: Loop %for.body: Trip multiple is 2
;
; void umin(unsigned a, unsigned b) {
; a *= 2;
@@ -165,7 +165,7 @@ define void @smin(i32 noundef %a, i32 noundef %b) {
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
; CHECK-NEXT: Predicates:
-; CHECK-NEXT: Loop %for.body: Trip multiple is 1
+; CHECK-NEXT: Loop %for.body: Trip multiple is 2
;
; void smin(signed a, signed b) {
; a *= 2;
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
index a0a5158bdff160..25fe63e9d61bc0 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
@@ -607,5 +607,43 @@ exit:
ret void
}
+define void @test_trip_scevmul_multiple_5(i32 %num1, i32 %num2) {
+; CHECK-LABEL: 'test_trip_scevmul_multiple_5'
+; CHECK-NEXT: Classifying expressions for: @test_trip_scevmul_multiple_5
+; CHECK-NEXT: %num = mul i32 %num1, %num2
+; CHECK-NEXT: --> (%num1 * %num2) U: full-set S: full-set
+; CHECK-NEXT: %u = urem i32 %num, 5
+; CHECK-NEXT: --> ((-5 * ((%num1 * %num2) /u 5)) + (%num1 * %num2)) U: full-set S: full-set
+; CHECK-NEXT: %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]
+; CHECK-NEXT: --> {0,+,1}<nuw><nsw><%for.body> U: [0,-2147483648) S: [0,-2147483648) Exits: (-1 + (%num1 * %num2)) LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT: --> {1,+,1}<nuw><nsw><%for.body> U: [1,-2147483648) S: [1,-2147483648) Exits: (%num1 * %num2) LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: Determining loop execution counts for: @test_trip_scevmul_multiple_5
+; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + (%num1 * %num2))
+; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is -2
+; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + (%num1 * %num2))
+; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + (%num1 * %num2))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: Loop %for.body: Trip multiple is 5
+;
+entry:
+ %num = mul i32 %num1, %num2
+ %u = urem i32 %num, 5
+ %cmp = icmp eq i32 %u, 0
+ tail call void @llvm.assume(i1 %cmp)
+ %cmp.1 = icmp uge i32 %num, 5
+ tail call void @llvm.assume(i1 %cmp.1)
+ br label %for.body
+
+for.body:
+ %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]
+ %inc = add nuw nsw i32 %i.010, 1
+ %cmp2 = icmp ult i32 %inc, %num
+ br i1 %cmp2, label %for.body, label %exit
+
+exit:
+ ret void
+}
+
declare void @llvm.assume(i1)
declare void @llvm.experimental.guard(i1, ...)
More information about the llvm-commits
mailing list