[llvm] 9a7286b - [SCEV] Help getLoopInvariantExitCondDuringFirstIterations deal with complex `umin` exit counts. PR59615

Max Kazantsev via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 21 03:12:24 PST 2022


Author: Max Kazantsev
Date: 2022-12-21T18:12:17+07:00
New Revision: 9a7286b61f61b78db1197fdedcaaa4cac1b04fe3

URL: https://github.com/llvm/llvm-project/commit/9a7286b61f61b78db1197fdedcaaa4cac1b04fe3
DIFF: https://github.com/llvm/llvm-project/commit/9a7286b61f61b78db1197fdedcaaa4cac1b04fe3.diff

LOG: [SCEV] Help getLoopInvariantExitCondDuringFirstIterations deal with complex `umin` exit counts. PR59615

Recent improvements in symbolic exit count computation revealed some problems with
SCEV's ability to find a predicate that is invariant during the first iterations. Ultimately, this
relies on SCEV's ability to prove certain facts about the value on the last iteration. When this last
value includes `umin` as part of the exit count, it isn't always simplified enough. The motivating example is the following:

https://github.com/llvm/llvm-project/issues/59615

Could not prove:
```
        Pred = 36, LHS = (-1 + (-1 * (2147483645 umin (-1 + %var)<nsw>))<nsw> + %var), RHS = %var
        FoundPred = 36, FoundLHS = {1,+,1}<nuw><nsw><%bb3>, FoundRHS = %var
```
Can prove:
```
        Pred = 36, LHS = (-1 + (-1 * (-1 + %var)<nsw>)<nsw> + %var), RHS = %var
        FoundPred = 36, FoundLHS = {1,+,1}<nuw><nsw><%bb3>, FoundRHS = %var
```

Here `(2147483645 umin (-1 + %var)<nsw>)` is the exit count composed of two parts from
two different exits: `2147483645` and `(-1 + %var)<nsw>`. When only the latter exit was
analyzable, everything was easily provable for it. Unfortunately, in the general case a `umin`
in one of the `add`'s operands doesn't guarantee that the whole sum reduces, especially in the presence
of negative steps and a lack of `nuw`. I don't think there is a generic legal way to work
around this `umin`.

So the ad-hoc solution is the following: if we fail to find an equivalent predicate that is invariant
during the first `MaxIter` iterations, and `MaxIter = umin(a, b, c...)`, try to find a solution for at least one
of `a`, `b`, `c`... Because each of them is `uge` `MaxIter`, whatever is true during the first `a` (or `b`, `c`)
iterations is also true during the first `MaxIter` iterations.

Differential Revision: https://reviews.llvm.org/D140456
Reviewed By: nikic

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Transforms/IndVarSimplify/X86/pr59615.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 917b6d1784697..eb194880d8e27 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1122,6 +1122,11 @@ class ScalarEvolution {
                                                 const Instruction *CtxI,
                                                 const SCEV *MaxIter);
 
+  std::optional<LoopInvariantPredicate>
+  getLoopInvariantExitCondDuringFirstIterationsImpl(
+      ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
+      const Instruction *CtxI, const SCEV *MaxIter);
+
   /// Simplify LHS and RHS in a comparison with predicate Pred. Return true
   /// iff any changes were made. If the operands are provably equal or
   /// unequal, LHS and RHS are set to the same value and Pred is set to either

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 3c445aace3035..34a9cddcb08bd 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11046,6 +11046,26 @@ std::optional<ScalarEvolution::LoopInvariantPredicate>
 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
     const Instruction *CtxI, const SCEV *MaxIter) {
+  if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
+          Pred, LHS, RHS, L, CtxI, MaxIter))
+    return LIP;
+  if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
+    // Number of iterations expressed as UMIN isn't always great for expressing
+    // the value on the last iteration. If the straightforward approach didn't
+    // work, try the following trick: if the a predicate is invariant for X, it
+    // is also invariant for umin(X, ...). So try to find something that works
+    // among subexpressions of MaxIter expressed as umin.
+    for (auto *Op : UMin->operands())
+      if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
+              Pred, LHS, RHS, L, CtxI, Op))
+        return LIP;
+  return std::nullopt;
+}
+
+std::optional<ScalarEvolution::LoopInvariantPredicate>
+ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
+    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
+    const Instruction *CtxI, const SCEV *MaxIter) {
   // Try to prove the following set of facts:
   // - The predicate is monotonic in the iteration space.
   // - If the check does not fail on the 1st iteration:

diff  --git a/llvm/test/Transforms/IndVarSimplify/X86/pr59615.ll b/llvm/test/Transforms/IndVarSimplify/X86/pr59615.ll
index 71f9d901b4590..17b7b9d40b07a 100644
--- a/llvm/test/Transforms/IndVarSimplify/X86/pr59615.ll
+++ b/llvm/test/Transforms/IndVarSimplify/X86/pr59615.ll
@@ -8,35 +8,30 @@ define void @test() {
 ; CHECK-LABEL: @test(
 ; CHECK-NEXT:  bb:
 ; CHECK-NEXT:    [[VAR:%.*]] = load atomic i32, ptr addrspace(1) poison unordered, align 8, !range [[RNG0:![0-9]+]], !invariant.load !1, !noundef !1
-; CHECK-NEXT:    [[VAR1:%.*]] = add nsw i32 [[VAR]], -1
 ; CHECK-NEXT:    [[VAR2:%.*]] = icmp eq i32 [[VAR]], 0
 ; CHECK-NEXT:    br i1 [[VAR2]], label [[BB18:%.*]], label [[BB19:%.*]]
 ; CHECK:       bb3:
 ; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ 0, [[BB19]] ], [ [[INDVARS_IV_NEXT:%.*]], [[BB12:%.*]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i64 [[TMP3:%.*]], [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32
-; CHECK-NEXT:    [[VAR6:%.*]] = icmp ult i32 [[TMP1]], [[VAR]]
-; CHECK-NEXT:    br i1 [[VAR6]], label [[BB8:%.*]], label [[BB7:%.*]]
+; CHECK-NEXT:    br i1 true, label [[BB8:%.*]], label [[BB7:%.*]]
 ; CHECK:       bb7:
 ; CHECK-NEXT:    ret void
 ; CHECK:       bb8:
 ; CHECK-NEXT:    [[VAR9:%.*]] = load atomic i32, ptr addrspace(1) poison unordered, align 8, !range [[RNG0]], !invariant.load !1, !noundef !1
-; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[VAR9]] to i64
-; CHECK-NEXT:    [[VAR10:%.*]] = icmp ult i64 [[INDVARS_IV]], [[TMP2]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i32 [[VAR9]] to i64
+; CHECK-NEXT:    [[VAR10:%.*]] = icmp ult i64 [[INDVARS_IV]], [[TMP0]]
 ; CHECK-NEXT:    br i1 [[VAR10]], label [[BB12]], label [[BB11:%.*]]
 ; CHECK:       bb11:
 ; CHECK-NEXT:    ret void
 ; CHECK:       bb12:
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
-; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp ne i64 [[INDVARS_IV_NEXT]], [[WIDE_TRIP_COUNT:%.*]]
+; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp ne i64 [[INDVARS_IV_NEXT]], [[TMP1:%.*]]
 ; CHECK-NEXT:    br i1 [[EXITCOND]], label [[BB3:%.*]], label [[BB17:%.*]]
 ; CHECK:       bb17:
 ; CHECK-NEXT:    unreachable
 ; CHECK:       bb18:
 ; CHECK-NEXT:    ret void
 ; CHECK:       bb19:
-; CHECK-NEXT:    [[TMP3]] = sext i32 [[VAR1]] to i64
-; CHECK-NEXT:    [[WIDE_TRIP_COUNT]] = zext i32 [[VAR]] to i64
+; CHECK-NEXT:    [[TMP1]] = zext i32 [[VAR]] to i64
 ; CHECK-NEXT:    br label [[BB3]]
 ;
 bb:


        

