[llvm] [LSR] Optimize lsr-term-fold if the operand of condition is PHINode (PR #96048)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 19 02:28:05 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Zhijin Zeng (zengdage)
<details>
<summary>Changes</summary>
`lsr-term-fold` currently only supports the case where the operand of the branch condition in the latch terminator is a `BinaryOperator`. If that operand is a `PHINode` instead, `lsr-term-fold` should also be able to fold the terminator condition; this patch adds that support.
---
Full diff: https://github.com/llvm/llvm-project/pull/96048.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp (+21-8)
- (modified) llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll (+66)
``````````diff
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 3a98e257367b2..6eda52a51c77d 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6830,15 +6830,18 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
return std::nullopt;
}
- if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
- LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
- return std::nullopt;
- }
-
BasicBlock *LoopLatch = L->getLoopLatch();
BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
if (!BI || BI->isUnconditional())
return std::nullopt;
+
+ if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
+ if (isa<SCEVCouldNotCompute>(SE.getExitCount(L, BI->getParent()))) {
+ LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
+ return std::nullopt;
+ }
+ }
+
auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
if (!TermCond) {
LLVM_DEBUG(
@@ -6852,9 +6855,8 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
return std::nullopt;
}
- BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
Value *RHS = TermCond->getOperand(1);
- if (!LHS || !L->isLoopInvariant(RHS))
+ if (!L->isLoopInvariant(RHS))
// We could pattern match the inverse form of the icmp, but that is
// non-canonical, and this pass is running *very* late in the pipeline.
return std::nullopt;
@@ -6862,7 +6864,15 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
// Find the IV used by the current exit condition.
PHINode *ToFold;
Value *ToFoldStart, *ToFoldStep;
- if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
+ if (BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0))) {
+ if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
+ return std::nullopt;
+ } else if (PHINode *LHS = dyn_cast<PHINode>(TermCond->getOperand(0))) {
+ BinaryOperator *BO = nullptr;
+ if (!matchSimpleRecurrence(LHS, BO, ToFoldStart, ToFoldStep))
+ return std::nullopt;
+ ToFold = LHS;
+ } else
return std::nullopt;
// Ensure the simple recurrence is a part of the current loop.
@@ -6887,6 +6897,9 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
}();
const SCEV *BECount = SE.getBackedgeTakenCount(L);
+ if (isa<SCEVCouldNotCompute>(BECount))
+ BECount = SE.getExitCount(L, BI->getParent());
+
const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
index 7299a014b7983..877f48dece710 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
@@ -733,3 +733,69 @@ for.body: ; preds = %for.body, %entry
for.end: ; preds = %for.body
ret void
}
+
+define ptr @no_binary_operator(ptr %start, ptr %end, i8 %value) {
+; CHECK-LABEL: @no_binary_operator(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[END_I:%.*]] = ptrtoint ptr [[END:%.*]] to i64
+; CHECK-NEXT: [[START_I:%.*]] = ptrtoint ptr [[START:%.*]] to i64
+; CHECK-NEXT: [[DELTA_I:%.*]] = sub i64 [[END_I]], [[START_I]]
+; CHECK-NEXT: [[DELTA:%.*]] = trunc i64 [[DELTA_I]] to i32
+; CHECK-NEXT: [[COND1:%.*]] = icmp sgt i32 [[DELTA]], 0
+; CHECK-NEXT: br i1 [[COND1]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_END:%.*]]
+; CHECK: for.body.preheader:
+; CHECK-NEXT: [[TMP0:%.*]] = trunc i64 [[END_I]] to i32
+; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[TMP0]], -1
+; CHECK-NEXT: [[TMP2:%.*]] = trunc i64 [[START_I]] to i32
+; CHECK-NEXT: [[TMP3:%.*]] = sub i32 [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
+; CHECK-NEXT: [[TMP5:%.*]] = add nuw nsw i64 [[TMP4]], 1
+; CHECK-NEXT: [[SCEVGEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[TMP5]]
+; CHECK-NEXT: br label [[FOR_BODY:%.*]]
+; CHECK: for.body:
+; CHECK-NEXT: [[ADDR:%.*]] = phi ptr [ [[NEW_ADDR:%.*]], [[FOR_LATCH:%.*]] ], [ [[START]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT: [[DATA:%.*]] = load i8, ptr [[ADDR]], align 1
+; CHECK-NEXT: [[COND2:%.*]] = icmp eq i8 [[DATA]], [[VALUE:%.*]]
+; CHECK-NEXT: br i1 [[COND2]], label [[FOR_BODY_FOR_END_LOOPEXIT_CRIT_EDGE:%.*]], label [[FOR_LATCH]]
+; CHECK: for.latch:
+; CHECK-NEXT: [[NEW_ADDR]] = getelementptr i8, ptr [[ADDR]], i64 1
+; CHECK-NEXT: [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[NEW_ADDR]], [[SCEVGEP]]
+; CHECK-NEXT: br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_END_LOOPEXITSPLIT:%.*]], label [[FOR_BODY]]
+; CHECK: for.end.loopexitsplit:
+; CHECK-NEXT: [[NEW_ADDR_LCSSA:%.*]] = phi ptr [ [[NEW_ADDR]], [[FOR_LATCH]] ]
+; CHECK-NEXT: br label [[FOR_END_LOOPEXIT:%.*]]
+; CHECK: for.body.for.end.loopexit_crit_edge:
+; CHECK-NEXT: [[ADDR_LCSSA:%.*]] = phi ptr [ [[ADDR]], [[FOR_BODY]] ]
+; CHECK-NEXT: br label [[FOR_END_LOOPEXIT]]
+; CHECK: for.end.loopexit:
+; CHECK-NEXT: [[RETV_PH:%.*]] = phi ptr [ [[ADDR_LCSSA]], [[FOR_BODY_FOR_END_LOOPEXIT_CRIT_EDGE]] ], [ [[NEW_ADDR_LCSSA]], [[FOR_END_LOOPEXITSPLIT]] ]
+; CHECK-NEXT: br label [[FOR_END]]
+; CHECK: for.end:
+; CHECK-NEXT: [[RETV:%.*]] = phi ptr [ [[START]], [[ENTRY:%.*]] ], [ [[RETV_PH]], [[FOR_END_LOOPEXIT]] ]
+; CHECK-NEXT: ret ptr [[RETV]]
+;
+entry:
+ %end_i = ptrtoint ptr %end to i64
+ %start_i = ptrtoint ptr %start to i64
+ %delta_i = sub i64 %end_i, %start_i
+ %delta = trunc i64 %delta_i to i32
+ %cond1 = icmp sgt i32 %delta, 0
+ br i1 %cond1, label %for.body, label %for.end
+
+for.body: ; preds = %entry, %for.latch
+ %trip_count = phi i32 [ %new_trip_count, %for.latch ], [ %delta, %entry ]
+ %addr = phi ptr [ %new_addr, %for.latch ], [ %start, %entry ]
+ %data = load i8, ptr %addr, align 1
+ %cond2 = icmp eq i8 %data, %value
+ br i1 %cond2, label %for.end, label %for.latch
+
+for.latch: ; preds = %for.body
+ %new_addr = getelementptr inbounds i8, ptr %addr, i64 1
+ %new_trip_count = add nsw i32 %trip_count, -1
+ %cond3 = icmp sgt i32 %trip_count, 1
+ br i1 %cond3, label %for.body, label %for.end
+
+for.end: ; preds = %for.body, %for.latch, %entry
+ %retv = phi ptr [ %start, %entry ], [ %new_addr, %for.latch ], [ %addr, %for.body ]
+ ret ptr %retv
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/96048
More information about the llvm-commits
mailing list