[Mlir-commits] [mlir] 64ce90e - [mlir] Add a new `print-ir-after-failure` IR pass printing flag

River Riddle llvmlistbot at llvm.org
Wed May 19 17:03:22 PDT 2021


Author: River Riddle
Date: 2021-05-19T16:54:20-07:00
New Revision: 64ce90e1af5c38822b7730bb6f21ed3d99f2f364

URL: https://github.com/llvm/llvm-project/commit/64ce90e1af5c38822b7730bb6f21ed3d99f2f364
DIFF: https://github.com/llvm/llvm-project/commit/64ce90e1af5c38822b7730bb6f21ed3d99f2f364.diff

LOG: [mlir] Add a new `print-ir-after-failure` IR pass printing flag

This flag prints the IR after a pass only when that pass fails. This makes it easier to view the invalid IR without needing to print after every pass in the pipeline.

Differential Revision: https://reviews.llvm.org/D101853

Added: 
    

Modified: 
    mlir/docs/PassManagement.md
    mlir/include/mlir/Pass/PassManager.h
    mlir/lib/Pass/IRPrinting.cpp
    mlir/lib/Pass/PassManagerOptions.cpp
    mlir/test/Pass/ir-printing.mlir
    mlir/test/lib/Pass/TestPassManager.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index ad9c7e566ad51..ec890dab710c6 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -1126,6 +1126,21 @@ func @simple_constant() -> (i32, i32) {
 }
 ```
 
+*   `print-ir-after-failure`
+    *   Only print IR after a pass failure.
+    *   This option should *not* be used with the other `print-ir-after` flags
+        above.
+
+```shell
+$ mlir-opt foo.mlir -pass-pipeline='func(cse,bad-pass)' -print-ir-after-failure
+
+*** IR Dump After BadPass Failed ***
+func @simple_constant() -> (i32, i32) {
+  %c1_i32 = constant 1 : i32
+  return %c1_i32, %c1_i32 : i32, i32
+}
+```
+
 *   `print-ir-module-scope`
     *   Always print the top-level module operation, regardless of pass type or
         operation nesting level.

diff  --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index fff2daf4cd626..a012395a32f43 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -242,10 +242,15 @@ class PassManager : public OpPassManager {
     ///   pass, in the case of a non-failure, we should first check if any
     ///   potential mutations were made. This allows for reducing the number of
     ///   logs that don't contain meaningful changes.
+    /// * 'printAfterOnlyOnFailure' signals that when printing the IR after a
+    ///   pass, we only print in the case of a failure.
+    ///     - This option should *not* be used with the other `printAfter` flags
+    ///       above.
     /// * 'opPrintingFlags' sets up the printing flags to use when printing the
     ///   IR.
     explicit IRPrinterConfig(
         bool printModuleScope = false, bool printAfterOnlyOnChange = false,
+        bool printAfterOnlyOnFailure = false,
         OpPrintingFlags opPrintingFlags = OpPrintingFlags());
     virtual ~IRPrinterConfig();
 
@@ -270,6 +275,12 @@ class PassManager : public OpPassManager {
     /// "changed".
     bool shouldPrintAfterOnlyOnChange() const { return printAfterOnlyOnChange; }
 
+    /// Returns true if the IR should only be printed after a pass if the pass
+    /// "failed".
+    bool shouldPrintAfterOnlyOnFailure() const {
+      return printAfterOnlyOnFailure;
+    }
+
     /// Returns the printing flags to be used to print the IR.
     OpPrintingFlags getOpPrintingFlags() const { return opPrintingFlags; }
 
@@ -281,6 +292,10 @@ class PassManager : public OpPassManager {
     /// a change is detected.
     bool printAfterOnlyOnChange;
 
+    /// A flag that indicates that the IR after a pass should only be printed if
+    /// the pass failed.
+    bool printAfterOnlyOnFailure;
+
     /// Flags to control printing behavior.
     OpPrintingFlags opPrintingFlags;
   };
@@ -299,16 +314,20 @@ class PassManager : public OpPassManager {
   /// * 'printAfterOnlyOnChange' signals that when printing the IR after a
   ///   pass, in the case of a non-failure, we should first check if any
   ///   potential mutations were made.
+  /// * 'printAfterOnlyOnFailure' signals that when printing the IR after a
+  ///   pass, we only print in the case of a failure.
+  ///     - This option should *not* be used with the other `printAfter` flags
+  ///       above.
+  /// * 'out' corresponds to the stream to output the printed IR to.
   /// * 'opPrintingFlags' sets up the printing flags to use when printing the
   ///   IR.
-  /// * 'out' corresponds to the stream to output the printed IR to.
   void enableIRPrinting(
       std::function<bool(Pass *, Operation *)> shouldPrintBeforePass =
           [](Pass *, Operation *) { return true; },
       std::function<bool(Pass *, Operation *)> shouldPrintAfterPass =
           [](Pass *, Operation *) { return true; },
       bool printModuleScope = true, bool printAfterOnlyOnChange = true,
-      raw_ostream &out = llvm::errs(),
+      bool printAfterOnlyOnFailure = false, raw_ostream &out = llvm::errs(),
       OpPrintingFlags opPrintingFlags = OpPrintingFlags());
 
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp
index f8c58c78ae76a..b4074cc48d9e2 100644
--- a/mlir/lib/Pass/IRPrinting.cpp
+++ b/mlir/lib/Pass/IRPrinting.cpp
@@ -134,6 +134,11 @@ void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
 void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
   if (isa<OpToOpPassAdaptor>(pass))
     return;
+
+  // If we are only printing on failure, don't print after a successful pass.
+  if (config->shouldPrintAfterOnlyOnFailure())
+    return;
+
   // If the config asked to detect changes, compare the current fingerprint with
   // the previous.
   if (config->shouldPrintAfterOnlyOnChange()) {
@@ -177,9 +182,11 @@ void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
 /// Initialize the configuration.
 PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope,
                                               bool printAfterOnlyOnChange,
+                                              bool printAfterOnlyOnFailure,
                                               OpPrintingFlags opPrintingFlags)
     : printModuleScope(printModuleScope),
       printAfterOnlyOnChange(printAfterOnlyOnChange),
+      printAfterOnlyOnFailure(printAfterOnlyOnFailure),
       opPrintingFlags(opPrintingFlags) {}
 PassManager::IRPrinterConfig::~IRPrinterConfig() {}
 
@@ -212,9 +219,10 @@ struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
       std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
       std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
       bool printModuleScope, bool printAfterOnlyOnChange,
-      OpPrintingFlags opPrintingFlags, raw_ostream &out)
+      bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags,
+      raw_ostream &out)
       : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
-                        opPrintingFlags),
+                        printAfterOnlyOnFailure, opPrintingFlags),
         shouldPrintBeforePass(shouldPrintBeforePass),
         shouldPrintAfterPass(shouldPrintAfterPass), out(out) {
     assert((shouldPrintBeforePass || shouldPrintAfterPass) &&
@@ -257,9 +265,11 @@ void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
 void PassManager::enableIRPrinting(
     std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
     std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
-    bool printModuleScope, bool printAfterOnlyOnChange, raw_ostream &out,
+    bool printModuleScope, bool printAfterOnlyOnChange,
+    bool printAfterOnlyOnFailure, raw_ostream &out,
     OpPrintingFlags opPrintingFlags) {
   enableIRPrinting(std::make_unique<BasicIRPrinterConfig>(
       std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
-      printModuleScope, printAfterOnlyOnChange, opPrintingFlags, out));
+      printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure,
+      opPrintingFlags, out));
 }

diff  --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp
index 133f00a08d77b..f5780ab27b5ba 100644
--- a/mlir/lib/Pass/PassManagerOptions.cpp
+++ b/mlir/lib/Pass/PassManagerOptions.cpp
@@ -48,6 +48,11 @@ struct PassManagerOptions {
       llvm::cl::desc(
           "When printing the IR after a pass, only print if the IR changed"),
       llvm::cl::init(false)};
+  llvm::cl::opt<bool> printAfterFailure{
+      "print-ir-after-failure",
+      llvm::cl::desc(
+          "When printing the IR after a pass, only print if the pass failed"),
+      llvm::cl::init(false)};
   llvm::cl::opt<bool> printModuleScope{
       "print-ir-module-scope",
       llvm::cl::desc("When printing IR for print-ir-[before|after]{-all} "
@@ -96,8 +101,9 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) {
   }
 
   // Handle print-after.
-  if (printAfterAll) {
-    // If we are printing after all, then just return true for the filter.
+  if (printAfterAll || printAfterFailure) {
+    // If we are printing after all or failure, then just return true for the
+    // filter.
     shouldPrintAfterPass = [](Pass *, Operation *) { return true; };
   } else if (printAfter.hasAnyOccurrences()) {
     // Otherwise if there are specific passes to print after, then check to see
@@ -114,7 +120,8 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) {
 
   // Otherwise, add the IR printing instrumentation.
   pm.enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
-                      printModuleScope, printAfterChange, llvm::errs());
+                      printModuleScope, printAfterChange, printAfterFailure,
+                      llvm::errs());
 }
 
 void mlir::registerPassManagerCLOptions() {

diff  --git a/mlir/test/Pass/ir-printing.mlir b/mlir/test/Pass/ir-printing.mlir
index 8bb86b36c1813..8616d6ff3f488 100644
--- a/mlir/test/Pass/ir-printing.mlir
+++ b/mlir/test/Pass/ir-printing.mlir
@@ -4,6 +4,7 @@
 // RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after-all -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL %s
 // RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -print-ir-module-scope -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_MODULE %s
 // RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,cse)' -print-ir-after-all -print-ir-after-change -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL_CHANGE %s
+// RUN: not mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,test-pass-failure)' -print-ir-after-failure -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_FAILURE %s
 
 func @foo() {
   %0 = constant 0 : i32
@@ -60,3 +61,6 @@ func @bar() {
 // AFTER_ALL_CHANGE-NOT: *** IR Dump After{{.*}}CSE ***
 // We expect that only 'foo' changed during CSE, and the second run of CSE did
 // nothing.
+
+// AFTER_FAILURE-NOT: *** IR Dump After{{.*}}CSE
+// AFTER_FAILURE: *** IR Dump After{{.*}}TestFailurePass Failed ***

diff  --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 4cf028c05e558..937a5c2317c27 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -58,6 +58,12 @@ class TestCrashRecoveryPass
   void runOnOperation() final { abort(); }
 };
 
+/// A test pass that always fails, which enables testing the failure recovery
+/// mechanisms of the pass manager.
+class TestFailurePass : public PassWrapper<TestFailurePass, OperationPass<>> {
+  void runOnOperation() final { signalPassFailure(); }
+};
+
 /// A test pass that contains a statistic.
 struct TestStatisticPass
     : public PassWrapper<TestStatisticPass, OperationPass<>> {
@@ -103,6 +109,8 @@ void registerPassManagerTestPass() {
 
   PassRegistration<TestCrashRecoveryPass>(
       "test-pass-crash", "Test a pass in the pass manager that always crashes");
+  PassRegistration<TestFailurePass>(
+      "test-pass-failure", "Test a pass in the pass manager that always fails");
 
   PassRegistration<TestStatisticPass> unusedStatP("test-stats-pass",
                                                   "Test pass statistics");


        


More information about the Mlir-commits mailing list