[llvm] [SCEV] Improve applyLoopGuards to support Mul (PR #83428)

via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 10 01:36:28 PST 2024


https://github.com/komalon1 updated https://github.com/llvm/llvm-project/pull/83428

>From 58b538ef59bb0dd3d525b6cccdf10e2ed420ca83 Mon Sep 17 00:00:00 2001
From: komalon1 <alon.kom at mobileye.com>
Date: Thu, 29 Feb 2024 14:51:18 +0200
Subject: [PATCH 1/2] [SCEV] Improve applyLoopGuards to support Mul

Improve applyLoopGuards to preserve divisibility information
of SCEVMul Expressions.
It does so by recording that a SCEVMul with a SCEVConstant operand is divisible by that constant.
It also generalizes the assumption-cache divisibility propagation to expressions other than SCEVUnknowns.
For example if:

TC = TC1 * TC2;
__builtin_assume(TC % 8 == 0);

We now propagate the information that the SCEVMul is divisible by 8, by rewriting it to ((TC1 * TC2) /u 8) * 8

This fixes #82367.
---
 llvm/lib/Analysis/ScalarEvolution.cpp         | 33 ++++++++--------
 .../ScalarEvolution/trip-count-minmax.ll      |  8 +++-
 .../trip-multiple-guard-info.ll               | 38 +++++++++++++++++++
 3 files changed, 62 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index acc0aa23107bb5..f0ff4d7bbab65a 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15035,6 +15035,13 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
       return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
     return I->second;
   }
+
+  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
+    auto I = Map.find(Expr);
+    if (I == Map.end())
+      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitMulExpr(Expr);
+    return I->second;
+  }
 };
 
 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
@@ -15184,17 +15191,14 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
       const SCEV *URemLHS = nullptr;
       const SCEV *URemRHS = nullptr;
       if (matchURem(LHS, URemLHS, URemRHS)) {
-        if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
-          auto I = RewriteMap.find(LHSUnknown);
-          const SCEV *RewrittenLHS =
-              I != RewriteMap.end() ? I->second : LHSUnknown;
-          RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
-          const auto *Multiple =
-              getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
-          RewriteMap[LHSUnknown] = Multiple;
-          ExprsToRewrite.push_back(LHSUnknown);
-          return;
-        }
+        auto I = RewriteMap.find(URemLHS);
+        const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
+        RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
+        const auto *Multiple =
+            getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
+        RewriteMap[URemLHS] = Multiple;
+        ExprsToRewrite.push_back(URemLHS);
+        return;
       }
     }
 
@@ -15243,11 +15247,8 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
             auto *MulRHS = Mul->getOperand(1);
             if (isa<SCEVConstant>(MulLHS))
               std::swap(MulLHS, MulRHS);
-            if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
-              if (Div->getOperand(1) == MulRHS) {
-                DividesBy = MulRHS;
-                return true;
-              }
+            DividesBy = MulRHS;
+            return true;
           }
           if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
             return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
index 8d091a00ed4b97..71f0213ba0efa5 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
@@ -61,7 +61,7 @@ define void @umin(i32 noundef %a, i32 noundef %b) {
 ; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
 ; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is i32 2147483646
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
-; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 2
 ;
 ; void umin(unsigned a, unsigned b) {
 ;   a *= 2;
@@ -157,7 +157,13 @@ define void @smin(i32 noundef %a, i32 noundef %b) {
 ; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
 ; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is i32 2147483646
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
+<<<<<<< HEAD
 ; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
+=======
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 2
+>>>>>>> 9e0555c485cf ([SCEV] Improve applyLoopGuards to suppport Mul)
 ;
 ; void smin(signed a, signed b) {
 ;   a *= 2;
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
index bf140c7fa216a9..433a9b6fb48ac6 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
@@ -574,5 +574,43 @@ exit:
   ret void
 }
 
+define void @test_trip_scevmul_multiple_5(i32 %num1, i32 %num2) {
+; CHECK-LABEL: 'test_trip_scevmul_multiple_5'
+; CHECK-NEXT:  Classifying expressions for: @test_trip_scevmul_multiple_5
+; CHECK-NEXT:    %num = mul i32 %num1, %num2
+; CHECK-NEXT:    --> (%num1 * %num2) U: full-set S: full-set
+; CHECK-NEXT:    %u = urem i32 %num, 5
+; CHECK-NEXT:    --> ((-5 * ((%num1 * %num2) /u 5)) + (%num1 * %num2)) U: full-set S: full-set
+; CHECK-NEXT:    %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]
+; CHECK-NEXT:    --> {0,+,1}<nuw><nsw><%for.body> U: [0,-2147483648) S: [0,-2147483648) Exits: (-1 + (%num1 * %num2)) LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:    --> {1,+,1}<nuw><nsw><%for.body> U: [1,-2147483648) S: [1,-2147483648) Exits: (%num1 * %num2) LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:  Determining loop execution counts for: @test_trip_scevmul_multiple_5
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + (%num1 * %num2))
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is -2
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + (%num1 * %num2))
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + (%num1 * %num2))
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 5
+;
+entry:
+  %num = mul i32 %num1, %num2
+  %u = urem i32 %num, 5
+  %cmp = icmp eq i32 %u, 0
+  tail call void @llvm.assume(i1 %cmp)
+  %cmp.1 = icmp uge i32 %num, 5
+  tail call void @llvm.assume(i1 %cmp.1)
+  br label %for.body
+
+for.body:
+  %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]
+  %inc = add nuw nsw i32 %i.010, 1
+  %cmp2 = icmp ult i32 %inc, %num
+  br i1 %cmp2, label %for.body, label %exit
+
+exit:
+  ret void
+}
+
 declare void @llvm.assume(i1)
 declare void @llvm.experimental.guard(i1, ...)

>From 8c87ab0f2a638fea08810105727712650aa2be9d Mon Sep 17 00:00:00 2001
From: komalon1 <alon.kom at mobileye.com>
Date: Sun, 10 Mar 2024 11:21:43 +0200
Subject: [PATCH 2/2] Split patches, improve documentation

---
 llvm/lib/Analysis/ScalarEvolution.cpp         | 37 +++++++++---------
 .../ScalarEvolution/trip-count-minmax.ll      |  6 ---
 .../trip-multiple-guard-info.ll               | 38 -------------------
 3 files changed, 17 insertions(+), 64 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index f0ff4d7bbab65a..c382fb8e0ad9d7 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15035,13 +15035,6 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
       return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
     return I->second;
   }
-
-  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
-    auto I = Map.find(Expr);
-    if (I == Map.end())
-      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitMulExpr(Expr);
-    return I->second;
-  }
 };
 
 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
@@ -15191,14 +15184,17 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
       const SCEV *URemLHS = nullptr;
       const SCEV *URemRHS = nullptr;
       if (matchURem(LHS, URemLHS, URemRHS)) {
-        auto I = RewriteMap.find(URemLHS);
-        const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
-        RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
-        const auto *Multiple =
-            getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
-        RewriteMap[URemLHS] = Multiple;
-        ExprsToRewrite.push_back(URemLHS);
-        return;
+        if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
+          auto I = RewriteMap.find(LHSUnknown);
+          const SCEV *RewrittenLHS =
+              I != RewriteMap.end() ? I->second : LHSUnknown;
+          RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
+          const auto *Multiple =
+              getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
+          RewriteMap[LHSUnknown] = Multiple;
+          ExprsToRewrite.push_back(LHSUnknown);
+          return;
+        }
       }
     }
 
@@ -15231,13 +15227,14 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
       return I != RewriteMap.end() ? I->second : S;
     };
 
-    // Check for the SCEV expression (A /u B) * B while B is a constant, inside
+    // Check for the SCEV expression A * B while B is a constant, inside
     // \p Expr. The check is done recuresively on \p Expr, which is assumed to
     // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
-    // /u B) * B was found, and return the divisor B in \p DividesBy. For
-    // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
-    // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
-    // DividesBy.
+    // * B) was found, and return the divisor B in \p DividesBy. For example, if
+    // Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since (A /u 8) * 8
+    // matched the pattern, and return the constant SCEV 8 in \p DividesBy.
+    // The returned \p DividesBy is just a candidate, and later on it is checked
+    // if the whole expression divides by it.
     std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
         [&](const SCEV *Expr, const SCEV *&DividesBy) {
           if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
index 71f0213ba0efa5..d38010403dad79 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
@@ -157,13 +157,7 @@ define void @smin(i32 noundef %a, i32 noundef %b) {
 ; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
 ; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is i32 2147483646
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
-<<<<<<< HEAD
-; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
-=======
-; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
-; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:  Loop %for.body: Trip multiple is 2
->>>>>>> 9e0555c485cf ([SCEV] Improve applyLoopGuards to suppport Mul)
 ;
 ; void smin(signed a, signed b) {
 ;   a *= 2;
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
index 433a9b6fb48ac6..bf140c7fa216a9 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
@@ -574,43 +574,5 @@ exit:
   ret void
 }
 
-define void @test_trip_scevmul_multiple_5(i32 %num1, i32 %num2) {
-; CHECK-LABEL: 'test_trip_scevmul_multiple_5'
-; CHECK-NEXT:  Classifying expressions for: @test_trip_scevmul_multiple_5
-; CHECK-NEXT:    %num = mul i32 %num1, %num2
-; CHECK-NEXT:    --> (%num1 * %num2) U: full-set S: full-set
-; CHECK-NEXT:    %u = urem i32 %num, 5
-; CHECK-NEXT:    --> ((-5 * ((%num1 * %num2) /u 5)) + (%num1 * %num2)) U: full-set S: full-set
-; CHECK-NEXT:    %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]
-; CHECK-NEXT:    --> {0,+,1}<nuw><nsw><%for.body> U: [0,-2147483648) S: [0,-2147483648) Exits: (-1 + (%num1 * %num2)) LoopDispositions: { %for.body: Computable }
-; CHECK-NEXT:    %inc = add nuw nsw i32 %i.010, 1
-; CHECK-NEXT:    --> {1,+,1}<nuw><nsw><%for.body> U: [1,-2147483648) S: [1,-2147483648) Exits: (%num1 * %num2) LoopDispositions: { %for.body: Computable }
-; CHECK-NEXT:  Determining loop execution counts for: @test_trip_scevmul_multiple_5
-; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + (%num1 * %num2))
-; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is -2
-; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + (%num1 * %num2))
-; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + (%num1 * %num2))
-; CHECK-NEXT:   Predicates:
-; CHECK-NEXT:  Loop %for.body: Trip multiple is 5
-;
-entry:
-  %num = mul i32 %num1, %num2
-  %u = urem i32 %num, 5
-  %cmp = icmp eq i32 %u, 0
-  tail call void @llvm.assume(i1 %cmp)
-  %cmp.1 = icmp uge i32 %num, 5
-  tail call void @llvm.assume(i1 %cmp.1)
-  br label %for.body
-
-for.body:
-  %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]
-  %inc = add nuw nsw i32 %i.010, 1
-  %cmp2 = icmp ult i32 %inc, %num
-  br i1 %cmp2, label %for.body, label %exit
-
-exit:
-  ret void
-}
-
 declare void @llvm.assume(i1)
 declare void @llvm.experimental.guard(i1, ...)



More information about the llvm-commits mailing list