[llvm] [SCEV] Improve applyLoopGuards to support Mul (PR #83428)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 29 05:52:19 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: None (komalon1)

<details>
<summary>Changes</summary>

Improve applyLoopGuards to preserve divisibility information for SCEVMulExpr expressions.

This fixes #82367.

---
Full diff: https://github.com/llvm/llvm-project/pull/83428.diff


3 Files Affected:

- (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+18-16) 
- (modified) llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll (+2-2) 
- (modified) llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll (+38) 


``````````diff
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 4b2db80bc1ec30..052ae4923d4a43 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15000,6 +15000,13 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
       return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
     return I->second;
   }
+
+  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
+    auto I = Map.find(Expr);
+    if (I == Map.end())
+      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitMulExpr(Expr);
+    return I->second;
+  }
 };
 
 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
@@ -15149,17 +15156,15 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
       const SCEV *URemLHS = nullptr;
       const SCEV *URemRHS = nullptr;
       if (matchURem(LHS, URemLHS, URemRHS)) {
-        if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
-          auto I = RewriteMap.find(LHSUnknown);
-          const SCEV *RewrittenLHS =
-              I != RewriteMap.end() ? I->second : LHSUnknown;
-          RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
-          const auto *Multiple =
-              getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
-          RewriteMap[LHSUnknown] = Multiple;
-          ExprsToRewrite.push_back(LHSUnknown);
-          return;
-        }
+        auto I = RewriteMap.find(URemLHS);
+        const SCEV *RewrittenLHS =
+            I != RewriteMap.end() ? I->second : URemLHS;
+        RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
+        const auto *Multiple =
+            getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
+        RewriteMap[URemLHS] = Multiple;
+        ExprsToRewrite.push_back(URemLHS);
+        return;
       }
     }
 
@@ -15208,11 +15213,8 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
             auto *MulRHS = Mul->getOperand(1);
             if (isa<SCEVConstant>(MulLHS))
               std::swap(MulLHS, MulRHS);
-            if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
-              if (Div->getOperand(1) == MulRHS) {
-                DividesBy = MulRHS;
-                return true;
-              }
+            DividesBy = MulRHS;
+            return true;
           }
           if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
             return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
index 7d4876baa9e5d9..482accc38cb391 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
@@ -65,7 +65,7 @@ define void @umin(i32 noundef %a, i32 noundef %b) {
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
 ; CHECK-NEXT:   Predicates:
-; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 2
 ;
 ; void umin(unsigned a, unsigned b) {
 ;   a *= 2;
@@ -165,7 +165,7 @@ define void @smin(i32 noundef %a, i32 noundef %b) {
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
 ; CHECK-NEXT:   Predicates:
-; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 2
 ;
 ; void smin(signed a, signed b) {
 ;   a *= 2;
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
index a0a5158bdff160..25fe63e9d61bc0 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
@@ -607,5 +607,43 @@ exit:
   ret void
 }
 
+define void @test_trip_scevmul_multiple_5(i32 %num1, i32 %num2) {
+; CHECK-LABEL: 'test_trip_scevmul_multiple_5'
+; CHECK-NEXT:  Classifying expressions for: @test_trip_scevmul_multiple_5
+; CHECK-NEXT:    %num = mul i32 %num1, %num2
+; CHECK-NEXT:    --> (%num1 * %num2) U: full-set S: full-set
+; CHECK-NEXT:    %u = urem i32 %num, 5
+; CHECK-NEXT:    --> ((-5 * ((%num1 * %num2) /u 5)) + (%num1 * %num2)) U: full-set S: full-set
+; CHECK-NEXT:    %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]
+; CHECK-NEXT:    --> {0,+,1}<nuw><nsw><%for.body> U: [0,-2147483648) S: [0,-2147483648) Exits: (-1 + (%num1 * %num2)) LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    %inc = add nuw nsw i32 %i.010, 1
+; CHECK-NEXT:    --> {1,+,1}<nuw><nsw><%for.body> U: [1,-2147483648) S: [1,-2147483648) Exits: (%num1 * %num2) LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:  Determining loop execution counts for: @test_trip_scevmul_multiple_5
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + (%num1 * %num2))
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is -2
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + (%num1 * %num2))
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + (%num1 * %num2))
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 5
+;
+entry:
+  %num = mul i32 %num1, %num2
+  %u = urem i32 %num, 5
+  %cmp = icmp eq i32 %u, 0
+  tail call void @llvm.assume(i1 %cmp)
+  %cmp.1 = icmp uge i32 %num, 5
+  tail call void @llvm.assume(i1 %cmp.1)
+  br label %for.body
+
+for.body:
+  %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]
+  %inc = add nuw nsw i32 %i.010, 1
+  %cmp2 = icmp ult i32 %inc, %num
+  br i1 %cmp2, label %for.body, label %exit
+
+exit:
+  ret void
+}
+
 declare void @llvm.assume(i1)
 declare void @llvm.experimental.guard(i1, ...)

``````````

</details>


https://github.com/llvm/llvm-project/pull/83428


More information about the llvm-commits mailing list