[llvm] c0ef83e - [LSR] Check if terminating value is safe to expand before transformation
via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 15 14:56:52 PST 2022
Author: eopXD
Date: 2022-11-15T14:56:47-08:00
New Revision: c0ef83e3b930c34cbfe861dccfe515bb1e450dfa
URL: https://github.com/llvm/llvm-project/commit/c0ef83e3b930c34cbfe861dccfe515bb1e450dfa
DIFF: https://github.com/llvm/llvm-project/commit/c0ef83e3b930c34cbfe861dccfe515bb1e450dfa.diff
LOG: [LSR] Check if terminating value is safe to expand before transformation
According to a report by @JojoR, the assertion error was hit, hence we need
to perform this check before the actual transformation.
Reviewed By: Meinersbur, #loopoptwg
Differential Revision: https://reviews.llvm.org/D136415
Added:
Modified:
llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index f9df7ba54b334..697d6e467db1c 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6614,7 +6614,7 @@ static llvm::PHINode *GetInductionVariable(const Loop &L, ScalarEvolution &SE,
return nullptr;
}
-static Optional<std::pair<PHINode *, PHINode *>>
+static Optional<std::pair<PHINode *, std::pair<PHINode *, const SCEV *>>>
canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
const LoopInfo &LI) {
if (!L->isInnermost()) {
@@ -6699,16 +6699,37 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
// For `IsToHelpFold`, other IV that is an affine AddRec will be sufficient to
// replace the terminating condition
- auto IsToHelpFold = [&](PHINode &PN) -> bool {
+ auto IsToHelpFold = [&](PHINode &PN) -> std::pair<bool, const SCEV *> {
+ const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
+ const SCEV *BECount = SE.getBackedgeTakenCount(L);
+ const SCEV *TermValueS = SE.getAddExpr(
+ AddRec->getOperand(0),
+ SE.getTruncateOrZeroExtend(
+ SE.getMulExpr(
+ AddRec->getOperand(1),
+ SE.getTruncateOrZeroExtend(
+ SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
+ AddRec->getOperand(1)->getType())),
+ AddRec->getOperand(0)->getType()));
+ const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
+ SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
+ if (!Expander.isSafeToExpand(TermValueS)) {
+ LLVM_DEBUG(
+ dbgs() << "Is not safe to expand terminating value for phi node" << PN
+ << "\n");
+ return {false, nullptr};
+ }
// TODO: Right now we limit the phi node to help the folding be of a start
// value of getelementptr. We can extend to any kinds of IV as long as it is
// an affine AddRec. Add a switch to cover more types of instructions here
// and down in the actual transformation.
- return isa<GetElementPtrInst>(PN.getIncomingValueForBlock(LoopPreheader));
+ return {isa<GetElementPtrInst>(PN.getIncomingValueForBlock(LoopPreheader)),
+ TermValueS};
};
PHINode *ToFold = nullptr;
PHINode *ToHelpFold = nullptr;
+ const SCEV *TermValueS = nullptr;
for (PHINode &PN : L->getHeader()->phis()) {
if (!SE.isSCEVable(PN.getType())) {
@@ -6729,8 +6750,10 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
if (IsToFold(PN))
ToFold = &PN;
- else if (IsToHelpFold(PN))
+ else if (auto P = IsToHelpFold(PN); P.first) {
ToHelpFold = &PN;
+ TermValueS = P.second;
+ }
}
LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
@@ -6746,7 +6769,7 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
if (!ToFold || !ToHelpFold)
return None;
- return {{ToFold, ToHelpFold}};
+ return {{ToFold, {ToHelpFold, TermValueS}}};
}
static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
@@ -6810,11 +6833,14 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
if (AllowTerminatingConditionFoldingAfterLSR) {
auto CanFoldTerminatingCondition = canFoldTermCondOfLoop(L, SE, DT, LI);
if (CanFoldTerminatingCondition) {
+ Changed = true;
+ NumTermFold++;
+
BasicBlock *LoopPreheader = L->getLoopPreheader();
BasicBlock *LoopLatch = L->getLoopLatch();
PHINode *ToFold = CanFoldTerminatingCondition->first;
- PHINode *ToHelpFold = CanFoldTerminatingCondition->second;
+ PHINode *ToHelpFold = CanFoldTerminatingCondition->second.first;
(void)ToFold;
LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
@@ -6834,56 +6860,35 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
GetElementPtrInst *StartValueGEP = cast<GetElementPtrInst>(StartValue);
Type *PtrTy = StartValueGEP->getPointerOperand()->getType();
- const SCEV *BECount = SE.getBackedgeTakenCount(L);
- const SCEVAddRecExpr *AddRec =
- cast<SCEVAddRecExpr>(SE.getSCEV(ToHelpFold));
-
- // TermValue = Start + Stride * (BackedgeCount + 1)
- const SCEV *TermValueS = SE.getAddExpr(
- AddRec->getOperand(0),
- SE.getTruncateOrZeroExtend(
- SE.getMulExpr(
- AddRec->getOperand(1),
- SE.getTruncateOrZeroExtend(
- SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
- AddRec->getOperand(1)->getType())),
- AddRec->getOperand(0)->getType()));
-
- // NOTE: If this is triggered, we should add this into predicate
- if (!Expander.isSafeToExpand(TermValueS)) {
- LLVMContext &Ctx = L->getHeader()->getContext();
- Ctx.emitError(
- "Terminating value is not safe to expand, need to add it to "
- "predicate");
- } else { // Now we replace the condition with ToHelpFold and remove ToFold
- Changed = true;
- NumTermFold++;
-
- Value *TermValue = Expander.expandCodeFor(
- TermValueS, PtrTy, LoopPreheader->getTerminator());
-
- LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
- << *StartValue << "\n"
- << "Terminating value of new term-cond phi-node:\n"
- << *TermValue << "\n");
-
- // Create new terminating condition at loop latch
- BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
- ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
- IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
- Value *NewTermCond = LatchBuilder.CreateICmp(
- OldTermCond->getPredicate(), LoopValue, TermValue,
- "lsr_fold_term_cond.replaced_term_cond");
-
- LLVM_DEBUG(dbgs() << "Old term-cond:\n"
- << *OldTermCond << "\n"
- << "New term-cond:\b" << *NewTermCond << "\n");
-
- BI->setCondition(NewTermCond);
-
- OldTermCond->eraseFromParent();
- DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
- }
+ const SCEV *TermValueS = CanFoldTerminatingCondition->second.second;
+ assert(
+ Expander.isSafeToExpand(TermValueS) &&
+ "Terminating value was checked safe in canFoldTerminatingCondition");
+
+ Value *TermValue = Expander.expandCodeFor(TermValueS, PtrTy,
+ LoopPreheader->getTerminator());
+
+ LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
+ << *StartValue << "\n"
+ << "Terminating value of new term-cond phi-node:\n"
+ << *TermValue << "\n");
+
+ // Create new terminating condition at loop latch
+ BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
+ ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
+ IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
+ Value *NewTermCond = LatchBuilder.CreateICmp(
+ OldTermCond->getPredicate(), LoopValue, TermValue,
+ "lsr_fold_term_cond.replaced_term_cond");
+
+ LLVM_DEBUG(dbgs() << "Old term-cond:\n"
+ << *OldTermCond << "\n"
+ << "New term-cond:\b" << *NewTermCond << "\n");
+
+ BI->setCondition(NewTermCond);
+
+ OldTermCond->eraseFromParent();
+ DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
ExpCleaner.markResultUsed();
}
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
index ba41ddaeb028e..21bb8e5bd1cef 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
@@ -158,3 +158,82 @@ for.body: ; preds = %for.body, %entry
for.end: ; preds = %for.body
ret void
}
+
+; The test case is reduced from FFmpeg/libavfilter/ebur128.c
+; Testing check if terminating value is safe to expand
+%struct.FFEBUR128State = type { i32, ptr, i64, i64 }
+
+ at histogram_energy_boundaries = global [1001 x double] zeroinitializer, align 8
+
+define void @ebur128_calc_gating_block(ptr %st, ptr %optional_output) {
+; CHECK: Is not safe to expand terminating value for phi node %i.026 = phi i64 [ 0, %for.body7.lr.ph ], [ %inc, %for.body7 ]
+entry:
+ %0 = load i32, ptr %st, align 8
+ %conv = zext i32 %0 to i64
+ %cmp28.not = icmp eq i32 %0, 0
+ br i1 %cmp28.not, label %for.end13, label %for.cond2.preheader.lr.ph
+
+for.cond2.preheader.lr.ph: ; preds = %entry
+ %audio_data_index = getelementptr inbounds %struct.FFEBUR128State, ptr %st, i64 0, i32 3
+ %1 = load i64, ptr %audio_data_index, align 8
+ %div = udiv i64 %1, %conv
+ %cmp525.not = icmp ult i64 %1, %conv
+ %audio_data = getelementptr inbounds %struct.FFEBUR128State, ptr %st, i64 0, i32 1
+ %umax = tail call i64 @llvm.umax.i64(i64 %div, i64 1)
+ br label %for.cond2.preheader
+
+for.cond2.preheader: ; preds = %for.cond2.preheader.lr.ph, %for.inc11
+ %channel_sum.030 = phi double [ 0.000000e+00, %for.cond2.preheader.lr.ph ], [ %channel_sum.1.lcssa, %for.inc11 ]
+ %c.029 = phi i64 [ 0, %for.cond2.preheader.lr.ph ], [ %inc12, %for.inc11 ]
+ br i1 %cmp525.not, label %for.inc11, label %for.body7.lr.ph
+
+for.body7.lr.ph: ; preds = %for.cond2.preheader
+ %2 = load ptr, ptr %audio_data, align 8
+ br label %for.body7
+
+for.body7: ; preds = %for.body7.lr.ph, %for.body7
+ %channel_sum.127 = phi double [ %channel_sum.030, %for.body7.lr.ph ], [ %add10, %for.body7 ]
+ %i.026 = phi i64 [ 0, %for.body7.lr.ph ], [ %inc, %for.body7 ]
+ %mul = mul i64 %i.026, %conv
+ %add = add i64 %mul, %c.029
+ %arrayidx = getelementptr inbounds double, ptr %2, i64 %add
+ %3 = load double, ptr %arrayidx, align 8
+ %add10 = fadd double %channel_sum.127, %3
+ %inc = add nuw i64 %i.026, 1
+ %exitcond.not = icmp eq i64 %inc, %umax
+ br i1 %exitcond.not, label %for.inc11, label %for.body7
+
+for.inc11: ; preds = %for.body7, %for.cond2.preheader
+ %channel_sum.1.lcssa = phi double [ %channel_sum.030, %for.cond2.preheader ], [ %add10, %for.body7 ]
+ %inc12 = add nuw nsw i64 %c.029, 1
+ %exitcond32.not = icmp eq i64 %inc12, %conv
+ br i1 %exitcond32.not, label %for.end13, label %for.cond2.preheader
+
+for.end13: ; preds = %for.inc11, %entry
+ %channel_sum.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %channel_sum.1.lcssa, %for.inc11 ]
+ %add14 = fadd double %channel_sum.0.lcssa, 0.000000e+00
+ store double %add14, ptr %optional_output, align 8
+ ret void
+}
+
+declare i64 @llvm.umax.i64(i64, i64)
+
+%struct.PAKT_INFO = type { i32, i32, i32, [0 x i32] }
+
+define i64 @alac_seek(ptr %0) {
+; CHECK: Is not safe to expand terminating value for phi node %indvars.iv.i = phi i64 [ 0, %entry ], [ %indvars.iv.next.i, %for.body.i ]
+entry:
+ %div = udiv i64 1, 0
+ br label %for.body.i
+
+for.body.i: ; preds = %for.body.i, %entry
+ %indvars.iv.i = phi i64 [ 0, %entry ], [ %indvars.iv.next.i, %for.body.i ]
+ %arrayidx.i = getelementptr %struct.PAKT_INFO, ptr %0, i64 0, i32 3, i64 %indvars.iv.i
+ %1 = load i32, ptr %arrayidx.i, align 4
+ %indvars.iv.next.i = add i64 %indvars.iv.i, 1
+ %exitcond.not.i = icmp eq i64 %indvars.iv.i, %div
+ br i1 %exitcond.not.i, label %alac_pakt_block_offset.exit, label %for.body.i
+
+alac_pakt_block_offset.exit: ; preds = %for.body.i
+ ret i64 0
+}
More information about the llvm-commits mailing list