[llvm] [SCEV] Collect and merge loop guards through PHI nodes with multiple incoming values (PR #113915)

Julian Nagele via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 28 07:39:52 PDT 2024


https://github.com/juliannagele created https://github.com/llvm/llvm-project/pull/113915

This patch aims to strengthen the collection of loop guards by also processing PHI nodes with multiple incoming values, as follows: for each incoming value/block pair of a PHI node, collect the guards that hold along that edge, then try to merge them into a single guard for the PHI node.

>From b595b8c98d547915c47f99e84dcf9204409aaaae Mon Sep 17 00:00:00 2001
From: Julian Nagele <j.nagele at apple.com>
Date: Fri, 18 Oct 2024 17:40:06 +0100
Subject: [PATCH] [SCEV] Collect and merge loop guards through PHI nodes with
 multiple incoming Values

---
 llvm/include/llvm/Analysis/ScalarEvolution.h  |  5 ++
 llvm/lib/Analysis/ScalarEvolution.cpp         | 85 +++++++++++++++++--
 .../Analysis/ScalarEvolution/trip-count.ll    | 82 ++++++++++++++++++
 .../Transforms/PhaseOrdering/X86/pr38280.ll   |  2 +-
 4 files changed, 164 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 179a2c38d9d3c2..cdc46cf24a0546 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1316,6 +1316,11 @@ class ScalarEvolution {
 
     LoopGuards(ScalarEvolution &SE) : SE(SE) {}
 
+    static LoopGuards
+    collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
+                     const BasicBlock *Block, const BasicBlock *Pred,
+                     SmallPtrSet<const BasicBlock *, 8> VisitedBlocks);
+
   public:
     /// Collect rewrite map for loop guards for loop \p L, together with flags
     /// indicating if NUW and NSW can be preserved during rewriting.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index c939270ed39a65..d9cab0471ef0f3 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10648,7 +10648,7 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
   if (const Loop *L = LI.getLoopFor(BB))
     return {L->getLoopPredecessor(), L->getHeader()};
 
-  return {nullptr, nullptr};
+  return {nullptr, BB};
 }
 
 /// SCEV structural equivalence is usually sufficient for testing whether two
@@ -15217,7 +15217,16 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
 
 ScalarEvolution::LoopGuards
 ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
+  BasicBlock *Header = L->getHeader();
+  BasicBlock *Pred = L->getLoopPredecessor();
   LoopGuards Guards(SE);
+  return collectFromBlock(SE, Guards, Header, Pred, {});
+}
+
+ScalarEvolution::LoopGuards ScalarEvolution::LoopGuards::collectFromBlock(
+    ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
+    const BasicBlock *Block, const BasicBlock *Pred,
+    SmallPtrSet<const BasicBlock *, 8> VisitedBlocks) {
   SmallVector<const SCEV *> ExprsToRewrite;
   auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
                               const SCEV *RHS,
@@ -15556,14 +15565,13 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
     }
   };
 
-  BasicBlock *Header = L->getHeader();
   SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
   // First, collect information from assumptions dominating the loop.
   for (auto &AssumeVH : SE.AC.assumptions()) {
     if (!AssumeVH)
       continue;
     auto *AssumeI = cast<CallInst>(AssumeVH);
-    if (!SE.DT.dominates(AssumeI, Header))
+    if (!SE.DT.dominates(AssumeI, Block))
       continue;
     Terms.emplace_back(AssumeI->getOperand(0), true);
   }
@@ -15574,8 +15582,8 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
   if (GuardDecl)
     for (const auto *GU : GuardDecl->users())
       if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
-        if (Guard->getFunction() == Header->getParent() &&
-            SE.DT.dominates(Guard, Header))
+        if (Guard->getFunction() == Block->getParent() &&
+            SE.DT.dominates(Guard, Block))
           Terms.emplace_back(Guard->getArgOperand(0), true);
 
   // Third, collect conditions from dominating branches. Starting at the loop
@@ -15583,11 +15591,10 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
   // predecessors that can be found that have unique successors leading to the
   // original header.
   // TODO: share this logic with isLoopEntryGuardedByCond.
-  for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
-           L->getLoopPredecessor(), Header);
-       Pair.first;
+  std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
+  for (; Pair.first;
        Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
-
+    VisitedBlocks.insert(Pair.second);
     const BranchInst *LoopEntryPredicate =
         dyn_cast<BranchInst>(Pair.first->getTerminator());
     if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
@@ -15596,6 +15603,66 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
     Terms.emplace_back(LoopEntryPredicate->getCondition(),
                        LoopEntryPredicate->getSuccessor(0) == Pair.second);
   }
+  // Finally, if we stopped climbing the predecessor chain because
+  // there wasn't a unique predecessor to continue with, try to collect
+  // conditions for PHI nodes by recursively following all of their
+  // incoming blocks, and try to merge the collected conditions into a
+  // new one for the PHI node.
+  if (Pair.second->hasNPredecessorsOrMore(2)) {
+    for (auto &Phi : Pair.second->phis()) {
+      if (!SE.isSCEVable(Phi.getType()))
+        continue;
+
+      using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
+      auto GetMinMaxConst = [&SE, &VisitedBlocks, &Pair,
+                             &Phi](unsigned int In) -> MinMaxPattern {
+        LoopGuards G(SE);
+        if (VisitedBlocks.insert(Phi.getIncomingBlock(In)).second)
+          collectFromBlock(SE, G, Pair.second, Phi.getIncomingBlock(In),
+                           VisitedBlocks);
+        const SCEV *S = G.RewriteMap[SE.getSCEV(Phi.getIncomingValue(In))];
+        auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S);
+        if (!SM)
+          return {nullptr, scCouldNotCompute};
+        if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
+          return {C0, SM->getSCEVType()};
+        if (const SCEVConstant *C1 = dyn_cast<SCEVConstant>(SM->getOperand(1)))
+          return {C1, SM->getSCEVType()};
+        return {nullptr, scCouldNotCompute};
+      };
+      auto MergeMinMaxConst = [](MinMaxPattern P1,
+                                 MinMaxPattern P2) -> MinMaxPattern {
+        auto [C1, T1] = P1;
+        auto [C2, T2] = P2;
+        if (!C1 || !C2 || T1 != T2)
+          return {nullptr, scCouldNotCompute};
+        switch (T1) {
+        case scUMaxExpr:
+          return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
+        case scSMaxExpr:
+          return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
+        case scUMinExpr:
+          return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
+        case scSMinExpr:
+          return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
+        default:
+          llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
+        }
+      };
+      auto P = GetMinMaxConst(0);
+      for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
+        if (!P.first)
+          break;
+        P = MergeMinMaxConst(P, GetMinMaxConst(In));
+      }
+      if (P.first) {
+        const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
+        SmallVector<const SCEV *, 2> Ops({P.first, LHS});
+        const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
+        Guards.RewriteMap.insert({LHS, RHS});
+      }
+    }
+  }
 
   // Now apply the information from the collected conditions to
   // Guards.RewriteMap. Conditions are processed in reverse order, so the
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count.ll b/llvm/test/Analysis/ScalarEvolution/trip-count.ll
index 8fc5b9b4096127..7304409814b0e1 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count.ll
@@ -211,3 +211,85 @@ for.body:
 exit:
   ret void
 }
+
+define void @epilogue(i64 %count) {
+; CHECK-LABEL: 'epilogue'
+; CHECK-NEXT:  Determining loop execution counts for: @epilogue
+; CHECK-NEXT:  Loop %epilogue: backedge-taken count is (-1 + %count.epilogue)
+; CHECK-NEXT:  Loop %epilogue: constant max backedge-taken count is i64 6
+; CHECK-NEXT:  Loop %epilogue: symbolic max backedge-taken count is (-1 + %count.epilogue)
+; CHECK-NEXT:  Loop %epilogue: Trip multiple is 1
+; CHECK-NEXT:  Loop %while.body: backedge-taken count is ((-8 + %count) /u 8)
+; CHECK-NEXT:  Loop %while.body: constant max backedge-taken count is i64 2305843009213693951
+; CHECK-NEXT:  Loop %while.body: symbolic max backedge-taken count is ((-8 + %count) /u 8)
+; CHECK-NEXT:  Loop %while.body: Trip multiple is 1
+entry:
+  %cmp = icmp ugt i64 %count, 7
+  br i1 %cmp, label %while.body, label %epilogue.preheader
+
+while.body:
+  %iv = phi i64 [ %sub, %while.body ], [ %count, %entry ]
+  %sub = add i64 %iv, -8
+  %exitcond.not = icmp ugt i64 %sub, 7
+  br i1 %exitcond.not, label %while.body, label %while.loopexit
+
+while.loopexit:
+  %sub.exit = phi i64 [ %sub, %while.body ]
+  br label %epilogue.preheader
+
+epilogue.preheader:
+  %count.epilogue = phi i64 [ %count, %entry ], [ %sub.exit, %while.loopexit ]
+  %epilogue.cmp = icmp eq i64 %count.epilogue, 0
+  br i1 %epilogue.cmp, label %exit, label %epilogue
+
+epilogue:
+  %iv.epilogue = phi i64 [ %dec, %epilogue ], [ %count.epilogue, %epilogue.preheader ]
+  %dec = add i64 %iv.epilogue, -1
+  %exitcond.epilogue = icmp eq i64 %dec, 0
+  br i1 %exitcond.epilogue, label %exit, label %epilogue
+
+exit:
+  ret void
+
+}
+
+define void @epilogue2(i64 %count) {
+; CHECK-LABEL: 'epilogue2'
+; CHECK-NEXT:  Determining loop execution counts for: @epilogue2
+; CHECK-NEXT:  Loop %epilogue: backedge-taken count is (-1 + %count.epilogue)
+; CHECK-NEXT:  Loop %epilogue: constant max backedge-taken count is i64 8
+; CHECK-NEXT:  Loop %epilogue: symbolic max backedge-taken count is (-1 + %count.epilogue)
+; CHECK-NEXT:  Loop %epilogue: Trip multiple is 1
+; CHECK-NEXT:  Loop %while.body: backedge-taken count is ((-8 + %count) /u 8)
+; CHECK-NEXT:  Loop %while.body: constant max backedge-taken count is i64 2305843009213693951
+; CHECK-NEXT:  Loop %while.body: symbolic max backedge-taken count is ((-8 + %count) /u 8)
+; CHECK-NEXT:  Loop %while.body: Trip multiple is 1
+entry:
+  %cmp = icmp ugt i64 %count, 9
+  br i1 %cmp, label %while.body, label %epilogue.preheader
+
+while.body:
+  %iv = phi i64 [ %sub, %while.body ], [ %count, %entry ]
+  %sub = add i64 %iv, -8
+  %exitcond.not = icmp ugt i64 %sub, 7
+  br i1 %exitcond.not, label %while.body, label %while.loopexit
+
+while.loopexit:
+  %sub.exit = phi i64 [ %sub, %while.body ]
+  br label %epilogue.preheader
+
+epilogue.preheader:
+  %count.epilogue = phi i64 [ %count, %entry ], [ %sub.exit, %while.loopexit ]
+  %epilogue.cmp = icmp eq i64 %count.epilogue, 0
+  br i1 %epilogue.cmp, label %exit, label %epilogue
+
+epilogue:
+  %iv.epilogue = phi i64 [ %dec, %epilogue ], [ %count.epilogue, %epilogue.preheader ]
+  %dec = add i64 %iv.epilogue, -1
+  %exitcond.epilogue = icmp eq i64 %dec, 0
+  br i1 %exitcond.epilogue, label %exit, label %epilogue
+
+exit:
+  ret void
+
+}
diff --git a/llvm/test/Transforms/PhaseOrdering/X86/pr38280.ll b/llvm/test/Transforms/PhaseOrdering/X86/pr38280.ll
index 70b002f766b753..966d7e3cded0ab 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/pr38280.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/pr38280.ll
@@ -41,7 +41,7 @@ define void @apply_delta(ptr nocapture noundef %dst, ptr nocapture noundef reado
 ; CHECK-NEXT:    [[INCDEC_PTR]] = getelementptr inbounds i8, ptr [[DST_ADDR_130]], i64 1
 ; CHECK-NEXT:    [[INCDEC_PTR8]] = getelementptr inbounds i8, ptr [[SRC_ADDR_129]], i64 1
 ; CHECK-NEXT:    [[TOBOOL_NOT:%.*]] = icmp eq i64 [[DEC]], 0
-; CHECK-NEXT:    br i1 [[TOBOOL_NOT]], label [[WHILE_END9]], label [[WHILE_BODY4]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEXT:    br i1 [[TOBOOL_NOT]], label [[WHILE_END9]], label [[WHILE_BODY4]]
 ; CHECK:       while.end9:
 ; CHECK-NEXT:    ret void
 ;



More information about the llvm-commits mailing list