[llvm] [LSR] Split the -lsr-term-fold transformation into its own pass (PR #104234)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 16 06:13:10 PDT 2024


================
@@ -0,0 +1,386 @@
+//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LoopTermFold.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
+#include <cassert>
+#include <optional>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-term-fold"
+
+// Bumped once for every loop whose exit condition RunTermFold rewrites.
+STATISTIC(
+    NumTermFold,
+    "Number of terminating condition fold recognized and performed");
+
+/// Decide whether the exit test of loop \p L can be rewritten to compare a
+/// different header IV, so that the IV currently tested becomes dead.
+///
+/// On success the returned tuple holds:
+///   0. ToFold: the header phi whose recurrence feeds the current exit icmp;
+///   1. ToHelpFold: the alternate header phi to compare against instead;
+///   2. TermValueS: SCEV of ToHelpFold's post-increment value on the exiting
+///      iteration (its post-inc AddRec evaluated at the backedge-taken count);
+///   3. MustDropPoison: whether the poison-generating flags on ToHelpFold's
+///      latch increment must be dropped before it gains the new use.
+/// Returns std::nullopt when any precondition fails. \p LI is not used here.
+static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
+canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
+                      const LoopInfo &LI, const TargetTransformInfo &TTI) {
+  if (!L->isInnermost()) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
+    return std::nullopt;
+  }
+  // Only inspect on simple loop structure
+  if (!L->isLoopSimplifyForm()) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
+    return std::nullopt;
+  }
+
+  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
+    return std::nullopt;
+  }
+
+  BasicBlock *LoopLatch = L->getLoopLatch();
+  BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
+  if (!BI || BI->isUnconditional())
+    return std::nullopt;
+  auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
+  if (!TermCond) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on branching condition that is not an "
+                         "ICmpInst\n");
+    return std::nullopt;
+  }
+  if (!TermCond->hasOneUse()) {
+    LLVM_DEBUG(
+        dbgs()
+        << "Cannot replace terminating condition with more than one use\n");
+    return std::nullopt;
+  }
+
+  BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
+  Value *RHS = TermCond->getOperand(1);
+  if (!LHS || !L->isLoopInvariant(RHS))
+    // We could pattern match the inverse form of the icmp, but that is
+    // non-canonical, and this pass is running *very* late in the pipeline.
+    return std::nullopt;
+
+  // Find the IV used by the current exit condition.
+  PHINode *ToFold;
+  Value *ToFoldStart, *ToFoldStep;
+  if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
+    return std::nullopt;
+
+  // Ensure the simple recurrence is a part of the current loop.
+  if (ToFold->getParent() != L->getHeader())
+    return std::nullopt;
+
+  // If that IV isn't dead after we rewrite the exit condition in terms of
+  // another IV, there's no point in doing the transform.
+  if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
+    return std::nullopt;
+
+  // Inserting instructions in the preheader has a runtime cost, scale
+  // the allowed cost with the loops trip count as best we can.
+  const unsigned ExpansionBudget = [&]() {
+    unsigned Budget = 2 * SCEVCheapExpansionBudget;
+    if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
+      return std::min(Budget, SmallTC);
+    if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
+      return std::min(Budget, *SmallTC);
+    // Unknown trip count, assume long running by default.
+    return Budget;
+  }();
+
+  const SCEV *BECount = SE.getBackedgeTakenCount(L);
+  const DataLayout &DL = L->getHeader()->getDataLayout();
+  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
+
+  PHINode *ToHelpFold = nullptr;
+  const SCEV *TermValueS = nullptr;
+  bool MustDropPoison = false;
+  auto InsertPt = L->getLoopPreheader()->getTerminator();
+  for (PHINode &PN : L->getHeader()->phis()) {
+    if (ToFold == &PN)
+      continue;
+
+    if (!SE.isSCEVable(PN.getType())) {
+      LLVM_DEBUG(dbgs() << "IV of phi '" << PN
+                        << "' is not SCEV-able, not qualified for the "
+                           "terminating condition folding.\n");
+      continue;
+    }
+    const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
+    // Only speculate on affine AddRec
+    if (!AddRec || !AddRec->isAffine()) {
+      LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
+                        << "' is not an affine add recursion, not qualified "
+                           "for the terminating condition folding.\n");
+      continue;
+    }
+
+    // Check that we can compute the value of AddRec on the exiting iteration
+    // without soundness problems.  evaluateAtIteration internally needs
+    // to multiply the stride of the iteration number - which may wrap around.
+    // The issue here is subtle because computing the result accounting for
+    // wrap is insufficient. In order to use the result in an exit test, we
+    // must also know that AddRec doesn't take the same value on any previous
+    // iteration. The simplest case to consider is a candidate IV which is
+    // narrower than the trip count (and thus original IV), but this can
+    // also happen due to non-unit strides on the candidate IVs.
+    if (!AddRec->hasNoSelfWrap() ||
+        !SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))
+      continue;
+
+    const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
+    const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
+    if (!Expander.isSafeToExpand(TermValueSLocal)) {
+      LLVM_DEBUG(dbgs() << "Is not safe to expand terminating value for phi "
+                           "node "
+                        << PN << "\n");
+      continue;
+    }
+
+    if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget,
+                                     &TTI, InsertPt)) {
+      LLVM_DEBUG(dbgs() << "Is too expensive to expand terminating value for "
+                           "phi node "
+                        << PN << "\n");
+      continue;
+    }
+
+    // The candidate IV may have been otherwise dead and poison from the
+    // very first iteration.  If we can't disprove that, we can't use the IV.
+    if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
+      LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV "
+                        << PN << "\n");
+      continue;
+    }
+
+    // The candidate IV may become poison on the last iteration.  If this
+    // value is not branched on, this is a well defined program.  We're
+    // about to add a new use to this IV, and we have to ensure we don't
+    // insert UB which didn't previously exist.
+    bool MustDropPoisonLocal = false;
+    Instruction *PostIncV =
+      cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
+    if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
+                                       &DT)) {
+      LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use "
+                        << PN << "\n");
+
+      // If this is a complex recurrence with multiple instructions computing
+      // the backedge value, we might need to strip poison flags from all of
+      // them.
+      if (PostIncV->getOperand(0) != &PN)
+        continue;
+
+      // In order to perform the transform, we need to drop the poison generating
+      // flags on this instruction (if any).
+      MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
+    }
+
+    // We pick the last legal alternate IV.  We could explore choosing an
+    // optimal alternate IV if we had a decent heuristic to do so.
+    ToHelpFold = &PN;
+    TermValueS = TermValueSLocal;
+    MustDropPoison = MustDropPoisonLocal;
+  }
+
+  // ToFold is known non-null here (it was dereferenced above), so only the
+  // search result needs checking.
+  if (!ToHelpFold) {
+    LLVM_DEBUG(dbgs() << "Cannot find other AddRec IV to help folding\n");
+    return std::nullopt;
+  }
+
+  LLVM_DEBUG(dbgs() << "\nFound loop that can fold terminating condition\n"
+                    << "  BECount (SCEV): " << *BECount << "\n"
+                    << "  TermCond: " << *TermCond << "\n"
+                    << "  BranchInst: " << *BI << "\n"
+                    << "  ToFold: " << *ToFold << "\n"
+                    << "  ToHelpFold: " << *ToHelpFold << "\n");
+
+  return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
+}
+
+/// Rewrite the exit test of loop \p L when canFoldTermCondOfLoop finds a
+/// suitable alternate IV, then delete header phis that became dead.
+///
+/// The latch branch is changed to exit when the alternate IV's latch value
+/// equals its terminating value, which is expanded in the preheader.
+/// Returns true only if the loop was rewritten. If \p MSSA is non-null, a
+/// MemorySSAUpdater for it is passed to DeleteDeadPHIs. \p LI is only
+/// forwarded to canFoldTermCondOfLoop.
+static bool RunTermFold(Loop *L, ScalarEvolution &SE,
+                        DominatorTree &DT, LoopInfo &LI,
+                        const TargetTransformInfo &TTI,
+                        TargetLibraryInfo &TLI,
+                        MemorySSA *MSSA) {
+  auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
+  if (!Opt)
+    return false;
+
+  auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
+
+  NumTermFold++;
+
+  BasicBlock *LoopPreheader = L->getLoopPreheader();
+  BasicBlock *LoopLatch = L->getLoopLatch();
+
+  // Only referenced from LLVM_DEBUG, which is compiled out in release builds.
+  (void)ToFold;
+  LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
+             << *ToFold << "\n"
+             << "New term-cond phi-node:\n"
+             << *ToHelpFold << "\n");
+
+  Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);
+  (void)StartValue;
+  Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);
+
+  // See comment in canFoldTermCondOfLoop on why this is sufficient.
+  if (MustDrop)
+    cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();
+
+  // SCEVExpander for both use in preheader and latch
+  const DataLayout &DL = L->getHeader()->getDataLayout();
+  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
+
+  assert(Expander.isSafeToExpand(TermValueS) &&
+         "Terminating value was checked safe in canFoldTermCondOfLoop");
+
+  // Create new terminating value at loop preheader
+  Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),
+                                            LoopPreheader->getTerminator());
+
+  LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
+             << *StartValue << "\n"
+             << "Terminating value of new term-cond phi-node:\n"
+             << *TermValue << "\n");
+
+  // Create new terminating condition at loop latch
+  BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
+  ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
+  IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
+  Value *NewTermCond =
+    LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
+                            "lsr_fold_term_cond.replaced_term_cond");
+  // Swap successors to exit loop body if IV equals to new TermValue
+  if (BI->getSuccessor(0) == L->getHeader())
+    BI->swapSuccessors();
+
+  LLVM_DEBUG(dbgs() << "Old term-cond:\n"
+             << *OldTermCond << "\n"
+             << "New term-cond:\n" << *NewTermCond << "\n");
+
+  BI->setCondition(NewTermCond);
+
+  Expander.clear();
+  // canFoldTermCondOfLoop required the old compare to have a single use (the
+  // branch), so it is unused now and can be erased.
+  OldTermCond->eraseFromParent();
+
+  // The updater is only needed once we know a rewrite happened; creating it
+  // here avoids a heap allocation on the common no-fold path.
+  std::unique_ptr<MemorySSAUpdater> MSSAU;
+  if (MSSA)
+    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
+  DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
+  return true;
+}
+
+namespace {
+
+/// Legacy pass manager (LoopPass) interface for the loop-term-fold transform.
+class LoopTermFold : public LoopPass {
+  // LoopPass hooks; private, as in the original declaration.
+  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+public:
+  /// Pass identification, replacement for typeid.
+  static char ID;
+
+  LoopTermFold();
+};
+
+} // end anonymous namespace
+
+/// Runs this pass's initializer against the global PassRegistry.
+LoopTermFold::LoopTermFold() : LoopPass(ID) {
+  PassRegistry &Registry = *PassRegistry::getPassRegistry();
+  initializeLoopTermFoldPass(Registry);
+}
+
+void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
+  // The transform only replaces the latch's exit compare (swapping the branch
+  // successors if needed), expands the new exit bound in the preheader, and
+  // deletes dead header phis. Unlike the LSR code it was split from, it does
+  // not split critical edges, which is why the analyses below are declared
+  // preserved.
+  AU.addPreservedID(LoopSimplifyID);
+
+  AU.addRequired<LoopInfoWrapperPass>();
+  AU.addPreserved<LoopInfoWrapperPass>();
+  AU.addRequiredID(LoopSimplifyID);
+  AU.addRequired<DominatorTreeWrapperPass>();
+  AU.addPreserved<DominatorTreeWrapperPass>();
+  AU.addRequired<ScalarEvolutionWrapperPass>();
+  AU.addPreserved<ScalarEvolutionWrapperPass>();
+  AU.addRequired<TargetLibraryInfoWrapperPass>();
+  AU.addRequired<TargetTransformInfoWrapperPass>();
+  AU.addPreserved<MemorySSAWrapperPass>();
+}
+
+
+bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
+  if (skipLoop(L))
+    return false;
+
+  auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
+      *L->getHeader()->getParent());
+  auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
+      *L->getHeader()->getParent());
+  auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
----------------
nikic wrote:

This MSSA fetch looks unnecessary? We're not using it and don't need to do anything to preserve it either.

https://github.com/llvm/llvm-project/pull/104234


More information about the llvm-commits mailing list