[llvm] 27a62ec - [LSR] Split the -lsr-term-fold transformation into its own pass (#104234)

via llvm-commits llvm-commits at lists.llvm.org
Sat Aug 17 18:34:27 PDT 2024


Author: Philip Reames
Date: 2024-08-17T18:34:23-07:00
New Revision: 27a62ec72aed3faf1388600f485552471b580e3b

URL: https://github.com/llvm/llvm-project/commit/27a62ec72aed3faf1388600f485552471b580e3b
DIFF: https://github.com/llvm/llvm-project/commit/27a62ec72aed3faf1388600f485552471b580e3b.diff

LOG: [LSR] Split the -lsr-term-fold transformation into its own pass (#104234)

This transformation doesn't actually use any of the internal state of
LSR and recomputes all information from SCEV.  Splitting it out makes
it easier to test.
    
Note that in the long term, I would like to write a version of this
transform which *is* integrated with LSR's solver; if that happens,
we'll just delete the extra pass.
    
Integration-wise, I switched from using TTI to using a pass configuration
variable.  This seems slightly more idiomatic, and it means the extra
logic doesn't run on any target other than RISCV.

Added: 
    llvm/include/llvm/Transforms/Scalar/LoopTermFold.h
    llvm/lib/Transforms/Scalar/LoopTermFold.cpp

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/include/llvm/CodeGen/TargetPassConfig.h
    llvm/include/llvm/InitializePasses.h
    llvm/include/llvm/LinkAllPasses.h
    llvm/include/llvm/Passes/MachinePassRegistry.def
    llvm/include/llvm/Transforms/Scalar.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/CodeGen/TargetPassConfig.cpp
    llvm/lib/Passes/PassBuilder.cpp
    llvm/lib/Passes/PassRegistry.def
    llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
    llvm/lib/Transforms/Scalar/CMakeLists.txt
    llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
    llvm/lib/Transforms/Scalar/Scalar.cpp
    llvm/test/CodeGen/RISCV/O3-pipeline.ll
    llvm/test/Transforms/LoopStrengthReduce/RISCV/lsr-cost-compare.ll
    llvm/test/Transforms/LoopStrengthReduce/RISCV/term-fold-crash.ll
    llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
    llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
    llvm/test/Transforms/LoopStrengthReduce/lsr-unreachable-bb-phi-node.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index dde90abc06cd9c..b2124c6106198e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -739,11 +739,6 @@ class TargetTransformInfo {
   /// cost should return false, otherwise return true.
   bool isNumRegsMajorCostOfLSR() const;
 
-  /// Return true if LSR should attempts to replace a use of an otherwise dead
-  /// primary IV in the latch condition with another IV available in the loop.
-  /// When successful, makes the primary IV dead.
-  bool shouldFoldTerminatingConditionAfterLSR() const;
-
   /// Return true if LSR should drop a found solution if it's calculated to be
   /// less profitable than the baseline.
   bool shouldDropLSRSolutionIfLessProfitable() const;
@@ -1888,7 +1883,6 @@ class TargetTransformInfo::Concept {
   virtual bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                              const TargetTransformInfo::LSRCost &C2) = 0;
   virtual bool isNumRegsMajorCostOfLSR() = 0;
-  virtual bool shouldFoldTerminatingConditionAfterLSR() const = 0;
   virtual bool shouldDropLSRSolutionIfLessProfitable() const = 0;
   virtual bool isProfitableLSRChainElement(Instruction *I) = 0;
   virtual bool canMacroFuseCmp() = 0;
@@ -2367,9 +2361,6 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   bool isNumRegsMajorCostOfLSR() override {
     return Impl.isNumRegsMajorCostOfLSR();
   }
-  bool shouldFoldTerminatingConditionAfterLSR() const override {
-    return Impl.shouldFoldTerminatingConditionAfterLSR();
-  }
   bool shouldDropLSRSolutionIfLessProfitable() const override {
     return Impl.shouldDropLSRSolutionIfLessProfitable();
   }

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index d208a710bb27fd..11b07ac0b7fc47 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -244,8 +244,6 @@ class TargetTransformInfoImplBase {
 
   bool isNumRegsMajorCostOfLSR() const { return true; }
 
-  bool shouldFoldTerminatingConditionAfterLSR() const { return false; }
-
   bool shouldDropLSRSolutionIfLessProfitable() const { return false; }
 
   bool isProfitableLSRChainElement(Instruction *I) const { return false; }

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 77ddc10e8a0e76..217e3f1324f9c9 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -394,11 +394,6 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return TargetTransformInfoImplBase::isNumRegsMajorCostOfLSR();
   }
 
-  bool shouldFoldTerminatingConditionAfterLSR() const {
-    return TargetTransformInfoImplBase::
-        shouldFoldTerminatingConditionAfterLSR();
-  }
-
   bool shouldDropLSRSolutionIfLessProfitable() const {
     return TargetTransformInfoImplBase::shouldDropLSRSolutionIfLessProfitable();
   }

diff  --git a/llvm/include/llvm/CodeGen/TargetPassConfig.h b/llvm/include/llvm/CodeGen/TargetPassConfig.h
index d00e0bed91a457..2f5951e3ec3bce 100644
--- a/llvm/include/llvm/CodeGen/TargetPassConfig.h
+++ b/llvm/include/llvm/CodeGen/TargetPassConfig.h
@@ -140,6 +140,9 @@ class TargetPassConfig : public ImmutablePass {
   /// callers.
   bool RequireCodeGenSCCOrder = false;
 
+  /// Enable LoopTermFold immediately after LSR
+  bool EnableLoopTermFold = false;
+
   /// Add the actual instruction selection passes. This does not include
   /// preparation passes on IR.
   bool addCoreISelPasses();

diff  --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
index a4ac314bb590e5..cc5e93c58f564a 100644
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -169,6 +169,7 @@ void initializeLoopInfoWrapperPassPass(PassRegistry &);
 void initializeLoopPassPass(PassRegistry &);
 void initializeLoopSimplifyPass(PassRegistry &);
 void initializeLoopStrengthReducePass(PassRegistry &);
+void initializeLoopTermFoldPass(PassRegistry &);
 void initializeLoopUnrollPass(PassRegistry &);
 void initializeLowerAtomicLegacyPassPass(PassRegistry &);
 void initializeLowerConstantIntrinsicsPass(PassRegistry &);

diff  --git a/llvm/include/llvm/LinkAllPasses.h b/llvm/include/llvm/LinkAllPasses.h
index e6a70dfd1ea6f8..1da02153d846f1 100644
--- a/llvm/include/llvm/LinkAllPasses.h
+++ b/llvm/include/llvm/LinkAllPasses.h
@@ -90,6 +90,7 @@ struct ForcePassLinking {
     (void)llvm::createLoopExtractorPass();
     (void)llvm::createLoopSimplifyPass();
     (void)llvm::createLoopStrengthReducePass();
+    (void)llvm::createLoopTermFoldPass();
     (void)llvm::createLoopUnrollPass();
     (void)llvm::createLowerGlobalDtorsLegacyPass();
     (void)llvm::createLowerInvokePass();

diff  --git a/llvm/include/llvm/Passes/MachinePassRegistry.def b/llvm/include/llvm/Passes/MachinePassRegistry.def
index 8e669ee5791239..05baf514fa7210 100644
--- a/llvm/include/llvm/Passes/MachinePassRegistry.def
+++ b/llvm/include/llvm/Passes/MachinePassRegistry.def
@@ -79,6 +79,7 @@ FUNCTION_PASS("win-eh-prepare", WinEHPreparePass())
 #define LOOP_PASS(NAME, CREATE_PASS)
 #endif
 LOOP_PASS("loop-reduce", LoopStrengthReducePass())
+LOOP_PASS("loop-term-fold", LoopTermFoldPass())
 #undef LOOP_PASS
 
 #ifndef MACHINE_MODULE_PASS

diff  --git a/llvm/include/llvm/Transforms/Scalar.h b/llvm/include/llvm/Transforms/Scalar.h
index 98d0adca355214..17f4327eb3e1ab 100644
--- a/llvm/include/llvm/Transforms/Scalar.h
+++ b/llvm/include/llvm/Transforms/Scalar.h
@@ -51,6 +51,14 @@ Pass *createLICMPass();
 //
 Pass *createLoopStrengthReducePass();
 
+//===----------------------------------------------------------------------===//
+//
+// LoopTermFold -  This pass attempts to eliminate the last use of an IV in
+// a loop terminator instruction by rewriting it in terms of another IV.
+// Expected to be run immediately after LSR.
+//
+Pass *createLoopTermFoldPass();
+
 //===----------------------------------------------------------------------===//
 //
 // LoopUnroll - This pass is a simple loop unrolling pass.

diff  --git a/llvm/include/llvm/Transforms/Scalar/LoopTermFold.h b/llvm/include/llvm/Transforms/Scalar/LoopTermFold.h
new file mode 100644
index 00000000000000..974024c586aa80
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Scalar/LoopTermFold.h
@@ -0,0 +1,30 @@
+//===- LoopTermFold.h - Loop Term Fold Pass ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_SCALAR_LOOPTERMFOLD_H
+#define LLVM_TRANSFORMS_SCALAR_LOOPTERMFOLD_H
+
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class Loop;
+class LPMUpdater;
+
+class LoopTermFoldPass : public PassInfoMixin<LoopTermFoldPass> {
+public:
+  PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
+                        LoopStandardAnalysisResults &AR, LPMUpdater &U);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_SCALAR_LOOPTERMFOLD_H

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index dcde78925bfa98..2c26493bd3f1ca 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -427,10 +427,6 @@ bool TargetTransformInfo::isNumRegsMajorCostOfLSR() const {
   return TTIImpl->isNumRegsMajorCostOfLSR();
 }
 
-bool TargetTransformInfo::shouldFoldTerminatingConditionAfterLSR() const {
-  return TTIImpl->shouldFoldTerminatingConditionAfterLSR();
-}
-
 bool TargetTransformInfo::shouldDropLSRSolutionIfLessProfitable() const {
   return TTIImpl->shouldDropLSRSolutionIfLessProfitable();
 }

diff  --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index 1b0012b65b80d4..1d52ebe6717f04 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -828,6 +828,8 @@ void TargetPassConfig::addIRPasses() {
     if (!DisableLSR) {
       addPass(createCanonicalizeFreezeInLoopsPass());
       addPass(createLoopStrengthReducePass());
+      if (EnableLoopTermFold)
+        addPass(createLoopTermFoldPass());
       if (PrintLSR)
         addPass(createPrintFunctionPass(dbgs(),
                                         "\n\n*** Code after LSR ***\n"));

diff  --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 3200767282b226..17eed97fd950c9 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -249,6 +249,7 @@
 #include "llvm/Transforms/Scalar/LoopSimplifyCFG.h"
 #include "llvm/Transforms/Scalar/LoopSink.h"
 #include "llvm/Transforms/Scalar/LoopStrengthReduce.h"
+#include "llvm/Transforms/Scalar/LoopTermFold.h"
 #include "llvm/Transforms/Scalar/LoopUnrollAndJamPass.h"
 #include "llvm/Transforms/Scalar/LoopUnrollPass.h"
 #include "llvm/Transforms/Scalar/LoopVersioningLICM.h"

diff  --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index a11fc3755494ab..6b5e1cf83c4698 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -646,6 +646,7 @@ LOOP_PASS("loop-idiom-vectorize", LoopIdiomVectorizePass())
 LOOP_PASS("loop-instsimplify", LoopInstSimplifyPass())
 LOOP_PASS("loop-predication", LoopPredicationPass())
 LOOP_PASS("loop-reduce", LoopStrengthReducePass())
+LOOP_PASS("loop-term-fold", LoopTermFoldPass())
 LOOP_PASS("loop-simplifycfg", LoopSimplifyCFGPass())
 LOOP_PASS("loop-unroll-full", LoopFullUnrollPass())
 LOOP_PASS("loop-versioning-licm", LoopVersioningLICMPass())

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
index b6884321f08411..794df2212dfa53 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
@@ -336,6 +336,7 @@ class RISCVPassConfig : public TargetPassConfig {
     if (TM.getOptLevel() != CodeGenOptLevel::None)
       substitutePass(&PostRASchedulerID, &PostMachineSchedulerID);
     setEnableSinkAndFold(EnableSinkFold);
+    EnableLoopTermFold = true;
   }
 
   RISCVTargetMachine &getRISCVTargetMachine() const {

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index f5eca2839acd05..cc69e1d118b5a1 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -394,9 +394,6 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                      const TargetTransformInfo::LSRCost &C2);
 
-  bool shouldFoldTerminatingConditionAfterLSR() const {
-    return true;
-  }
   bool
   shouldConsiderAddressTypePromotion(const Instruction &I,
                                      bool &AllowPromotionWithoutCommonHeader);

diff  --git a/llvm/lib/Transforms/Scalar/CMakeLists.txt b/llvm/lib/Transforms/Scalar/CMakeLists.txt
index ba09ebf8b04c4c..939a1457239567 100644
--- a/llvm/lib/Transforms/Scalar/CMakeLists.txt
+++ b/llvm/lib/Transforms/Scalar/CMakeLists.txt
@@ -44,6 +44,7 @@ add_llvm_component_library(LLVMScalarOpts
   LoopRotation.cpp
   LoopSimplifyCFG.cpp
   LoopStrengthReduce.cpp
+  LoopTermFold.cpp
   LoopUnrollPass.cpp
   LoopUnrollAndJamPass.cpp
   LoopVersioningLICM.cpp

diff  --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 91461d1ed27592..a62b87fe2a53d4 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -189,10 +189,6 @@ static cl::opt<unsigned> SetupCostDepthLimit(
     "lsr-setupcost-depth-limit", cl::Hidden, cl::init(7),
     cl::desc("The limit on recursion depth for LSRs setup cost"));
 
-static cl::opt<cl::boolOrDefault> AllowTerminatingConditionFoldingAfterLSR(
-    "lsr-term-fold", cl::Hidden,
-    cl::desc("Attempt to replace primary IV with other IV."));
-
 static cl::opt<cl::boolOrDefault> AllowDropSolutionIfLessProfitable(
     "lsr-drop-solution", cl::Hidden,
     cl::desc("Attempt to drop solution if it is less profitable"));
@@ -205,9 +201,6 @@ static cl::opt<bool> DropScaledForVScale(
     "lsr-drop-scaled-reg-for-vscale", cl::Hidden, cl::init(true),
     cl::desc("Avoid using scaled registers with vscale-relative addressing"));
 
-STATISTIC(NumTermFold,
-          "Number of terminating condition fold recognized and performed");
-
 #ifndef NDEBUG
 // Stress test IV chain generation.
 static cl::opt<bool> StressIVChain(
@@ -7062,186 +7055,6 @@ static llvm::PHINode *GetInductionVariable(const Loop &L, ScalarEvolution &SE,
   return nullptr;
 }
 
-static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
-canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
-                      const LoopInfo &LI, const TargetTransformInfo &TTI) {
-  if (!L->isInnermost()) {
-    LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
-    return std::nullopt;
-  }
-  // Only inspect on simple loop structure
-  if (!L->isLoopSimplifyForm()) {
-    LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
-    return std::nullopt;
-  }
-
-  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
-    LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
-    return std::nullopt;
-  }
-
-  BasicBlock *LoopLatch = L->getLoopLatch();
-  BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
-  if (!BI || BI->isUnconditional())
-    return std::nullopt;
-  auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
-  if (!TermCond) {
-    LLVM_DEBUG(
-        dbgs() << "Cannot fold on branching condition that is not an ICmpInst");
-    return std::nullopt;
-  }
-  if (!TermCond->hasOneUse()) {
-    LLVM_DEBUG(
-        dbgs()
-        << "Cannot replace terminating condition with more than one use\n");
-    return std::nullopt;
-  }
-
-  BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
-  Value *RHS = TermCond->getOperand(1);
-  if (!LHS || !L->isLoopInvariant(RHS))
-    // We could pattern match the inverse form of the icmp, but that is
-    // non-canonical, and this pass is running *very* late in the pipeline.
-    return std::nullopt;
-
-  // Find the IV used by the current exit condition.
-  PHINode *ToFold;
-  Value *ToFoldStart, *ToFoldStep;
-  if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
-    return std::nullopt;
-
-  // Ensure the simple recurrence is a part of the current loop.
-  if (ToFold->getParent() != L->getHeader())
-    return std::nullopt;
-
-  // If that IV isn't dead after we rewrite the exit condition in terms of
-  // another IV, there's no point in doing the transform.
-  if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
-    return std::nullopt;
-
-  // Inserting instructions in the preheader has a runtime cost, scale
-  // the allowed cost with the loops trip count as best we can.
-  const unsigned ExpansionBudget = [&]() {
-    unsigned Budget = 2 * SCEVCheapExpansionBudget;
-    if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
-      return std::min(Budget, SmallTC);
-    if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
-      return std::min(Budget, *SmallTC);
-    // Unknown trip count, assume long running by default.
-    return Budget;
-  }();
-
-  const SCEV *BECount = SE.getBackedgeTakenCount(L);
-  const DataLayout &DL = L->getHeader()->getDataLayout();
-  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
-
-  PHINode *ToHelpFold = nullptr;
-  const SCEV *TermValueS = nullptr;
-  bool MustDropPoison = false;
-  auto InsertPt = L->getLoopPreheader()->getTerminator();
-  for (PHINode &PN : L->getHeader()->phis()) {
-    if (ToFold == &PN)
-      continue;
-
-    if (!SE.isSCEVable(PN.getType())) {
-      LLVM_DEBUG(dbgs() << "IV of phi '" << PN
-                        << "' is not SCEV-able, not qualified for the "
-                           "terminating condition folding.\n");
-      continue;
-    }
-    const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
-    // Only speculate on affine AddRec
-    if (!AddRec || !AddRec->isAffine()) {
-      LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
-                        << "' is not an affine add recursion, not qualified "
-                           "for the terminating condition folding.\n");
-      continue;
-    }
-
-    // Check that we can compute the value of AddRec on the exiting iteration
-    // without soundness problems.  evaluateAtIteration internally needs
-    // to multiply the stride of the iteration number - which may wrap around.
-    // The issue here is subtle because computing the result accounting for
-    // wrap is insufficient. In order to use the result in an exit test, we
-    // must also know that AddRec doesn't take the same value on any previous
-    // iteration. The simplest case to consider is a candidate IV which is
-    // narrower than the trip count (and thus original IV), but this can
-    // also happen due to non-unit strides on the candidate IVs.
-    if (!AddRec->hasNoSelfWrap() ||
-        !SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))
-      continue;
-
-    const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
-    const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
-    if (!Expander.isSafeToExpand(TermValueSLocal)) {
-      LLVM_DEBUG(
-          dbgs() << "Is not safe to expand terminating value for phi node" << PN
-                 << "\n");
-      continue;
-    }
-
-    if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget,
-                                     &TTI, InsertPt)) {
-      LLVM_DEBUG(
-          dbgs() << "Is too expensive to expand terminating value for phi node"
-                 << PN << "\n");
-      continue;
-    }
-
-    // The candidate IV may have been otherwise dead and poison from the
-    // very first iteration.  If we can't disprove that, we can't use the IV.
-    if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
-      LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV "
-                        << PN << "\n");
-      continue;
-    }
-
-    // The candidate IV may become poison on the last iteration.  If this
-    // value is not branched on, this is a well defined program.  We're
-    // about to add a new use to this IV, and we have to ensure we don't
-    // insert UB which didn't previously exist.
-    bool MustDropPoisonLocal = false;
-    Instruction *PostIncV =
-      cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
-    if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
-                                       &DT)) {
-      LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use"
-                        << PN << "\n");
-
-      // If this is a complex recurrance with multiple instructions computing
-      // the backedge value, we might need to strip poison flags from all of
-      // them.
-      if (PostIncV->getOperand(0) != &PN)
-        continue;
-
-      // In order to perform the transform, we need to drop the poison generating
-      // flags on this instruction (if any).
-      MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
-    }
-
-    // We pick the last legal alternate IV.  We could expore choosing an optimal
-    // alternate IV if we had a decent heuristic to do so.
-    ToHelpFold = &PN;
-    TermValueS = TermValueSLocal;
-    MustDropPoison = MustDropPoisonLocal;
-  }
-
-  LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
-                 << "Cannot find other AddRec IV to help folding\n";);
-
-  LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs()
-             << "\nFound loop that can fold terminating condition\n"
-             << "  BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n"
-             << "  TermCond: " << *TermCond << "\n"
-             << "  BrandInst: " << *BI << "\n"
-             << "  ToFold: " << *ToFold << "\n"
-             << "  ToHelpFold: " << *ToHelpFold << "\n");
-
-  if (!ToFold || !ToHelpFold)
-    return std::nullopt;
-  return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
-}
-
 static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
                                DominatorTree &DT, LoopInfo &LI,
                                const TargetTransformInfo &TTI,
@@ -7302,81 +7115,6 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
     }
   }
 
-  const bool EnableFormTerm = [&] {
-    switch (AllowTerminatingConditionFoldingAfterLSR) {
-    case cl::BOU_TRUE:
-      return true;
-    case cl::BOU_FALSE:
-      return false;
-    case cl::BOU_UNSET:
-      return TTI.shouldFoldTerminatingConditionAfterLSR();
-    }
-    llvm_unreachable("Unhandled cl::boolOrDefault enum");
-  }();
-
-  if (EnableFormTerm) {
-    if (auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI)) {
-      auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
-
-      Changed = true;
-      NumTermFold++;
-
-      BasicBlock *LoopPreheader = L->getLoopPreheader();
-      BasicBlock *LoopLatch = L->getLoopLatch();
-
-      (void)ToFold;
-      LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
-                        << *ToFold << "\n"
-                        << "New term-cond phi-node:\n"
-                        << *ToHelpFold << "\n");
-
-      Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);
-      (void)StartValue;
-      Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);
-
-      // See comment in canFoldTermCondOfLoop on why this is sufficient.
-      if (MustDrop)
-        cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();
-
-      // SCEVExpander for both use in preheader and latch
-      const DataLayout &DL = L->getHeader()->getDataLayout();
-      SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
-
-      assert(Expander.isSafeToExpand(TermValueS) &&
-             "Terminating value was checked safe in canFoldTerminatingCondition");
-
-      // Create new terminating value at loop preheader
-      Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),
-                                                LoopPreheader->getTerminator());
-
-      LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
-                        << *StartValue << "\n"
-                        << "Terminating value of new term-cond phi-node:\n"
-                        << *TermValue << "\n");
-
-      // Create new terminating condition at loop latch
-      BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
-      ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
-      IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
-      Value *NewTermCond =
-          LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
-                                  "lsr_fold_term_cond.replaced_term_cond");
-      // Swap successors to exit loop body if IV equals to new TermValue
-      if (BI->getSuccessor(0) == L->getHeader())
-        BI->swapSuccessors();
-
-      LLVM_DEBUG(dbgs() << "Old term-cond:\n"
-                        << *OldTermCond << "\n"
-                        << "New term-cond:\n" << *NewTermCond << "\n");
-
-      BI->setCondition(NewTermCond);
-
-      Expander.clear();
-      OldTermCond->eraseFromParent();
-      DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
-    }
-  }
-
   if (SalvageableDVIRecords.empty())
     return Changed;
 

diff  --git a/llvm/lib/Transforms/Scalar/LoopTermFold.cpp b/llvm/lib/Transforms/Scalar/LoopTermFold.cpp
new file mode 100644
index 00000000000000..12ef367adc43e3
--- /dev/null
+++ b/llvm/lib/Transforms/Scalar/LoopTermFold.cpp
@@ -0,0 +1,379 @@
+//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
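+//
+// This file implements a loop pass that rewrites the exit test in a loop's
+// latch to compare a different induction variable against a terminating
+// value computed in the preheader, so that the induction variable originally
+// used by the exit test can be deleted.
+//
+//===----------------------------------------------------------------------===//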
+
+#include "llvm/Transforms/Scalar/LoopTermFold.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
+#include <cassert>
+#include <optional>
+#include <utility>
+
+using namespace llvm;
+
+// Tag used by LLVM_DEBUG (for -debug-only= filtering) and as the group name
+// of the STATISTIC below.
+#define DEBUG_TYPE "loop-term-fold"
+
+// Incremented once for every loop whose exit test RunTermFold rewrites.
+STATISTIC(NumTermFold,
+          "Number of terminating condition fold recognized and performed");
+
+/// Decide whether the exit test in the latch of \p L can be re-expressed in
+/// terms of a different induction variable, so that the IV it currently uses
+/// becomes dead.
+///
+/// Requirements checked here:
+/// - \p L is innermost, in loop-simplify form and has a loop-invariant
+///   backedge-taken count.
+/// - The latch ends in a conditional branch on a single-use ICmpInst.
+/// - That compare's LHS is a simple recurrence (ToFold) of the loop header,
+///   and its RHS is loop invariant.
+/// - isAlmostDeadIV accepts ToFold.
+/// - A second header phi (ToHelpFold) must be an affine, no-self-wrap AddRec
+///   with a non-zero step.
+/// - ToHelpFold's post-increment value on the exiting iteration must be safe
+///   and cheap enough to expand in the preheader, and ToHelpFold must pass
+///   the poison checks below.
+/// When several phis qualify, the last one wins.
+///
+/// \returns (ToFold, ToHelpFold, terminating value SCEV, MustDropPoison) on
+/// success, std::nullopt otherwise. MustDropPoison means the poison-generating
+/// flags of ToHelpFold's incoming value from the latch have to be dropped
+/// before the new compare uses it. \p LI is currently unused.
+static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
+canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
+                      const LoopInfo &LI, const TargetTransformInfo &TTI) {
+  if (!L->isInnermost()) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
+    return std::nullopt;
+  }
+  // Only inspect on simple loop structure
+  if (!L->isLoopSimplifyForm()) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
+    return std::nullopt;
+  }
+
+  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
+    return std::nullopt;
+  }
+
+  BasicBlock *LoopLatch = L->getLoopLatch();
+  BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
+  if (!BI || BI->isUnconditional())
+    return std::nullopt;
+  auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
+  if (!TermCond) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on branching condition that is not an "
+                         "ICmpInst\n");
+    return std::nullopt;
+  }
+  if (!TermCond->hasOneUse()) {
+    LLVM_DEBUG(
+        dbgs()
+        << "Cannot replace terminating condition with more than one use\n");
+    return std::nullopt;
+  }
+
+  BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
+  Value *RHS = TermCond->getOperand(1);
+  if (!LHS || !L->isLoopInvariant(RHS))
+    // We could pattern match the inverse form of the icmp, but that is
+    // non-canonical, and this pass is running *very* late in the pipeline.
+    return std::nullopt;
+
+  // Find the IV used by the current exit condition.
+  PHINode *ToFold;
+  Value *ToFoldStart, *ToFoldStep;
+  if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
+    return std::nullopt;
+
+  // Ensure the simple recurrence is a part of the current loop.
+  if (ToFold->getParent() != L->getHeader())
+    return std::nullopt;
+
+  // If that IV isn't dead after we rewrite the exit condition in terms of
+  // another IV, there's no point in doing the transform.
+  if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
+    return std::nullopt;
+
+  // Inserting instructions in the preheader has a runtime cost, scale
+  // the allowed cost with the loops trip count as best we can.
+  const unsigned ExpansionBudget = [&]() {
+    unsigned Budget = 2 * SCEVCheapExpansionBudget;
+    if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
+      return std::min(Budget, SmallTC);
+    if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
+      return std::min(Budget, *SmallTC);
+    // Unknown trip count, assume long running by default.
+    return Budget;
+  }();
+
+  const SCEV *BECount = SE.getBackedgeTakenCount(L);
+  const DataLayout &DL = L->getHeader()->getDataLayout();
+  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
+
+  PHINode *ToHelpFold = nullptr;
+  const SCEV *TermValueS = nullptr;
+  bool MustDropPoison = false;
+  auto InsertPt = L->getLoopPreheader()->getTerminator();
+  for (PHINode &PN : L->getHeader()->phis()) {
+    if (ToFold == &PN)
+      continue;
+
+    if (!SE.isSCEVable(PN.getType())) {
+      LLVM_DEBUG(dbgs() << "IV of phi '" << PN
+                        << "' is not SCEV-able, not qualified for the "
+                           "terminating condition folding.\n");
+      continue;
+    }
+    const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
+    // Only speculate on affine AddRec
+    if (!AddRec || !AddRec->isAffine()) {
+      LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
+                        << "' is not an affine add recursion, not qualified "
+                           "for the terminating condition folding.\n");
+      continue;
+    }
+
+    // Check that we can compute the value of AddRec on the exiting iteration
+    // without soundness problems.  evaluateAtIteration internally needs
+    // to multiply the stride of the iteration number - which may wrap around.
+    // The issue here is subtle because computing the result accounting for
+    // wrap is insufficient. In order to use the result in an exit test, we
+    // must also know that AddRec doesn't take the same value on any previous
+    // iteration. The simplest case to consider is a candidate IV which is
+    // narrower than the trip count (and thus original IV), but this can
+    // also happen due to non-unit strides on the candidate IVs.
+    if (!AddRec->hasNoSelfWrap() ||
+        !SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))
+      continue;
+
+    const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
+    const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
+    if (!Expander.isSafeToExpand(TermValueSLocal)) {
+      LLVM_DEBUG(
+          dbgs() << "Is not safe to expand terminating value for phi node" << PN
+                 << "\n");
+      continue;
+    }
+
+    if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI,
+                                     InsertPt)) {
+      LLVM_DEBUG(
+          dbgs() << "Is too expensive to expand terminating value for phi node"
+                 << PN << "\n");
+      continue;
+    }
+
+    // The candidate IV may have been otherwise dead and poison from the
+    // very first iteration.  If we can't disprove that, we can't use the IV.
+    if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
+      LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n");
+      continue;
+    }
+
+    // The candidate IV may become poison on the last iteration.  If this
+    // value is not branched on, this is a well defined program.  We're
+    // about to add a new use to this IV, and we have to ensure we don't
+    // insert UB which didn't previously exist.
+    bool MustDropPoisonLocal = false;
+    Instruction *PostIncV =
+        cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
+    if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
+                                       &DT)) {
+      LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN
+                        << "\n");
+
+      // If this is a complex recurrence with multiple instructions computing
+      // the backedge value, we might need to strip poison flags from all of
+      // them.
+      if (PostIncV->getOperand(0) != &PN)
+        continue;
+
+      // In order to perform the transform, we need to drop the poison
+      // generating flags on this instruction (if any).
+      MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
+    }
+
+    // We pick the last legal alternate IV.  We could explore choosing an
+    // optimal alternate IV if we had a decent heuristic to do so.
+    ToHelpFold = &PN;
+    TermValueS = TermValueSLocal;
+    MustDropPoison = MustDropPoisonLocal;
+  }
+
+  // ToFold is always set here (matchSimpleRecurrence succeeded above), so the
+  // only way to fail now is that no helper IV qualified.
+  if (!ToHelpFold) {
+    LLVM_DEBUG(dbgs() << "Cannot find other AddRec IV to help folding\n");
+    return std::nullopt;
+  }
+
+  LLVM_DEBUG(dbgs() << "\nFound loop that can fold terminating condition\n"
+                    << "  BECount (SCEV): " << *BECount << "\n"
+                    << "  TermCond: " << *TermCond << "\n"
+                    << "  BranchInst: " << *BI << "\n"
+                    << "  ToFold: " << *ToFold << "\n"
+                    << "  ToHelpFold: " << *ToHelpFold << "\n");
+
+  return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
+}
+
+/// Apply the fold found by canFoldTermCondOfLoop to \p L.
+///
+/// The helper IV's terminating value is expanded in the preheader. The latch
+/// condition is then replaced with an equality test of the helper IV's latch
+/// value against that value, exiting when they are equal. Finally the old
+/// compare is erased and header phis that became dead are deleted.
+///
+/// \returns true if \p L was changed. When \p MSSA is non-null it is handed
+/// to the dead-phi cleanup through a MemorySSAUpdater.
+static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
+                        LoopInfo &LI, const TargetTransformInfo &TTI,
+                        TargetLibraryInfo &TLI, MemorySSA *MSSA) {
+  auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
+  if (!Opt)
+    return false;
+
+  auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
+
+  NumTermFold++;
+
+  BasicBlock *LoopPreheader = L->getLoopPreheader();
+  BasicBlock *LoopLatch = L->getLoopLatch();
+
+  (void)ToFold;
+  LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
+                    << *ToFold << "\n"
+                    << "New term-cond phi-node:\n"
+                    << *ToHelpFold << "\n");
+
+  Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);
+
+  // See comment in canFoldTermCondOfLoop on why this is sufficient.
+  if (MustDrop)
+    cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();
+
+  // SCEVExpander for both use in preheader and latch
+  const DataLayout &DL = L->getHeader()->getDataLayout();
+  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
+
+  assert(Expander.isSafeToExpand(TermValueS) &&
+         "Terminating value was checked safe in canFoldTermCondOfLoop");
+
+  // Create new terminating value at loop preheader
+  Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),
+                                            LoopPreheader->getTerminator());
+
+  // The start value is only needed for this dump, so it is looked up inside
+  // LLVM_DEBUG rather than unconditionally.
+  LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
+                    << *ToHelpFold->getIncomingValueForBlock(LoopPreheader)
+                    << "\n"
+                    << "Terminating value of new term-cond phi-node:\n"
+                    << *TermValue << "\n");
+
+  // Create new terminating condition at loop latch
+  BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
+  ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
+  IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
+  Value *NewTermCond =
+      LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
+                              "lsr_fold_term_cond.replaced_term_cond");
+  // Swap successors to exit loop body if IV equals to new TermValue
+  if (BI->getSuccessor(0) == L->getHeader())
+    BI->swapSuccessors();
+
+  LLVM_DEBUG(dbgs() << "Old term-cond:\n"
+                    << *OldTermCond << "\n"
+                    << "New term-cond:\n"
+                    << *NewTermCond << "\n");
+
+  BI->setCondition(NewTermCond);
+
+  Expander.clear();
+  OldTermCond->eraseFromParent();
+
+  // The updater is only needed for the dead-phi cleanup, so it is created
+  // here rather than before we know the fold applies.
+  std::unique_ptr<MemorySSAUpdater> MSSAU;
+  if (MSSA)
+    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
+  DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
+  return true;
+}
+
+namespace {
+
+/// Legacy pass-manager entry point for the terminating-condition fold; its
+/// runOnLoop forwards each loop to RunTermFold.
+class LoopTermFold : public LoopPass {
+public:
+  /// Pass ID, replacement for typeid (its address identifies the pass).
+  static char ID;
+
+  LoopTermFold();
+
+private:
+  /// Lists the analyses this pass requires and the ones it preserves.
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+  /// Runs the fold on \p L; returns true when the IR was modified.
+  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+};
+
+} // end anonymous namespace
+
+LoopTermFold::LoopTermFold() : LoopPass(ID) {
+  // Register this pass with the process-wide legacy pass registry.
+  PassRegistry &Registry = *PassRegistry::getPassRegistry();
+  initializeLoopTermFoldPass(Registry);
+}
+
+void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
+  // Analyses that runOnLoop fetches (request order kept unchanged).
+  AU.addRequired<LoopInfoWrapperPass>();
+  AU.addRequiredID(LoopSimplifyID);
+  AU.addRequired<DominatorTreeWrapperPass>();
+  AU.addRequired<ScalarEvolutionWrapperPass>();
+  AU.addRequired<TargetLibraryInfoWrapperPass>();
+  AU.addRequired<TargetTransformInfoWrapperPass>();
+
+  // Analyses this pass declares as preserved across the rewrite.
+  AU.addPreserved<LoopInfoWrapperPass>();
+  AU.addPreservedID(LoopSimplifyID);
+  AU.addPreserved<DominatorTreeWrapperPass>();
+  AU.addPreserved<ScalarEvolutionWrapperPass>();
+  AU.addPreserved<MemorySSAWrapperPass>();
+}
+
+bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
+  if (skipLoop(L))
+    return false;
+
+  // TTI and TLI are per-function; both are queried for the loop's function.
+  Function &F = *L->getHeader()->getParent();
+  ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  const TargetTransformInfo &TTI =
+      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+  TargetLibraryInfo &TLI =
+      getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+
+  // MemorySSA is optional: it is used only if the pass manager has it.
+  auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>();
+  MemorySSA *MSSA = MSSAWP ? &MSSAWP->getMSSA() : nullptr;
+
+  return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA);
+}
+
+PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM,
+                                        LoopStandardAnalysisResults &AR,
+                                        LPMUpdater &) {
+  const bool Changed =
+      RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA);
+  if (!Changed)
+    return PreservedAnalyses::all();
+
+  // After a rewrite, report the standard loop-pass set as preserved, plus
+  // MemorySSA when it was handed to (and thus kept updated by) the transform.
+  PreservedAnalyses PA = getLoopPassPreservedAnalyses();
+  if (AR.MSSA)
+    PA.preserve<MemorySSAAnalysis>();
+  return PA;
+}
+
+char LoopTermFold::ID = 0;
+
+// Each analysis that getAnalysisUsage requests via addRequired is listed as a
+// dependency, so initializeLoopTermFoldPass also initializes it. This includes
+// TargetLibraryInfoWrapperPass, which was previously missing from the list.
+INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
+INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
+                    false, false)
+
+/// Returns a newly allocated legacy-PM LoopTermFold pass.
+Pass *llvm::createLoopTermFoldPass() {
+  return new LoopTermFold;
+}

diff  --git a/llvm/lib/Transforms/Scalar/Scalar.cpp b/llvm/lib/Transforms/Scalar/Scalar.cpp
index 86669e8c5aa49b..7aeee1d31f7e79 100644
--- a/llvm/lib/Transforms/Scalar/Scalar.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalar.cpp
@@ -30,6 +30,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {
   initializeLegacyLICMPassPass(Registry);
   initializeLoopDataPrefetchLegacyPassPass(Registry);
   initializeLoopStrengthReducePass(Registry);
+  initializeLoopTermFoldPass(Registry);
   initializeLoopUnrollPass(Registry);
   initializeLowerAtomicLegacyPassPass(Registry);
   initializeMergeICmpsLegacyPassPass(Registry);

diff  --git a/llvm/test/CodeGen/RISCV/O3-pipeline.ll b/llvm/test/CodeGen/RISCV/O3-pipeline.ll
index df9cb5de5d7682..44c270fdc3c257 100644
--- a/llvm/test/CodeGen/RISCV/O3-pipeline.ll
+++ b/llvm/test/CodeGen/RISCV/O3-pipeline.ll
@@ -45,6 +45,7 @@
 ; CHECK-NEXT:         Canonicalize Freeze Instructions in Loops
 ; CHECK-NEXT:         Induction Variable Users
 ; CHECK-NEXT:         Loop Strength Reduction
+; CHECK-NEXT:         Loop Terminator Folding
 ; CHECK-NEXT:       Basic Alias Analysis (stateless AA impl)
 ; CHECK-NEXT:       Function Alias Analysis Results
 ; CHECK-NEXT:       Merge contiguous icmps into a memcmp

diff  --git a/llvm/test/Transforms/LoopStrengthReduce/RISCV/lsr-cost-compare.ll b/llvm/test/Transforms/LoopStrengthReduce/RISCV/lsr-cost-compare.ll
index 9c11bd064ad47c..cadee94ff40960 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/RISCV/lsr-cost-compare.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/RISCV/lsr-cost-compare.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -loop-reduce -S | FileCheck %s
+; RUN: opt < %s -passes=loop-reduce,loop-term-fold -S | FileCheck %s
 
 target datalayout = "e-m:e-p:64:64-i64:64-i128:128-n64-S128"
 target triple = "riscv64"

diff  --git a/llvm/test/Transforms/LoopStrengthReduce/RISCV/term-fold-crash.ll b/llvm/test/Transforms/LoopStrengthReduce/RISCV/term-fold-crash.ll
index 8ca7f0010bbbe5..9fb240684d232b 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/RISCV/term-fold-crash.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/RISCV/term-fold-crash.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt -S -passes=loop-reduce -mtriple=riscv64-unknown-linux-gnu < %s | FileCheck %s
+; RUN: opt -S -passes=loop-reduce,loop-term-fold -mtriple=riscv64-unknown-linux-gnu < %s | FileCheck %s
 
 define void @test(ptr %p, i8 %arg, i32 %start) {
 ; CHECK-LABEL: define void @test(

diff  --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
index 2d3d3a4b72a1ac..89ddba3343ffa2 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
 ; REQUIRES: asserts
-; RUN: opt < %s -passes="loop-reduce" -S -debug -lsr-term-fold 2>&1 | FileCheck %s
+; RUN: opt < %s -passes=loop-reduce,loop-term-fold -S -debug 2>&1 | FileCheck %s
 
 target datalayout = "e-p:64:64:64-n64"
 

diff  --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
index 7299a014b79835..6f34dc843ae1ee 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes="loop-reduce" -S -lsr-term-fold | FileCheck %s
+; RUN: opt < %s -passes="loop-reduce,loop-term-fold" -S | FileCheck %s
 
 target datalayout = "e-p:64:64:64-n64"
 

diff  --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-unreachable-bb-phi-node.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-unreachable-bb-phi-node.ll
index 1454535b52bccb..67a71496e4cec8 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-unreachable-bb-phi-node.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-unreachable-bb-phi-node.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -loop-reduce -S -lsr-term-fold | FileCheck %s
+; RUN: opt < %s -passes=loop-reduce,loop-term-fold -S | FileCheck %s
 
 ; This test used to crash due to matchSimpleRecurrence matching the simple
 ; recurrence in pn-loop when evaluating unrelated-loop. Since unrelated-loop
@@ -13,9 +13,10 @@ define void @phi_node_different_bb() {
 ; CHECK-NEXT:    [[TMP3:%.*]] = icmp ugt i32 [[TMP2]], 1
 ; CHECK-NEXT:    br i1 [[TMP3]], label [[PN_LOOP]], label [[UNRELATED_LOOP_PREHEADER:%.*]]
 ; CHECK:       unrelated-loop.preheader:
+; CHECK-NEXT:    [[DOTLCSSA:%.*]] = phi i32 [ [[TMP2]], [[PN_LOOP]] ]
 ; CHECK-NEXT:    br label [[UNRELATED_LOOP:%.*]]
 ; CHECK:       unrelated-loop:
-; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[TMP2]], 0
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[DOTLCSSA]], 0
 ; CHECK-NEXT:    br i1 [[TMP4]], label [[END:%.*]], label [[UNRELATED_LOOP]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret void


        


More information about the llvm-commits mailing list