[llvm] c0ef83e - [LSR] Check if terminating value is safe to expand before transformation

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 15 14:56:52 PST 2022


Author: eopXD
Date: 2022-11-15T14:56:47-08:00
New Revision: c0ef83e3b930c34cbfe861dccfe515bb1e450dfa

URL: https://github.com/llvm/llvm-project/commit/c0ef83e3b930c34cbfe861dccfe515bb1e450dfa
DIFF: https://github.com/llvm/llvm-project/commit/c0ef83e3b930c34cbfe861dccfe515bb1e450dfa.diff

LOG: [LSR] Check if terminating value is safe to expand before transformation

According to a report by @JojoR, an assertion failure was hit; hence this
check must be performed before the actual transformation.

Reviewed By: Meinersbur, #loopoptwg

Differential Revision: https://reviews.llvm.org/D136415

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
    llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index f9df7ba54b334..697d6e467db1c 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6614,7 +6614,7 @@ static llvm::PHINode *GetInductionVariable(const Loop &L, ScalarEvolution &SE,
   return nullptr;
 }
 
-static Optional<std::pair<PHINode *, PHINode *>>
+static Optional<std::pair<PHINode *, std::pair<PHINode *, const SCEV *>>>
 canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
                       const LoopInfo &LI) {
   if (!L->isInnermost()) {
@@ -6699,16 +6699,37 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
 
   // For `IsToHelpFold`, other IV that is an affine AddRec will be sufficient to
   // replace the terminating condition
-  auto IsToHelpFold = [&](PHINode &PN) -> bool {
+  auto IsToHelpFold = [&](PHINode &PN) -> std::pair<bool, const SCEV *> {
+    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
+    const SCEV *BECount = SE.getBackedgeTakenCount(L);
+    const SCEV *TermValueS = SE.getAddExpr(
+        AddRec->getOperand(0),
+        SE.getTruncateOrZeroExtend(
+            SE.getMulExpr(
+                AddRec->getOperand(1),
+                SE.getTruncateOrZeroExtend(
+                    SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
+                    AddRec->getOperand(1)->getType())),
+            AddRec->getOperand(0)->getType()));
+    const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
+    SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
+    if (!Expander.isSafeToExpand(TermValueS)) {
+      LLVM_DEBUG(
+          dbgs() << "Is not safe to expand terminating value for phi node" << PN
+                 << "\n");
+      return {false, nullptr};
+    }
     // TODO: Right now we limit the phi node to help the folding be of a start
     // value of getelementptr. We can extend to any kinds of IV as long as it is
     // an affine AddRec. Add a switch to cover more types of instructions here
     // and down in the actual transformation.
-    return isa<GetElementPtrInst>(PN.getIncomingValueForBlock(LoopPreheader));
+    return {isa<GetElementPtrInst>(PN.getIncomingValueForBlock(LoopPreheader)),
+            TermValueS};
   };
 
   PHINode *ToFold = nullptr;
   PHINode *ToHelpFold = nullptr;
+  const SCEV *TermValueS = nullptr;
 
   for (PHINode &PN : L->getHeader()->phis()) {
     if (!SE.isSCEVable(PN.getType())) {
@@ -6729,8 +6750,10 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
 
     if (IsToFold(PN))
       ToFold = &PN;
-    else if (IsToHelpFold(PN))
+    else if (auto P = IsToHelpFold(PN); P.first) {
       ToHelpFold = &PN;
+      TermValueS = P.second;
+    }
   }
 
   LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
@@ -6746,7 +6769,7 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
 
   if (!ToFold || !ToHelpFold)
     return None;
-  return {{ToFold, ToHelpFold}};
+  return {{ToFold, {ToHelpFold, TermValueS}}};
 }
 
 static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
@@ -6810,11 +6833,14 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
   if (AllowTerminatingConditionFoldingAfterLSR) {
     auto CanFoldTerminatingCondition = canFoldTermCondOfLoop(L, SE, DT, LI);
     if (CanFoldTerminatingCondition) {
+      Changed = true;
+      NumTermFold++;
+
       BasicBlock *LoopPreheader = L->getLoopPreheader();
       BasicBlock *LoopLatch = L->getLoopLatch();
 
       PHINode *ToFold = CanFoldTerminatingCondition->first;
-      PHINode *ToHelpFold = CanFoldTerminatingCondition->second;
+      PHINode *ToHelpFold = CanFoldTerminatingCondition->second.first;
 
       (void)ToFold;
       LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
@@ -6834,56 +6860,35 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
       GetElementPtrInst *StartValueGEP = cast<GetElementPtrInst>(StartValue);
       Type *PtrTy = StartValueGEP->getPointerOperand()->getType();
 
-      const SCEV *BECount = SE.getBackedgeTakenCount(L);
-      const SCEVAddRecExpr *AddRec =
-          cast<SCEVAddRecExpr>(SE.getSCEV(ToHelpFold));
-
-      // TermValue = Start + Stride * (BackedgeCount + 1)
-      const SCEV *TermValueS = SE.getAddExpr(
-          AddRec->getOperand(0),
-          SE.getTruncateOrZeroExtend(
-              SE.getMulExpr(
-                  AddRec->getOperand(1),
-                  SE.getTruncateOrZeroExtend(
-                      SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
-                      AddRec->getOperand(1)->getType())),
-              AddRec->getOperand(0)->getType()));
-
-      // NOTE: If this is triggered, we should add this into predicate
-      if (!Expander.isSafeToExpand(TermValueS)) {
-        LLVMContext &Ctx = L->getHeader()->getContext();
-        Ctx.emitError(
-            "Terminating value is not safe to expand, need to add it to "
-            "predicate");
-      } else { // Now we replace the condition with ToHelpFold and remove ToFold
-        Changed = true;
-        NumTermFold++;
-
-        Value *TermValue = Expander.expandCodeFor(
-            TermValueS, PtrTy, LoopPreheader->getTerminator());
-
-        LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
-                          << *StartValue << "\n"
-                          << "Terminating value of new term-cond phi-node:\n"
-                          << *TermValue << "\n");
-
-        // Create new terminating condition at loop latch
-        BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
-        ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
-        IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
-        Value *NewTermCond = LatchBuilder.CreateICmp(
-            OldTermCond->getPredicate(), LoopValue, TermValue,
-            "lsr_fold_term_cond.replaced_term_cond");
-
-        LLVM_DEBUG(dbgs() << "Old term-cond:\n"
-                          << *OldTermCond << "\n"
-                          << "New term-cond:\b" << *NewTermCond << "\n");
-
-        BI->setCondition(NewTermCond);
-
-        OldTermCond->eraseFromParent();
-        DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
-      }
+      const SCEV *TermValueS = CanFoldTerminatingCondition->second.second;
+      assert(
+          Expander.isSafeToExpand(TermValueS) &&
+          "Terminating value was checked safe in canFoldTerminatingCondition");
+
+      Value *TermValue = Expander.expandCodeFor(TermValueS, PtrTy,
+                                                LoopPreheader->getTerminator());
+
+      LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
+                        << *StartValue << "\n"
+                        << "Terminating value of new term-cond phi-node:\n"
+                        << *TermValue << "\n");
+
+      // Create new terminating condition at loop latch
+      BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
+      ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
+      IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
+      Value *NewTermCond = LatchBuilder.CreateICmp(
+          OldTermCond->getPredicate(), LoopValue, TermValue,
+          "lsr_fold_term_cond.replaced_term_cond");
+
+      LLVM_DEBUG(dbgs() << "Old term-cond:\n"
+                        << *OldTermCond << "\n"
+                        << "New term-cond:\b" << *NewTermCond << "\n");
+
+      BI->setCondition(NewTermCond);
+
+      OldTermCond->eraseFromParent();
+      DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
 
       ExpCleaner.markResultUsed();
     }

diff  --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
index ba41ddaeb028e..21bb8e5bd1cef 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
@@ -158,3 +158,82 @@ for.body:                                         ; preds = %for.body, %entry
 for.end:                                          ; preds = %for.body
   ret void
 }
+
+; The test case is reduced from FFmpeg/libavfilter/ebur128.c
+; Testing check if terminating value is safe to expand
+%struct.FFEBUR128State = type { i32, ptr, i64, i64 }
+
+@histogram_energy_boundaries = global [1001 x double] zeroinitializer, align 8
+
+define void @ebur128_calc_gating_block(ptr %st, ptr %optional_output) {
+; CHECK: Is not safe to expand terminating value for phi node  %i.026 = phi i64 [ 0, %for.body7.lr.ph ], [ %inc, %for.body7 ]
+entry:
+  %0 = load i32, ptr %st, align 8
+  %conv = zext i32 %0 to i64
+  %cmp28.not = icmp eq i32 %0, 0
+  br i1 %cmp28.not, label %for.end13, label %for.cond2.preheader.lr.ph
+
+for.cond2.preheader.lr.ph:                        ; preds = %entry
+  %audio_data_index = getelementptr inbounds %struct.FFEBUR128State, ptr %st, i64 0, i32 3
+  %1 = load i64, ptr %audio_data_index, align 8
+  %div = udiv i64 %1, %conv
+  %cmp525.not = icmp ult i64 %1, %conv
+  %audio_data = getelementptr inbounds %struct.FFEBUR128State, ptr %st, i64 0, i32 1
+  %umax = tail call i64 @llvm.umax.i64(i64 %div, i64 1)
+  br label %for.cond2.preheader
+
+for.cond2.preheader:                              ; preds = %for.cond2.preheader.lr.ph, %for.inc11
+  %channel_sum.030 = phi double [ 0.000000e+00, %for.cond2.preheader.lr.ph ], [ %channel_sum.1.lcssa, %for.inc11 ]
+  %c.029 = phi i64 [ 0, %for.cond2.preheader.lr.ph ], [ %inc12, %for.inc11 ]
+  br i1 %cmp525.not, label %for.inc11, label %for.body7.lr.ph
+
+for.body7.lr.ph:                                  ; preds = %for.cond2.preheader
+  %2 = load ptr, ptr %audio_data, align 8
+  br label %for.body7
+
+for.body7:                                        ; preds = %for.body7.lr.ph, %for.body7
+  %channel_sum.127 = phi double [ %channel_sum.030, %for.body7.lr.ph ], [ %add10, %for.body7 ]
+  %i.026 = phi i64 [ 0, %for.body7.lr.ph ], [ %inc, %for.body7 ]
+  %mul = mul i64 %i.026, %conv
+  %add = add i64 %mul, %c.029
+  %arrayidx = getelementptr inbounds double, ptr %2, i64 %add
+  %3 = load double, ptr %arrayidx, align 8
+  %add10 = fadd double %channel_sum.127, %3
+  %inc = add nuw i64 %i.026, 1
+  %exitcond.not = icmp eq i64 %inc, %umax
+  br i1 %exitcond.not, label %for.inc11, label %for.body7
+
+for.inc11:                                        ; preds = %for.body7, %for.cond2.preheader
+  %channel_sum.1.lcssa = phi double [ %channel_sum.030, %for.cond2.preheader ], [ %add10, %for.body7 ]
+  %inc12 = add nuw nsw i64 %c.029, 1
+  %exitcond32.not = icmp eq i64 %inc12, %conv
+  br i1 %exitcond32.not, label %for.end13, label %for.cond2.preheader
+
+for.end13:                                        ; preds = %for.inc11, %entry
+  %channel_sum.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %channel_sum.1.lcssa, %for.inc11 ]
+  %add14 = fadd double %channel_sum.0.lcssa, 0.000000e+00
+  store double %add14, ptr %optional_output, align 8
+  ret void
+}
+
+declare i64 @llvm.umax.i64(i64, i64)
+
+%struct.PAKT_INFO = type { i32, i32, i32, [0 x i32] }
+
+define i64 @alac_seek(ptr %0) {
+; CHECK: Is not safe to expand terminating value for phi node  %indvars.iv.i = phi i64 [ 0, %entry ], [ %indvars.iv.next.i, %for.body.i ]
+entry:
+  %div = udiv i64 1, 0
+  br label %for.body.i
+
+for.body.i:                                       ; preds = %for.body.i, %entry
+  %indvars.iv.i = phi i64 [ 0, %entry ], [ %indvars.iv.next.i, %for.body.i ]
+  %arrayidx.i = getelementptr %struct.PAKT_INFO, ptr %0, i64 0, i32 3, i64 %indvars.iv.i
+  %1 = load i32, ptr %arrayidx.i, align 4
+  %indvars.iv.next.i = add i64 %indvars.iv.i, 1
+  %exitcond.not.i = icmp eq i64 %indvars.iv.i, %div
+  br i1 %exitcond.not.i, label %alac_pakt_block_offset.exit, label %for.body.i
+
+alac_pakt_block_offset.exit:                      ; preds = %for.body.i
+  ret i64 0
+}


        


More information about the llvm-commits mailing list