[llvm] 1006ac3 - [LoopNest] Consider loop nest with inner loop guard using outer loop

Whitney Tsang via llvm-commits llvm-commits at lists.llvm.org
Fri May 7 09:04:37 PDT 2021

Author: Whitney Tsang
Date: 2021-05-07T16:04:18Z
New Revision: 1006ac3963eaf39153d6637b631662e87ebf3b4d

URL: https://github.com/llvm/llvm-project/commit/1006ac3963eaf39153d6637b631662e87ebf3b4d
DIFF: https://github.com/llvm/llvm-project/commit/1006ac3963eaf39153d6637b631662e87ebf3b4d.diff

LOG: [LoopNest] Consider loop nest with inner loop guard using outer loop
induction variable to be perfect

This patch allows more conditional branches to be considered loop
guards, and so more loop nests can be considered perfect.

Reviewed By: bmahjour, sidbav

Differential Revision: https://reviews.llvm.org/D94717




diff  --git a/llvm/include/llvm/Analysis/LoopNestAnalysis.h b/llvm/include/llvm/Analysis/LoopNestAnalysis.h
index ace17547444f7..e045419f8d537 100644
--- a/llvm/include/llvm/Analysis/LoopNestAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopNestAnalysis.h
@@ -61,10 +61,12 @@ class LoopNest {
   static unsigned getMaxPerfectDepth(const Loop &Root, ScalarEvolution &SE);
   /// Recursivelly traverse all empty 'single successor' basic blocks of \p From
-  /// (if there are any). Return the last basic block found or \p End if it was
-  /// reached during the search.
+  /// (if there are any). When \p CheckUniquePred is set to true, check if
+  /// each of the empty single successors has a unique predecessor. Return
+  /// the last basic block found or \p End if it was reached during the search.
   static const BasicBlock &skipEmptyBlockUntil(const BasicBlock *From,
-                                               const BasicBlock *End);
+                                               const BasicBlock *End,
+                                               bool CheckUniquePred = false);
   /// Return the outermost loop in the loop nest.
   Loop &getOutermostLoop() const { return *Loops.front(); }

diff  --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp
index adb2bdb184c2f..b2d7edb356689 100644
--- a/llvm/lib/Analysis/LoopInfo.cpp
+++ b/llvm/lib/Analysis/LoopInfo.cpp
@@ -20,6 +20,7 @@
 #include "llvm/Analysis/IVDescriptors.h"
 #include "llvm/Analysis/LoopInfoImpl.h"
 #include "llvm/Analysis/LoopIterator.h"
+#include "llvm/Analysis/LoopNestAnalysis.h"
 #include "llvm/Analysis/MemorySSA.h"
 #include "llvm/Analysis/MemorySSAUpdater.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
@@ -380,10 +381,6 @@ BranchInst *Loop::getLoopGuardBranch() const {
   if (!ExitFromLatch)
     return nullptr;
-  BasicBlock *ExitFromLatchSucc = ExitFromLatch->getUniqueSuccessor();
-  if (!ExitFromLatchSucc)
-    return nullptr;
   BasicBlock *GuardBB = Preheader->getUniquePredecessor();
   if (!GuardBB)
     return nullptr;
@@ -397,7 +394,17 @@ BranchInst *Loop::getLoopGuardBranch() const {
   BasicBlock *GuardOtherSucc = (GuardBI->getSuccessor(0) == Preheader)
                                    ? GuardBI->getSuccessor(1)
                                    : GuardBI->getSuccessor(0);
-  return (GuardOtherSucc == ExitFromLatchSucc) ? GuardBI : nullptr;
+  // Check if ExitFromLatch (or any BasicBlock which is an empty unique
+  // successor of ExitFromLatch) is equal to GuardOtherSucc. If
+  // skipEmptyBlockUntil returns GuardOtherSucc, then the guard branch for the
+  // loop is GuardBI (return GuardBI), otherwise return nullptr.
+  if (&LoopNest::skipEmptyBlockUntil(ExitFromLatch, GuardOtherSucc,
+                                     /*CheckUniquePred=*/true) ==
+      GuardOtherSucc)
+    return GuardBI;
+  else
+    return nullptr;
 bool Loop::isCanonical(ScalarEvolution &SE) const {

diff  --git a/llvm/lib/Analysis/LoopNestAnalysis.cpp b/llvm/lib/Analysis/LoopNestAnalysis.cpp
index ee74d4b0d04be..2649ed60f762b 100644
--- a/llvm/lib/Analysis/LoopNestAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopNestAnalysis.cpp
@@ -206,7 +206,8 @@ unsigned LoopNest::getMaxPerfectDepth(const Loop &Root, ScalarEvolution &SE) {
 const BasicBlock &LoopNest::skipEmptyBlockUntil(const BasicBlock *From,
-                                                const BasicBlock *End) {
+                                                const BasicBlock *End,
+                                                bool CheckUniquePred) {
   assert(From && "Expecting valid From");
   assert(End && "Expecting valid End");
@@ -220,8 +221,9 @@ const BasicBlock &LoopNest::skipEmptyBlockUntil(const BasicBlock *From,
   // Visited is used to avoid running into an infinite loop.
   SmallPtrSet<const BasicBlock *, 4> Visited;
   const BasicBlock *BB = From->getUniqueSuccessor();
-  const BasicBlock *PredBB = BB;
-  while (BB && BB != End && IsEmpty(BB) && !Visited.count(BB)) {
+  const BasicBlock *PredBB = From;
+  while (BB && BB != End && IsEmpty(BB) && !Visited.count(BB) &&
+         (!CheckUniquePred || BB->getUniquePredecessor())) {
     PredBB = BB;
     BB = BB->getUniqueSuccessor();
@@ -335,9 +337,11 @@ static bool checkLoopsStructure(const Loop &OuterLoop, const Loop &InnerLoop,
   // Ensure the inner loop exit block lead to the outer loop latch possibly
   // through empty blocks.
-  const BasicBlock &SuccInner =
-      LoopNest::skipEmptyBlockUntil(InnerLoop.getExitBlock(), OuterLoopLatch);
-  if (&SuccInner != OuterLoopLatch && &SuccInner != ExtraPhiBlock) {
+  if ((!ExtraPhiBlock ||
+       &LoopNest::skipEmptyBlockUntil(InnerLoop.getExitBlock(),
+                                      ExtraPhiBlock) != ExtraPhiBlock) &&
+      (&LoopNest::skipEmptyBlockUntil(InnerLoop.getExitBlock(),
+                                      OuterLoopLatch) != OuterLoopLatch)) {
         dbgs() << "Inner loop exit block " << *InnerLoopExit

diff  --git a/llvm/test/Analysis/LoopNestAnalysis/imperfectnest.ll b/llvm/test/Analysis/LoopNestAnalysis/imperfectnest.ll
index 4c8066ec58775..77b361bc6baef 100644
--- a/llvm/test/Analysis/LoopNestAnalysis/imperfectnest.ll
+++ b/llvm/test/Analysis/LoopNestAnalysis/imperfectnest.ll
@@ -424,70 +424,3 @@ for.cond.for.end13_crit_edge:
   ret void
-; Test an imperfect loop nest of the form:
-;   for (int i = 0; i < nx; ++i)
-;     if (i > 5) { // user branch
-;       for (int j = 1; j <= 5; j+=2)
-;         y[j][i] = x[i][j] + j;
-;     }
-define void @imperf_nest_6(i32** %y, i32** %x, i32 signext %nx, i32 signext %ny) {
-;    CHECK-LABEL: IsPerfect=false, Depth=2, OutermostLoop: imperf_nest_6_loop_i, Loops: ( imperf_nest_6_loop_i imperf_nest_6_loop_j )
-  %cmp2 = icmp slt i32 0, %nx
-  br i1 %cmp2, label %imperf_nest_6_loop_i.lr.ph, label %for.end13
-  br label %imperf_nest_6_loop_i
-  %i.0 = phi i32 [ 0, %imperf_nest_6_loop_i.lr.ph ], [ %inc12, %for.inc11 ]
-  %cmp1 = icmp sgt i32 %i.0, 5
-  br i1 %cmp1, label %imperf_nest_6_loop_j.lr.ph, label %if.end
-  br label %imperf_nest_6_loop_j
-  %j.0 = phi i32 [ 1, %imperf_nest_6_loop_j.lr.ph ], [ %inc, %for.inc ]
-  %idxprom = sext i32 %i.0 to i64
-  %arrayidx = getelementptr inbounds i32*, i32** %x, i64 %idxprom
-  %0 = load i32*, i32** %arrayidx, align 8
-  %idxprom5 = sext i32 %j.0 to i64
-  %arrayidx6 = getelementptr inbounds i32, i32* %0, i64 %idxprom5
-  %1 = load i32, i32* %arrayidx6, align 4
-  %add = add nsw i32 %1, %j.0
-  %idxprom7 = sext i32 %j.0 to i64
-  %arrayidx8 = getelementptr inbounds i32*, i32** %y, i64 %idxprom7
-  %2 = load i32*, i32** %arrayidx8, align 8
-  %idxprom9 = sext i32 %i.0 to i64
-  %arrayidx10 = getelementptr inbounds i32, i32* %2, i64 %idxprom9
-  store i32 %add, i32* %arrayidx10, align 4
-  br label %for.inc
-  %inc = add nsw i32 %j.0, 2
-  %cmp3 = icmp sle i32 %inc, 5
-  br i1 %cmp3, label %imperf_nest_6_loop_j, label %for.cond2.for.end_crit_edge
-  br label %for.end
-  br label %if.end
-  br label %for.inc11
-  %inc12 = add nsw i32 %i.0, 1
-  %cmp = icmp slt i32 %inc12, %nx
-  br i1 %cmp, label %imperf_nest_6_loop_i, label %for.cond.for.end13_crit_edge
-  br label %for.end13
-  ret void

diff  --git a/llvm/test/Analysis/LoopNestAnalysis/perfectnest.ll b/llvm/test/Analysis/LoopNestAnalysis/perfectnest.ll
index 7593d6f1748b3..f8b0e6ad2c884 100644
--- a/llvm/test/Analysis/LoopNestAnalysis/perfectnest.ll
+++ b/llvm/test/Analysis/LoopNestAnalysis/perfectnest.ll
@@ -322,3 +322,148 @@ for.end7:
   %x.addr.0.lcssa = phi i32 [ %split7, %for.cond.for.end7_crit_edge ], [ %x, %entry ]
   ret i32 %x.addr.0.lcssa
+; Test a perfect loop nest of the form:
+;   for (int i = 0; i < nx; ++i)
+;     if (i < ny) { // guard branch for the j-loop
+;       for (int j=i; j < ny; j+=1)
+;         y[j][i] = x[i][j] + j;
+;     }
+define double @perf_nest_guard_branch(i32** %y, i32** %x, i32 signext %nx, i32 signext %ny) {
+; CHECK-LABEL: IsPerfect=true, Depth=1, OutermostLoop: test6Loop2, Loops: ( test6Loop2 )
+; CHECK-LABEL: IsPerfect=true, Depth=2, OutermostLoop: test6Loop1, Loops: ( test6Loop1 test6Loop2 )
+  %cmp2 = icmp slt i32 0, %nx
+  br i1 %cmp2, label %test6Loop1.lr.ph, label %for.end13
+test6Loop1.lr.ph:                                   ; preds = %entry
+  br label %test6Loop1
+test6Loop1:                                         ; preds = %test6Loop1.lr.ph, %for.inc11
+  %i.0 = phi i32 [ 0, %test6Loop1.lr.ph ], [ %inc12, %for.inc11 ]
+  %cmp1 = icmp slt i32 %i.0, %ny
+  br i1 %cmp1, label %test6Loop2.lr.ph, label %if.end
+test6Loop2.lr.ph:                                  ; preds = %if.then
+  br label %test6Loop2
+test6Loop2:                                        ; preds = %test6Loop2.lr.ph, %for.inc
+  %j.0 = phi i32 [ %i.0, %test6Loop2.lr.ph ], [ %inc, %for.inc ]
+  %idxprom = sext i32 %i.0 to i64
+  %arrayidx = getelementptr inbounds i32*, i32** %x, i64 %idxprom
+  %0 = load i32*, i32** %arrayidx, align 8
+  %idxprom5 = sext i32 %j.0 to i64
+  %arrayidx6 = getelementptr inbounds i32, i32* %0, i64 %idxprom5
+  %1 = load i32, i32* %arrayidx6, align 4
+  %add = add nsw i32 %1, %j.0
+  %idxprom7 = sext i32 %j.0 to i64
+  %arrayidx8 = getelementptr inbounds i32*, i32** %y, i64 %idxprom7
+  %2 = load i32*, i32** %arrayidx8, align 8
+  %idxprom9 = sext i32 %i.0 to i64
+  %arrayidx10 = getelementptr inbounds i32, i32* %2, i64 %idxprom9
+  store i32 %add, i32* %arrayidx10, align 4
+  br label %for.inc
+for.inc:                                          ; preds = %test6Loop2
+  %inc = add nsw i32 %j.0, 1
+  %cmp3 = icmp slt i32 %inc, %ny
+  br i1 %cmp3, label %test6Loop2, label %for.cond2.for.end_crit_edge
+for.cond2.for.end_crit_edge:                      ; preds = %for.inc
+  br label %for.end
+for.end:                                          ; preds = %for.cond2.for.end_crit_edge, %if.then
+  br label %if.end
+if.end:                                           ; preds = %for.end, %test6Loop1
+  br label %for.inc11
+for.inc11:                                        ; preds = %if.end
+  %inc12 = add nsw i32 %i.0, 1
+  %cmp = icmp slt i32 %inc12, %nx
+  br i1 %cmp, label %test6Loop1, label %for.cond.for.end13_crit_edge
+for.cond.for.end13_crit_edge:                     ; preds = %for.inc11
+  br label %for.end13
+for.end13:                                        ; preds = %for.cond.for.end13_crit_edge, %entry
+  %arrayidx14 = getelementptr inbounds i32*, i32** %y, i64 0
+  %3 = load i32*, i32** %arrayidx14, align 8
+  %arrayidx15 = getelementptr inbounds i32, i32* %3, i64 0
+  %4 = load i32, i32* %arrayidx15, align 4
+  %conv = sitofp i32 %4 to double
+  ret double %conv
+; Test a perfect loop nest of the form:
+;   for (int i = 0; i < nx; ++i)
+;     if (i < ny) { // guard branch for the j-loop
+;       for (int j=i; j < ny; j+=1)
+;         y[j][i] = x[i][j] + j;
+;     }
+define double @test6(i32** %y, i32** %x, i32 signext %nx, i32 signext %ny) {
+; CHECK-LABEL: IsPerfect=true, Depth=1, OutermostLoop: test6Loop2, Loops: ( test6Loop2 )
+; CHECK-LABEL: IsPerfect=true, Depth=2, OutermostLoop: test6Loop1, Loops: ( test6Loop1 test6Loop2 )
+  %cmp2 = icmp slt i32 0, %nx
+  br i1 %cmp2, label %test6Loop1.lr.ph, label %for.end13
+test6Loop1.lr.ph:                                   ; preds = %entry
+  br label %test6Loop1
+test6Loop1:                                         ; preds = %test6Loop1.lr.ph, %for.inc11
+  %i.0 = phi i32 [ 0, %test6Loop1.lr.ph ], [ %inc12, %for.inc11 ]
+  %cmp1 = icmp slt i32 %i.0, %ny
+  br i1 %cmp1, label %test6Loop2.lr.ph, label %if.end
+test6Loop2.lr.ph:                                  ; preds = %if.then
+  br label %test6Loop2
+test6Loop2:                                        ; preds = %test6Loop2.lr.ph, %for.inc
+  %j.0 = phi i32 [ %i.0, %test6Loop2.lr.ph ], [ %inc, %for.inc ]
+  %idxprom = sext i32 %i.0 to i64
+  %arrayidx = getelementptr inbounds i32*, i32** %x, i64 %idxprom
+  %0 = load i32*, i32** %arrayidx, align 8
+  %idxprom5 = sext i32 %j.0 to i64
+  %arrayidx6 = getelementptr inbounds i32, i32* %0, i64 %idxprom5
+  %1 = load i32, i32* %arrayidx6, align 4
+  %add = add nsw i32 %1, %j.0
+  %idxprom7 = sext i32 %j.0 to i64
+  %arrayidx8 = getelementptr inbounds i32*, i32** %y, i64 %idxprom7
+  %2 = load i32*, i32** %arrayidx8, align 8
+  %idxprom9 = sext i32 %i.0 to i64
+  %arrayidx10 = getelementptr inbounds i32, i32* %2, i64 %idxprom9
+  store i32 %add, i32* %arrayidx10, align 4
+  br label %for.inc
+for.inc:                                          ; preds = %test6Loop2
+  %inc = add nsw i32 %j.0, 1
+  %cmp3 = icmp slt i32 %inc, %ny
+  br i1 %cmp3, label %test6Loop2, label %for.cond2.for.end_crit_edge
+for.cond2.for.end_crit_edge:                      ; preds = %for.inc
+  br label %for.end
+for.end:                                          ; preds = %for.cond2.for.end_crit_edge, %if.then
+  br label %if.end
+if.end:                                           ; preds = %for.end, %test6Loop1
+  br label %for.inc11
+for.inc11:                                        ; preds = %if.end
+  %inc12 = add nsw i32 %i.0, 1
+  %cmp = icmp slt i32 %inc12, %nx
+  br i1 %cmp, label %test6Loop1, label %for.cond.for.end13_crit_edge
+for.cond.for.end13_crit_edge:                     ; preds = %for.inc11
+  br label %for.end13
+for.end13:                                        ; preds = %for.cond.for.end13_crit_edge, %entry
+  %arrayidx14 = getelementptr inbounds i32*, i32** %y, i64 0
+  %3 = load i32*, i32** %arrayidx14, align 8
+  %arrayidx15 = getelementptr inbounds i32, i32* %3, i64 0
+  %4 = load i32, i32* %arrayidx15, align 4
+  %conv = sitofp i32 %4 to double
+  ret double %conv

diff  --git a/llvm/unittests/Analysis/LoopInfoTest.cpp b/llvm/unittests/Analysis/LoopInfoTest.cpp
index bb518904e8182..db6484f6928fc 100644
--- a/llvm/unittests/Analysis/LoopInfoTest.cpp
+++ b/llvm/unittests/Analysis/LoopInfoTest.cpp
@@ -1500,3 +1500,51 @@ TEST(LoopInfoTest, LoopNotRotated) {
+TEST(LoopInfoTest, LoopUserBranch) {
+  const char *ModuleStr =
+      "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
+      "define void @foo(i32* %B, i64 signext %nx, i1 %cond) {\n"
+      "entry:\n"
+      "  br i1 %cond, label %bb, label %guard\n"
+      "guard:\n"
+      "  %cmp.guard = icmp slt i64 0, %nx\n"
+      "  br i1 %cmp.guard, label %for.i.preheader, label %for.end\n"
+      "for.i.preheader:\n"
+      "  br label %for.i\n"
+      "for.i:\n"
+      "  %i = phi i64 [ 0, %for.i.preheader ], [ %inc13, %for.i ]\n"
+      "  %Bi = getelementptr inbounds i32, i32* %B, i64 %i\n"
+      "  store i32 0, i32* %Bi, align 4\n"
+      "  %inc13 = add nsw i64 %i, 1\n"
+      "  %cmp = icmp slt i64 %inc13, %nx\n"
+      "  br i1 %cmp, label %for.i, label %for.i.exit\n"
+      "for.i.exit:\n"
+      "  br label %bb\n"
+      "bb:\n"
+      "  br label %for.end\n"
+      "for.end:\n"
+      "  ret void\n"
+      "}\n";
+  // Parse the module.
+  LLVMContext Context;
+  std::unique_ptr<Module> M = makeLLVMModule(Context, ModuleStr);
+  runWithLoopInfo(*M, "foo", [&](Function &F, LoopInfo &LI) {
+    Function::iterator FI = F.begin();
+    FI = ++FI;
+    BasicBlock *Guard = &*FI;
+    assert(Guard->getName() == "guard");
+    FI = ++FI;
+    BasicBlock *Header = &*(++FI);
+    assert(Header->getName() == "for.i");
+    Loop *L = LI.getLoopFor(Header);
+    EXPECT_NE(L, nullptr);
+    // L should not have a guard branch
+    EXPECT_EQ(L->getLoopGuardBranch(), nullptr);
+  });


More information about the llvm-commits mailing list