[Mlir-commits] [mlir] 56a6985 - [mlir][Pass][NFC] Merge OpToOpPassAdaptor and OpToOpPassAdaptorParallel

River Riddle llvmlistbot at llvm.org
Wed Apr 29 15:23:22 PDT 2020


Author: River Riddle
Date: 2020-04-29T15:23:10-07:00
New Revision: 56a698510faef5bf3ef224c229a049bb1e376a56

URL: https://github.com/llvm/llvm-project/commit/56a698510faef5bf3ef224c229a049bb1e376a56
DIFF: https://github.com/llvm/llvm-project/commit/56a698510faef5bf3ef224c229a049bb1e376a56.diff

LOG: [mlir][Pass][NFC] Merge OpToOpPassAdaptor and OpToOpPassAdaptorParallel

This moves the threading check into runOnOperation. This produces a much cleaner interface for the adaptor pass, and will make it easier to enable/disable threading in the future.

Differential Revision: https://reviews.llvm.org/D78313

Added: 
    

Modified: 
    mlir/lib/Pass/IRPrinting.cpp
    mlir/lib/Pass/Pass.cpp
    mlir/lib/Pass/PassDetail.h
    mlir/lib/Pass/PassStatistics.cpp
    mlir/lib/Pass/PassTiming.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp
index 679a9ec27ead..ba9ff989ea79 100644
--- a/mlir/lib/Pass/IRPrinting.cpp
+++ b/mlir/lib/Pass/IRPrinting.cpp
@@ -99,7 +99,7 @@ class IRPrinterInstrumentation : public PassInstrumentation {
 
 /// Returns true if the given pass is hidden from IR printing.
 static bool isHiddenPass(Pass *pass) {
-  return isAdaptorPass(pass) || isa<VerifierPass>(pass);
+  return isa<OpToOpPassAdaptor>(pass) || isa<VerifierPass>(pass);
 }
 
 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
@@ -173,7 +173,7 @@ void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
 }
 
 void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
-  if (isAdaptorPass(pass))
+  if (isa<OpToOpPassAdaptor>(pass))
     return;
   if (config->shouldPrintAfterOnlyOnChange())
     beforePassFingerPrints.erase(pass);

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 53ccd4f005a4..b6bef48cb3ec 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -51,7 +51,7 @@ void Pass::copyOptionValuesFrom(const Pass *other) {
 /// an adaptor pass, print with the op_name(sub_pass,...) format.
 void Pass::printAsTextualPipeline(raw_ostream &os) {
   // Special case for adaptors to use the 'op_name(sub_passes)' format.
-  if (auto *adaptor = getAdaptorPassBase(this)) {
+  if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
     llvm::interleaveComma(adaptor->getPassManagers(), os,
                           [&](OpPassManager &pm) {
                             os << pm.getOpName() << "(";
@@ -152,15 +152,15 @@ struct OpPassManagerImpl {
 void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
   // Bail out early if there are no adaptor passes.
   if (llvm::none_of(passes, [](std::unique_ptr<Pass> &pass) {
-        return isAdaptorPass(pass.get());
+        return isa<OpToOpPassAdaptor>(pass.get());
       }))
     return;
 
   // Walk the pass list and merge adjacent adaptors.
-  OpToOpPassAdaptorBase *lastAdaptor = nullptr;
+  OpToOpPassAdaptor *lastAdaptor = nullptr;
   for (auto it = passes.begin(), e = passes.end(); it != e; ++it) {
     // Check to see if this pass is an adaptor.
-    if (auto *currentAdaptor = getAdaptorPassBase(it->get())) {
+    if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(it->get())) {
       // If it is the first adaptor in a possible chain, remember it and
       // continue.
       if (!lastAdaptor) {
@@ -243,16 +243,7 @@ LogicalResult OpPassManager::run(Operation *op, AnalysisManager am) {
 /// pass manager.
 OpPassManager &OpPassManager::nest(const OperationName &nestedName) {
   OpPassManager nested(nestedName, impl->disableThreads, impl->verifyPasses);
-
-  /// Create an adaptor for this pass. If multi-threading is disabled, then
-  /// create a synchronous adaptor.
-  if (impl->disableThreads || !llvm::llvm_is_multithreaded()) {
-    auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
-    addPass(std::unique_ptr<Pass>(adaptor));
-    return adaptor->getPassManagers().front();
-  }
-
-  auto *adaptor = new OpToOpPassAdaptorParallel(std::move(nested));
+  auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
   addPass(std::unique_ptr<Pass>(adaptor));
   return adaptor->getPassManagers().front();
 }
@@ -330,12 +321,12 @@ static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
   return it == mgrs.end() ? nullptr : &*it;
 }
 
-OpToOpPassAdaptorBase::OpToOpPassAdaptorBase(OpPassManager &&mgr) {
+OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
   mgrs.emplace_back(std::move(mgr));
 }
 
 /// Merge the current pass adaptor into given 'rhs'.
-void OpToOpPassAdaptorBase::mergeInto(OpToOpPassAdaptorBase &rhs) {
+void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
   for (auto &pm : mgrs) {
     // If an existing pass manager exists, then merge the given pass manager
     // into it.
@@ -357,7 +348,7 @@ void OpToOpPassAdaptorBase::mergeInto(OpToOpPassAdaptorBase &rhs) {
 }
 
 /// Returns the adaptor pass name.
-std::string OpToOpPassAdaptorBase::getName() {
+std::string OpToOpPassAdaptor::getAdaptorName() {
   std::string name = "Pipeline Collection : [";
   llvm::raw_string_ostream os(name);
   llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) {
@@ -367,11 +358,16 @@ std::string OpToOpPassAdaptorBase::getName() {
   return os.str();
 }
 
-OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr)
-    : OpToOpPassAdaptorBase(std::move(mgr)) {}
-
 /// Run the held pipeline over all nested operations.
 void OpToOpPassAdaptor::runOnOperation() {
+  if (mgrs.front().getImpl().disableThreads || !llvm::llvm_is_multithreaded())
+    runOnOperationImpl();
+  else
+    runOnOperationAsyncImpl();
+}
+
+/// Run this pass adaptor synchronously.
+void OpToOpPassAdaptor::runOnOperationImpl() {
   auto am = getAnalysisManager();
   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
                                                         this};
@@ -397,9 +393,6 @@ void OpToOpPassAdaptor::runOnOperation() {
   }
 }
 
-OpToOpPassAdaptorParallel::OpToOpPassAdaptorParallel(OpPassManager &&mgr)
-    : OpToOpPassAdaptorBase(std::move(mgr)) {}
-
 /// Utility functor that checks if the two ranges of pass managers have a size
 /// mismatch.
 static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
@@ -409,8 +402,8 @@ static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
                       [&](size_t i) { return lhs[i].size() != rhs[i].size(); });
 }
 
-// Run the held pipeline asynchronously across the functions within the module.
-void OpToOpPassAdaptorParallel::runOnOperation() {
+/// Run this pass adaptor asynchronously.
+void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
   AnalysisManager am = getAnalysisManager();
 
   // Create the async executors if they haven't been created, or if the main
@@ -491,16 +484,6 @@ void OpToOpPassAdaptorParallel::runOnOperation() {
     signalPassFailure();
 }
 
-/// Utility function to convert the given class to the base adaptor it is an
-/// adaptor pass, returns nullptr otherwise.
-OpToOpPassAdaptorBase *mlir::detail::getAdaptorPassBase(Pass *pass) {
-  if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
-    return adaptor;
-  if (auto *adaptor = dyn_cast<OpToOpPassAdaptorParallel>(pass))
-    return adaptor;
-  return nullptr;
-}
-
 //===----------------------------------------------------------------------===//
 // PassCrashReproducer
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index 59d9a7a0576f..2342a1a7af97 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -27,69 +27,44 @@ class VerifierPass : public PassWrapper<VerifierPass, OperationPass<>> {
 // OpToOpPassAdaptor
 //===----------------------------------------------------------------------===//
 
-/// A base class for Op-to-Op adaptor passes.
-class OpToOpPassAdaptorBase {
-public:
-  OpToOpPassAdaptorBase(OpPassManager &&mgr);
-  OpToOpPassAdaptorBase(const OpToOpPassAdaptorBase &rhs) = default;
-
-  /// Merge the current pass adaptor into given 'rhs'.
-  void mergeInto(OpToOpPassAdaptorBase &rhs);
-
-  /// Returns the pass managers held by this adaptor.
-  MutableArrayRef<OpPassManager> getPassManagers() { return mgrs; }
-
-  /// Returns the adaptor pass name.
-  std::string getName();
-
-protected:
-  // A set of adaptors to run.
-  SmallVector<OpPassManager, 1> mgrs;
-};
-
-/// An adaptor pass used to run operation passes over nested operations
-/// synchronously on a single thread.
+/// An adaptor pass used to run operation passes over nested operations.
 class OpToOpPassAdaptor
-    : public PassWrapper<OpToOpPassAdaptor, OperationPass<>>,
-      public OpToOpPassAdaptorBase {
+    : public PassWrapper<OpToOpPassAdaptor, OperationPass<>> {
 public:
   OpToOpPassAdaptor(OpPassManager &&mgr);
+  OpToOpPassAdaptor(const OpToOpPassAdaptor &rhs) = default;
 
   /// Run the held pipeline over all operations.
   void runOnOperation() override;
-};
 
-/// An adaptor pass used to run operation passes over nested operations
-/// asynchronously across multiple threads.
-class OpToOpPassAdaptorParallel
-    : public PassWrapper<OpToOpPassAdaptorParallel, OperationPass<>>,
-      public OpToOpPassAdaptorBase {
-public:
-  OpToOpPassAdaptorParallel(OpPassManager &&mgr);
+  /// Merge the current pass adaptor into given 'rhs'.
+  void mergeInto(OpToOpPassAdaptor &rhs);
 
-  /// Run the held pipeline over all operations.
-  void runOnOperation() override;
+  /// Returns the pass managers held by this adaptor.
+  MutableArrayRef<OpPassManager> getPassManagers() { return mgrs; }
 
   /// Return the async pass managers held by this parallel adaptor.
   MutableArrayRef<SmallVector<OpPassManager, 1>> getParallelPassManagers() {
     return asyncExecutors;
   }
 
+  /// Returns the adaptor pass name.
+  std::string getAdaptorName();
+
 private:
-  // A set of executors, cloned from the main executor, that run asynchronously
-  // on different threads.
-  SmallVector<SmallVector<OpPassManager, 1>, 8> asyncExecutors;
-};
+  /// Run this pass adaptor synchronously.
+  void runOnOperationImpl();
+
+  /// Run this pass adaptor asynchronously.
+  void runOnOperationAsyncImpl();
 
-/// Utility function to convert the given class to the base adaptor it is an
-/// adaptor pass, returns nullptr otherwise.
-OpToOpPassAdaptorBase *getAdaptorPassBase(Pass *pass);
+  /// A set of adaptors to run.
+  SmallVector<OpPassManager, 1> mgrs;
 
-/// Utility function to return if a pass refers to an adaptor pass. Adaptor
-/// passes are those that internally execute a pipeline.
-inline bool isAdaptorPass(Pass *pass) {
-  return isa<OpToOpPassAdaptorParallel>(pass) || isa<OpToOpPassAdaptor>(pass);
-}
+  /// A set of executors, cloned from the main executor, that run asynchronously
+  /// on different threads. This is used when threading is enabled.
+  SmallVector<SmallVector<OpPassManager, 1>, 8> asyncExecutors;
+};
 
 } // end namespace detail
 } // end namespace mlir

diff  --git a/mlir/lib/Pass/PassStatistics.cpp b/mlir/lib/Pass/PassStatistics.cpp
index 7ac54f7cf1af..6ef0d3bbea6a 100644
--- a/mlir/lib/Pass/PassStatistics.cpp
+++ b/mlir/lib/Pass/PassStatistics.cpp
@@ -60,7 +60,7 @@ static void printPassEntry(raw_ostream &os, unsigned indent, StringRef pass,
 static void printResultsAsList(raw_ostream &os, OpPassManager &pm) {
   llvm::StringMap<std::vector<Statistic>> mergedStats;
   std::function<void(Pass *)> addStats = [&](Pass *pass) {
-    auto *adaptor = getAdaptorPassBase(pass);
+    auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass);
 
     // If this is not an adaptor, add the stats to the list if there are any.
     if (!adaptor) {
@@ -105,13 +105,12 @@ static void printResultsAsList(raw_ostream &os, OpPassManager &pm) {
 static void printResultsAsPipeline(raw_ostream &os, OpPassManager &pm) {
   std::function<void(unsigned, Pass *)> printPass = [&](unsigned indent,
                                                         Pass *pass) {
-    // Handle the case of an adaptor pass.
-    if (auto *adaptor = getAdaptorPassBase(pass)) {
+    if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass)) {
       // If this adaptor has more than one internal pipeline, print an entry for
       // it.
       auto mgrs = adaptor->getPassManagers();
       if (mgrs.size() > 1) {
-        printPassEntry(os, indent, adaptor->getName());
+        printPassEntry(os, indent, adaptor->getAdaptorName());
         indent += 2;
       }
 
@@ -195,8 +194,8 @@ void OpPassManager::mergeStatisticsInto(OpPassManager &other) {
     Pass &pass = std::get<0>(passPair), &otherPass = std::get<1>(passPair);
 
     // If this is an adaptor, then recursively merge the pass managers.
-    if (auto *adaptorPass = getAdaptorPassBase(&pass)) {
-      auto *otherAdaptorPass = getAdaptorPassBase(&otherPass);
+    if (auto *adaptorPass = dyn_cast<OpToOpPassAdaptor>(&pass)) {
+      auto *otherAdaptorPass = cast<OpToOpPassAdaptor>(&otherPass);
       for (auto mgrs : llvm::zip(adaptorPass->getPassManagers(),
                                  otherAdaptorPass->getPassManagers()))
         std::get<0>(mgrs).mergeStatisticsInto(std::get<1>(mgrs));
@@ -217,18 +216,16 @@ void OpPassManager::mergeStatisticsInto(OpPassManager &other) {
 /// consumption(e.g. dumping).
 static void prepareStatistics(OpPassManager &pm) {
   for (Pass &pass : pm.getPasses()) {
-    OpToOpPassAdaptorBase *adaptor = getAdaptorPassBase(&pass);
+    OpToOpPassAdaptor *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
     if (!adaptor)
       continue;
     MutableArrayRef<OpPassManager> nestedPms = adaptor->getPassManagers();
 
-    // If this is a parallel adaptor, merge the statistics from the async
-    // pass managers into the main nested pass managers.
-    if (auto *parallelAdaptor = dyn_cast<OpToOpPassAdaptorParallel>(&pass)) {
-      for (auto &asyncPM : parallelAdaptor->getParallelPassManagers()) {
-        for (unsigned i = 0, e = asyncPM.size(); i != e; ++i)
-          asyncPM[i].mergeStatisticsInto(nestedPms[i]);
-      }
+    // Merge the statistics from the async pass managers into the main nested
+    // pass managers.
+    for (auto &asyncPM : adaptor->getParallelPassManagers()) {
+      for (unsigned i = 0, e = asyncPM.size(); i != e; ++i)
+        asyncPM[i].mergeStatisticsInto(nestedPms[i]);
     }
 
     // Prepare the statistics of each of the nested passes.

diff  --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp
index c8f0ad8afa50..71bf822a864b 100644
--- a/mlir/lib/Pass/PassTiming.cpp
+++ b/mlir/lib/Pass/PassTiming.cpp
@@ -277,17 +277,17 @@ void PassTiming::runAfterPipeline(const OperationName &name,
 
 /// Start a new timer for the given pass.
 void PassTiming::startPassTimer(Pass *pass) {
-  auto kind = isAdaptorPass(pass) ? TimerKind::PipelineCollection
-                                  : TimerKind::PassOrAnalysis;
+  auto kind = isa<OpToOpPassAdaptor>(pass) ? TimerKind::PipelineCollection
+                                           : TimerKind::PassOrAnalysis;
   Timer *timer = getTimer(pass, kind, [pass]() -> std::string {
-    if (auto *adaptor = getAdaptorPassBase(pass))
-      return adaptor->getName();
+    if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
+      return adaptor->getAdaptorName();
     return std::string(pass->getName());
   });
 
   // We don't actually want to time the adaptor passes, they gather their total
   // from their held passes.
-  if (!isAdaptorPass(pass))
+  if (!isa<OpToOpPassAdaptor>(pass))
     timer->start();
 }
 
@@ -302,9 +302,9 @@ void PassTiming::startAnalysisTimer(StringRef name, TypeID id) {
 void PassTiming::runAfterPass(Pass *pass, Operation *) {
   Timer *timer = popLastActiveTimer();
 
-  // If this is an OpToOpPassAdaptorParallel, then we need to merge in the
-  // timing data for the pipelines running on other threads.
-  if (isa<OpToOpPassAdaptorParallel>(pass)) {
+  // If this is a pass adaptor, then we need to merge in the timing data for the
+  // pipelines running on other threads.
+  if (isa<OpToOpPassAdaptor>(pass)) {
     auto toMerge = pipelinesToMerge.find({llvm::get_threadid(), pass});
     if (toMerge != pipelinesToMerge.end()) {
       for (auto &it : toMerge->second)
@@ -314,10 +314,7 @@ void PassTiming::runAfterPass(Pass *pass, Operation *) {
     return;
   }
 
-  // Adaptor passes aren't timed directly, so we don't need to stop their
-  // timers.
-  if (!isAdaptorPass(pass))
-    timer->stop();
+  timer->stop();
 }
 
 /// Stop a timer.


        


More information about the Mlir-commits mailing list