[llvm] [LSR][term-fold] Ensure the simple recurrence is reachable from the current loop (PR #83085)

Patrick O'Neill via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 4 14:00:23 PST 2024


https://github.com/patrick-rivos updated https://github.com/llvm/llvm-project/pull/83085

>From 2a8a5a11e72d529a97c8521b11989f971602964a Mon Sep 17 00:00:00 2001
From: Patrick O'Neill <patrick at rivosinc.com>
Date: Mon, 26 Feb 2024 15:21:07 -0800
Subject: [PATCH 1/3] [LSR][term-fold] Ensure the simple recurrence is
 reachable from the current loop

If the phi node matched by matchSimpleRecurrence does not belong to the
current loop, it has no incoming value from the loop latch, so
isAlmostDeadIV crashes. With this patch we bail out early.

Signed-off-by: Patrick O'Neill <patrick at rivosinc.com>
---
 .../Transforms/Scalar/LoopStrengthReduce.cpp  |  6 +++
 llvm/lib/Transforms/Utils/LoopUtils.cpp       |  1 +
 .../lsr-unreachable-bb-phi-node.ll            | 40 +++++++++++++++++++
 3 files changed, 47 insertions(+)
 create mode 100644 llvm/test/Transforms/LoopStrengthReduce/lsr-unreachable-bb-phi-node.ll

diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 08021f3ba853e8..62c271494cf2d7 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6808,6 +6808,12 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
   if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
     return std::nullopt;
 
+  // If ToFold does not have an incoming value from LoopLatch then the simple
+  // recurrence is from a prior loop unreachable from the loop we're currently
+  // considering.
+  if (ToFold->getBasicBlockIndex(LoopLatch) == -1)
+    return std::nullopt;
+
   // If that IV isn't dead after we rewrite the exit condition in terms of
   // another IV, there's no point in doing the transform.
   if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index a4fdc1f8c12e50..7491a99b03f66e 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -468,6 +468,7 @@ llvm::collectChildrenInLoop(DomTreeNode *N, const Loop *CurLoop) {
 
 bool llvm::isAlmostDeadIV(PHINode *PN, BasicBlock *LatchBlock, Value *Cond) {
   int LatchIdx = PN->getBasicBlockIndex(LatchBlock);
+  assert(LatchIdx != -1 && "LatchBlock is not a case in this PHINode");
   Value *IncV = PN->getIncomingValue(LatchIdx);
 
   for (User *U : PN->users())
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-unreachable-bb-phi-node.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-unreachable-bb-phi-node.ll
new file mode 100644
index 00000000000000..1454535b52bccb
--- /dev/null
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-unreachable-bb-phi-node.ll
@@ -0,0 +1,40 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -loop-reduce -S -lsr-term-fold | FileCheck %s
+
+; This test used to crash because matchSimpleRecurrence matched the simple
+; recurrence in pn-loop while unrelated-loop was being evaluated. Since
+; unrelated-loop cannot branch back to pn-loop, isAlmostDeadIV crashed.
+define void @phi_node_different_bb() {
+; CHECK-LABEL: @phi_node_different_bb(
+; CHECK-NEXT:    br label [[PN_LOOP:%.*]]
+; CHECK:       pn-loop:
+; CHECK-NEXT:    [[TMP1:%.*]] = phi i32 [ 1, [[TMP0:%.*]] ], [ [[TMP2:%.*]], [[PN_LOOP]] ]
+; CHECK-NEXT:    [[TMP2]] = add i32 [[TMP1]], 1
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ugt i32 [[TMP2]], 1
+; CHECK-NEXT:    br i1 [[TMP3]], label [[PN_LOOP]], label [[UNRELATED_LOOP_PREHEADER:%.*]]
+; CHECK:       unrelated-loop.preheader:
+; CHECK-NEXT:    br label [[UNRELATED_LOOP:%.*]]
+; CHECK:       unrelated-loop:
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[TMP2]], 0
+; CHECK-NEXT:    br i1 [[TMP4]], label [[END:%.*]], label [[UNRELATED_LOOP]]
+; CHECK:       end:
+; CHECK-NEXT:    ret void
+;
+  br label %pn-loop
+
+pn-loop:                                          ; preds = %pn-loop, %0
+  %1 = phi i32 [ 1, %0 ], [ %2, %pn-loop ]
+  %2 = add i32 %1, 1
+  %3 = icmp ugt i32 %2, 1
+  br i1 %3, label %pn-loop, label %unrelated-loop.preheader
+
+unrelated-loop.preheader:                         ; preds = %pn-loop
+  br label %unrelated-loop
+
+unrelated-loop:                                   ; preds = %unrelated-loop, %unrelated-loop.preheader
+  %4 = icmp eq i32 %2, 0
+  br i1 %4, label %end, label %unrelated-loop
+
+end:                                              ; preds = %unrelated-loop
+  ret void
+}

>From 6a86d9c00c165a7551fb1b500a06c7d986553097 Mon Sep 17 00:00:00 2001
From: Patrick O'Neill <102189596+patrick-rivos at users.noreply.github.com>
Date: Mon, 4 Mar 2024 13:40:33 -0800
Subject: [PATCH 2/3] Directly compare loop header and phi-node parent

---
 llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 62c271494cf2d7..07f350ba2c3bda 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6811,7 +6811,7 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
   // If ToFold does not have an incoming value from LoopLatch then the simple
   // recurrence is from a prior loop unreachable from the loop we're currently
   // considering.
-  if (ToFold->getBasicBlockIndex(LoopLatch) == -1)
+  if (L->getHeader() != ToFold->getParent())
     return std::nullopt;
 
   // If that IV isn't dead after we rewrite the exit condition in terms of

>From 3d54ee47e97d1043a851b87246bcc1226dd4b1d4 Mon Sep 17 00:00:00 2001
From: Patrick O'Neill <102189596+patrick-rivos at users.noreply.github.com>
Date: Mon, 4 Mar 2024 14:00:11 -0800
Subject: [PATCH 3/3] Update comment and order of comparison

---
 llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 07f350ba2c3bda..4f550161148410 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6808,10 +6808,8 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
   if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
     return std::nullopt;
 
-  // If ToFold does not have an incoming value from LoopLatch then the simple
-  // recurrence is from a prior loop unreachable from the loop we're currently
-  // considering.
-  if (L->getHeader() != ToFold->getParent())
+  // Ensure the simple recurrence is a part of the current loop.
+  if (ToFold->getParent() != L->getHeader())
     return std::nullopt;
 
   // If that IV isn't dead after we rewrite the exit condition in terms of



More information about the llvm-commits mailing list