[Mlir-commits] [mlir] 92469ca - [mlir] Refactor the implementation of pass crash reproducers

River Riddle llvmlistbot at llvm.org
Wed May 19 17:03:24 PDT 2021


Author: River Riddle
Date: 2021-05-19T16:59:53-07:00
New Revision: 92469ca027b2e794aa9162931665b379445ca711

URL: https://github.com/llvm/llvm-project/commit/92469ca027b2e794aa9162931665b379445ca711
DIFF: https://github.com/llvm/llvm-project/commit/92469ca027b2e794aa9162931665b379445ca711.diff

LOG: [mlir] Refactor the implementation of pass crash reproducers

The current implementation has several key limitations and oddities, e.g. local reproducers don't support dynamic pass pipelines, and error messages don't include the passes that failed. This revision refactors the implementation to support more use cases and to be much cleaner.

The main change in this revision, aside from moving the implementation out of Pass.cpp and into its own file, is the addition of a crash recovery pass instrumentation. For local reproducers, this instrumentation handles setting up the recovery context before executing each pass. For global reproducers, the instrumentation is used to provide a more detailed error message, containing information about which passes are running and on which operations.
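The instrumentation's bookkeeping can be sketched in isolation. The following is a simplified, standalone model, not the MLIR API: `Op`, `Pass`, `ReproContext`, and `ReproGenerator` are hypothetical stand-ins for `Operation`, `Pass`, `RecoveryReproducerContext`, and `PassCrashReproducerGenerator`. No IR is cloned, no signal handler is installed, and the single whole-pipeline context used by global reproducers is not modeled. It shows the two modes: a global reproducer only records which (pass, operation) pairs are running, while a local reproducer pushes one context per pass whose pipeline string nests the pass inside its parent operation scopes, disabling the enclosing context while a nested (dynamic) pipeline runs.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for mlir::Operation and mlir::Pass.
struct Op {
  std::string name;
  const Op *parent = nullptr;
};
struct Pass {
  std::string argument; // e.g. "canonicalize"
};

// Hypothetical model of a reproducer context: the pipeline string it would
// write and whether it is currently armed.
struct ReproContext {
  std::string pipeline;
  bool enabled = true;
};

// Hypothetical model of the generator driven by runBeforePass/runAfterPass.
struct ReproGenerator {
  explicit ReproGenerator(bool local) : local(local) {}

  // Mirrors runBeforePass: always record the running pass; in local mode also
  // push a context whose pipeline contains only this pass.
  void before(const Pass &pass, const Op &op) {
    running.emplace_back(&pass, &op);
    if (!local)
      return;
    // Disable the enclosing context, e.g. for a dynamic pass pipeline.
    if (!contexts.empty())
      contexts.back().enabled = false;

    // Collect op and its ancestors, excluding the top-level operation.
    std::vector<std::string> scopes;
    for (const Op *cur = &op; cur->parent; cur = cur->parent)
      scopes.push_back(cur->name);

    // Nest the pass inside its scopes, outermost first: "func(canonicalize)".
    std::string str;
    for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
      str += *it + "(";
    str += pass.argument;
    str.append(scopes.size(), ')');
    contexts.push_back({str, true});
  }

  // Mirrors runAfterPass: forget the pass; in local mode pop its context and
  // re-arm the previous one.
  void after(const Pass &pass, const Op &op) {
    auto it = std::find(running.begin(), running.end(),
                        std::make_pair(&pass, &op));
    if (it != running.end())
      running.erase(it);
    if (!local)
      return;
    assert(!contexts.empty() && "expected an active local context");
    contexts.pop_back();
    if (!contexts.empty())
      contexts.back().enabled = true;
  }

  bool local;
  std::vector<std::pair<const Pass *, const Op *>> running;
  std::vector<ReproContext> contexts;
};
```

Disabling rather than discarding the outer context is what lets the innermost failing pass own the crash, while the outer context is restored once the nested pipeline finishes.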

Example of new message:

```
error: Failures have been detected while processing an MLIR pass pipeline
note: Pipeline failed while executing [`TestCrashRecoveryPass` on 'module' operation: @foo]: reproducer generated at `crash-recovery.mlir.tmp`
```
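The note in this message can be reproduced with plain string formatting. The sketch below models the diagnostic text only; `RunningPass`, `formatPassOp`, `formatGlobalNote`, and `formatLocalNote` are hypothetical names, not MLIR functions. The `@symbol` suffix is appended only when the operation is a symbol, which the real code detects via `SymbolOpInterface`.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical record of a running pass: the pass name, the operation name,
// and the symbol name if the operation is a symbol.
struct RunningPass {
  std::string passName;
  std::string opName;
  std::optional<std::string> symbol;
};

// Formats one entry as: `PassName` on 'op' operation[: @symbol]
std::string formatPassOp(const RunningPass &rp) {
  std::string s = "`" + rp.passName + "` on '" + rp.opName + "' operation";
  if (rp.symbol)
    s += ": @" + *rp.symbol;
  return s;
}

// Global reproducer: every running pass, comma separated inside brackets.
std::string formatGlobalNote(const std::vector<RunningPass> &running,
                             const std::string &description) {
  assert(!running.empty() && "expected at least one running pass");
  std::string note = "Pipeline failed while executing [";
  const char *sep = "";
  for (const RunningPass &rp : running) {
    note += sep;
    note += formatPassOp(rp);
    sep = ", ";
  }
  return note + "]: " + description;
}

// Local reproducer: only the most recently started pass, no brackets.
std::string formatLocalNote(const RunningPass &last,
                            const std::string &description) {
  return "Pipeline failed while executing " + formatPassOp(last) + ": " +
         description;
}
```

With a single running `TestCrashRecoveryPass` on a `module` symbol named `foo`, `formatGlobalNote` yields the note text shown in the example message above.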

Differential Revision: https://reviews.llvm.org/D101854

Added: 
    mlir/lib/Pass/PassCrashRecovery.cpp

Modified: 
    mlir/docs/PassManagement.md
    mlir/include/mlir/Pass/PassManager.h
    mlir/lib/Pass/CMakeLists.txt
    mlir/lib/Pass/Pass.cpp
    mlir/lib/Pass/PassDetail.h
    mlir/test/Pass/crash-recovery.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index ec890dab710c..fcbf6365c1b9 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -1213,11 +1213,14 @@ useful for situations where the crash is known to be within a specific pass, or
 when the original input relies on components (like dialects or passes) that may
 not always be available.
 
+Note: Local reproducer generation requires that multi-threading is
+disabled (`-mlir-disable-threading`).
+
 For example, if the failure in the previous example came from `canonicalize`,
 the following reproducer will be generated:
 
 ```mlir
-// configuration: -pass-pipeline='func(canonicalize)' -verify-each
+// configuration: -pass-pipeline='func(canonicalize)' -verify-each -mlir-disable-threading
 
 module {
   func @foo() {

diff  --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index a012395a32f4..0765e3841cdc 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -37,6 +37,7 @@ class PassInstrumentor;
 namespace detail {
 struct OpPassManagerImpl;
 class OpToOpPassAdaptor;
+class PassCrashReproducerGenerator;
 struct PassExecutionState;
 } // end namespace detail
 
@@ -373,12 +374,11 @@ class PassManager : public OpPassManager {
   /// Dump the statistics of the passes within this pass manager.
   void dumpStatistics();
 
-  /// Run the pass manager with crash recover enabled.
+  /// Run the pass manager with crash recovery enabled.
   LogicalResult runWithCrashRecovery(Operation *op, AnalysisManager am);
-  /// Run the given passes with crash recover enabled.
-  LogicalResult
-  runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
-                       Operation *op, AnalysisManager am);
+
+  /// Run the passes of the pass manager, and return the result.
+  LogicalResult runPasses(Operation *op, AnalysisManager am);
 
   /// Context this PassManager was initialized with.
   MLIRContext *context;
@@ -389,8 +389,9 @@ class PassManager : public OpPassManager {
   /// A manager for pass instrumentations.
   std::unique_ptr<PassInstrumentor> instrumentor;
 
-  /// An optional factory to use when generating a crash reproducer if valid.
-  ReproducerStreamFactory crashReproducerStreamFactory;
+  /// An optional crash reproducer generator, if this pass manager is set up
+  /// to generate reproducers.
+  std::unique_ptr<detail::PassCrashReproducerGenerator> crashReproGenerator;
 
   /// A hash key used to detect when reinitialization is necessary.
   llvm::hash_code initializationKey;
@@ -398,9 +399,6 @@ class PassManager : public OpPassManager {
   /// Flag that specifies if pass timing is enabled.
   bool passTiming : 1;
 
-  /// Flag that specifies if the generated crash reproducer should be local.
-  bool localReproducer : 1;
-
   /// A flag that indicates if the IR should be verified in between passes.
   bool verifyPasses : 1;
 };

diff  --git a/mlir/lib/Pass/CMakeLists.txt b/mlir/lib/Pass/CMakeLists.txt
index 3ec596e7a7c2..5ca9b4163bb8 100644
--- a/mlir/lib/Pass/CMakeLists.txt
+++ b/mlir/lib/Pass/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_library(MLIRPass
   IRPrinting.cpp
   Pass.cpp
+  PassCrashRecovery.cpp
   PassManagerOptions.cpp
   PassRegistry.cpp
   PassStatistics.cpp

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index e02c71d58e93..28309e462ea6 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -102,10 +102,6 @@ struct OpPassManagerImpl {
   /// recursively through the pipeline graph.
   void coalesceAdjacentAdaptorPasses();
 
-  /// Split all of AdaptorPasses such that each adaptor only contains one leaf
-  /// pass.
-  void splitAdaptorPasses();
-
   /// Return the operation name of this pass manager as an identifier.
   Identifier getOpName(MLIRContext &context) {
     if (!identifier)
@@ -213,27 +209,6 @@ void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
   llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>());
 }
 
-void OpPassManagerImpl::splitAdaptorPasses() {
-  std::vector<std::unique_ptr<Pass>> oldPasses;
-  std::swap(passes, oldPasses);
-
-  for (std::unique_ptr<Pass> &pass : oldPasses) {
-    // If this pass isn't an adaptor, move it directly to the new pass list.
-    auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass.get());
-    if (!currentAdaptor) {
-      addPass(std::move(pass));
-      continue;
-    }
-
-    // Otherwise, split the adaptors of each manager within the adaptor.
-    for (OpPassManager &adaptorPM : currentAdaptor->getPassManagers()) {
-      adaptorPM.getImpl().splitAdaptorPasses();
-      for (std::unique_ptr<Pass> &nestedPass : adaptorPM.getImpl().passes)
-        nest(adaptorPM.getOpName()).addPass(std::move(nestedPass));
-    }
-  }
-}
-
 //===----------------------------------------------------------------------===//
 // OpPassManager
 //===----------------------------------------------------------------------===//
@@ -645,210 +620,6 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
     signalPassFailure();
 }
 
-//===----------------------------------------------------------------------===//
-// PassCrashReproducer
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This class contains all of the context for generating a recovery reproducer.
-/// Each recovery context is registered globally to allow for generating
-/// reproducers when a signal is raised, such as a segfault.
-struct RecoveryReproducerContext {
-  RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes,
-                            Operation *op,
-                            PassManager::ReproducerStreamFactory &crashStream,
-                            bool disableThreads, bool verifyPasses);
-  ~RecoveryReproducerContext();
-
-  /// Generate a reproducer with the current context.
-  LogicalResult generate(std::string &error);
-
-private:
-  /// This function is invoked in the event of a crash.
-  static void crashHandler(void *);
-
-  /// Register a signal handler to run in the event of a crash.
-  static void registerSignalHandler();
-
-  /// The textual description of the currently executing pipeline.
-  std::string pipeline;
-
-  /// The MLIR operation representing the IR before the crash.
-  Operation *preCrashOperation;
-
-  /// The factory for the reproducer output stream to use when generating the
-  /// reproducer.
-  PassManager::ReproducerStreamFactory &crashStreamFactory;
-
-  /// Various pass manager and context flags.
-  bool disableThreads;
-  bool verifyPasses;
-
-  /// The current set of active reproducer contexts. This is used in the event
-  /// of a crash. This is not thread_local as the pass manager may produce any
-  /// number of child threads. This uses a set to allow for multiple MLIR pass
-  /// managers to be running at the same time.
-  static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
-  static llvm::ManagedStatic<
-      llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
-      reproducerSet;
-};
-
-/// Instance of ReproducerStream backed by file.
-struct FileReproducerStream : public PassManager::ReproducerStream {
-  FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
-      : outputFile(std::move(outputFile)) {}
-  ~FileReproducerStream() override;
-
-  /// Description of the reproducer stream.
-  StringRef description() override;
-
-  /// Stream on which to output reprooducer.
-  raw_ostream &os() override;
-
-private:
-  /// ToolOutputFile corresponding to opened `filename`.
-  std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
-};
-
-} // end anonymous namespace
-
-llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
-    RecoveryReproducerContext::reproducerMutex;
-llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
-    RecoveryReproducerContext::reproducerSet;
-
-RecoveryReproducerContext::RecoveryReproducerContext(
-    MutableArrayRef<std::unique_ptr<Pass>> passes, Operation *op,
-    PassManager::ReproducerStreamFactory &crashStreamFactory,
-    bool disableThreads, bool verifyPasses)
-    : preCrashOperation(op->clone()), crashStreamFactory(crashStreamFactory),
-      disableThreads(disableThreads), verifyPasses(verifyPasses) {
-  // Grab the textual pipeline being executed..
-  {
-    llvm::raw_string_ostream pipelineOS(pipeline);
-    ::printAsTextualPipeline(passes, pipelineOS);
-  }
-
-  // Make sure that the handler is registered, and update the current context.
-  llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex);
-  if (reproducerSet->empty())
-    llvm::CrashRecoveryContext::Enable();
-  registerSignalHandler();
-  reproducerSet->insert(this);
-}
-
-RecoveryReproducerContext::~RecoveryReproducerContext() {
-  // Erase the cloned preCrash IR that we cached.
-  preCrashOperation->erase();
-
-  llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex);
-  reproducerSet->remove(this);
-  if (reproducerSet->empty())
-    llvm::CrashRecoveryContext::Disable();
-}
-
-/// Description of the reproducer stream.
-StringRef FileReproducerStream::description() {
-  return outputFile->getFilename();
-}
-
-/// Stream on which to output reproducer.
-raw_ostream &FileReproducerStream::os() { return outputFile->os(); }
-
-FileReproducerStream::~FileReproducerStream() { outputFile->keep(); }
-
-LogicalResult RecoveryReproducerContext::generate(std::string &error) {
-  std::unique_ptr<PassManager::ReproducerStream> crashStream =
-      crashStreamFactory(error);
-  if (!crashStream)
-    return failure();
-
-  // Output the current pass manager configuration.
-  auto &os = crashStream->os();
-  os << "// configuration: -pass-pipeline='" << pipeline << "'";
-  if (disableThreads)
-    os << " -mlir-disable-threading";
-  if (verifyPasses)
-    os << " -verify-each";
-  os << '\n';
-
-  // Output the .mlir module.
-  preCrashOperation->print(os);
-
-  bool shouldPrintOnOp =
-      preCrashOperation->getContext()->shouldPrintOpOnDiagnostic();
-  preCrashOperation->getContext()->printOpOnDiagnostic(false);
-  preCrashOperation->emitError()
-      << "A failure has been detected while processing the MLIR module, a "
-         "reproducer has been generated in '"
-      << crashStream->description() << "'";
-  preCrashOperation->getContext()->printOpOnDiagnostic(shouldPrintOnOp);
-  return success();
-}
-
-void RecoveryReproducerContext::crashHandler(void *) {
-  // Walk the current stack of contexts and generate a reproducer for each one.
-  // We can't know for certain which one was the cause, so we need to generate
-  // a reproducer for all of them.
-  std::string ignored;
-  for (RecoveryReproducerContext *context : *reproducerSet)
-    (void)context->generate(ignored);
-}
-
-void RecoveryReproducerContext::registerSignalHandler() {
-  // Ensure that the handler is only registered once.
-  static bool registered =
-      (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
-  (void)registered;
-}
-
-/// Run the pass manager with crash recover enabled.
-LogicalResult PassManager::runWithCrashRecovery(Operation *op,
-                                                AnalysisManager am) {
-  // If this isn't a local producer, run all of the passes in recovery mode.
-  if (!localReproducer)
-    return runWithCrashRecovery(impl->passes, op, am);
-
-  // Split the passes within adaptors to ensure that each pass can be run in
-  // isolation.
-  impl->splitAdaptorPasses();
-
-  // If this is a local producer, run each of the passes individually.
-  MutableArrayRef<std::unique_ptr<Pass>> passes = impl->passes;
-  for (std::unique_ptr<Pass> &pass : passes)
-    if (failed(runWithCrashRecovery(pass, op, am)))
-      return failure();
-  return success();
-}
-
-/// Run the given passes with crash recover enabled.
-LogicalResult
-PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
-                                  Operation *op, AnalysisManager am) {
-  RecoveryReproducerContext context(passes, op, crashReproducerStreamFactory,
-                                    !getContext()->isMultithreadingEnabled(),
-                                    verifyPasses);
-
-  // Safely invoke the passes within a recovery context.
-  LogicalResult passManagerResult = failure();
-  llvm::CrashRecoveryContext recoveryContext;
-  recoveryContext.RunSafelyOnThread([&] {
-    for (std::unique_ptr<Pass> &pass : passes)
-      if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses,
-                                        impl->initializationGeneration)))
-        return;
-    passManagerResult = success();
-  });
-  if (succeeded(passManagerResult))
-    return success();
-
-  std::string error;
-  if (failed(context.generate(error)))
-    return op->emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error;
-  return failure();
-}
-
 //===----------------------------------------------------------------------===//
 // PassManager
 //===----------------------------------------------------------------------===//
@@ -857,7 +628,7 @@ PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
                          StringRef operationName)
     : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx),
       initializationKey(DenseMapInfo<llvm::hash_code>::getTombstoneKey()),
-      passTiming(false), localReproducer(false), verifyPasses(true) {}
+      passTiming(false), verifyPasses(true) {}
 
 PassManager::~PassManager() {}
 
@@ -898,10 +669,7 @@ LogicalResult PassManager::run(Operation *op) {
   // If reproducer generation is enabled, run the pass manager with crash
   // handling enabled.
   LogicalResult result =
-      crashReproducerStreamFactory
-          ? runWithCrashRecovery(op, am)
-          : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses,
-                                           impl->initializationGeneration);
+      crashReproGenerator ? runWithCrashRecovery(op, am) : runPasses(op, am);
 
   // Notify the context that the run is done.
   context->exitMultiThreadedExecution();
@@ -912,40 +680,6 @@ LogicalResult PassManager::run(Operation *op) {
   return result;
 }
 
-/// Enable support for the pass manager to generate a reproducer on the event
-/// of a crash or a pass failure. `outputFile` is a .mlir filename used to write
-/// the generated reproducer. If `genLocalReproducer` is true, the pass manager
-/// will attempt to generate a local reproducer that contains the smallest
-/// pipeline.
-void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
-                                                  bool genLocalReproducer) {
-  // Capture the filename by value in case outputFile is out of scope when
-  // invoked.
-  std::string filename = outputFile.str();
-  enableCrashReproducerGeneration(
-      [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
-        std::unique_ptr<llvm::ToolOutputFile> outputFile =
-            mlir::openOutputFile(filename, &error);
-        if (!outputFile) {
-          error = "Failed to create reproducer stream: " + error;
-          return nullptr;
-        }
-        return std::make_unique<FileReproducerStream>(std::move(outputFile));
-      },
-      genLocalReproducer);
-}
-
-/// Enable support for the pass manager to generate a reproducer on the event
-/// of a crash or a pass failure. `factory` is used to construct the streams
-/// to write the generated reproducer to. If `genLocalReproducer` is true, the
-/// pass manager will attempt to generate a local reproducer that contains the
-/// smallest pipeline.
-void PassManager::enableCrashReproducerGeneration(
-    ReproducerStreamFactory factory, bool genLocalReproducer) {
-  crashReproducerStreamFactory = factory;
-  localReproducer = genLocalReproducer;
-}
-
 /// Add the provided instrumentation to the pass manager.
 void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
   if (!instrumentor)
@@ -954,6 +688,11 @@ void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
   instrumentor->addInstrumentation(std::move(pi));
 }
 
+LogicalResult PassManager::runPasses(Operation *op, AnalysisManager am) {
+  return OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses,
+                                        impl->initializationGeneration);
+}
+
 //===----------------------------------------------------------------------===//
 // AnalysisManager
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Pass/PassCrashRecovery.cpp b/mlir/lib/Pass/PassCrashRecovery.cpp
new file mode 100644
index 000000000000..86180737d321
--- /dev/null
+++ b/mlir/lib/Pass/PassCrashRecovery.cpp
@@ -0,0 +1,441 @@
+//===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/CrashRecoveryContext.h"
+#include "llvm/Support/Mutex.h"
+#include "llvm/Support/Parallel.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/Support/Threading.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+// RecoveryReproducerContext
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace detail {
+/// This class contains all of the context for generating a recovery reproducer.
+/// Each recovery context is registered globally to allow for generating
+/// reproducers when a signal is raised, such as a segfault.
+struct RecoveryReproducerContext {
+  RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
+                            PassManager::ReproducerStreamFactory &streamFactory,
+                            bool verifyPasses);
+  ~RecoveryReproducerContext();
+
+  /// Generate a reproducer with the current context.
+  void generate(std::string &description);
+
+  /// Disable this reproducer context. This prevents the context from generating
+  /// a reproducer in the event of a crash.
+  void disable();
+
+  /// Enable a previously disabled reproducer context.
+  void enable();
+
+private:
+  /// This function is invoked in the event of a crash.
+  static void crashHandler(void *);
+
+  /// Register a signal handler to run in the event of a crash.
+  static void registerSignalHandler();
+
+  /// The textual description of the currently executing pipeline.
+  std::string pipeline;
+
+  /// The MLIR operation representing the IR before the crash.
+  Operation *preCrashOperation;
+
+  /// The factory for the reproducer output stream to use when generating the
+  /// reproducer.
+  PassManager::ReproducerStreamFactory &streamFactory;
+
+  /// Various pass manager and context flags.
+  bool disableThreads;
+  bool verifyPasses;
+
+  /// The current set of active reproducer contexts. This is used in the event
+  /// of a crash. This is not thread_local as the pass manager may produce any
+  /// number of child threads. This uses a set to allow for multiple MLIR pass
+  /// managers to be running at the same time.
+  static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
+  static llvm::ManagedStatic<
+      llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
+      reproducerSet;
+};
+} // namespace detail
+} // namespace mlir
+
+llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
+    RecoveryReproducerContext::reproducerMutex;
+llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
+    RecoveryReproducerContext::reproducerSet;
+
+RecoveryReproducerContext::RecoveryReproducerContext(
+    std::string passPipelineStr, Operation *op,
+    PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses)
+    : pipeline(std::move(passPipelineStr)), preCrashOperation(op->clone()),
+      streamFactory(streamFactory),
+      disableThreads(!op->getContext()->isMultithreadingEnabled()),
+      verifyPasses(verifyPasses) {
+  enable();
+}
+
+RecoveryReproducerContext::~RecoveryReproducerContext() {
+  // Erase the cloned preCrash IR that we cached.
+  preCrashOperation->erase();
+  disable();
+}
+
+void RecoveryReproducerContext::generate(std::string &description) {
+  llvm::raw_string_ostream descOS(description);
+
+  // Try to create a new output stream for this crash reproducer.
+  std::string error;
+  std::unique_ptr<PassManager::ReproducerStream> stream = streamFactory(error);
+  if (!stream) {
+    descOS << "failed to create output stream: " << error;
+    return;
+  }
+  descOS << "reproducer generated at `" << stream->description() << "`";
+
+  // Output the current pass manager configuration to the crash stream.
+  auto &os = stream->os();
+  os << "// configuration: -pass-pipeline='" << pipeline << "'";
+  if (disableThreads)
+    os << " -mlir-disable-threading";
+  if (verifyPasses)
+    os << " -verify-each";
+  os << '\n';
+
+  // Output the .mlir module.
+  preCrashOperation->print(os);
+}
+
+void RecoveryReproducerContext::disable() {
+  llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
+  reproducerSet->remove(this);
+  if (reproducerSet->empty())
+    llvm::CrashRecoveryContext::Disable();
+}
+
+void RecoveryReproducerContext::enable() {
+  llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
+  if (reproducerSet->empty())
+    llvm::CrashRecoveryContext::Enable();
+  registerSignalHandler();
+  reproducerSet->insert(this);
+}
+
+void RecoveryReproducerContext::crashHandler(void *) {
+  // Walk the current stack of contexts and generate a reproducer for each one.
+  // We can't know for certain which one was the cause, so we need to generate
+  // a reproducer for all of them.
+  for (RecoveryReproducerContext *context : *reproducerSet) {
+    std::string description;
+    context->generate(description);
+
+    // Emit an error using information only available within the context.
+    context->preCrashOperation->getContext()->printOpOnDiagnostic(false);
+    context->preCrashOperation->emitError()
+        << "A failure has been detected while processing the MLIR module:"
+        << description;
+  }
+}
+
+void RecoveryReproducerContext::registerSignalHandler() {
+  // Ensure that the handler is only registered once.
+  static bool registered =
+      (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
+  (void)registered;
+}
+
+//===----------------------------------------------------------------------===//
+// PassCrashReproducerGenerator
+//===----------------------------------------------------------------------===//
+
+struct PassCrashReproducerGenerator::Impl {
+  Impl(PassManager::ReproducerStreamFactory &streamFactory,
+       bool localReproducer)
+      : streamFactory(streamFactory), localReproducer(localReproducer) {}
+
+  /// The factory to use when generating a crash reproducer.
+  PassManager::ReproducerStreamFactory streamFactory;
+
+  /// Flag indicating if reproducer generation should be localized to the
+  /// failing pass.
+  bool localReproducer;
+
+  /// A record of all of the currently active reproducer contexts.
+  SmallVector<std::unique_ptr<RecoveryReproducerContext>> activeContexts;
+
+  /// The set of all currently running passes. Note: This is not populated when
+  /// `localReproducer` is true, as each pass will get its own recovery context.
+  SetVector<std::pair<Pass *, Operation *>> runningPasses;
+
+  /// Various pass manager flags that get emitted when generating a reproducer.
+  bool pmFlagVerifyPasses;
+};
+
+PassCrashReproducerGenerator::PassCrashReproducerGenerator(
+    PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer)
+    : impl(std::make_unique<Impl>(streamFactory, localReproducer)) {}
+PassCrashReproducerGenerator::~PassCrashReproducerGenerator() {}
+
+void PassCrashReproducerGenerator::initialize(
+    iterator_range<PassManager::pass_iterator> passes, Operation *op,
+    bool pmFlagVerifyPasses) {
+  assert((!impl->localReproducer ||
+          !op->getContext()->isMultithreadingEnabled()) &&
+         "expected multi-threading to be disabled when generating a local "
+         "reproducer");
+
+  llvm::CrashRecoveryContext::Enable();
+  impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
+
+  // If we aren't generating a local reproducer, prepare a reproducer for the
+  // given top-level operation.
+  if (!impl->localReproducer)
+    prepareReproducerFor(passes, op);
+}
+
+static void
+formatPassOpReproducerMessage(Diagnostic &os,
+                              std::pair<Pass *, Operation *> passOpPair) {
+  os << "`" << passOpPair.first->getName() << "` on "
+     << "'" << passOpPair.second->getName() << "' operation";
+  if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second))
+    os << ": @" << symbol.getName();
+}
+
+void PassCrashReproducerGenerator::finalize(Operation *rootOp,
+                                            LogicalResult executionResult) {
+  // If the pass manager execution succeeded, we don't generate any reproducers.
+  if (succeeded(executionResult))
+    return impl->activeContexts.clear();
+
+  MLIRContext *context = rootOp->getContext();
+  bool shouldPrintOnOp = context->shouldPrintOpOnDiagnostic();
+  context->printOpOnDiagnostic(false);
+  InFlightDiagnostic diag = rootOp->emitError()
+                            << "Failures have been detected while "
+                               "processing an MLIR pass pipeline";
+  context->printOpOnDiagnostic(shouldPrintOnOp);
+
+  // If we are generating a global reproducer, we include all of the running
+  // passes in the error message for the only active context.
+  if (!impl->localReproducer) {
+    assert(impl->activeContexts.size() == 1 && "expected one active context");
+
+    // Generate the reproducer.
+    std::string description;
+    impl->activeContexts.front()->generate(description);
+
+    // Emit an error to the user.
+    Diagnostic &note = diag.attachNote() << "Pipeline failed while executing [";
+    llvm::interleaveComma(impl->runningPasses, note,
+                          [&](const std::pair<Pass *, Operation *> &value) {
+                            formatPassOpReproducerMessage(note, value);
+                          });
+    note << "]: " << description;
+    return;
+  }
+
+  // If we were generating a local reproducer, we generate a reproducer for the
+  // most recently executing pass using the matching entry from `runningPasses`
+  // to generate a localized diagnostic message.
+  assert(impl->activeContexts.size() == impl->runningPasses.size() &&
+         "expected running passes to match active contexts");
+
+  // Generate the reproducer.
+  RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back();
+  std::string description;
+  reproducerContext.generate(description);
+
+  // Emit an error to the user.
+  Diagnostic &note = diag.attachNote() << "Pipeline failed while executing ";
+  formatPassOpReproducerMessage(note, impl->runningPasses.back());
+  note << ": " << description;
+
+  impl->activeContexts.clear();
+}
+
+void PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass,
+                                                        Operation *op) {
+  // If not tracking local reproducers, we simply remember that this pass is
+  // running.
+  impl->runningPasses.insert(std::make_pair(pass, op));
+  if (!impl->localReproducer)
+    return;
+
+  // Disable the current pass recovery context, if there is one. This may happen
+  // in the case of dynamic pass pipelines.
+  if (!impl->activeContexts.empty())
+    impl->activeContexts.back()->disable();
+
+  // Collect all of the parent scopes of this operation.
+  SmallVector<OperationName> scopes;
+  while (Operation *parentOp = op->getParentOp()) {
+    scopes.push_back(op->getName());
+    op = parentOp;
+  }
+
+  // Emit a pass pipeline string for the current pass running on the current
+  // operation type.
+  std::string passStr;
+  llvm::raw_string_ostream passOS(passStr);
+  for (OperationName scope : llvm::reverse(scopes))
+    passOS << scope << "(";
+  pass->printAsTextualPipeline(passOS);
+  for (unsigned i = 0, e = scopes.size(); i < e; ++i)
+    passOS << ")";
+
+  impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
+      passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses));
+}
+void PassCrashReproducerGenerator::prepareReproducerFor(
+    iterator_range<PassManager::pass_iterator> passes, Operation *op) {
+  std::string passStr;
+  llvm::raw_string_ostream passOS(passStr);
+  llvm::interleaveComma(
+      passes, passOS, [&](Pass &pass) { pass.printAsTextualPipeline(passOS); });
+
+  impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
+      passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses));
+}
+
+void PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass,
+                                                           Operation *op) {
+  // We only pop the active context if we are tracking local reproducers.
+  impl->runningPasses.remove(std::make_pair(pass, op));
+  if (impl->localReproducer) {
+    impl->activeContexts.pop_back();
+
+    // Re-enable the previous pass recovery context, if there was one. This may
+    // happen in the case of dynamic pass pipelines.
+    if (!impl->activeContexts.empty())
+      impl->activeContexts.back()->enable();
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// CrashReproducerInstrumentation
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct CrashReproducerInstrumentation : public PassInstrumentation {
+  CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator)
+      : generator(generator) {}
+  ~CrashReproducerInstrumentation() override = default;
+
+  /// A callback to run before a pass is executed.
+  void runBeforePass(Pass *pass, Operation *op) override {
+    if (!isa<OpToOpPassAdaptor>(pass))
+      generator.prepareReproducerFor(pass, op);
+  }
+
+  /// A callback to run after a pass is successfully executed. This function
+  /// takes a pointer to the pass to be executed, as well as the current
+  /// operation being operated on.
+  void runAfterPass(Pass *pass, Operation *op) override {
+    if (!isa<OpToOpPassAdaptor>(pass))
+      generator.removeLastReproducerFor(pass, op);
+  }
+
+private:
+  /// The generator used to create crash reproducers.
+  PassCrashReproducerGenerator &generator;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// FileReproducerStream
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class represents a default instance of PassManager::ReproducerStream
+/// that is backed by a file.
+struct FileReproducerStream : public PassManager::ReproducerStream {
+  FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
+      : outputFile(std::move(outputFile)) {}
+  ~FileReproducerStream() override { outputFile->keep(); }
+
+  /// Returns a description of the reproducer stream.
+  StringRef description() override { return outputFile->getFilename(); }
+
+  /// Returns the stream on which to output the reproducer.
+  raw_ostream &os() override { return outputFile->os(); }
+
+private:
+  /// ToolOutputFile corresponding to opened `filename`.
+  std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// PassManager
+//===----------------------------------------------------------------------===//
+
+LogicalResult PassManager::runWithCrashRecovery(Operation *op,
+                                                AnalysisManager am) {
+  crashReproGenerator->initialize(getPasses(), op, verifyPasses);
+
+  // Safely invoke the passes within a recovery context.
+  LogicalResult passManagerResult = failure();
+  llvm::CrashRecoveryContext recoveryContext;
+  recoveryContext.RunSafelyOnThread(
+      [&] { passManagerResult = runPasses(op, am); });
+  crashReproGenerator->finalize(op, passManagerResult);
+  return passManagerResult;
+}
+
+void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
+                                                  bool genLocalReproducer) {
+  // Capture the filename by value in case outputFile is out of scope when
+  // invoked.
+  std::string filename = outputFile.str();
+  enableCrashReproducerGeneration(
+      [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
+        std::unique_ptr<llvm::ToolOutputFile> outputFile =
+            mlir::openOutputFile(filename, &error);
+        if (!outputFile) {
+          error = "Failed to create reproducer stream: " + error;
+          return nullptr;
+        }
+        return std::make_unique<FileReproducerStream>(std::move(outputFile));
+      },
+      genLocalReproducer);
+}
+
+void PassManager::enableCrashReproducerGeneration(
+    ReproducerStreamFactory factory, bool genLocalReproducer) {
+  assert(!crashReproGenerator &&
+         "crash reproducer has already been initialized");
+  if (genLocalReproducer && getContext()->isMultithreadingEnabled())
+    llvm::report_fatal_error(
+        "Local crash reproduction can't be setup on a "
+        "pass-manager without disabling multi-threading first.");
+
+  crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>(
+      factory, genLocalReproducer);
+  addInstrumentation(
+      std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator));
+}

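The local-reproducer bookkeeping in `prepareReproducerFor` and `removeLastReproducerFor` in PassCrashRecovery.cpp comes down to two ideas. First, the names of the operation and each of its ancestors (except the root) are wrapped around the pass's textual pipeline. Second, the recovery contexts form a stack in which only the innermost entry is enabled, which is what allows dynamic pass pipelines to work. The standalone sketch below models that logic in plain C++. It uses hypothetical names (`ToyReproducerContext`, `ToyGenerator`, `buildLocalPipeline`) and string stand-ins for MLIR's `Operation`, `OperationName` and `Pass`. It is not the MLIR API.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for RecoveryReproducerContext: it only records the
// pipeline string it would emit and whether it is currently enabled.
struct ToyReproducerContext {
  explicit ToyReproducerContext(std::string pipeline)
      : pipeline(std::move(pipeline)) {}
  void disable() { enabled = false; }
  void enable() { enabled = true; }

  std::string pipeline;
  bool enabled = true;
};

// Builds the textual pipeline the way prepareReproducerFor does.
// `scopesInnerToOuter` holds the operation names collected while walking up
// the parent chain (innermost first, root operation excluded). They are
// printed outermost first, each followed by "(", then the pass pipeline,
// then one ")" per scope.
std::string buildLocalPipeline(const std::vector<std::string> &scopesInnerToOuter,
                               const std::string &passPipeline) {
  std::string result;
  for (auto it = scopesInnerToOuter.rbegin(); it != scopesInnerToOuter.rend();
       ++it)
    result += *it + "(";
  result += passPipeline;
  result.append(scopesInnerToOuter.size(), ')');
  return result;
}

// Hypothetical model of the activeContexts stack: pushing a context disables
// the previous innermost one, and popping re-enables it. As a result, only
// the innermost running pass (e.g. one inside a dynamic pipeline) is enabled.
struct ToyGenerator {
  void prepare(std::string pipeline) {
    if (!activeContexts.empty())
      activeContexts.back()->disable();
    activeContexts.push_back(
        std::make_unique<ToyReproducerContext>(std::move(pipeline)));
  }
  void removeLast() {
    activeContexts.pop_back();
    if (!activeContexts.empty())
      activeContexts.back()->enable();
  }

  std::vector<std::unique_ptr<ToyReproducerContext>> activeContexts;
};
```

Consider the dynamic-pipeline RUN line in crash-recovery.mlir. There, `test-pass-crash` runs on `module @foo`, which is nested directly inside the root `module @inner_mod1`. The only collected scope is therefore `module`, so this model yields `module(test-pass-crash)`. That is the string the `REPRO_LOCAL_DYNAMIC` check expects.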
diff  --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index 436ad65629cc..24782fc80f1b 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -82,6 +82,43 @@ class OpToOpPassAdaptor
   friend class mlir::PassManager;
 };
 
+//===----------------------------------------------------------------------===//
+// PassCrashReproducerGenerator
+//===----------------------------------------------------------------------===//
+
+class PassCrashReproducerGenerator {
+public:
+  PassCrashReproducerGenerator(
+      PassManager::ReproducerStreamFactory &streamFactory,
+      bool localReproducer);
+  ~PassCrashReproducerGenerator();
+
+  /// Initialize the generator in preparation for reproducer generation. The
+  /// generator should be reinitialized before each run of the pass manager.
+  void initialize(iterator_range<PassManager::pass_iterator> passes,
+                  Operation *op, bool pmFlagVerifyPasses);
+  /// Finalize the current run of the generator, generating any necessary
+  /// reproducers if the provided execution result is a failure.
+  void finalize(Operation *rootOp, LogicalResult executionResult);
+
+  /// Prepare a new reproducer for the given pass, operating on `op`.
+  void prepareReproducerFor(Pass *pass, Operation *op);
+
+  /// Prepare a new reproducer for the given passes, operating on `op`.
+  void prepareReproducerFor(iterator_range<PassManager::pass_iterator> passes,
+                            Operation *op);
+
+  /// Remove the last recorded reproducer anchored at the given pass and
+  /// operation.
+  void removeLastReproducerFor(Pass *pass, Operation *op);
+
+private:
+  struct Impl;
+
+  /// The internal implementation of the crash reproducer.
+  std::unique_ptr<Impl> impl;
+};
+
 } // end namespace detail
 } // end namespace mlir
 #endif // MLIR_PASS_PASSDETAIL_H_

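`PassCrashReproducerGenerator` in PassDetail.h hides its state behind a forward-declared `struct Impl` held in a `std::unique_ptr`. Its destructor is only declared in the header. Destroying a `std::unique_ptr<Impl>` requires `Impl` to be a complete type, so the destructor must be defined in the source file after `Impl`. Files that include only the header then never need the full definition. The minimal sketch below shows the same pimpl layout in one file. `ToyCrashGenerator` and its methods are hypothetical. They loosely mirror how running passes are always tracked, while reproducer contexts are kept only when `localReproducer` is set.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// ---- "Header" part: only a forward declaration of the implementation. ----
class ToyCrashGenerator {
public:
  explicit ToyCrashGenerator(bool localReproducer);
  // Only declared here; defined in the "source" part once Impl is complete,
  // so code that sees just this header never has to destroy an incomplete Impl.
  ~ToyCrashGenerator();

  void prepareReproducerFor(std::string pass);
  void removeLastReproducerFor(const std::string &pass);
  std::size_t numRunningPasses() const;
  std::size_t numActiveContexts() const;

private:
  struct Impl;
  std::unique_ptr<Impl> impl;
};

// ---- "Source" part: the implementation details live here. ----
struct ToyCrashGenerator::Impl {
  bool localReproducer;
  std::vector<std::string> runningPasses;  // always tracked
  std::vector<std::string> activeContexts; // only used for local reproducers
};

ToyCrashGenerator::ToyCrashGenerator(bool localReproducer)
    : impl(std::make_unique<Impl>(Impl{localReproducer, {}, {}})) {}
ToyCrashGenerator::~ToyCrashGenerator() = default;

void ToyCrashGenerator::prepareReproducerFor(std::string pass) {
  impl->runningPasses.push_back(pass);
  if (!impl->localReproducer)
    return;
  impl->activeContexts.push_back(std::move(pass));
}

void ToyCrashGenerator::removeLastReproducerFor(const std::string &pass) {
  auto it = std::find(impl->runningPasses.begin(), impl->runningPasses.end(),
                      pass);
  if (it != impl->runningPasses.end())
    impl->runningPasses.erase(it);
  // Contexts are only pushed, and therefore only popped, for local reproducers.
  if (impl->localReproducer)
    impl->activeContexts.pop_back();
}

std::size_t ToyCrashGenerator::numRunningPasses() const {
  return impl->runningPasses.size();
}
std::size_t ToyCrashGenerator::numActiveContexts() const {
  return impl->activeContexts.size();
}
```

The real class splits the pieces across PassDetail.h and PassCrashRecovery.cpp instead of keeping them in one file. This keeps `RecoveryReproducerContext` and the other helper types out of the header.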
diff  --git a/mlir/test/Pass/crash-recovery.mlir b/mlir/test/Pass/crash-recovery.mlir
index 0d654b63a4d0..58743e55969a 100644
--- a/mlir/test/Pass/crash-recovery.mlir
+++ b/mlir/test/Pass/crash-recovery.mlir
@@ -1,26 +1,33 @@
-// RUN: mlir-opt %s -pass-pipeline='func(test-function-pass, test-pass-crash)' -pass-pipeline-crash-reproducer=%t -verify-diagnostics
+// RUN: mlir-opt %s -pass-pipeline='module(test-module-pass, test-pass-crash)' -pass-pipeline-crash-reproducer=%t -verify-diagnostics
 // RUN: cat %t | FileCheck -check-prefix=REPRO %s
-// RUN: mlir-opt %s -pass-pipeline='func(test-function-pass, test-pass-crash)' -pass-pipeline-crash-reproducer=%t -verify-diagnostics -pass-pipeline-local-reproducer
+// RUN: mlir-opt %s -pass-pipeline='module(test-module-pass, test-pass-crash)' -pass-pipeline-crash-reproducer=%t -verify-diagnostics -pass-pipeline-local-reproducer -mlir-disable-threading
 // RUN: cat %t | FileCheck -check-prefix=REPRO_LOCAL %s
 
-// Check that we correctly handle verifiers passes with local reproducer, this use to crash.
-// RUN: mlir-opt %s -test-function-pass -test-function-pass  -test-module-pass -pass-pipeline-crash-reproducer=%t -pass-pipeline-local-reproducer
+// Check that we correctly handle verifiers passes with local reproducer, this used to crash.
+// RUN: mlir-opt %s -test-module-pass -test-module-pass  -test-module-pass -pass-pipeline-crash-reproducer=%t -pass-pipeline-local-reproducer -mlir-disable-threading
+// RUN: cat %t | FileCheck -check-prefix=REPRO_LOCAL %s
+
+// Check that local reproducers will also traverse dynamic pass pipelines.
+// RUN: mlir-opt %s -pass-pipeline='test-module-pass,test-dynamic-pipeline{op-name=inner_mod1 run-on-nested-operations=1 dynamic-pipeline=test-pass-crash}' -pass-pipeline-crash-reproducer=%t -verify-diagnostics -pass-pipeline-local-reproducer --mlir-disable-threading
+// RUN: cat %t | FileCheck -check-prefix=REPRO_LOCAL_DYNAMIC %s
 
-// expected-error at +1 {{A failure has been detected while processing the MLIR module}}
-module {
-  func @foo() {
-    return
-  }
+// expected-error at below {{Failures have been detected while processing an MLIR pass pipeline}}
+// expected-note at below {{Pipeline failed while executing}}
+module @inner_mod1 {
+  module @foo {}
 }
 
-// REPRO: configuration: -pass-pipeline='func(test-function-pass, test-pass-crash)'
+// REPRO: configuration: -pass-pipeline='module(test-module-pass, test-pass-crash)'
+
+// REPRO: module @inner_mod1
+// REPRO: module @foo {
+
+// REPRO_LOCAL: configuration: -pass-pipeline='module(test-pass-crash)'
 
-// REPRO: module
-// REPRO: func @foo() {
-// REPRO-NEXT: return
+// REPRO_LOCAL: module @inner_mod1
+// REPRO_LOCAL: module @foo {
 
-// REPRO_LOCAL: configuration: -pass-pipeline='func(test-pass-crash)'
+// REPRO_LOCAL_DYNAMIC: configuration: -pass-pipeline='module(test-pass-crash)'
 
-// REPRO_LOCAL: module
-// REPRO_LOCAL: func @foo() {
-// REPRO_LOCAL-NEXT: return
+// REPRO_LOCAL_DYNAMIC: module @inner_mod1
+// REPRO_LOCAL_DYNAMIC: module @foo {

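The file-backed `enableCrashReproducerGeneration(StringRef, bool)` overload in PassCrashRecovery.cpp turns a path into a `ReproducerStreamFactory`. It copies the filename into a `std::string` before capturing it, because the caller's `StringRef` may be out of scope when the factory finally runs. On failure, it sets the `std::string &error` out-parameter and returns `nullptr`. The sketch below shows the same pattern using only the standard library. `ToyReproducerStream`, `ToyFileStream`, `ToyStreamFactory` and `makeFileStreamFactory` are hypothetical names. `std::ofstream` stands in for `llvm::ToolOutputFile`, which deletes its file on destruction unless `keep()` is called; this is why `FileReproducerStream`'s destructor calls `keep()`.

```cpp
#include <cassert>
#include <fstream>
#include <functional>
#include <memory>
#include <ostream>
#include <string>
#include <utility>

// Hypothetical stand-in for PassManager::ReproducerStream.
struct ToyReproducerStream {
  virtual ~ToyReproducerStream() = default;
  virtual std::string description() = 0;
  virtual std::ostream &os() = 0;
};

// File-backed stream, analogous in spirit to FileReproducerStream.
struct ToyFileStream : ToyReproducerStream {
  explicit ToyFileStream(std::string name)
      : filename(std::move(name)), file(filename) {}
  bool isOpen() const { return file.is_open(); }
  std::string description() override { return filename; }
  std::ostream &os() override { return file; }

private:
  std::string filename; // declared before `file` so it is initialized first
  std::ofstream file;
};

using ToyStreamFactory =
    std::function<std::unique_ptr<ToyReproducerStream>(std::string &error)>;

ToyStreamFactory makeFileStreamFactory(const std::string &outputFile) {
  // Copy the name into the closure so it outlives the caller's argument.
  std::string filename = outputFile;
  return [filename](std::string &error) -> std::unique_ptr<ToyReproducerStream> {
    auto file = std::make_unique<ToyFileStream>(filename);
    if (!file->isOpen()) {
      error = "Failed to create reproducer stream: cannot open " + filename;
      return nullptr;
    }
    return std::unique_ptr<ToyReproducerStream>(std::move(file));
  };
}
```

Deferring stream creation to a factory means no file is opened until a reproducer is actually generated. Callers can also plug in non-file streams through the `ReproducerStreamFactory` overload.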

        


More information about the Mlir-commits mailing list