[Mlir-commits] [mlir] f8eceb4 - [MLIR] [Python] align python ir printing with mlir-print-ir-after-all (#107522)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 17 20:54:19 PDT 2024


Author: Bimo
Date: 2024-09-18T11:54:16+08:00
New Revision: f8eceb45d0bbca092164efffc92f2e9d66b304a5

URL: https://github.com/llvm/llvm-project/commit/f8eceb45d0bbca092164efffc92f2e9d66b304a5
DIFF: https://github.com/llvm/llvm-project/commit/f8eceb45d0bbca092164efffc92f2e9d66b304a5.diff

LOG: [MLIR] [Python] align python ir printing with mlir-print-ir-after-all (#107522)

When using the `enable_ir_printing` API from Python, it invokes IR
printing with default args, printing the IR before each pass and
printing the IR after a pass only if there have been changes. This PR
attempts to align the `enable_ir_printing` API with the documentation.

Added: 
    

Modified: 
    mlir/include/mlir-c/Pass.h
    mlir/lib/Bindings/Python/Pass.cpp
    mlir/lib/CAPI/IR/Pass.cpp
    mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
    mlir/test/python/pass_manager.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 35db138305d1e2..2218ec0f47d199 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -74,9 +74,11 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);
 MLIR_CAPI_EXPORTED MlirLogicalResult
 mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
 
-/// Enable mlir-print-ir-after-all.
-MLIR_CAPI_EXPORTED void
-mlirPassManagerEnableIRPrinting(MlirPassManager passManager);
+/// Enable IR printing.
+MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
+    MlirPassManager passManager, bool printBeforeAll, bool printAfterAll,
+    bool printModuleScope, bool printAfterOnlyOnChange,
+    bool printAfterOnlyOnFailure);
 
 /// Enable / disable verify-each.
 MLIR_CAPI_EXPORTED void

diff  --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index a68421b61641f6..1d0e5ce2115a0a 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -74,10 +74,17 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
            "Releases (leaks) the backing pass manager (testing)")
       .def(
           "enable_ir_printing",
-          [](PyPassManager &passManager) {
-            mlirPassManagerEnableIRPrinting(passManager.get());
+          [](PyPassManager &passManager, bool printBeforeAll,
+             bool printAfterAll, bool printModuleScope, bool printAfterChange,
+             bool printAfterFailure) {
+            mlirPassManagerEnableIRPrinting(
+                passManager.get(), printBeforeAll, printAfterAll,
+                printModuleScope, printAfterChange, printAfterFailure);
           },
-          "Enable mlir-print-ir-after-all.")
+          "print_before_all"_a = false, "print_after_all"_a = true,
+          "print_module_scope"_a = false, "print_after_change"_a = false,
+          "print_after_failure"_a = false,
+          "Enable IR printing, default as mlir-print-ir-after-all.")
       .def(
           "enable_verifier",
           [](PyPassManager &passManager, bool enable) {

diff  --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index d242baae99c086..a6c9fbd08d45a6 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -44,8 +44,21 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
   return wrap(unwrap(passManager)->run(unwrap(op)));
 }
 
-void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
-  return unwrap(passManager)->enableIRPrinting();
+void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
+                                     bool printBeforeAll, bool printAfterAll,
+                                     bool printModuleScope,
+                                     bool printAfterOnlyOnChange,
+                                     bool printAfterOnlyOnFailure) {
+  auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
+    return printBeforeAll;
+  };
+  auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
+    return printAfterAll;
+  };
+  return unwrap(passManager)
+      ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
+                         printModuleScope, printAfterOnlyOnChange,
+                         printAfterOnlyOnFailure);
 }
 
 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
index c072d5e0fb86f3..5d115e8222d730 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
@@ -16,7 +16,14 @@ class PassManager:
     def __init__(self, context: Optional[_ir.Context] = None) -> None: ...
     def _CAPICreate(self) -> object: ...
     def _testing_release(self) -> None: ...
-    def enable_ir_printing(self) -> None: ...
+    def enable_ir_printing(
+        self,
+        print_before_all: bool = False,
+        print_after_all: bool = True,
+        print_module_scope: bool = False,
+        print_after_change: bool = False,
+        print_after_failure: bool = False,
+    ) -> None: ...
     def enable_verifier(self, enable: bool) -> None: ...
     @staticmethod
     def parse(pipeline: str, context: Optional[_ir.Context] = None) -> PassManager: ...

diff  --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 43af80b53166cc..74967032562351 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -300,14 +300,40 @@ def testPrintIrAfterAll():
         pm = PassManager.parse("builtin.module(canonicalize)")
         ctx.enable_multithreading(False)
         pm.enable_ir_printing()
-        # CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) ('builtin.module' operation) //----- //
+        # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- //
+        # CHECK: module {
+        # CHECK:   func.func @main() {
+        # CHECK:     return
+        # CHECK:   }
+        # CHECK: }
+        pm.run(module)
+
+
+# CHECK-LABEL: TEST: testPrintIrBeforeAndAfterAll
+@run
+def testPrintIrBeforeAndAfterAll():
+    with Context() as ctx:
+        module = ModuleOp.parse(
+            """
+          module {
+            func.func @main() {
+              %0 = arith.constant 10
+              return
+            }
+          }
+        """
+        )
+        pm = PassManager.parse("builtin.module(canonicalize)")
+        ctx.enable_multithreading(False)
+        pm.enable_ir_printing(print_before_all=True, print_after_all=True)
+        # CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) //----- //
         # CHECK: module {
         # CHECK:   func.func @main() {
         # CHECK:     %[[C10:.*]] = arith.constant 10 : i64
         # CHECK:     return
         # CHECK:   }
         # CHECK: }
-        # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) ('builtin.module' operation) //----- //
+        # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- //
         # CHECK: module {
         # CHECK:   func.func @main() {
         # CHECK:     return


        


More information about the Mlir-commits mailing list