[llvm] 706ead0 - [LoopFlatten] Make it a FunctionPass

Sjoerd Meijer via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 10 12:03:51 PST 2020


Author: Sjoerd Meijer
Date: 2020-11-10T20:03:31Z
New Revision: 706ead0e875bf0a127c429ce507e8e79d330d731

URL: https://github.com/llvm/llvm-project/commit/706ead0e875bf0a127c429ce507e8e79d330d731
DIFF: https://github.com/llvm/llvm-project/commit/706ead0e875bf0a127c429ce507e8e79d330d731.diff

LOG: [LoopFlatten] Make it a FunctionPass

This converts LoopFlatten from a LoopPass to a FunctionPass so that we don't
run into the problems caused by a loop pass deleting an (inner) loop.

Differential Revision: https://reviews.llvm.org/D90940

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Scalar.h
    llvm/include/llvm/Transforms/Scalar/LoopFlatten.h
    llvm/lib/Passes/PassBuilder.cpp
    llvm/lib/Passes/PassRegistry.def
    llvm/lib/Transforms/Scalar/LoopFlatten.cpp
    llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Scalar.h b/llvm/include/llvm/Transforms/Scalar.h
index 50946b54fb985..5575ba2d52d97 100644
--- a/llvm/include/llvm/Transforms/Scalar.h
+++ b/llvm/include/llvm/Transforms/Scalar.h
@@ -153,7 +153,7 @@ Pass *createLoopInterchangePass();
 //
 // LoopFlatten - This pass flattens nested loops into a single loop.
 //
-Pass *createLoopFlattenPass();
+FunctionPass *createLoopFlattenPass();
 
 //===----------------------------------------------------------------------===//
 //

diff  --git a/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h b/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h
index 9d1c44c1732c9..41f91f0900132 100644
--- a/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h
+++ b/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h
@@ -24,8 +24,7 @@ class LoopFlattenPass : public PassInfoMixin<LoopFlattenPass> {
 public:
   LoopFlattenPass() = default;
 
-  PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
-                        LoopStandardAnalysisResults &AR, LPMUpdater &U);
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
 };
 
 } // end namespace llvm

diff  --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index e178b6ebd471b..ffaec542b9813 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -543,7 +543,7 @@ PassBuilder::buildO1FunctionSimplificationPipeline(OptimizationLevel Level,
 
   LPM2.addPass(LoopDeletionPass());
   if (EnableLoopFlatten)
-    LPM2.addPass(LoopFlattenPass());
+    FPM.addPass(LoopFlattenPass());
   // Do not enable unrolling in PreLinkThinLTO phase during sample PGO
   // because it changes IR to makes profile annotation in back compile
   // inaccurate. The normal unroller doesn't pay attention to forced full unroll

diff  --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 9fe17d58bc602..eea6d4be0304e 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -240,6 +240,7 @@ FUNCTION_PASS("load-store-vectorizer", LoadStoreVectorizerPass())
 FUNCTION_PASS("loop-simplify", LoopSimplifyPass())
 FUNCTION_PASS("loop-sink", LoopSinkPass())
 FUNCTION_PASS("loop-unroll-and-jam", LoopUnrollAndJamPass())
+FUNCTION_PASS("loop-flatten", LoopFlattenPass())
 FUNCTION_PASS("lowerinvoke", LowerInvokePass())
 FUNCTION_PASS("lowerswitch", LowerSwitchPass())
 FUNCTION_PASS("mem2reg", PromotePass())
@@ -380,7 +381,6 @@ LOOP_PASS("loop-rotate", LoopRotatePass())
 LOOP_PASS("no-op-loop", NoOpLoopPass())
 LOOP_PASS("print", PrintLoopPass(dbgs()))
 LOOP_PASS("loop-deletion", LoopDeletionPass())
-LOOP_PASS("loop-flatten", LoopFlattenPass())
 LOOP_PASS("loop-simplifycfg", LoopSimplifyCFGPass())
 LOOP_PASS("loop-reduce", LoopStrengthReducePass())
 LOOP_PASS("indvars", IndVarSimplifyPass())

diff  --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index 65ea4734e452b..6167e2d06ddd0 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -29,7 +29,6 @@
 #include "llvm/Transforms/Scalar/LoopFlatten.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/LoopInfo.h"
-#include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -416,17 +415,14 @@ static OverflowResult checkOverflow(struct FlattenInfo &FI,
 
 static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
                             LoopInfo *LI, ScalarEvolution *SE,
-                            AssumptionCache *AC, const TargetTransformInfo *TTI,
-                            std::function<void(Loop *)> markLoopAsDeleted) {
+                            AssumptionCache *AC, TargetTransformInfo *TTI) {
   Function *F = FI.OuterLoop->getHeader()->getParent();
-
   LLVM_DEBUG(dbgs() << "Loop flattening running on outer loop "
                     << FI.OuterLoop->getHeader()->getName() << " and inner loop "
                     << FI.InnerLoop->getHeader()->getName() << " in "
                     << F->getName() << "\n");
 
   SmallPtrSet<Instruction *, 8> IterationInstructions;
-
   if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI,
                           FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE))
     return false;
@@ -528,40 +524,51 @@ static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
 
   // Tell LoopInfo, SCEV and the pass manager that the inner loop has been
   // deleted, and any information that have about the outer loop invalidated.
-  markLoopAsDeleted(FI.InnerLoop);
   SE->forgetLoop(FI.OuterLoop);
   SE->forgetLoop(FI.InnerLoop);
   LI->erase(FI.InnerLoop);
-
   return true;
 }
 
-PreservedAnalyses LoopFlattenPass::run(Loop &L, LoopAnalysisManager &AM,
-                                       LoopStandardAnalysisResults &AR,
-                                       LPMUpdater &Updater) {
-  if (L.getSubLoops().size() != 1)
-    return PreservedAnalyses::all();
+bool Flatten(DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
+             AssumptionCache *AC, TargetTransformInfo *TTI) {
+  bool Changed = false;
+  for (auto *InnerLoop : LI->getLoopsInPreorder()) {
+    auto *OuterLoop = InnerLoop->getParentLoop();
+    if (!OuterLoop)
+      continue;
+    struct FlattenInfo FI(OuterLoop, InnerLoop);
+    Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI);
+  }
+  return Changed;
+}
 
-  Loop *InnerLoop = *L.begin();
-  std::string LoopName(InnerLoop->getName());
-  struct FlattenInfo FI(InnerLoop->getParentLoop(), InnerLoop);
-  if (!FlattenLoopPair(
-          FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI,
-          [&](Loop *L) { Updater.markLoopAsDeleted(*L, LoopName); }))
+PreservedAnalyses LoopFlattenPass::run(Function &F,
+                                       FunctionAnalysisManager &AM) {
+  auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
+  auto *LI = &AM.getResult<LoopAnalysis>(F);
+  auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
+  auto *AC = &AM.getResult<AssumptionAnalysis>(F);
+  auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
+
+  if (!Flatten(DT, LI, SE, AC, TTI))
     return PreservedAnalyses::all();
-  return getLoopPassPreservedAnalyses();
+
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
 }
 
 namespace {
-class LoopFlattenLegacyPass : public LoopPass {
+class LoopFlattenLegacyPass : public FunctionPass {
 public:
   static char ID; // Pass ID, replacement for typeid
-  LoopFlattenLegacyPass() : LoopPass(ID) {
+  LoopFlattenLegacyPass() : FunctionPass(ID) {
     initializeLoopFlattenLegacyPassPass(*PassRegistry::getPassRegistry());
   }
 
   // Possibly flatten loop L into its child.
-  bool runOnLoop(Loop *L, LPPassManager &) override;
+  bool runOnFunction(Function &F) override;
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     getLoopAnalysisUsage(AU);
@@ -576,33 +583,20 @@ class LoopFlattenLegacyPass : public LoopPass {
 char LoopFlattenLegacyPass::ID = 0;
 INITIALIZE_PASS_BEGIN(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
                       false, false)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
                     false, false)
 
-Pass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); }
-
-bool LoopFlattenLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
-  if (skipLoop(L))
-    return false;
-
-  if (L->getSubLoops().size() != 1)
-    return false;
+FunctionPass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); }
 
+bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
   ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
   LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
   auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
   DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr;
   auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>();
-  TargetTransformInfo *TTI = &TTIP.getTTI(*L->getHeader()->getParent());
-  AssumptionCache *AC =
-      &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
-          *L->getHeader()->getParent());
-
-  Loop *InnerLoop = *L->begin();
-  struct FlattenInfo FI(InnerLoop->getParentLoop(), InnerLoop);
-  return FlattenLoopPair(FI, DT, LI, SE, AC, TTI,
-                         [&](Loop *L) { LPM.markLoopAsDeleted(*L); });
+  auto *TTI = &TTIP.getTTI(F);
+  auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+  return Flatten(DT, LI, SE, AC, TTI);
 }

diff  --git a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
index aad23318f6e9e..ca7cbd42468f3 100644
--- a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
+++ b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
@@ -393,3 +393,216 @@ for.end16:                                        ; preds = %lor.end
 for.end19:                                        ; preds = %for.end16
   ret i32 undef
 }
+
+; A 3d loop corresponding to:
+;
+; for (int i = 0; i < N; ++i)
+;    for (int j = 0; j < N; ++j)
+;      for (int k = 0; k < N; ++k)
+;        f(&A[i + N * (j + N * k)]);
+;
+define void @d3_1(i32* %A, i32 %N) {
+entry:
+  %cmp35 = icmp sgt i32 %N, 0
+  br i1 %cmp35, label %for.cond1.preheader.lr.ph, label %for.cond.cleanup
+
+for.cond1.preheader.lr.ph:
+  br label %for.cond1.preheader.us
+
+for.cond1.preheader.us:
+  %i.036.us = phi i32 [ 0, %for.cond1.preheader.lr.ph ], [ %inc15.us, %for.cond1.for.cond.cleanup3_crit_edge.us ]
+  br i1 true, label %for.cond5.preheader.us.us.preheader, label %for.cond5.preheader.us52.preheader
+
+for.cond5.preheader.us52.preheader:
+  br label %for.cond5.preheader.us52
+
+for.cond5.preheader.us.us.preheader:
+  br label %for.cond5.preheader.us.us
+
+for.cond5.preheader.us52:
+  br i1 false, label %for.cond5.preheader.us52, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit58
+
+for.cond1.for.cond.cleanup3_crit_edge.us.loopexit:
+  br label %for.cond1.for.cond.cleanup3_crit_edge.us
+
+for.cond1.for.cond.cleanup3_crit_edge.us.loopexit58:
+  br label %for.cond1.for.cond.cleanup3_crit_edge.us
+
+for.cond1.for.cond.cleanup3_crit_edge.us:
+  %inc15.us = add nuw nsw i32 %i.036.us, 1
+  %cmp.us = icmp slt i32 %inc15.us, %N
+  br i1 %cmp.us, label %for.cond1.preheader.us, label %for.cond.cleanup.loopexit
+
+for.cond5.preheader.us.us:
+  %j.033.us.us = phi i32 [ %inc12.us.us, %for.cond5.for.cond.cleanup7_crit_edge.us.us ], [ 0, %for.cond5.preheader.us.us.preheader ]
+  br label %for.body8.us.us
+
+for.cond5.for.cond.cleanup7_crit_edge.us.us:
+  %inc12.us.us = add nuw nsw i32 %j.033.us.us, 1
+  %cmp2.us.us = icmp slt i32 %inc12.us.us, %N
+  br i1 %cmp2.us.us, label %for.cond5.preheader.us.us, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit
+
+for.body8.us.us:
+  %k.031.us.us = phi i32 [ 0, %for.cond5.preheader.us.us ], [ %inc.us.us, %for.body8.us.us ]
+  %mul.us.us = mul nsw i32 %k.031.us.us, %N
+  %add.us.us = add nsw i32 %mul.us.us, %j.033.us.us
+  %mul9.us.us = mul nsw i32 %add.us.us, %N
+  %add10.us.us = add nsw i32 %mul9.us.us, %i.036.us
+  %idxprom.us.us = sext i32 %add10.us.us to i64
+  %arrayidx.us.us = getelementptr inbounds i32, i32* %A, i64 %idxprom.us.us
+  tail call void @f(i32* %arrayidx.us.us) #2
+  %inc.us.us = add nuw nsw i32 %k.031.us.us, 1
+  %cmp6.us.us = icmp slt i32 %inc.us.us, %N
+  br i1 %cmp6.us.us, label %for.body8.us.us, label %for.cond5.for.cond.cleanup7_crit_edge.us.us
+
+for.cond.cleanup.loopexit:
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+}
+
+; A 3d loop corresponding to:
+;
+;   for (int k = 0; k < N; ++k)
+;    for (int i = 0; i < N; ++i)
+;      for (int j = 0; j < M; ++j)
+;        f(&A[i*M+j]);
+;
+; This could be supported, but isn't at the moment.
+;
+define void @d3_2(i32* %A, i32 %N, i32 %M) {
+entry:
+  %cmp30 = icmp sgt i32 %N, 0
+  br i1 %cmp30, label %for.cond1.preheader.lr.ph, label %for.cond.cleanup
+
+for.cond1.preheader.lr.ph:
+  %cmp625 = icmp sgt i32 %M, 0
+  br label %for.cond1.preheader.us
+
+for.cond1.preheader.us:
+  %k.031.us = phi i32 [ 0, %for.cond1.preheader.lr.ph ], [ %inc13.us, %for.cond1.for.cond.cleanup3_crit_edge.us ]
+  br i1 %cmp625, label %for.cond5.preheader.us.us.preheader, label %for.cond5.preheader.us43.preheader
+
+for.cond5.preheader.us43.preheader:
+  br label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit50
+
+for.cond5.preheader.us.us.preheader:
+  br label %for.cond5.preheader.us.us
+
+for.cond1.for.cond.cleanup3_crit_edge.us.loopexit:
+  br label %for.cond1.for.cond.cleanup3_crit_edge.us
+
+for.cond1.for.cond.cleanup3_crit_edge.us.loopexit50:
+  br label %for.cond1.for.cond.cleanup3_crit_edge.us
+
+for.cond1.for.cond.cleanup3_crit_edge.us:
+  %inc13.us = add nuw nsw i32 %k.031.us, 1
+  %exitcond52 = icmp ne i32 %inc13.us, %N
+  br i1 %exitcond52, label %for.cond1.preheader.us, label %for.cond.cleanup.loopexit
+
+for.cond5.preheader.us.us:
+  %i.028.us.us = phi i32 [ %inc10.us.us, %for.cond5.for.cond.cleanup7_crit_edge.us.us ], [ 0, %for.cond5.preheader.us.us.preheader ]
+  %mul.us.us = mul nsw i32 %i.028.us.us, %M
+  br label %for.body8.us.us
+
+for.cond5.for.cond.cleanup7_crit_edge.us.us:
+  %inc10.us.us = add nuw nsw i32 %i.028.us.us, 1
+  %exitcond51 = icmp ne i32 %inc10.us.us, %N
+  br i1 %exitcond51, label %for.cond5.preheader.us.us, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit
+
+for.body8.us.us:
+  %j.026.us.us = phi i32 [ 0, %for.cond5.preheader.us.us ], [ %inc.us.us, %for.body8.us.us ]
+  %add.us.us = add nsw i32 %j.026.us.us, %mul.us.us
+  %idxprom.us.us = sext i32 %add.us.us to i64
+  %arrayidx.us.us = getelementptr inbounds i32, i32* %A, i64 %idxprom.us.us
+  tail call void @f(i32* %arrayidx.us.us) #2
+  %inc.us.us = add nuw nsw i32 %j.026.us.us, 1
+  %exitcond = icmp ne i32 %inc.us.us, %M
+  br i1 %exitcond, label %for.body8.us.us, label %for.cond5.for.cond.cleanup7_crit_edge.us.us
+
+for.cond.cleanup.loopexit:
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+}
+
+; A 3d loop corresponding to:
+;
+;   for (int i = 0; i < N; ++i)
+;     for (int j = 0; j < M; ++j) {
+;       A[i*M+j] = 0;
+;       for (int k = 0; k < N; ++k)
+;         g();
+;     }
+;
+define void @d3_3(i32* nocapture %A, i32 %N, i32 %M) {
+entry:
+  %cmp29 = icmp sgt i32 %N, 0
+  br i1 %cmp29, label %for.cond1.preheader.lr.ph, label %for.cond.cleanup
+
+for.cond1.preheader.lr.ph:
+  %cmp227 = icmp sgt i32 %M, 0
+  br i1 %cmp227, label %for.cond1.preheader.us.preheader, label %for.cond1.preheader.preheader
+
+for.cond1.preheader.preheader:
+  br label %for.cond.cleanup.loopexit49
+
+for.cond1.preheader.us.preheader:
+  br label %for.cond1.preheader.us
+
+for.cond1.preheader.us:
+  %i.030.us = phi i32 [ %inc13.us, %for.cond1.for.cond.cleanup3_crit_edge.us ], [ 0, %for.cond1.preheader.us.preheader ]
+  %mul.us = mul nsw i32 %i.030.us, %M
+  br i1 true, label %for.body4.us.us.preheader, label %for.body4.us32.preheader
+
+for.body4.us32.preheader:
+  br label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit48
+
+for.body4.us.us.preheader:
+  br label %for.body4.us.us
+
+for.cond1.for.cond.cleanup3_crit_edge.us.loopexit:
+  br label %for.cond1.for.cond.cleanup3_crit_edge.us
+
+for.cond1.for.cond.cleanup3_crit_edge.us.loopexit48:
+  br label %for.cond1.for.cond.cleanup3_crit_edge.us
+
+for.cond1.for.cond.cleanup3_crit_edge.us:
+  %inc13.us = add nuw nsw i32 %i.030.us, 1
+  %exitcond51 = icmp ne i32 %inc13.us, %N
+  br i1 %exitcond51, label %for.cond1.preheader.us, label %for.cond.cleanup.loopexit
+
+for.body4.us.us:
+  %j.028.us.us = phi i32 [ %inc10.us.us, %for.cond5.for.cond.cleanup7_crit_edge.us.us ], [ 0, %for.body4.us.us.preheader ]
+  %add.us.us = add nsw i32 %j.028.us.us, %mul.us
+  %idxprom.us.us = sext i32 %add.us.us to i64
+  %arrayidx.us.us = getelementptr inbounds i32, i32* %A, i64 %idxprom.us.us
+  store i32 0, i32* %arrayidx.us.us, align 4
+  br label %for.body8.us.us
+
+for.cond5.for.cond.cleanup7_crit_edge.us.us:
+  %inc10.us.us = add nuw nsw i32 %j.028.us.us, 1
+  %exitcond50 = icmp ne i32 %inc10.us.us, %M
+  br i1 %exitcond50, label %for.body4.us.us, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit
+
+for.body8.us.us:
+  %k.026.us.us = phi i32 [ 0, %for.body4.us.us ], [ %inc.us.us, %for.body8.us.us ]
+  tail call void bitcast (void (...)* @g to void ()*)() #2
+  %inc.us.us = add nuw nsw i32 %k.026.us.us, 1
+  %exitcond = icmp ne i32 %inc.us.us, %N
+  br i1 %exitcond, label %for.body8.us.us, label %for.cond5.for.cond.cleanup7_crit_edge.us.us
+
+for.cond.cleanup.loopexit:
+  br label %for.cond.cleanup
+
+for.cond.cleanup.loopexit49:
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+}
+
+declare dso_local void @f(i32*)
+declare dso_local void @g(...)


        


More information about the llvm-commits mailing list