[llvm] fa3693a - [LoopNest] Handle loop-nest passes in LoopPassManager

Whitney Tsang via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 16 09:07:40 PST 2020


Author: Whitney Tsang
Date: 2020-12-16T17:07:14Z
New Revision: fa3693ad0b34ef1d64f49e3d3dd10865b9fb7a8b

URL: https://github.com/llvm/llvm-project/commit/fa3693ad0b34ef1d64f49e3d3dd10865b9fb7a8b
DIFF: https://github.com/llvm/llvm-project/commit/fa3693ad0b34ef1d64f49e3d3dd10865b9fb7a8b.diff

LOG: [LoopNest] Handle loop-nest passes in LoopPassManager

Per http://llvm.org/OpenProjects.html#llvm_loopnest, the goal of this
patch (and other following patches) is to create facilities that allow
implementing loop nest passes that run on top-level loop nests for the
New Pass Manager.

This patch extends the functionality of LoopPassManager to handle
loop-nest passes by specializing the definition of LoopPassManager so
that its addPass method accepts both kinds of passes.

Only loop passes are executed if L is not a top-level loop, while both
kinds of passes are executed if L is top-level. Currently, loop-nest
passes should have the following run method:

PreservedAnalyses run(LoopNest &, LoopAnalysisManager &,
LoopStandardAnalysisResults &, LPMUpdater &);

Reviewed By: Whitney, ychen
Differential Revision: https://reviews.llvm.org/D87045

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/LoopNestAnalysis.h
    llvm/include/llvm/Transforms/Scalar/LoopPassManager.h
    llvm/lib/Analysis/LoopNestAnalysis.cpp
    llvm/lib/Transforms/Scalar/LoopPassManager.cpp
    llvm/unittests/IR/PassBuilderCallbacksTest.cpp
    llvm/unittests/Transforms/Scalar/LoopPassManagerTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Analysis/LoopNestAnalysis.h b/llvm/include/llvm/Analysis/LoopNestAnalysis.h
index 792958a312ce..4d77d735819f 100644
--- a/llvm/include/llvm/Analysis/LoopNestAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopNestAnalysis.h
@@ -128,6 +128,8 @@ class LoopNest {
                         [](const Loop *L) { return L->isLoopSimplifyForm(); });
   }
 
+  StringRef getName() const { return Loops.front()->getName(); }
+
 protected:
   const unsigned MaxPerfectDepth; // maximum perfect nesting depth level.
   LoopVectorTy Loops; // the loops in the nest (in breadth first order).

diff --git a/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h b/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h
index 4ac12061e79e..a1f43aa6d404 100644
--- a/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h
+++ b/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h
@@ -45,6 +45,7 @@
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/LoopAnalysisManager.h"
 #include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopNestAnalysis.h"
 #include "llvm/Analysis/MemorySSA.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
@@ -67,13 +68,136 @@ class LPMUpdater;
 // See the comments on the definition of the specialization for details on how
 // it differs from the primary template.
 template <>
-PreservedAnalyses
-PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,
-            LPMUpdater &>::run(Loop &InitialL, LoopAnalysisManager &AM,
-                               LoopStandardAnalysisResults &AnalysisResults,
-                               LPMUpdater &U);
-extern template class PassManager<Loop, LoopAnalysisManager,
-                                  LoopStandardAnalysisResults &, LPMUpdater &>;
+class PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,
+                  LPMUpdater &>
+    : public PassInfoMixin<
+          PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,
+                      LPMUpdater &>> {
+private:
+  template <typename PassT>
+  using HasRunOnLoopT = decltype(std::declval<PassT>().run(
+      std::declval<Loop &>(), std::declval<LoopAnalysisManager &>(),
+      std::declval<LoopStandardAnalysisResults &>(),
+      std::declval<LPMUpdater &>()));
+
+public:
+  /// Construct a pass manager.
+  ///
+  /// If \p DebugLogging is true, we'll log our progress to llvm::dbgs().
+  explicit PassManager(bool DebugLogging = false)
+      : DebugLogging(DebugLogging) {}
+
+  // FIXME: These are equivalent to the default move constructor/move
+  // assignment. However, using = default triggers linker errors due to the
+  // explicit instantiations below. Find a way to use the default and remove the
+  // duplicated code here.
+  PassManager(PassManager &&Arg)
+      : IsLoopNestPass(std::move(Arg.IsLoopNestPass)),
+        LoopPasses(std::move(Arg.LoopPasses)),
+        LoopNestPasses(std::move(Arg.LoopNestPasses)),
+        DebugLogging(std::move(Arg.DebugLogging)) {}
+
+  PassManager &operator=(PassManager &&RHS) {
+    IsLoopNestPass = std::move(RHS.IsLoopNestPass);
+    LoopPasses = std::move(RHS.LoopPasses);
+    LoopNestPasses = std::move(RHS.LoopNestPasses);
+    DebugLogging = std::move(RHS.DebugLogging);
+    return *this;
+  }
+
+  PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
+                        LoopStandardAnalysisResults &AR, LPMUpdater &U);
+
+  /// Add either a loop pass or a loop-nest pass to the pass manager. Append \p
+  /// Pass to the list of loop passes if it has a dedicated \fn run() method for
+  /// loops and to the list of loop-nest passes if the \fn run() method is for
+  /// loop-nests instead. Also append whether \p Pass is loop-nest pass or not
+  /// to the end of \var IsLoopNestPass so we can easily identify the types of
+  /// passes in the pass manager later.
+  template <typename PassT>
+  std::enable_if_t<is_detected<HasRunOnLoopT, PassT>::value>
+  addPass(PassT Pass) {
+    using LoopPassModelT =
+        detail::PassModel<Loop, PassT, PreservedAnalyses, LoopAnalysisManager,
+                          LoopStandardAnalysisResults &, LPMUpdater &>;
+    IsLoopNestPass.push_back(false);
+    LoopPasses.emplace_back(new LoopPassModelT(std::move(Pass)));
+  }
+
+  template <typename PassT>
+  std::enable_if_t<!is_detected<HasRunOnLoopT, PassT>::value>
+  addPass(PassT Pass) {
+    using LoopNestPassModelT =
+        detail::PassModel<LoopNest, PassT, PreservedAnalyses,
+                          LoopAnalysisManager, LoopStandardAnalysisResults &,
+                          LPMUpdater &>;
+    IsLoopNestPass.push_back(true);
+    LoopNestPasses.emplace_back(new LoopNestPassModelT(std::move(Pass)));
+  }
+
+  // Specializations of `addPass` for `RepeatedPass`. These are necessary since
+  // `RepeatedPass` has a templated `run` method that will result in incorrect
+  // detection of `HasRunOnLoopT`.
+  template <typename PassT>
+  std::enable_if_t<is_detected<HasRunOnLoopT, PassT>::value>
+  addPass(RepeatedPass<PassT> Pass) {
+    using RepeatedLoopPassModelT =
+        detail::PassModel<Loop, RepeatedPass<PassT>, PreservedAnalyses,
+                          LoopAnalysisManager, LoopStandardAnalysisResults &,
+                          LPMUpdater &>;
+    IsLoopNestPass.push_back(false);
+    LoopPasses.emplace_back(new RepeatedLoopPassModelT(std::move(Pass)));
+  }
+
+  template <typename PassT>
+  std::enable_if_t<!is_detected<HasRunOnLoopT, PassT>::value>
+  addPass(RepeatedPass<PassT> Pass) {
+    using RepeatedLoopNestPassModelT =
+        detail::PassModel<LoopNest, RepeatedPass<PassT>, PreservedAnalyses,
+                          LoopAnalysisManager, LoopStandardAnalysisResults &,
+                          LPMUpdater &>;
+    IsLoopNestPass.push_back(true);
+    LoopNestPasses.emplace_back(
+        new RepeatedLoopNestPassModelT(std::move(Pass)));
+  }
+
+  bool isEmpty() const { return LoopPasses.empty() && LoopNestPasses.empty(); }
+
+  static bool isRequired() { return true; }
+
+protected:
+  using LoopPassConceptT =
+      detail::PassConcept<Loop, LoopAnalysisManager,
+                          LoopStandardAnalysisResults &, LPMUpdater &>;
+  using LoopNestPassConceptT =
+      detail::PassConcept<LoopNest, LoopAnalysisManager,
+                          LoopStandardAnalysisResults &, LPMUpdater &>;
+
+  // BitVector that identifies whether the passes are loop passes or loop-nest
+  // passes (true for loop-nest passes).
+  BitVector IsLoopNestPass;
+  std::vector<std::unique_ptr<LoopPassConceptT>> LoopPasses;
+  std::vector<std::unique_ptr<LoopNestPassConceptT>> LoopNestPasses;
+
+  /// Flag indicating whether we should do debug logging.
+  bool DebugLogging;
+
+  /// Run either a loop pass or a loop-nest pass. Returns `None` if
+  /// PassInstrumentation's BeforePass returns false. Otherwise, returns the
+  /// preserved analyses of the pass.
+  template <typename IRUnitT, typename PassT>
+  Optional<PreservedAnalyses>
+  runSinglePass(IRUnitT &IR, PassT &Pass, LoopAnalysisManager &AM,
+                LoopStandardAnalysisResults &AR, LPMUpdater &U,
+                PassInstrumentation &PI);
+
+  PreservedAnalyses runWithLoopNestPasses(Loop &L, LoopAnalysisManager &AM,
+                                          LoopStandardAnalysisResults &AR,
+                                          LPMUpdater &U);
+  PreservedAnalyses runWithoutLoopNestPasses(Loop &L, LoopAnalysisManager &AM,
+                                             LoopStandardAnalysisResults &AR,
+                                             LPMUpdater &U);
+};
 
 /// The Loop pass manager.
 ///
@@ -223,6 +347,29 @@ class LPMUpdater {
       : Worklist(Worklist), LAM(LAM) {}
 };
 
+template <typename IRUnitT, typename PassT>
+Optional<PreservedAnalyses> LoopPassManager::runSinglePass(
+    IRUnitT &IR, PassT &Pass, LoopAnalysisManager &AM,
+    LoopStandardAnalysisResults &AR, LPMUpdater &U, PassInstrumentation &PI) {
+  // Check the PassInstrumentation's BeforePass callbacks before running the
+  // pass, skip its execution completely if asked to (callback returns false).
+  if (!PI.runBeforePass<IRUnitT>(*Pass, IR))
+    return None;
+
+  PreservedAnalyses PA;
+  {
+    TimeTraceScope TimeScope(Pass->name(), IR.getName());
+    PA = Pass->run(IR, AM, AR, U);
+  }
+
+  // do not pass deleted Loop into the instrumentation
+  if (U.skipCurrentLoop())
+    PI.runAfterPassInvalidated<IRUnitT>(*Pass, PA);
+  else
+    PI.runAfterPass<IRUnitT>(*Pass, IR, PA);
+  return PA;
+}
+
 /// Adaptor that maps from a function to its loops.
 ///
 /// Designed to allow composition of a LoopPass(Manager) and a

diff --git a/llvm/lib/Analysis/LoopNestAnalysis.cpp b/llvm/lib/Analysis/LoopNestAnalysis.cpp
index 1e322e15f74c..ef10b7e97461 100644
--- a/llvm/lib/Analysis/LoopNestAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopNestAnalysis.cpp
@@ -306,6 +306,8 @@ static bool checkLoopsStructure(const Loop &OuterLoop, const Loop &InnerLoop,
   return true;
 }
 
+AnalysisKey LoopNestAnalysis::Key;
+
 raw_ostream &llvm::operator<<(raw_ostream &OS, const LoopNest &LN) {
   OS << "IsPerfect=";
   if (LN.getMaxPerfectDepth() == LN.getNestDepth())

diff --git a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
index 90e23c88cb84..809f43eb4dd8 100644
--- a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
@@ -12,58 +12,101 @@
 
 using namespace llvm;
 
-// Explicit template instantiations and specialization defininitions for core
-// template typedefs.
 namespace llvm {
-template class PassManager<Loop, LoopAnalysisManager,
-                           LoopStandardAnalysisResults &, LPMUpdater &>;
 
 /// Explicitly specialize the pass manager's run method to handle loop nest
 /// structure updates.
-template <>
 PreservedAnalyses
 PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,
             LPMUpdater &>::run(Loop &L, LoopAnalysisManager &AM,
                                LoopStandardAnalysisResults &AR, LPMUpdater &U) {
-  PreservedAnalyses PA = PreservedAnalyses::all();
 
   if (DebugLogging)
     dbgs() << "Starting Loop pass manager run.\n";
 
+  // Runs loop-nest passes only when the current loop is a top-level one.
+  PreservedAnalyses PA = (L.isOutermost() && !LoopNestPasses.empty())
+                             ? runWithLoopNestPasses(L, AM, AR, U)
+                             : runWithoutLoopNestPasses(L, AM, AR, U);
+
+  // Invalidation for the current loop should be handled above, and other loop
+  // analysis results shouldn't be impacted by runs over this loop. Therefore,
+  // the remaining analysis results in the AnalysisManager are preserved. We
+  // mark this with a set so that we don't need to inspect each one
+  // individually.
+  // FIXME: This isn't correct! This loop and all nested loops' analyses should
+  // be preserved, but unrolling should invalidate the parent loop's analyses.
+  PA.preserveSet<AllAnalysesOn<Loop>>();
+
+  if (DebugLogging)
+    dbgs() << "Finished Loop pass manager run.\n";
+
+  return PA;
+}
+
+// Run both loop passes and loop-nest passes on top-level loop \p L.
+PreservedAnalyses
+LoopPassManager::runWithLoopNestPasses(Loop &L, LoopAnalysisManager &AM,
+                                       LoopStandardAnalysisResults &AR,
+                                       LPMUpdater &U) {
+  assert(L.isOutermost() &&
+         "Loop-nest passes should only run on top-level loops.");
+  PreservedAnalyses PA = PreservedAnalyses::all();
+
   // Request PassInstrumentation from analysis manager, will use it to run
   // instrumenting callbacks for the passes later.
   PassInstrumentation PI = AM.getResult<PassInstrumentationAnalysis>(L, AR);
-  for (auto &Pass : Passes) {
-    // Check the PassInstrumentation's BeforePass callbacks before running the
-    // pass, skip its execution completely if asked to (callback returns false).
-    if (!PI.runBeforePass<Loop>(*Pass, L))
-      continue;

-    PreservedAnalyses PassPA;
-    {
-      TimeTraceScope TimeScope(Pass->name(), L.getName());
-      PassPA = Pass->run(L, AM, AR, U);
+  unsigned LoopPassIndex = 0, LoopNestPassIndex = 0;
+
+  // `LoopNestPtr` points to the `LoopNest` object for the current top-level
+  // loop and `IsLoopNestPtrValid` indicates whether the pointer is still valid.
+  // The `LoopNest` object will have to be re-constructed if the pointer is
+  // invalid when encountering a loop-nest pass.
+  std::unique_ptr<LoopNest> LoopNestPtr;
+  bool IsLoopNestPtrValid = false;
+
+  for (size_t I = 0, E = IsLoopNestPass.size(); I != E; ++I) {
+    Optional<PreservedAnalyses> PassPA;
+    if (!IsLoopNestPass[I]) {
+      // The `I`-th pass is a loop pass.
+      auto &Pass = LoopPasses[LoopPassIndex++];
+      PassPA = runSinglePass(L, Pass, AM, AR, U, PI);
+    } else {
+      // The `I`-th pass is a loop-nest pass.
+      auto &Pass = LoopNestPasses[LoopNestPassIndex++];
+
+      // If the loop-nest object calculated before is no longer valid,
+      // re-calculate it here before running the loop-nest pass.
+      if (!IsLoopNestPtrValid) {
+        LoopNestPtr = LoopNest::getLoopNest(L, AR.SE);
+        IsLoopNestPtrValid = true;
+      }
+      PassPA = runSinglePass(*LoopNestPtr, Pass, AM, AR, U, PI);
     }

-    // do not pass deleted Loop into the instrumentation
-    if (U.skipCurrentLoop())
-      PI.runAfterPassInvalidated<Loop>(*Pass, PassPA);
-    else
-      PI.runAfterPass<Loop>(*Pass, L, PassPA);
+    // `PassPA` is `None` means that the before-pass callbacks in
+    // `PassInstrumentation` return false. The pass does not run in this case,
+    // so we can skip the following procedure.
+    if (!PassPA)
+      continue;

     // If the loop was deleted, abort the run and return to the outer walk.
     if (U.skipCurrentLoop()) {
-      PA.intersect(std::move(PassPA));
+      PA.intersect(std::move(*PassPA));
       break;
     }

     // Update the analysis manager as each pass runs and potentially
     // invalidates analyses.
-    AM.invalidate(L, PassPA);
+    AM.invalidate(L, *PassPA);
+
+    // Record whether the pass preserved LoopNest before PassPA is moved from.
+    IsLoopNestPtrValid &= PassPA->getChecker<LoopNestAnalysis>().preserved();

     // Finally, we intersect the final preserved analyses to compute the
     // aggregate preserved set for this pass manager.
-    PA.intersect(std::move(PassPA));
+    PA.intersect(std::move(*PassPA));

     // FIXME: Historically, the pass managers all called the LLVM context's
     // yield function here. We don't have a generic way to acquire the
@@ -71,22 +114,53 @@ PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,
     // in the new pass manager so it is currently omitted.
     // ...getContext().yield();
   }
+  return PA;
+}
 
-  // Invalidation for the current loop should be handled above, and other loop
-  // analysis results shouldn't be impacted by runs over this loop. Therefore,
-  // the remaining analysis results in the AnalysisManager are preserved. We
-  // mark this with a set so that we don't need to inspect each one
-  // individually.
-  // FIXME: This isn't correct! This loop and all nested loops' analyses should
-  // be preserved, but unrolling should invalidate the parent loop's analyses.
-  PA.preserveSet<AllAnalysesOn<Loop>>();
+// Run all loop passes on loop \p L. Loop-nest passes don't run either because
+// \p L is not a top-level one or simply because there are no loop-nest passes
+// in the pass manager at all.
+PreservedAnalyses
+LoopPassManager::runWithoutLoopNestPasses(Loop &L, LoopAnalysisManager &AM,
+                                          LoopStandardAnalysisResults &AR,
+                                          LPMUpdater &U) {
+  PreservedAnalyses PA = PreservedAnalyses::all();
 
-  if (DebugLogging)
-    dbgs() << "Finished Loop pass manager run.\n";
+  // Request PassInstrumentation from analysis manager, will use it to run
+  // instrumenting callbacks for the passes later.
+  PassInstrumentation PI = AM.getResult<PassInstrumentationAnalysis>(L, AR);
+  for (auto &Pass : LoopPasses) {
+    Optional<PreservedAnalyses> PassPA = runSinglePass(L, Pass, AM, AR, U, PI);
 
+    // `PassPA` is `None` means that the before-pass callbacks in
+    // `PassInstrumentation` return false. The pass does not run in this case,
+    // so we can skip the following procedure.
+    if (!PassPA)
+      continue;
+
+    // If the loop was deleted, abort the run and return to the outer walk.
+    if (U.skipCurrentLoop()) {
+      PA.intersect(std::move(*PassPA));
+      break;
+    }
+
+    // Update the analysis manager as each pass runs and potentially
+    // invalidates analyses.
+    AM.invalidate(L, *PassPA);
+
+    // Finally, we intersect the final preserved analyses to compute the
+    // aggregate preserved set for this pass manager.
+    PA.intersect(std::move(*PassPA));
+
+    // FIXME: Historically, the pass managers all called the LLVM context's
+    // yield function here. We don't have a generic way to acquire the
+    // context and it isn't yet clear what the right pattern is for yielding
+    // in the new pass manager so it is currently omitted.
+    // ...getContext().yield();
+  }
   return PA;
 }
-}
+} // namespace llvm
 
 PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F,
                                                  FunctionAnalysisManager &AM) {
@@ -152,8 +226,10 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F,
   PI.pushBeforeNonSkippedPassCallback([&LAR, &LI](StringRef PassID, Any IR) {
     if (isSpecialPass(PassID, {"PassManager"}))
       return;
-    assert(any_isa<const Loop *>(IR));
-    const Loop *L = any_cast<const Loop *>(IR);
+    assert(any_isa<const Loop *>(IR) || any_isa<const LoopNest *>(IR));
+    const Loop *L = any_isa<const Loop *>(IR)
+                        ? any_cast<const Loop *>(IR)
+                        : &any_cast<const LoopNest *>(IR)->getOutermostLoop();
     assert(L && "Loop should be valid for printing");
 
     // Verify the loop structure and LCSSA form before visiting the loop.

diff --git a/llvm/unittests/IR/PassBuilderCallbacksTest.cpp b/llvm/unittests/IR/PassBuilderCallbacksTest.cpp
index a4366e10bd68..edd46b8521d6 100644
--- a/llvm/unittests/IR/PassBuilderCallbacksTest.cpp
+++ b/llvm/unittests/IR/PassBuilderCallbacksTest.cpp
@@ -174,6 +174,22 @@ struct MockPassHandle<Loop>
   MockPassHandle() { setDefaults(); }
 };
 
+template <>
+struct MockPassHandle<LoopNest>
+    : MockPassHandleBase<MockPassHandle<LoopNest>, LoopNest,
+                         LoopAnalysisManager, LoopStandardAnalysisResults &,
+                         LPMUpdater &> {
+  MOCK_METHOD4(run,
+               PreservedAnalyses(LoopNest &, LoopAnalysisManager &,
+                                 LoopStandardAnalysisResults &, LPMUpdater &));
+  static void invalidateLoopNest(LoopNest &L, LoopAnalysisManager &,
+                                 LoopStandardAnalysisResults &,
+                                 LPMUpdater &Updater) {
+    Updater.markLoopAsDeleted(L.getOutermostLoop(), L.getName());
+  }
+  MockPassHandle() { setDefaults(); }
+};
+
 template <>
 struct MockPassHandle<Function>
     : MockPassHandleBase<MockPassHandle<Function>, Function> {
@@ -284,6 +300,8 @@ template <> std::string getName(const llvm::Any &WrappedIR) {
     return any_cast<const Function *>(WrappedIR)->getName().str();
   if (any_isa<const Loop *>(WrappedIR))
     return any_cast<const Loop *>(WrappedIR)->getName().str();
+  if (any_isa<const LoopNest *>(WrappedIR))
+    return any_cast<const LoopNest *>(WrappedIR)->getName().str();
   if (any_isa<const LazyCallGraph::SCC *>(WrappedIR))
     return any_cast<const LazyCallGraph::SCC *>(WrappedIR)->getName();
   return "<UNKNOWN>";
@@ -384,6 +402,11 @@ struct MockPassInstrumentationCallbacks {
   }
 };
 
+template <typename IRUnitT>
+using ExtraMockPassHandle =
+    std::conditional_t<std::is_same<IRUnitT, Loop>::value,
+                       MockPassHandle<LoopNest>, MockPassHandle<IRUnitT>>;
+
 template <typename PassManagerT> class PassBuilderCallbacksTest;
 
 /// This test fixture is shared between all the actual tests below and
@@ -416,6 +439,8 @@ class PassBuilderCallbacksTest<PassManager<
   ModuleAnalysisManager AM;
 
   MockPassHandle<IRUnitT> PassHandle;
+  ExtraMockPassHandle<IRUnitT> ExtraPassHandle;
+
   MockAnalysisHandle<IRUnitT> AnalysisHandle;
 
   static PreservedAnalyses getAnalysisResult(IRUnitT &U, AnalysisManagerT &AM,
@@ -475,6 +500,8 @@ class PassBuilderCallbacksTest<PassManager<
           /// Parse the name of our pass mock handle
           if (Name == "test-transform") {
             PM.addPass(PassHandle.getPass());
+            if (std::is_same<IRUnitT, Loop>::value)
+              PM.addPass(ExtraPassHandle.getPass());
             return true;
           }
           return false;
@@ -781,6 +808,7 @@ TEST_F(LoopCallbacksTest, Passes) {
   EXPECT_CALL(AnalysisHandle, run(HasName("loop"), _, _));
   EXPECT_CALL(PassHandle, run(HasName("loop"), _, _, _))
       .WillOnce(WithArgs<0, 1, 2>(Invoke(getAnalysisResult)));
+  EXPECT_CALL(ExtraPassHandle, run(HasName("loop"), _, _, _));
 
   StringRef PipelineText = "test-transform";
   ASSERT_THAT_ERROR(PB.parsePassPipeline(PM, PipelineText), Succeeded())
@@ -798,6 +826,7 @@ TEST_F(LoopCallbacksTest, InstrumentedPasses) {
   EXPECT_CALL(AnalysisHandle, run(HasName("loop"), _, _));
   EXPECT_CALL(PassHandle, run(HasName("loop"), _, _, _))
       .WillOnce(WithArgs<0, 1, 2>(Invoke(getAnalysisResult)));
+  EXPECT_CALL(ExtraPassHandle, run(HasName("loop"), _, _, _));
 
   // PassInstrumentation calls should happen in-sequence, in the same order
   // as passes/analyses are scheduled.
@@ -821,6 +850,19 @@ TEST_F(LoopCallbacksTest, InstrumentedPasses) {
               runAfterPass(HasNameRegex("MockPassHandle"), HasName("loop"), _))
       .InSequence(PISequence);
 
+  EXPECT_CALL(CallbacksHandle,
+              runBeforePass(HasNameRegex("MockPassHandle<.*LoopNest>"),
+                            HasName("loop")))
+      .InSequence(PISequence);
+  EXPECT_CALL(CallbacksHandle,
+              runBeforeNonSkippedPass(
+                  HasNameRegex("MockPassHandle<.*LoopNest>"), HasName("loop")))
+      .InSequence(PISequence);
+  EXPECT_CALL(CallbacksHandle,
+              runAfterPass(HasNameRegex("MockPassHandle<.*LoopNest>"),
+                           HasName("loop"), _))
+      .InSequence(PISequence);
+
   // Our mock pass does not invalidate IR.
   EXPECT_CALL(CallbacksHandle,
               runAfterPassInvalidated(HasNameRegex("MockPassHandle"), _))
@@ -887,6 +929,77 @@ TEST_F(LoopCallbacksTest, InstrumentedInvalidatingPasses) {
   PM.run(*M, AM);
 }
 
+TEST_F(LoopCallbacksTest, InstrumentedInvalidatingLoopNestPasses) {
+  CallbacksHandle.registerPassInstrumentation();
+  // Non-mock instrumentation not specifically mentioned below can be ignored.
+  CallbacksHandle.ignoreNonMockPassInstrumentation("<string>");
+  CallbacksHandle.ignoreNonMockPassInstrumentation("foo");
+  CallbacksHandle.ignoreNonMockPassInstrumentation("loop");
+
+  EXPECT_CALL(AnalysisHandle, run(HasName("loop"), _, _));
+  EXPECT_CALL(PassHandle, run(HasName("loop"), _, _, _))
+      .WillOnce(WithArgs<0, 1, 2>(Invoke(getAnalysisResult)));
+  EXPECT_CALL(ExtraPassHandle, run(HasName("loop"), _, _, _))
+      .WillOnce(DoAll(Invoke(ExtraPassHandle.invalidateLoopNest),
+                      Invoke([&](LoopNest &, LoopAnalysisManager &,
+                                 LoopStandardAnalysisResults &, LPMUpdater &) {
+                        return PreservedAnalyses::all();
+                      })));
+
+  // PassInstrumentation calls should happen in-sequence, in the same order
+  // as passes/analyses are scheduled.
+  ::testing::Sequence PISequence;
+  EXPECT_CALL(CallbacksHandle,
+              runBeforePass(HasNameRegex("MockPassHandle"), HasName("loop")))
+      .InSequence(PISequence);
+  EXPECT_CALL(
+      CallbacksHandle,
+      runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), HasName("loop")))
+      .InSequence(PISequence);
+  EXPECT_CALL(
+      CallbacksHandle,
+      runBeforeAnalysis(HasNameRegex("MockAnalysisHandle"), HasName("loop")))
+      .InSequence(PISequence);
+  EXPECT_CALL(
+      CallbacksHandle,
+      runAfterAnalysis(HasNameRegex("MockAnalysisHandle"), HasName("loop")))
+      .InSequence(PISequence);
+  EXPECT_CALL(CallbacksHandle,
+              runAfterPass(HasNameRegex("MockPassHandle"), HasName("loop"), _))
+      .InSequence(PISequence);
+
+  EXPECT_CALL(CallbacksHandle,
+              runBeforePass(HasNameRegex("MockPassHandle<.*LoopNest>"),
+                            HasName("loop")))
+      .InSequence(PISequence);
+  EXPECT_CALL(CallbacksHandle,
+              runBeforeNonSkippedPass(
+                  HasNameRegex("MockPassHandle<.*LoopNest>"), HasName("loop")))
+      .InSequence(PISequence);
+  EXPECT_CALL(
+      CallbacksHandle,
+      runAfterPassInvalidated(HasNameRegex("MockPassHandle<.*LoopNest>"), _))
+      .InSequence(PISequence);
+
+  EXPECT_CALL(CallbacksHandle,
+              runAfterPassInvalidated(HasNameRegex("^PassManager"), _))
+      .InSequence(PISequence);
+
+  // Our mock pass invalidates IR, thus normal runAfterPass is never called.
+  EXPECT_CALL(CallbacksHandle, runAfterPassInvalidated(
+                                   HasNameRegex("MockPassHandle<.*Loop>"), _))
+      .Times(0);
+  EXPECT_CALL(CallbacksHandle,
+              runAfterPass(HasNameRegex("MockPassHandle<.*LoopNest>"),
+                           HasName("loop"), _))
+      .Times(0);
+
+  StringRef PipelineText = "test-transform";
+  ASSERT_THAT_ERROR(PB.parsePassPipeline(PM, PipelineText), Succeeded())
+      << "Pipeline was: " << PipelineText;
+  PM.run(*M, AM);
+}
+
 TEST_F(LoopCallbacksTest, InstrumentedSkippedPasses) {
   CallbacksHandle.registerPassInstrumentation();
   // Non-mock instrumentation run here can safely be ignored.
@@ -895,28 +1008,51 @@ TEST_F(LoopCallbacksTest, InstrumentedSkippedPasses) {
   CallbacksHandle.ignoreNonMockPassInstrumentation("loop");
 
   // Skip the pass by returning false.
+  EXPECT_CALL(
+      CallbacksHandle,
+      runBeforePass(HasNameRegex("MockPassHandle<.*Loop>"), HasName("loop")))
+      .WillOnce(Return(false));
+
   EXPECT_CALL(CallbacksHandle,
-              runBeforePass(HasNameRegex("MockPassHandle"), HasName("loop")))
+              runBeforeSkippedPass(HasNameRegex("MockPassHandle<.*Loop>"),
+                                   HasName("loop")))
+      .Times(1);
+
+  EXPECT_CALL(CallbacksHandle,
+              runBeforePass(HasNameRegex("MockPassHandle<.*LoopNest>"),
+                            HasName("loop")))
       .WillOnce(Return(false));
 
-  EXPECT_CALL(
-      CallbacksHandle,
-      runBeforeSkippedPass(HasNameRegex("MockPassHandle"), HasName("loop")))
+  EXPECT_CALL(CallbacksHandle,
+              runBeforeSkippedPass(HasNameRegex("MockPassHandle<.*LoopNest>"),
+                                   HasName("loop")))
       .Times(1);
 
   EXPECT_CALL(AnalysisHandle, run(HasName("loop"), _, _)).Times(0);
   EXPECT_CALL(PassHandle, run(HasName("loop"), _, _, _)).Times(0);
+  EXPECT_CALL(ExtraPassHandle, run(HasName("loop"), _, _, _)).Times(0);
 
   // As the pass is skipped there is no afterPass, beforeAnalysis/afterAnalysis
   // as well.
-  EXPECT_CALL(CallbacksHandle,
-              runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), _))
+  EXPECT_CALL(CallbacksHandle, runBeforeNonSkippedPass(
+                                   HasNameRegex("MockPassHandle<.*Loop>"), _))
       .Times(0);
   EXPECT_CALL(CallbacksHandle,
-              runAfterPass(HasNameRegex("MockPassHandle"), _, _))
+              runAfterPass(HasNameRegex("MockPassHandle<.*Loop>"), _, _))
+      .Times(0);
+  EXPECT_CALL(CallbacksHandle, runAfterPassInvalidated(
+                                   HasNameRegex("MockPassHandle<.*Loop>"), _))
+      .Times(0);
+  EXPECT_CALL(
+      CallbacksHandle,
+      runBeforeNonSkippedPass(HasNameRegex("MockPassHandle<.*LoopNest>"), _))
       .Times(0);
   EXPECT_CALL(CallbacksHandle,
-              runAfterPassInvalidated(HasNameRegex("MockPassHandle"), _))
+              runAfterPass(HasNameRegex("MockPassHandle<.*LoopNest>"), _, _))
+      .Times(0);
+  EXPECT_CALL(
+      CallbacksHandle,
+      runAfterPassInvalidated(HasNameRegex("MockPassHandle<.*LoopNest>"), _))
       .Times(0);
   EXPECT_CALL(CallbacksHandle,
               runBeforeAnalysis(HasNameRegex("MockAnalysisHandle"), _))

diff --git a/llvm/unittests/Transforms/Scalar/LoopPassManagerTest.cpp b/llvm/unittests/Transforms/Scalar/LoopPassManagerTest.cpp
index c5b3e29d2a78..fc41bfa00ead 100644
--- a/llvm/unittests/Transforms/Scalar/LoopPassManagerTest.cpp
+++ b/llvm/unittests/Transforms/Scalar/LoopPassManagerTest.cpp
@@ -193,6 +193,16 @@ struct MockLoopPassHandle
   MockLoopPassHandle() { setDefaults(); }
 };
 
+struct MockLoopNestPassHandle
+    : MockPassHandleBase<MockLoopNestPassHandle, LoopNest, LoopAnalysisManager,
+                         LoopStandardAnalysisResults &, LPMUpdater &> {
+  MOCK_METHOD4(run,
+               PreservedAnalyses(LoopNest &, LoopAnalysisManager &,
+                                 LoopStandardAnalysisResults &, LPMUpdater &));
+
+  MockLoopNestPassHandle() { setDefaults(); }
+};
+
 struct MockFunctionPassHandle
     : MockPassHandleBase<MockFunctionPassHandle, Function> {
   MOCK_METHOD2(run, PreservedAnalyses(Function &, FunctionAnalysisManager &));
@@ -242,6 +252,7 @@ class LoopPassManagerTest : public ::testing::Test {
 
   MockLoopAnalysisHandle MLAHandle;
   MockLoopPassHandle MLPHandle;
+  MockLoopNestPassHandle MLNPHandle;
   MockFunctionPassHandle MFPHandle;
   MockModulePassHandle MMPHandle;
 
@@ -1590,4 +1601,31 @@ TEST_F(LoopPassManagerTest, LoopDeletion) {
   MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
   MPM.run(*M, MAM);
 }
+
+TEST_F(LoopPassManagerTest, HandleLoopNestPass) {
+  ::testing::InSequence MakeExpectationsSequenced;
+
+  EXPECT_CALL(MLPHandle, run(HasName("loop.0.0"), _, _, _)).Times(2);
+  EXPECT_CALL(MLPHandle, run(HasName("loop.0.1"), _, _, _)).Times(2);
+  EXPECT_CALL(MLPHandle, run(HasName("loop.0"), _, _, _));
+  EXPECT_CALL(MLNPHandle, run(HasName("loop.0"), _, _, _));
+  EXPECT_CALL(MLPHandle, run(HasName("loop.0"), _, _, _));
+  EXPECT_CALL(MLNPHandle, run(HasName("loop.0"), _, _, _));
+  EXPECT_CALL(MLPHandle, run(HasName("loop.g.0"), _, _, _));
+  EXPECT_CALL(MLNPHandle, run(HasName("loop.g.0"), _, _, _));
+  EXPECT_CALL(MLPHandle, run(HasName("loop.g.0"), _, _, _));
+  EXPECT_CALL(MLNPHandle, run(HasName("loop.g.0"), _, _, _));
+
+  LoopPassManager LPM(true);
+  LPM.addPass(MLPHandle.getPass());
+  LPM.addPass(MLNPHandle.getPass());
+  LPM.addPass(MLPHandle.getPass());
+  LPM.addPass(MLNPHandle.getPass());
+
+  ModulePassManager MPM(true);
+  MPM.addPass(createModuleToFunctionPassAdaptor(
+      createFunctionToLoopPassAdaptor(std::move(LPM))));
+  MPM.run(*M, MAM);
 }
+
+} // namespace


        


More information about the llvm-commits mailing list