[llvm] [LSR] Optimize lsr-term-fold if the operand of condition is PHINode (PR #96048)

Zhijin Zeng via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 19 02:30:00 PDT 2024


https://github.com/zengdage updated https://github.com/llvm/llvm-project/pull/96048

>From 1e3f6d595656392b64c10e6d43475bc113ec1842 Mon Sep 17 00:00:00 2001
From: Zhijin Zeng <zhijin.zeng at spacemit.com>
Date: Wed, 19 Jun 2024 10:49:44 +0800
Subject: [PATCH 1/2] [LSR][NFC] Add one test case for lsr-term-fold

The test case covers a loop in which the LHS operand of the
branch condition in the latch terminator is a PHINode rather
than a BinaryOperator:

```
%trip_count = phi i32 [ %new_trip_count, %for.latch ], [ %delta, %entry ]
...
%cond3 = icmp sgt i32 %trip_count, 1
```
---
 .../LoopStrengthReduce/lsr-term-fold.ll       | 61 +++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
index 7299a014b7983..75d0c727517ea 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
@@ -733,3 +733,64 @@ for.body:                                         ; preds = %for.body, %entry
 for.end:                                          ; preds = %for.body
   ret void
 }
+
+define ptr @no_binary_operator(ptr %start, ptr %end, i8 %value) {
+; CHECK-LABEL: @no_binary_operator(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[END_I:%.*]] = ptrtoint ptr [[END:%.*]] to i64
+; CHECK-NEXT:    [[START_I:%.*]] = ptrtoint ptr [[START:%.*]] to i64
+; CHECK-NEXT:    [[DELTA_I:%.*]] = sub i64 [[END_I]], [[START_I]]
+; CHECK-NEXT:    [[DELTA:%.*]] = trunc i64 [[DELTA_I]] to i32
+; CHECK-NEXT:    [[COND1:%.*]] = icmp sgt i32 [[DELTA]], 0
+; CHECK-NEXT:    br i1 [[COND1]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_END:%.*]]
+; CHECK:       for.body.preheader:
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.body:
+; CHECK-NEXT:    [[TRIP_COUNT:%.*]] = phi i32 [ [[NEW_TRIP_COUNT:%.*]], [[FOR_LATCH:%.*]] ], [ [[DELTA]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[ADDR:%.*]] = phi ptr [ [[NEW_ADDR:%.*]], [[FOR_LATCH]] ], [ [[START]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[DATA:%.*]] = load i8, ptr [[ADDR]], align 1
+; CHECK-NEXT:    [[COND2:%.*]] = icmp eq i8 [[DATA]], [[VALUE:%.*]]
+; CHECK-NEXT:    br i1 [[COND2]], label [[FOR_BODY_FOR_END_LOOPEXIT_CRIT_EDGE:%.*]], label [[FOR_LATCH]]
+; CHECK:       for.latch:
+; CHECK-NEXT:    [[NEW_ADDR]] = getelementptr inbounds i8, ptr [[ADDR]], i64 1
+; CHECK-NEXT:    [[NEW_TRIP_COUNT]] = add nsw i32 [[TRIP_COUNT]], -1
+; CHECK-NEXT:    [[COND3:%.*]] = icmp sgt i32 [[TRIP_COUNT]], 1
+; CHECK-NEXT:    br i1 [[COND3]], label [[FOR_BODY]], label [[FOR_END_LOOPEXITSPLIT:%.*]]
+; CHECK:       for.end.loopexitsplit:
+; CHECK-NEXT:    [[NEW_ADDR_LCSSA:%.*]] = phi ptr [ [[NEW_ADDR]], [[FOR_LATCH]] ]
+; CHECK-NEXT:    br label [[FOR_END_LOOPEXIT:%.*]]
+; CHECK:       for.body.for.end.loopexit_crit_edge:
+; CHECK-NEXT:    [[ADDR_LCSSA:%.*]] = phi ptr [ [[ADDR]], [[FOR_BODY]] ]
+; CHECK-NEXT:    br label [[FOR_END_LOOPEXIT]]
+; CHECK:       for.end.loopexit:
+; CHECK-NEXT:    [[RETV_PH:%.*]] = phi ptr [ [[ADDR_LCSSA]], [[FOR_BODY_FOR_END_LOOPEXIT_CRIT_EDGE]] ], [ [[NEW_ADDR_LCSSA]], [[FOR_END_LOOPEXITSPLIT]] ]
+; CHECK-NEXT:    br label [[FOR_END]]
+; CHECK:       for.end:
+; CHECK-NEXT:    [[RETV:%.*]] = phi ptr [ [[START]], [[ENTRY:%.*]] ], [ [[RETV_PH]], [[FOR_END_LOOPEXIT]] ]
+; CHECK-NEXT:    ret ptr [[RETV]]
+;
+entry:
+  %end_i = ptrtoint ptr %end to i64
+  %start_i = ptrtoint ptr %start to i64
+  %delta_i = sub i64 %end_i, %start_i
+  %delta = trunc i64 %delta_i to i32
+  %cond1 = icmp sgt i32 %delta, 0
+  br i1 %cond1, label %for.body, label %for.end
+
+for.body:                                                ; preds = %entry, %for.latch
+  %trip_count = phi i32 [ %new_trip_count, %for.latch ], [ %delta, %entry ]
+  %addr = phi ptr [ %new_addr, %for.latch ], [ %start, %entry ]
+  %data = load i8, ptr %addr, align 1
+  %cond2 = icmp eq i8 %data, %value
+  br i1 %cond2, label %for.end, label %for.latch
+
+for.latch:                                               ; preds = %for.body
+  %new_addr = getelementptr inbounds i8, ptr %addr, i64 1
+  %new_trip_count = add nsw i32 %trip_count, -1
+  %cond3 = icmp sgt i32 %trip_count, 1
+  br i1 %cond3, label %for.body, label %for.end
+
+for.end:                                               ; preds = %for.body, %for.latch, %entry
+  %retv = phi ptr [ %start, %entry ], [ %new_addr, %for.latch ], [ %addr, %for.body ]
+  ret ptr %retv
+}

>From 7ee93c027e96666cf5f3cc6281beefa223a40829 Mon Sep 17 00:00:00 2001
From: Zhijin Zeng <zhijin.zeng at spacemit.com>
Date: Wed, 19 Jun 2024 11:43:22 +0800
Subject: [PATCH 2/2] [LSR] Optimize lsr-term-fold if the operand of condition
 is PHINode

lsr-term-fold currently only handles the case where the LHS operand of the
branch condition in the latch terminator is a BinaryOperator. When that
operand is a PHINode, lsr-term-fold should fold the terminator as well.
---
 .../Transforms/Scalar/LoopStrengthReduce.cpp  | 29 ++++++++++++++-----
 .../LoopStrengthReduce/lsr-term-fold.ll       | 17 +++++++----
 2 files changed, 32 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 3a98e257367b2..6eda52a51c77d 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6830,15 +6830,18 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
     return std::nullopt;
   }
 
-  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
-    LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
-    return std::nullopt;
-  }
-
   BasicBlock *LoopLatch = L->getLoopLatch();
   BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
   if (!BI || BI->isUnconditional())
     return std::nullopt;
+
+  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
+    if (isa<SCEVCouldNotCompute>(SE.getExitCount(L, BI->getParent()))) {
+      LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
+      return std::nullopt;
+    }
+  }
+
   auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
   if (!TermCond) {
     LLVM_DEBUG(
@@ -6852,9 +6855,8 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
     return std::nullopt;
   }
 
-  BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
   Value *RHS = TermCond->getOperand(1);
-  if (!LHS || !L->isLoopInvariant(RHS))
+  if (!L->isLoopInvariant(RHS))
     // We could pattern match the inverse form of the icmp, but that is
     // non-canonical, and this pass is running *very* late in the pipeline.
     return std::nullopt;
@@ -6862,7 +6864,15 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
   // Find the IV used by the current exit condition.
   PHINode *ToFold;
   Value *ToFoldStart, *ToFoldStep;
-  if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
+  if (BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0))) {
+    if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
+      return std::nullopt;
+  } else if (PHINode *LHS = dyn_cast<PHINode>(TermCond->getOperand(0))) {
+    BinaryOperator *BO = nullptr;
+    if (!matchSimpleRecurrence(LHS, BO, ToFoldStart, ToFoldStep))
+      return std::nullopt;
+    ToFold = LHS;
+  } else
     return std::nullopt;
 
   // Ensure the simple recurrence is a part of the current loop.
@@ -6887,6 +6897,9 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
   }();
 
   const SCEV *BECount = SE.getBackedgeTakenCount(L);
+  if (isa<SCEVCouldNotCompute>(BECount))
+    BECount = SE.getExitCount(L, BI->getParent());
+
   const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
   SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
 
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
index 75d0c727517ea..877f48dece710 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
@@ -744,18 +744,23 @@ define ptr @no_binary_operator(ptr %start, ptr %end, i8 %value) {
 ; CHECK-NEXT:    [[COND1:%.*]] = icmp sgt i32 [[DELTA]], 0
 ; CHECK-NEXT:    br i1 [[COND1]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_END:%.*]]
 ; CHECK:       for.body.preheader:
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc i64 [[END_I]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[TMP0]], -1
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i64 [[START_I]] to i32
+; CHECK-NEXT:    [[TMP3:%.*]] = sub i32 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
+; CHECK-NEXT:    [[TMP5:%.*]] = add nuw nsw i64 [[TMP4]], 1
+; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[TMP5]]
 ; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.body:
-; CHECK-NEXT:    [[TRIP_COUNT:%.*]] = phi i32 [ [[NEW_TRIP_COUNT:%.*]], [[FOR_LATCH:%.*]] ], [ [[DELTA]], [[FOR_BODY_PREHEADER]] ]
-; CHECK-NEXT:    [[ADDR:%.*]] = phi ptr [ [[NEW_ADDR:%.*]], [[FOR_LATCH]] ], [ [[START]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[ADDR:%.*]] = phi ptr [ [[NEW_ADDR:%.*]], [[FOR_LATCH:%.*]] ], [ [[START]], [[FOR_BODY_PREHEADER]] ]
 ; CHECK-NEXT:    [[DATA:%.*]] = load i8, ptr [[ADDR]], align 1
 ; CHECK-NEXT:    [[COND2:%.*]] = icmp eq i8 [[DATA]], [[VALUE:%.*]]
 ; CHECK-NEXT:    br i1 [[COND2]], label [[FOR_BODY_FOR_END_LOOPEXIT_CRIT_EDGE:%.*]], label [[FOR_LATCH]]
 ; CHECK:       for.latch:
-; CHECK-NEXT:    [[NEW_ADDR]] = getelementptr inbounds i8, ptr [[ADDR]], i64 1
-; CHECK-NEXT:    [[NEW_TRIP_COUNT]] = add nsw i32 [[TRIP_COUNT]], -1
-; CHECK-NEXT:    [[COND3:%.*]] = icmp sgt i32 [[TRIP_COUNT]], 1
-; CHECK-NEXT:    br i1 [[COND3]], label [[FOR_BODY]], label [[FOR_END_LOOPEXITSPLIT:%.*]]
+; CHECK-NEXT:    [[NEW_ADDR]] = getelementptr i8, ptr [[ADDR]], i64 1
+; CHECK-NEXT:    [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[NEW_ADDR]], [[SCEVGEP]]
+; CHECK-NEXT:    br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_END_LOOPEXITSPLIT:%.*]], label [[FOR_BODY]]
 ; CHECK:       for.end.loopexitsplit:
 ; CHECK-NEXT:    [[NEW_ADDR_LCSSA:%.*]] = phi ptr [ [[NEW_ADDR]], [[FOR_LATCH]] ]
 ; CHECK-NEXT:    br label [[FOR_END_LOOPEXIT:%.*]]



More information about the llvm-commits mailing list