[llvm] [LSR] Optimize lsr-term-fold if the operand of condition is PHINode (PR #96048)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 19 02:28:05 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Zhijin Zeng (zengdage)

<details>
<summary>Changes</summary>

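`lsr-term-fold` currently only handles the case where the first operand of the branch condition in the latch terminator is a `BinaryOperator`. When that operand is a `PHINode` instead, `lsr-term-fold` should also be able to fold the terminator. This patch also permits the fold when the loop's backedge-taken count is not loop-invariant, as long as the exit count of the latch block is computable; that latch exit count is then used in place of the backedge-taken count.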

---
Full diff: https://github.com/llvm/llvm-project/pull/96048.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp (+21-8) 
- (modified) llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll (+66) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 3a98e257367b2..6eda52a51c77d 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6830,15 +6830,18 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
     return std::nullopt;
   }
 
-  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
-    LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
-    return std::nullopt;
-  }
-
   BasicBlock *LoopLatch = L->getLoopLatch();
   BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
   if (!BI || BI->isUnconditional())
     return std::nullopt;
+
+  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
+    if (isa<SCEVCouldNotCompute>(SE.getExitCount(L, BI->getParent()))) {
+      LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
+      return std::nullopt;
+    }
+  }
+
   auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
   if (!TermCond) {
     LLVM_DEBUG(
@@ -6852,9 +6855,8 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
     return std::nullopt;
   }
 
-  BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
   Value *RHS = TermCond->getOperand(1);
-  if (!LHS || !L->isLoopInvariant(RHS))
+  if (!L->isLoopInvariant(RHS))
     // We could pattern match the inverse form of the icmp, but that is
     // non-canonical, and this pass is running *very* late in the pipeline.
     return std::nullopt;
@@ -6862,7 +6864,15 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
   // Find the IV used by the current exit condition.
   PHINode *ToFold;
   Value *ToFoldStart, *ToFoldStep;
-  if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
+  if (BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0))) {
+    if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
+      return std::nullopt;
+  } else if (PHINode *LHS = dyn_cast<PHINode>(TermCond->getOperand(0))) {
+    BinaryOperator *BO = nullptr;
+    if (!matchSimpleRecurrence(LHS, BO, ToFoldStart, ToFoldStep))
+      return std::nullopt;
+    ToFold = LHS;
+  } else
     return std::nullopt;
 
   // Ensure the simple recurrence is a part of the current loop.
@@ -6887,6 +6897,9 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
   }();
 
   const SCEV *BECount = SE.getBackedgeTakenCount(L);
+  if (isa<SCEVCouldNotCompute>(BECount))
+    BECount = SE.getExitCount(L, BI->getParent());
+
   const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
   SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
 
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
index 7299a014b7983..877f48dece710 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
@@ -733,3 +733,69 @@ for.body:                                         ; preds = %for.body, %entry
 for.end:                                          ; preds = %for.body
   ret void
 }
+
+define ptr @no_binary_operator(ptr %start, ptr %end, i8 %value) {
+; CHECK-LABEL: @no_binary_operator(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[END_I:%.*]] = ptrtoint ptr [[END:%.*]] to i64
+; CHECK-NEXT:    [[START_I:%.*]] = ptrtoint ptr [[START:%.*]] to i64
+; CHECK-NEXT:    [[DELTA_I:%.*]] = sub i64 [[END_I]], [[START_I]]
+; CHECK-NEXT:    [[DELTA:%.*]] = trunc i64 [[DELTA_I]] to i32
+; CHECK-NEXT:    [[COND1:%.*]] = icmp sgt i32 [[DELTA]], 0
+; CHECK-NEXT:    br i1 [[COND1]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_END:%.*]]
+; CHECK:       for.body.preheader:
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc i64 [[END_I]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[TMP0]], -1
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i64 [[START_I]] to i32
+; CHECK-NEXT:    [[TMP3:%.*]] = sub i32 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
+; CHECK-NEXT:    [[TMP5:%.*]] = add nuw nsw i64 [[TMP4]], 1
+; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[TMP5]]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.body:
+; CHECK-NEXT:    [[ADDR:%.*]] = phi ptr [ [[NEW_ADDR:%.*]], [[FOR_LATCH:%.*]] ], [ [[START]], [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT:    [[DATA:%.*]] = load i8, ptr [[ADDR]], align 1
+; CHECK-NEXT:    [[COND2:%.*]] = icmp eq i8 [[DATA]], [[VALUE:%.*]]
+; CHECK-NEXT:    br i1 [[COND2]], label [[FOR_BODY_FOR_END_LOOPEXIT_CRIT_EDGE:%.*]], label [[FOR_LATCH]]
+; CHECK:       for.latch:
+; CHECK-NEXT:    [[NEW_ADDR]] = getelementptr i8, ptr [[ADDR]], i64 1
+; CHECK-NEXT:    [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[NEW_ADDR]], [[SCEVGEP]]
+; CHECK-NEXT:    br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_END_LOOPEXITSPLIT:%.*]], label [[FOR_BODY]]
+; CHECK:       for.end.loopexitsplit:
+; CHECK-NEXT:    [[NEW_ADDR_LCSSA:%.*]] = phi ptr [ [[NEW_ADDR]], [[FOR_LATCH]] ]
+; CHECK-NEXT:    br label [[FOR_END_LOOPEXIT:%.*]]
+; CHECK:       for.body.for.end.loopexit_crit_edge:
+; CHECK-NEXT:    [[ADDR_LCSSA:%.*]] = phi ptr [ [[ADDR]], [[FOR_BODY]] ]
+; CHECK-NEXT:    br label [[FOR_END_LOOPEXIT]]
+; CHECK:       for.end.loopexit:
+; CHECK-NEXT:    [[RETV_PH:%.*]] = phi ptr [ [[ADDR_LCSSA]], [[FOR_BODY_FOR_END_LOOPEXIT_CRIT_EDGE]] ], [ [[NEW_ADDR_LCSSA]], [[FOR_END_LOOPEXITSPLIT]] ]
+; CHECK-NEXT:    br label [[FOR_END]]
+; CHECK:       for.end:
+; CHECK-NEXT:    [[RETV:%.*]] = phi ptr [ [[START]], [[ENTRY:%.*]] ], [ [[RETV_PH]], [[FOR_END_LOOPEXIT]] ]
+; CHECK-NEXT:    ret ptr [[RETV]]
+;
+entry:
+  %end_i = ptrtoint ptr %end to i64
+  %start_i = ptrtoint ptr %start to i64
+  %delta_i = sub i64 %end_i, %start_i
+  %delta = trunc i64 %delta_i to i32
+  %cond1 = icmp sgt i32 %delta, 0
+  br i1 %cond1, label %for.body, label %for.end
+
+for.body:                                                ; preds = %entry, %for.latch
+  %trip_count = phi i32 [ %new_trip_count, %for.latch ], [ %delta, %entry ]
+  %addr = phi ptr [ %new_addr, %for.latch ], [ %start, %entry ]
+  %data = load i8, ptr %addr, align 1
+  %cond2 = icmp eq i8 %data, %value
+  br i1 %cond2, label %for.end, label %for.latch
+
+for.latch:                                               ; preds = %for.body
+  %new_addr = getelementptr inbounds i8, ptr %addr, i64 1
+  %new_trip_count = add nsw i32 %trip_count, -1
+  %cond3 = icmp sgt i32 %trip_count, 1
+  br i1 %cond3, label %for.body, label %for.end
+
+for.end:                                               ; preds = %for.body, %for.latch, %entry
+  %retv = phi ptr [ %start, %entry ], [ %new_addr, %for.latch ], [ %addr, %for.body ]
+  ret ptr %retv
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/96048


More information about the llvm-commits mailing list