[Mlir-commits] [mlir] [MLIR][Python] Add optional arguments to PassManager IR printing (PR #89301)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Apr 21 16:48:04 PDT 2024


https://github.com/vraspar updated https://github.com/llvm/llvm-project/pull/89301

>From dd062a077de06018cd765526696f6e7200b66c10 Mon Sep 17 00:00:00 2001
From: Vrajang Parikh <vrajangp at d-matrix.ai>
Date: Mon, 8 Apr 2024 21:07:00 +0000
Subject: [PATCH] Add optional arguments to PassManager IR printing

---
 mlir/include/mlir-c/Pass.h                    | 10 ++++++++--
 mlir/lib/Bindings/Python/Pass.cpp             | 19 +++++++++++++++++--
 mlir/lib/CAPI/IR/Pass.cpp                     | 19 +++++++++++++++++--
 .../mlir/_mlir_libs/_mlir/passmanager.pyi     | 10 +++++++++-
 4 files changed, 51 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 35db138305d1e2..3e3e2bcc656e60 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -74,9 +74,22 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);
 MLIR_CAPI_EXPORTED MlirLogicalResult
 mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
 
-/// Enable mlir-print-ir-after-all.
+/// Enable IR printing on `passManager`:
+///  - `shouldPrintBeforePass`: print the IR before every pass.
+///  - `shouldPrintAfterPass`: print the IR after every pass.
+///  - `printModuleScope`: print the whole top-level operation instead of only
+///    the operation the pass ran on. MLIR rejects this mode unless
+///    multi-threading is disabled on the context.
+///  - `printAfterOnlyOnChange`: only print after a pass that changed the IR.
+///  - `printAfterOnlyOnFailure`: only print after a pass that failed.
+MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
+    MlirPassManager passManager, bool shouldPrintBeforePass,
+    bool shouldPrintAfterPass, bool printModuleScope,
+    bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure);
+
+/// Enable timing of the passes run by `passManager`.
 MLIR_CAPI_EXPORTED void
-mlirPassManagerEnableIRPrinting(MlirPassManager passManager);
+mlirPassManagerEnableTiming(MlirPassManager passManager);
 
 /// Enable / disable verify-each.
 MLIR_CAPI_EXPORTED void
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index a68421b61641f6..f52503d341c303 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -74,10 +74,27 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
            "Releases (leaks) the backing pass manager (testing)")
       .def(
           "enable_ir_printing",
-          [](PyPassManager &passManager) {
-            mlirPassManagerEnableIRPrinting(passManager.get());
+          [](PyPassManager &passManager, bool printBeforePass,
+             bool printAfterPass, bool printModuleScope,
+             bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure) {
+            mlirPassManagerEnableIRPrinting(
+                passManager.get(), printBeforePass, printAfterPass,
+                printModuleScope, printAfterOnlyOnChange,
+                printAfterOnlyOnFailure);
           },
-          "Enable mlir-print-ir-after-all.")
+          "print_before_pass"_a = true, "print_after_pass"_a = true,
+          "printModuleScope"_a = true, "print_after_only_on_change"_a = true,
+          "print_after_only_on_failure"_a = false,
+          "Enable IR printing. With the default arguments the IR is printed "
+          "before and after every pass at module scope, and after a pass only "
+          "when that pass changed the IR. Module-scope printing requires "
+          "multithreading to be disabled on the context.")
+      .def(
+          "enable_timing",
+          [](PyPassManager &passManager) {
+            mlirPassManagerEnableTiming(passManager.get());
+          },
+          "Enable timing of passes")
       .def(
           "enable_verifier",
           [](PyPassManager &passManager, bool enable) {
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index d242baae99c086..6d63b6519bda06 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -44,8 +44,28 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
   return wrap(unwrap(passManager)->run(unwrap(op)));
 }
 
-void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
-  return unwrap(passManager)->enableIRPrinting();
+void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
+                                     bool shouldPrintBeforePass,
+                                     bool shouldPrintAfterPass,
+                                     bool printModuleScope,
+                                     bool printAfterOnlyOnChange,
+                                     bool printAfterOnlyOnFailure) {
+  // The C API exposes all-or-nothing flags; adapt them to the per-pass
+  // (Pass *, Operation *) predicates taken by `enableIRPrinting`.
+  auto shouldPrintBeforeFn = [shouldPrintBeforePass](Pass *, Operation *) {
+    return shouldPrintBeforePass;
+  };
+  auto shouldPrintAfterFn = [shouldPrintAfterPass](Pass *, Operation *) {
+    return shouldPrintAfterPass;
+  };
+  unwrap(passManager)
+      ->enableIRPrinting(shouldPrintBeforeFn, shouldPrintAfterFn,
+                         printModuleScope, printAfterOnlyOnChange,
+                         printAfterOnlyOnFailure);
+}
+
+void mlirPassManagerEnableTiming(MlirPassManager passManager) {
+  unwrap(passManager)->enableTiming();
 }
 
 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
index c072d5e0fb86f3..e08553cda33cab 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
@@ -16,7 +16,15 @@ class PassManager:
     def __init__(self, context: Optional[_ir.Context] = None) -> None: ...
     def _CAPICreate(self) -> object: ...
     def _testing_release(self) -> None: ...
-    def enable_ir_printing(self) -> None: ...
+    def enable_ir_printing(
+        self,
+        print_before_pass: bool = True,
+        print_after_pass: bool = True,
+        printModuleScope: bool = True,
+        print_after_only_on_change: bool = True,
+        print_after_only_on_failure: bool = False,
+    ) -> None: ...
+    def enable_timing(self) -> None: ...
     def enable_verifier(self, enable: bool) -> None: ...
     @staticmethod
     def parse(pipeline: str, context: Optional[_ir.Context] = None) -> PassManager: ...



More information about the Mlir-commits mailing list