[llvm] [LSR] Extend lsr-term-fold to multiple exits loop and support the operand of condition is PHINode (PR #96048)
Zhijin Zeng via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 20 00:57:47 PDT 2024
https://github.com/zengdage updated https://github.com/llvm/llvm-project/pull/96048
>From 1e3f6d595656392b64c10e6d43475bc113ec1842 Mon Sep 17 00:00:00 2001
From: Zhijin Zeng <zhijin.zeng at spacemit.com>
Date: Wed, 19 Jun 2024 10:49:44 +0800
Subject: [PATCH 1/2] [LSR][NFC] Add one test case for lsr-term-fold
The test case covers the situation where the LHS operand of
the branch condition in the latch terminator is a PHINode
rather than a BinaryOperator.
```
%trip_count = phi i32 [ %new_trip_count, %for.latch ], [ %delta, %entry ]
...
%cond3 = icmp sgt i32 %trip_count, 1
```
---
.../LoopStrengthReduce/lsr-term-fold.ll | 61 +++++++++++++++++++
1 file changed, 61 insertions(+)
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
index 7299a014b7983..75d0c727517ea 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
@@ -733,3 +733,64 @@ for.body: ; preds = %for.body, %entry
for.end: ; preds = %for.body
ret void
}
+
+define ptr @no_binary_operator(ptr %start, ptr %end, i8 %value) {
+; CHECK-LABEL: @no_binary_operator(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[END_I:%.*]] = ptrtoint ptr [[END:%.*]] to i64
+; CHECK-NEXT: [[START_I:%.*]] = ptrtoint ptr [[START:%.*]] to i64
+; CHECK-NEXT: [[DELTA_I:%.*]] = sub i64 [[END_I]], [[START_I]]
+; CHECK-NEXT: [[DELTA:%.*]] = trunc i64 [[DELTA_I]] to i32
+; CHECK-NEXT: [[COND1:%.*]] = icmp sgt i32 [[DELTA]], 0
+; CHECK-NEXT: br i1 [[COND1]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_END:%.*]]
+; CHECK: for.body.preheader:
+; CHECK-NEXT: br label [[FOR_BODY:%.*]]
+; CHECK: for.body:
+; CHECK-NEXT: [[TRIP_COUNT:%.*]] = phi i32 [ [[NEW_TRIP_COUNT:%.*]], [[FOR_LATCH:%.*]] ], [ [[DELTA]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT: [[ADDR:%.*]] = phi ptr [ [[NEW_ADDR:%.*]], [[FOR_LATCH]] ], [ [[START]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT: [[DATA:%.*]] = load i8, ptr [[ADDR]], align 1
+; CHECK-NEXT: [[COND2:%.*]] = icmp eq i8 [[DATA]], [[VALUE:%.*]]
+; CHECK-NEXT: br i1 [[COND2]], label [[FOR_BODY_FOR_END_LOOPEXIT_CRIT_EDGE:%.*]], label [[FOR_LATCH]]
+; CHECK: for.latch:
+; CHECK-NEXT: [[NEW_ADDR]] = getelementptr inbounds i8, ptr [[ADDR]], i64 1
+; CHECK-NEXT: [[NEW_TRIP_COUNT]] = add nsw i32 [[TRIP_COUNT]], -1
+; CHECK-NEXT: [[COND3:%.*]] = icmp sgt i32 [[TRIP_COUNT]], 1
+; CHECK-NEXT: br i1 [[COND3]], label [[FOR_BODY]], label [[FOR_END_LOOPEXITSPLIT:%.*]]
+; CHECK: for.end.loopexitsplit:
+; CHECK-NEXT: [[NEW_ADDR_LCSSA:%.*]] = phi ptr [ [[NEW_ADDR]], [[FOR_LATCH]] ]
+; CHECK-NEXT: br label [[FOR_END_LOOPEXIT:%.*]]
+; CHECK: for.body.for.end.loopexit_crit_edge:
+; CHECK-NEXT: [[ADDR_LCSSA:%.*]] = phi ptr [ [[ADDR]], [[FOR_BODY]] ]
+; CHECK-NEXT: br label [[FOR_END_LOOPEXIT]]
+; CHECK: for.end.loopexit:
+; CHECK-NEXT: [[RETV_PH:%.*]] = phi ptr [ [[ADDR_LCSSA]], [[FOR_BODY_FOR_END_LOOPEXIT_CRIT_EDGE]] ], [ [[NEW_ADDR_LCSSA]], [[FOR_END_LOOPEXITSPLIT]] ]
+; CHECK-NEXT: br label [[FOR_END]]
+; CHECK: for.end:
+; CHECK-NEXT: [[RETV:%.*]] = phi ptr [ [[START]], [[ENTRY:%.*]] ], [ [[RETV_PH]], [[FOR_END_LOOPEXIT]] ]
+; CHECK-NEXT: ret ptr [[RETV]]
+;
+entry:
+ %end_i = ptrtoint ptr %end to i64
+ %start_i = ptrtoint ptr %start to i64
+ %delta_i = sub i64 %end_i, %start_i
+ %delta = trunc i64 %delta_i to i32
+ %cond1 = icmp sgt i32 %delta, 0
+ br i1 %cond1, label %for.body, label %for.end
+
+for.body: ; preds = %entry, %for.latch
+ %trip_count = phi i32 [ %new_trip_count, %for.latch ], [ %delta, %entry ]
+ %addr = phi ptr [ %new_addr, %for.latch ], [ %start, %entry ]
+ %data = load i8, ptr %addr, align 1
+ %cond2 = icmp eq i8 %data, %value
+ br i1 %cond2, label %for.end, label %for.latch
+
+for.latch: ; preds = %for.body
+ %new_addr = getelementptr inbounds i8, ptr %addr, i64 1
+ %new_trip_count = add nsw i32 %trip_count, -1
+ %cond3 = icmp sgt i32 %trip_count, 1
+ br i1 %cond3, label %for.body, label %for.end
+
+for.end: ; preds = %for.body, %for.latch, %entry
+ %retv = phi ptr [ %start, %entry ], [ %new_addr, %for.latch ], [ %addr, %for.body ]
+ ret ptr %retv
+}
>From c1f826f1d9903214eaa492327f5aefdc252a8017 Mon Sep 17 00:00:00 2001
From: Zhijin Zeng <zhijin.zeng at spacemit.com>
Date: Wed, 19 Jun 2024 11:43:22 +0800
Subject: [PATCH 2/2] [LSR] Extend lsr-term-fold to multiple exits loop and
support the operand of condition is PHINode
1. lsr-term-fold currently only handles the case where the LHS operand of the
branch condition in the latch terminator is a BinaryOperator. Extend it to
also fold the terminator when that operand is a PHINode.
2. Extend lsr-term-fold to loops with multiple exits: when the backedge-taken count of the loop is not computable, use the exit count of the latch instead.
---
.../Transforms/Scalar/LoopStrengthReduce.cpp | 29 ++++++++++++++-----
.../LoopStrengthReduce/lsr-term-fold.ll | 17 +++++++----
2 files changed, 32 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 3a98e257367b2..6eda52a51c77d 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6830,15 +6830,18 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
return std::nullopt;
}
- if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
- LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
- return std::nullopt;
- }
-
BasicBlock *LoopLatch = L->getLoopLatch();
BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
if (!BI || BI->isUnconditional())
return std::nullopt;
+
+ if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
+ if (isa<SCEVCouldNotCompute>(SE.getExitCount(L, BI->getParent()))) {
+ LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
+ return std::nullopt;
+ }
+ }
+
auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
if (!TermCond) {
LLVM_DEBUG(
@@ -6852,9 +6855,8 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
return std::nullopt;
}
- BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
Value *RHS = TermCond->getOperand(1);
- if (!LHS || !L->isLoopInvariant(RHS))
+ if (!L->isLoopInvariant(RHS))
// We could pattern match the inverse form of the icmp, but that is
// non-canonical, and this pass is running *very* late in the pipeline.
return std::nullopt;
@@ -6862,7 +6864,15 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
// Find the IV used by the current exit condition.
PHINode *ToFold;
Value *ToFoldStart, *ToFoldStep;
- if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
+ if (BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0))) {
+ if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
+ return std::nullopt;
+ } else if (PHINode *LHS = dyn_cast<PHINode>(TermCond->getOperand(0))) {
+ BinaryOperator *BO = nullptr;
+ if (!matchSimpleRecurrence(LHS, BO, ToFoldStart, ToFoldStep))
+ return std::nullopt;
+ ToFold = LHS;
+ } else
return std::nullopt;
// Ensure the simple recurrence is a part of the current loop.
@@ -6887,6 +6897,9 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
}();
const SCEV *BECount = SE.getBackedgeTakenCount(L);
+ if (isa<SCEVCouldNotCompute>(BECount))
+ BECount = SE.getExitCount(L, BI->getParent());
+
const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
index 75d0c727517ea..877f48dece710 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
@@ -744,18 +744,23 @@ define ptr @no_binary_operator(ptr %start, ptr %end, i8 %value) {
; CHECK-NEXT: [[COND1:%.*]] = icmp sgt i32 [[DELTA]], 0
; CHECK-NEXT: br i1 [[COND1]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_END:%.*]]
; CHECK: for.body.preheader:
+; CHECK-NEXT: [[TMP0:%.*]] = trunc i64 [[END_I]] to i32
+; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[TMP0]], -1
+; CHECK-NEXT: [[TMP2:%.*]] = trunc i64 [[START_I]] to i32
+; CHECK-NEXT: [[TMP3:%.*]] = sub i32 [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
+; CHECK-NEXT: [[TMP5:%.*]] = add nuw nsw i64 [[TMP4]], 1
+; CHECK-NEXT: [[SCEVGEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[TMP5]]
; CHECK-NEXT: br label [[FOR_BODY:%.*]]
; CHECK: for.body:
-; CHECK-NEXT: [[TRIP_COUNT:%.*]] = phi i32 [ [[NEW_TRIP_COUNT:%.*]], [[FOR_LATCH:%.*]] ], [ [[DELTA]], [[FOR_BODY_PREHEADER]] ]
-; CHECK-NEXT: [[ADDR:%.*]] = phi ptr [ [[NEW_ADDR:%.*]], [[FOR_LATCH]] ], [ [[START]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT: [[ADDR:%.*]] = phi ptr [ [[NEW_ADDR:%.*]], [[FOR_LATCH:%.*]] ], [ [[START]], [[FOR_BODY_PREHEADER]] ]
; CHECK-NEXT: [[DATA:%.*]] = load i8, ptr [[ADDR]], align 1
; CHECK-NEXT: [[COND2:%.*]] = icmp eq i8 [[DATA]], [[VALUE:%.*]]
; CHECK-NEXT: br i1 [[COND2]], label [[FOR_BODY_FOR_END_LOOPEXIT_CRIT_EDGE:%.*]], label [[FOR_LATCH]]
; CHECK: for.latch:
-; CHECK-NEXT: [[NEW_ADDR]] = getelementptr inbounds i8, ptr [[ADDR]], i64 1
-; CHECK-NEXT: [[NEW_TRIP_COUNT]] = add nsw i32 [[TRIP_COUNT]], -1
-; CHECK-NEXT: [[COND3:%.*]] = icmp sgt i32 [[TRIP_COUNT]], 1
-; CHECK-NEXT: br i1 [[COND3]], label [[FOR_BODY]], label [[FOR_END_LOOPEXITSPLIT:%.*]]
+; CHECK-NEXT: [[NEW_ADDR]] = getelementptr i8, ptr [[ADDR]], i64 1
+; CHECK-NEXT: [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[NEW_ADDR]], [[SCEVGEP]]
+; CHECK-NEXT: br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_END_LOOPEXITSPLIT:%.*]], label [[FOR_BODY]]
; CHECK: for.end.loopexitsplit:
; CHECK-NEXT: [[NEW_ADDR_LCSSA:%.*]] = phi ptr [ [[NEW_ADDR]], [[FOR_LATCH]] ]
; CHECK-NEXT: br label [[FOR_END_LOOPEXIT:%.*]]
More information about the llvm-commits
mailing list