[Mlir-commits] [mlir] cd7107a - Handle the verifier at run() time in the PassManager instead of build time

Mehdi Amini llvmlistbot at llvm.org
Tue Nov 3 03:17:32 PST 2020


Author: Mehdi Amini
Date: 2020-11-03T11:17:14Z
New Revision: cd7107a62b4fe887c7a8f3a3e226ec5e92b7e5f5

URL: https://github.com/llvm/llvm-project/commit/cd7107a62b4fe887c7a8f3a3e226ec5e92b7e5f5
DIFF: https://github.com/llvm/llvm-project/commit/cd7107a62b4fe887c7a8f3a3e226ec5e92b7e5f5.diff

LOG: Handle the verifier at run() time in the PassManager instead of build time

This simplifies a few parts of the pass manager. In particular, we no longer add one
VerifierPass after every pass in the pipeline, and the verifier can now be enabled or
disabled after the fact on an already-built PassManager.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D90669

Added: 
    

Modified: 
    mlir/include/mlir/Pass/PassManager.h
    mlir/lib/Pass/Pass.cpp
    mlir/lib/Pass/PassDetail.h
    mlir/lib/Support/MlirOptMain.cpp
    mlir/test/Pass/pass-timing.mlir
    mlir/test/Pass/pipeline-stats.mlir
    mlir/test/lib/Transforms/TestDynamicPipeline.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 4008911e05af..17bba93ee9a6 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -48,8 +48,8 @@ struct PassExecutionState;
 /// other OpPassManagers or the top-level PassManager.
 class OpPassManager {
 public:
-  OpPassManager(Identifier name, bool verifyPasses);
-  OpPassManager(StringRef name, bool verifyPasses);
+  OpPassManager(Identifier name);
+  OpPassManager(StringRef name);
   OpPassManager(OpPassManager &&rhs);
   OpPassManager(const OpPassManager &rhs);
   ~OpPassManager();
@@ -149,8 +149,7 @@ enum class PassDisplayMode {
 /// The main pass manager and pipeline builder.
 class PassManager : public OpPassManager {
 public:
-  // If verifyPasses is true, the verifier is run after each pass.
-  PassManager(MLIRContext *ctx, bool verifyPasses = true);
+  PassManager(MLIRContext *ctx);
   ~PassManager();
 
   /// Run the passes within this manager on the provided module.
@@ -168,6 +167,9 @@ class PassManager : public OpPassManager {
   void enableCrashReproducerGeneration(StringRef outputFile,
                                        bool genLocalReproducer = false);
 
+  /// Enable (or, with `enabled = false`, disable) running the verifier after each individual pass.
+  void enableVerifier(bool enabled = true);
+
   //===--------------------------------------------------------------------===//
   // Instrumentations
   //===--------------------------------------------------------------------===//
@@ -330,6 +332,9 @@ class PassManager : public OpPassManager {
 
   /// Flag that specifies if the generated crash reproducer should be local.
   bool localReproducer : 1;
+
+  /// A flag that indicates if the IR should be verified in between passes.
+  bool verifyPasses : 1;
 };
 
 /// Register a set of useful command-line options that can be used to configure

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index d84b71ccf59a..fb56d85e896b 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -91,10 +91,9 @@ void VerifierPass::runOnOperation() {
 namespace mlir {
 namespace detail {
 struct OpPassManagerImpl {
-  OpPassManagerImpl(Identifier identifier, bool verifyPasses)
-      : name(identifier), identifier(identifier), verifyPasses(verifyPasses) {}
-  OpPassManagerImpl(StringRef name, bool verifyPasses)
-      : name(name), verifyPasses(verifyPasses) {}
+  OpPassManagerImpl(Identifier identifier)
+      : name(identifier), identifier(identifier) {}
+  OpPassManagerImpl(StringRef name) : name(name) {}
 
   /// Merge the passes of this pass manager into the one provided.
   void mergeInto(OpPassManagerImpl &rhs);
@@ -129,9 +128,6 @@ struct OpPassManagerImpl {
   /// operation that passes of this pass manager operate on.
   Optional<Identifier> identifier;
 
-  /// Flag that specifies if the IR should be verified after each pass has run.
-  bool verifyPasses : 1;
-
   /// The set of passes to run as part of this pass manager.
   std::vector<std::unique_ptr<Pass>> passes;
 };
@@ -146,14 +142,14 @@ void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
 }
 
 OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) {
-  OpPassManager nested(nestedName, verifyPasses);
+  OpPassManager nested(nestedName);
   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
   addPass(std::unique_ptr<Pass>(adaptor));
   return adaptor->getPassManagers().front();
 }
 
 OpPassManager &OpPassManagerImpl::nest(StringRef nestedName) {
-  OpPassManager nested(nestedName, verifyPasses);
+  OpPassManager nested(nestedName);
   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
   addPass(std::unique_ptr<Pass>(adaptor));
   return adaptor->getPassManagers().front();
@@ -167,8 +163,6 @@ void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
     return nest(*passOpName).addPass(std::move(pass));
 
   passes.emplace_back(std::move(pass));
-  if (verifyPasses)
-    passes.emplace_back(std::make_unique<VerifierPass>());
 }
 
 void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
@@ -193,14 +187,6 @@ void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
       // Otherwise, merge into the existing adaptor and delete the current one.
       currentAdaptor->mergeInto(*lastAdaptor);
       it->reset();
-
-      // If the verifier is enabled, then next pass is a verifier run so
-      // drop it. Verifier passes are inserted after every pass, so this one
-      // would be a duplicate.
-      if (verifyPasses) {
-        assert(std::next(it) != e && isa<VerifierPass>(*std::next(it)));
-        (++it)->reset();
-      }
     } else if (lastAdaptor && !isa<VerifierPass>(*it)) {
       // If this pass is not an adaptor and not a verifier pass, then coalesce
       // and forget any existing adaptor.
@@ -254,14 +240,14 @@ void OpPassManagerImpl::splitAdaptorPasses() {
 // OpPassManager
 //===----------------------------------------------------------------------===//
 
-OpPassManager::OpPassManager(Identifier name, bool verifyPasses)
-    : impl(new OpPassManagerImpl(name, verifyPasses)) {}
-OpPassManager::OpPassManager(StringRef name, bool verifyPasses)
-    : impl(new OpPassManagerImpl(name, verifyPasses)) {}
+OpPassManager::OpPassManager(Identifier name)
+    : impl(new OpPassManagerImpl(name)) {}
+OpPassManager::OpPassManager(StringRef name)
+    : impl(new OpPassManagerImpl(name)) {}
 OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {}
 OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
 OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
-  impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->verifyPasses));
+  impl.reset(new OpPassManagerImpl(rhs.impl->name));
   for (auto &pass : rhs.impl->passes)
     impl->passes.emplace_back(pass->clone());
   return *this;
@@ -356,7 +342,7 @@ void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
 //===----------------------------------------------------------------------===//
 
 LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
-                                     AnalysisManager am) {
+                                     AnalysisManager am, bool verifyPasses) {
   if (!op->getName().getAbstractOperation())
     return op->emitOpError()
            << "trying to schedule a pass on an unregistered operation";
@@ -368,18 +354,18 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
   // Initialize the pass state with a callback for the pass to dynamically
   // execute a pipeline on the currently visited operation.
   auto dynamic_pipeline_callback =
-      [op, &am](OpPassManager &pipeline, Operation *root) {
-        if (!op->isAncestor(root)) {
-          root->emitOpError()
-              << "Trying to schedule a dynamic pipeline on an "
-                 "operation that isn't "
-                 "nested under the current operation the pass is processing";
-          return failure();
-        }
-        AnalysisManager nestedAm = am.nest(root);
-        return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root,
-                                              nestedAm);
-      };
+      [op, &am, verifyPasses](OpPassManager &pipeline,
+                              Operation *root) -> LogicalResult {
+    if (!op->isAncestor(root))
+      return root->emitOpError()
+             << "Trying to schedule a dynamic pipeline on an "
+                "operation that isn't "
+                "nested under the current operation the pass is processing";
+
+    AnalysisManager nestedAm = am.nest(root);
+    return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm,
+                                          verifyPasses);
+  };
   pass->passState.emplace(op, am, dynamic_pipeline_callback);
   // Instrument before the pass has run.
   PassInstrumentor *pi = am.getPassInstrumentor();
@@ -387,13 +373,20 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
     pi->runBeforePass(pass, op);
 
   // Invoke the virtual runOnOperation method.
-  pass->runOnOperation();
+  if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
+    adaptor->runOnOperation(verifyPasses);
+  else
+    pass->runOnOperation();
+  bool passFailed = pass->passState->irAndPassFailed.getInt();
 
   // Invalidate any non preserved analyses.
   am.invalidate(pass->passState->preservedAnalyses);
 
+  // Run the verifier if this pass didn't fail already.
+  if (!passFailed && verifyPasses)
+    passFailed = failed(verify(op));
+
   // Instrument after the pass has run.
-  bool passFailed = pass->passState->irAndPassFailed.getInt();
   if (pi) {
     if (passFailed)
       pi->runAfterPassFailed(pass, op);
@@ -408,7 +401,7 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
 /// Run the given operation and analysis manager on a provided op pass manager.
 LogicalResult OpToOpPassAdaptor::runPipeline(
     iterator_range<OpPassManager::pass_iterator> passes, Operation *op,
-    AnalysisManager am) {
+    AnalysisManager am, bool verifyPasses) {
   auto scope_exit = llvm::make_scope_exit([&] {
     // Clear out any computed operation analyses. These analyses won't be used
     // any more in this pipeline, and this helps reduce the current working set
@@ -419,7 +412,7 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
 
   // Run the pipeline over the provided operation.
   for (Pass &pass : passes)
-    if (failed(run(&pass, op, am)))
+    if (failed(run(&pass, op, am, verifyPasses)))
       return failure();
 
   return success();
@@ -485,16 +478,21 @@ std::string OpToOpPassAdaptor::getAdaptorName() {
   return os.str();
 }
 
-/// Run the held pipeline over all nested operations.
 void OpToOpPassAdaptor::runOnOperation() {
+  llvm_unreachable(
+      "Unexpected call to Pass::runOnOperation() on OpToOpPassAdaptor");
+}
+
+/// Run the held pipeline over all nested operations.
+void OpToOpPassAdaptor::runOnOperation(bool verifyPasses) {
   if (getContext().isMultithreadingEnabled())
-    runOnOperationAsyncImpl();
+    runOnOperationAsyncImpl(verifyPasses);
   else
-    runOnOperationImpl();
+    runOnOperationImpl(verifyPasses);
 }
 
 /// Run this pass adaptor synchronously.
-void OpToOpPassAdaptor::runOnOperationImpl() {
+void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
   auto am = getAnalysisManager();
   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
                                                         this};
@@ -511,7 +509,8 @@ void OpToOpPassAdaptor::runOnOperationImpl() {
         // Run the held pipeline over the current operation.
         if (instrumentor)
           instrumentor->runBeforePipeline(opName, parentInfo);
-        auto result = runPipeline(mgr->getPasses(), &op, am.nest(&op));
+        LogicalResult result =
+            runPipeline(mgr->getPasses(), &op, am.nest(&op), verifyPasses);
         if (instrumentor)
           instrumentor->runAfterPipeline(opName, parentInfo);
 
@@ -532,7 +531,7 @@ static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
 }
 
 /// Run this pass adaptor synchronously.
-void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
+void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
   AnalysisManager am = getAnalysisManager();
 
   // Create the async executors if they haven't been created, or if the main
@@ -594,7 +593,7 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
           if (instrumentor)
             instrumentor->runBeforePipeline(opName, parentInfo);
           auto pipelineResult =
-              runPipeline(pm->getPasses(), it.first, it.second);
+              runPipeline(pm->getPasses(), it.first, it.second, verifyPasses);
           if (instrumentor)
             instrumentor->runAfterPipeline(opName, parentInfo);
 
@@ -741,15 +740,11 @@ LogicalResult PassManager::runWithCrashRecovery(ModuleOp module,
   // isolation.
   impl->splitAdaptorPasses();
 
-  // If this is a local producer, run each of the passes individually. If the
-  // verifier is enabled, each pass will have a verifier after. This is included
-  // in the recovery run.
-  unsigned stride = impl->verifyPasses ? 2 : 1;
+  // If this is a local reproducer, run each of the passes individually.
   MutableArrayRef<std::unique_ptr<Pass>> passes = impl->passes;
-  for (unsigned i = 0, e = passes.size(); i != e; i += stride) {
-    if (failed(runWithCrashRecovery(passes.slice(i, stride), module, am)))
+  for (std::unique_ptr<Pass> &pass : passes)
+    if (failed(runWithCrashRecovery(pass, module, am)))
       return failure();
-  }
   return success();
 }
 
@@ -759,7 +754,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
                                   ModuleOp module, AnalysisManager am) {
   RecoveryReproducerContext context(passes, module, *crashReproducerFileName,
                                     !getContext()->isMultithreadingEnabled(),
-                                    impl->verifyPasses);
+                                    verifyPasses);
 
   // Safely invoke the passes within a recovery context.
   llvm::CrashRecoveryContext::Enable();
@@ -767,7 +762,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
   llvm::CrashRecoveryContext recoveryContext;
   recoveryContext.RunSafelyOnThread([&] {
     for (std::unique_ptr<Pass> &pass : passes)
-      if (failed(OpToOpPassAdaptor::run(pass.get(), module, am)))
+      if (failed(OpToOpPassAdaptor::run(pass.get(), module, am, verifyPasses)))
         return;
     passManagerResult = success();
   });
@@ -788,13 +783,15 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
 // PassManager
 //===----------------------------------------------------------------------===//
 
-PassManager::PassManager(MLIRContext *ctx, bool verifyPasses)
-    : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx),
-                    verifyPasses),
-      context(ctx), passTiming(false), localReproducer(false) {}
+PassManager::PassManager(MLIRContext *ctx)
+    : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx)),
+      context(ctx), passTiming(false), localReproducer(false),
+      verifyPasses(true) {}
 
 PassManager::~PassManager() {}
 
+void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
+
 /// Run the passes within this manager on the provided module.
 LogicalResult PassManager::run(ModuleOp module) {
   // Before running, make sure to coalesce any adjacent pass adaptors in the
@@ -814,10 +811,10 @@ LogicalResult PassManager::run(ModuleOp module) {
 
   // If reproducer generation is enabled, run the pass manager with crash
   // handling enabled.
-  LogicalResult result =
-      crashReproducerFileName
-          ? runWithCrashRecovery(module, am)
-          : OpToOpPassAdaptor::runPipeline(getPasses(), module, am);
+  LogicalResult result = crashReproducerFileName
+                             ? runWithCrashRecovery(module, am)
+                             : OpToOpPassAdaptor::runPipeline(
+                                   getPasses(), module, am, verifyPasses);
 
   // Notify the context that the run is done.
   module.getContext()->exitMultiThreadedExecution();

diff  --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index fcd03da7b9bc..6bcf021d20fe 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -35,6 +35,7 @@ class OpToOpPassAdaptor
   OpToOpPassAdaptor(const OpToOpPassAdaptor &rhs) = default;
 
   /// Run the held pipeline over all operations.
+  void runOnOperation(bool verifyPasses);
   void runOnOperation() override;
 
   /// Merge the current pass adaptor into given 'rhs'.
@@ -57,19 +58,20 @@ class OpToOpPassAdaptor
 
 private:
   /// Run this pass adaptor synchronously.
-  void runOnOperationImpl();
+  void runOnOperationImpl(bool verifyPasses);
 
   /// Run this pass adaptor asynchronously.
-  void runOnOperationAsyncImpl();
+  void runOnOperationAsyncImpl(bool verifyPasses);
 
   /// Run the given operation and analysis manager on a single pass.
-  static LogicalResult run(Pass *pass, Operation *op, AnalysisManager am);
+  static LogicalResult run(Pass *pass, Operation *op, AnalysisManager am,
+                           bool verifyPasses);
 
   /// Run the given operation and analysis manager on a provided op pass
   /// manager.
   static LogicalResult
   runPipeline(iterator_range<OpPassManager::pass_iterator> passes,
-              Operation *op, AnalysisManager am);
+              Operation *op, AnalysisManager am, bool verifyPasses);
 
   /// A set of adaptors to run.
   SmallVector<OpPassManager, 1> mgrs;

diff  --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
index 1c2a1fe18ca3..3752cf4f3a13 100644
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -58,7 +58,8 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
     return failure();
 
   // Apply any pass manager command line options.
-  PassManager pm(context, verifyPasses);
+  PassManager pm(context);
+  pm.enableVerifier(verifyPasses);
   applyPassManagerCLOptions(pm);
 
   // Build the provided pipeline.

diff  --git a/mlir/test/Pass/pass-timing.mlir b/mlir/test/Pass/pass-timing.mlir
index 6cd8a29e6f48..63a2945eca4d 100644
--- a/mlir/test/Pass/pass-timing.mlir
+++ b/mlir/test/Pass/pass-timing.mlir
@@ -8,7 +8,6 @@
 // LIST: Total Execution Time:
 // LIST: Name
 // LIST-DAG: Canonicalizer
-// LIST-DAG: Verifier
 // LIST-DAG: CSE
 // LIST-DAG: DominanceInfo
 // LIST: Total
@@ -19,20 +18,15 @@
 // PIPELINE-NEXT: 'func' Pipeline
 // PIPELINE-NEXT:   CSE
 // PIPELINE-NEXT:     (A) DominanceInfo
-// PIPELINE-NEXT:   Verifier
 // PIPELINE-NEXT:   Canonicalizer
-// PIPELINE-NEXT:   Verifier
 // PIPELINE-NEXT:   CSE
 // PIPELINE-NEXT:     (A) DominanceInfo
-// PIPELINE-NEXT:   Verifier
-// PIPELINE-NEXT: Verifier
 // PIPELINE-NEXT: Total
 
 // MT_LIST: Pass execution timing report
 // MT_LIST: Total Execution Time:
 // MT_LIST: Name
 // MT_LIST-DAG: Canonicalizer
-// MT_LIST-DAG: Verifier
 // MT_LIST-DAG: CSE
 // MT_LIST-DAG: DominanceInfo
 // MT_LIST: Total
@@ -43,13 +37,9 @@
 // MT_PIPELINE-NEXT: 'func' Pipeline
 // MT_PIPELINE-NEXT:   CSE
 // MT_PIPELINE-NEXT:     (A) DominanceInfo
-// MT_PIPELINE-NEXT:   Verifier
 // MT_PIPELINE-NEXT:   Canonicalizer
-// MT_PIPELINE-NEXT:   Verifier
 // MT_PIPELINE-NEXT:   CSE
 // MT_PIPELINE-NEXT:     (A) DominanceInfo
-// MT_PIPELINE-NEXT:   Verifier
-// MT_PIPELINE-NEXT: Verifier
 // MT_PIPELINE-NEXT: Total
 
 // NESTED_MT_PIPELINE: Pass execution timing report

diff  --git a/mlir/test/Pass/pipeline-stats.mlir b/mlir/test/Pass/pipeline-stats.mlir
index e3ee144ad471..6aee2bfc8314 100644
--- a/mlir/test/Pass/pipeline-stats.mlir
+++ b/mlir/test/Pass/pipeline-stats.mlir
@@ -10,11 +10,8 @@
 // PIPELINE: 'func' Pipeline
 // PIPELINE-NEXT:   TestStatisticPass
 // PIPELINE-NEXT:     (S) {{0|4}} num-ops - Number of operations counted
-// PIPELINE-NEXT:   Verifier
 // PIPELINE-NEXT:   TestStatisticPass
 // PIPELINE-NEXT:     (S) {{0|4}} num-ops - Number of operations counted
-// PIPELINE-NEXT:   Verifier
-// PIPELINE-NEXT: Verifier
 
 func @foo() {
   return

diff  --git a/mlir/test/lib/Transforms/TestDynamicPipeline.cpp b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp
index 92e8861b5ad0..b7d6f1227bc0 100644
--- a/mlir/test/lib/Transforms/TestDynamicPipeline.cpp
+++ b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp
@@ -25,7 +25,7 @@ class TestDynamicPipelinePass
     : public PassWrapper<TestDynamicPipelinePass, OperationPass<>> {
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
-    OpPassManager pm(ModuleOp::getOperationName(), false);
+    OpPassManager pm(ModuleOp::getOperationName());
     parsePassPipeline(pipeline, pm, llvm::errs());
     pm.getDependentDialects(registry);
   }
@@ -54,7 +54,7 @@ class TestDynamicPipelinePass
     }
     if (!pm) {
       pm = std::make_unique<OpPassManager>(
-          getOperation()->getName().getIdentifier(), false);
+          getOperation()->getName().getIdentifier());
       parsePassPipeline(pipeline, *pm, llvm::errs());
     }
 


        


More information about the Mlir-commits mailing list