[llvm] [LSR] Split the -lsr-term-fold transformation into its own pass (PR #104234)
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 16 06:13:10 PDT 2024
================
@@ -0,0 +1,386 @@
+//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LoopTermFold.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
+#include <cassert>
+#include <optional>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-term-fold"
+
+STATISTIC(NumTermFold, // Bumped by RunTermFold for every folded loop.
+ "Number of terminating condition fold recognized and performed");
+
+/// Decide whether the latch exit test of \p L can be re-expressed in terms of
+/// a different induction variable, so that the IV currently feeding the test
+/// becomes dead.
+///
+/// On success returns the tuple (ToFold, ToHelpFold, TermValueS,
+/// MustDropPoison):
+///  - ToFold:         header phi of the IV compared by the current latch icmp.
+///  - ToHelpFold:     header phi of the alternate IV selected for the new test.
+///  - TermValueS:     SCEV of ToHelpFold's post-increment value evaluated at
+///                    the backedge-taken count of \p L.
+///  - MustDropPoison: true when the latch-incoming value of ToHelpFold carries
+///                    poison-generating flags that have to be dropped before
+///                    it may feed the new exit test.
+/// Returns std::nullopt when any precondition fails. The rewrite itself is
+/// left to the caller.
+///
+/// \p LI is accepted for interface compatibility; it is not consulted here.
+static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
+canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
+                      const LoopInfo &LI, const TargetTransformInfo &TTI) {
+  if (!L->isInnermost()) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
+    return std::nullopt;
+  }
+  // Only inspect on simple loop structure
+  if (!L->isLoopSimplifyForm()) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
+    return std::nullopt;
+  }
+
+  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
+    return std::nullopt;
+  }
+
+  BasicBlock *LoopLatch = L->getLoopLatch();
+  BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
+  if (!BI || BI->isUnconditional())
+    return std::nullopt;
+  auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
+  if (!TermCond) {
+    LLVM_DEBUG(
+        dbgs()
+        << "Cannot fold on branching condition that is not an ICmpInst\n");
+    return std::nullopt;
+  }
+  if (!TermCond->hasOneUse()) {
+    LLVM_DEBUG(
+        dbgs()
+        << "Cannot replace terminating condition with more than one use\n");
+    return std::nullopt;
+  }
+
+  BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
+  Value *RHS = TermCond->getOperand(1);
+  if (!LHS || !L->isLoopInvariant(RHS))
+    // We could pattern match the inverse form of the icmp, but that is
+    // non-canonical, and this pass is running *very* late in the pipeline.
+    return std::nullopt;
+
+  // Find the IV used by the current exit condition.
+  PHINode *ToFold = nullptr;
+  Value *ToFoldStart = nullptr, *ToFoldStep = nullptr;
+  if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
+    return std::nullopt;
+
+  // Ensure the simple recurrence is a part of the current loop.
+  if (ToFold->getParent() != L->getHeader())
+    return std::nullopt;
+
+  // If that IV isn't dead after we rewrite the exit condition in terms of
+  // another IV, there's no point in doing the transform.
+  if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
+    return std::nullopt;
+
+  // Inserting instructions in the preheader has a runtime cost, scale
+  // the allowed cost with the loops trip count as best we can.
+  const unsigned ExpansionBudget = [&]() {
+    unsigned Budget = 2 * SCEVCheapExpansionBudget;
+    if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
+      return std::min(Budget, SmallTC);
+    if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
+      return std::min(Budget, *SmallTC);
+    // Unknown trip count, assume long running by default.
+    return Budget;
+  }();
+
+  const SCEV *BECount = SE.getBackedgeTakenCount(L);
+  const DataLayout &DL = L->getHeader()->getDataLayout();
+  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
+
+  PHINode *ToHelpFold = nullptr;
+  const SCEV *TermValueS = nullptr;
+  bool MustDropPoison = false;
+  auto *InsertPt = L->getLoopPreheader()->getTerminator();
+  for (PHINode &PN : L->getHeader()->phis()) {
+    if (ToFold == &PN)
+      continue;
+
+    if (!SE.isSCEVable(PN.getType())) {
+      LLVM_DEBUG(dbgs() << "IV of phi '" << PN
+                        << "' is not SCEV-able, not qualified for the "
+                           "terminating condition folding.\n");
+      continue;
+    }
+    const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
+    // Only speculate on affine AddRec
+    if (!AddRec || !AddRec->isAffine()) {
+      LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
+                        << "' is not an affine add recursion, not qualified "
+                           "for the terminating condition folding.\n");
+      continue;
+    }
+
+    // Check that we can compute the value of AddRec on the exiting iteration
+    // without soundness problems.  evaluateAtIteration internally needs
+    // to multiply the stride of the iteration number - which may wrap around.
+    // The issue here is subtle because computing the result accounting for
+    // wrap is insufficient. In order to use the result in an exit test, we
+    // must also know that AddRec doesn't take the same value on any previous
+    // iteration. The simplest case to consider is a candidate IV which is
+    // narrower than the trip count (and thus original IV), but this can
+    // also happen due to non-unit strides on the candidate IVs.
+    if (!AddRec->hasNoSelfWrap() ||
+        !SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))
+      continue;
+
+    const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
+    const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
+    if (!Expander.isSafeToExpand(TermValueSLocal)) {
+      LLVM_DEBUG(
+          dbgs() << "Is not safe to expand terminating value for phi node" << PN
+                 << "\n");
+      continue;
+    }
+
+    if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget,
+                                     &TTI, InsertPt)) {
+      LLVM_DEBUG(
+          dbgs() << "Is too expensive to expand terminating value for phi node"
+                 << PN << "\n");
+      continue;
+    }
+
+    // The candidate IV may have been otherwise dead and poison from the
+    // very first iteration.  If we can't disprove that, we can't use the IV.
+    // BI is the latch terminator fetched above.
+    if (!mustExecuteUBIfPoisonOnPathTo(&PN, BI, &DT)) {
+      LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV "
+                        << PN << "\n");
+      continue;
+    }
+
+    // The candidate IV may become poison on the last iteration.  If this
+    // value is not branched on, this is a well defined program.  We're
+    // about to add a new use to this IV, and we have to ensure we don't
+    // insert UB which didn't previously exist.
+    bool MustDropPoisonLocal = false;
+    Instruction *PostIncV =
+        cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
+    if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, BI, &DT)) {
+      LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use"
+                        << PN << "\n");
+
+      // If this is a complex recurrence with multiple instructions computing
+      // the backedge value, we might need to strip poison flags from all of
+      // them.
+      if (PostIncV->getOperand(0) != &PN)
+        continue;
+
+      // In order to perform the transform, we need to drop the poison
+      // generating flags on this instruction (if any).
+      MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
+    }
+
+    // We pick the last legal alternate IV.  We could explore choosing an
+    // optimal alternate IV if we had a decent heuristic to do so.
+    ToHelpFold = &PN;
+    TermValueS = TermValueSLocal;
+    MustDropPoison = MustDropPoisonLocal;
+  }
+
+  // ToFold is known to be non-null here: it was dereferenced right after
+  // matchSimpleRecurrence succeeded. Only the alternate IV can be missing.
+  if (!ToHelpFold) {
+    LLVM_DEBUG(dbgs() << "Cannot find other AddRec IV to help folding\n");
+    return std::nullopt;
+  }
+
+  LLVM_DEBUG(dbgs() << "\nFound loop that can fold terminating condition\n"
+                    << " BECount (SCEV): " << *BECount << "\n"
+                    << " TermCond: " << *TermCond << "\n"
+                    << " BranchInst: " << *BI << "\n"
+                    << " ToFold: " << *ToFold << "\n"
+                    << " ToHelpFold: " << *ToHelpFold << "\n");
+
+  return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
+}
+
+/// Fold the terminating condition of \p L onto another induction variable
+/// when canFoldTermCondOfLoop reports that this is possible.
+///
+/// The terminating value of the alternate IV is expanded in the preheader,
+/// the latch branch is rewritten to exit when the alternate IV's latch value
+/// equals it, the old compare is erased and dead header PHIs are deleted.
+///
+/// \param MSSA may be null; when non-null a MemorySSAUpdater built on it is
+///        handed to DeleteDeadPHIs.
+/// \returns true if the loop was changed, false otherwise.
+static bool RunTermFold(Loop *L, ScalarEvolution &SE,
+                        DominatorTree &DT, LoopInfo &LI,
+                        const TargetTransformInfo &TTI,
+                        TargetLibraryInfo &TLI,
+                        MemorySSA *MSSA) {
+  auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
+  if (!Opt)
+    return false;
+
+  auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
+
+  ++NumTermFold;
+
+  // Only build the updater once we know a transform will happen.
+  std::unique_ptr<MemorySSAUpdater> MSSAU;
+  if (MSSA)
+    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
+
+  BasicBlock *LoopPreheader = L->getLoopPreheader();
+  BasicBlock *LoopLatch = L->getLoopLatch();
+
+  // ToFold is only referenced from debug output.
+  (void)ToFold;
+  LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
+                    << *ToFold << "\n"
+                    << "New term-cond phi-node:\n"
+                    << *ToHelpFold << "\n");
+
+  // StartValue is only referenced from debug output.
+  Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);
+  (void)StartValue;
+  Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);
+
+  // See comment in canFoldTermCondOfLoop on why this is sufficient.
+  if (MustDrop)
+    cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();
+
+  // SCEVExpander for both use in preheader and latch
+  const DataLayout &DL = L->getHeader()->getDataLayout();
+  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
+
+  assert(Expander.isSafeToExpand(TermValueS) &&
+         "Terminating value was checked safe in canFoldTermCondOfLoop");
+
+  // Create new terminating value at loop preheader
+  Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),
+                                            LoopPreheader->getTerminator());
+
+  LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
+                    << *StartValue << "\n"
+                    << "Terminating value of new term-cond phi-node:\n"
+                    << *TermValue << "\n");
+
+  // Create new terminating condition at loop latch
+  BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
+  ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
+  IRBuilder<> LatchBuilder(BI);
+  Value *NewTermCond =
+      LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
+                              "lsr_fold_term_cond.replaced_term_cond");
+  // Swap successors to exit loop body if IV equals to new TermValue
+  if (BI->getSuccessor(0) == L->getHeader())
+    BI->swapSuccessors();
+
+  LLVM_DEBUG(dbgs() << "Old term-cond:\n"
+                    << *OldTermCond << "\n"
+                    << "New term-cond:\n" << *NewTermCond << "\n");
+
+  BI->setCondition(NewTermCond);
+
+  // The old compare has just lost its use in the branch; erase it, then
+  // delete the header PHIs that became dead.
+  Expander.clear();
+  OldTermCond->eraseFromParent();
+  DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
+  return true;
+}
+
+namespace {
+
+/// Legacy pass-manager loop pass wrapper for the terminating-condition fold.
+class LoopTermFold : public LoopPass {
+public:
+  /// Registers the pass with the global PassRegistry.
+  LoopTermFold();
+
+  /// Pass identification; its address stands in for typeid.
+  static char ID;
+
+private:
+  /// Declares the analyses this pass requires and preserves.
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+  /// Per-loop entry point invoked by the loop pass manager.
+  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+};
+
+} // end anonymous namespace
+
+/// Construct the pass and make sure it is known to the global registry.
+LoopTermFold::LoopTermFold() : LoopPass(ID) {
+  PassRegistry &Registry = *PassRegistry::getPassRegistry();
+  initializeLoopTermFoldPass(Registry);
+}
+
+/// Declare which analyses this pass needs as input and which ones it keeps
+/// valid.
+void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
+  // Analyses (and the LoopSimplify canonical form) required before running.
+  AU.addRequired<LoopInfoWrapperPass>();
+  AU.addRequiredID(LoopSimplifyID);
+  AU.addRequired<DominatorTreeWrapperPass>();
+  AU.addRequired<ScalarEvolutionWrapperPass>();
+  AU.addRequired<TargetLibraryInfoWrapperPass>();
+  AU.addRequired<TargetTransformInfoWrapperPass>();
+
+  // Analyses declared as still valid once this pass has run.
+  AU.addPreservedID(LoopSimplifyID);
+  AU.addPreserved<LoopInfoWrapperPass>();
+  AU.addPreserved<DominatorTreeWrapperPass>();
+  AU.addPreserved<ScalarEvolutionWrapperPass>();
+  AU.addPreserved<MemorySSAWrapperPass>();
+}
+
+
+bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
+ if (skipLoop(L))
+ return false;
+
+ auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
+ *L->getHeader()->getParent());
+ auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
+ *L->getHeader()->getParent());
+ auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
----------------
nikic wrote:
This MSSA fetch looks unnecessary? We're not using it and don't need to do anything to preserve it either.
https://github.com/llvm/llvm-project/pull/104234
More information about the llvm-commits
mailing list