[Mlir-commits] [mlir] [mlir][pybind] export more options on enable_ir_printing() api (PR #65854)

Yuanqiang Liu llvmlistbot at llvm.org
Sat Sep 9 10:51:01 PDT 2023


https://github.com/qingyunqu created https://github.com/llvm/llvm-project/pull/65854:

(No description provided.)

From 45c6dd1358dce3b3892caa77555f445edd9bb2df Mon Sep 17 00:00:00 2001
From: LiuYuanqiang <liuyuanqiang.yqliu at bytedance.com>
Date: Sun, 10 Sep 2023 01:49:49 +0800
Subject: [PATCH] [mlir][pybind] export more options on enable_ir_printing()
 api

---
 mlir/include/mlir-c/Pass.h        |  6 ++++--
 mlir/lib/Bindings/Python/Pass.cpp | 29 +++++++++++++++++++++++++++--
 mlir/lib/CAPI/IR/Pass.cpp         | 20 ++++++++++++++++++--
 3 files changed, 49 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 35db138305d1e22..6a59b1642d30367 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -75,8 +75,10 @@ MLIR_CAPI_EXPORTED MlirLogicalResult
 mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
 
 /// Enable mlir-print-ir-after-all.
-MLIR_CAPI_EXPORTED void
-mlirPassManagerEnableIRPrinting(MlirPassManager passManager);
+MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
+    MlirPassManager passManager, bool printBeforePass, bool printAfterPass,
+    bool printModuleScope, bool printAfterOnlyOnChange,
+    bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags);
 
 /// Enable / disable verify-each.
 MLIR_CAPI_EXPORTED void
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cdbfcfbc22957a6..e9a3780eb772e28 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -73,9 +73,34 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
            "Releases (leaks) the backing pass manager (testing)")
       .def(
           "enable_ir_printing",
-          [](PyPassManager &passManager) {
-            mlirPassManagerEnableIRPrinting(passManager.get());
+          [](PyPassManager &passManager, bool printBeforePass,
+             bool printAfterPass, bool printModuleScope,
+             bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure,
+             std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
+             bool printGenericOpForm) {
+            MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
+            if (largeElementsLimit)
+              mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
+                                                         *largeElementsLimit);
+            if (enableDebugInfo)
+              mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
+                                                 /*prettyForm=*/false);
+            if (printGenericOpForm)
+              mlirOpPrintingFlagsPrintGenericOpForm(flags);
+            mlirPassManagerEnableIRPrinting(passManager.get(), printBeforePass,
+                                            printAfterPass, printModuleScope,
+                                            printAfterOnlyOnChange,
+                                            printAfterOnlyOnFailure, flags);
+            mlirOpPrintingFlagsDestroy(flags);
           },
+          py::arg("print_before_pass") = true,
+          py::arg("print_after_pass") = true,
+          py::arg("print_module_scope") = true,
+          py::arg("print_after_only_on_change") = true,
+          py::arg("print_after_only_on_failure") = false,
+          py::arg("large_elements_limit") = py::none(),
+          py::arg("enable_debug_info") = false,
+          py::arg("print_generic_op_form") = false,
           "Enable mlir-print-ir-after-all.")
       .def(
           "enable_verifier",
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index d242baae99c0862..d13a71bb19cf794 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -13,6 +13,7 @@
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Utils.h"
 #include "mlir/Pass/PassManager.h"
+#include <functional>
 #include <optional>
 
 using namespace mlir;
@@ -44,8 +45,23 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
   return wrap(unwrap(passManager)->run(unwrap(op)));
 }
 
-void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
-  return unwrap(passManager)->enableIRPrinting();
+void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
+                                     bool printBeforePass, bool printAfterPass,
+                                     bool printModuleScope,
+                                     bool printAfterOnlyOnChange,
+                                     bool printAfterOnlyOnFailure,
+                                     MlirOpPrintingFlags flags) {
+  std::function<bool(Pass *, Operation *)> shouldPrintBeforePass = nullptr;
+  std::function<bool(Pass *, Operation *)> shouldPrintAfterPass = nullptr;
+  if (printBeforePass)
+    shouldPrintBeforePass = [](Pass *, Operation *) { return true; };
+  if (printAfterPass)
+    shouldPrintAfterPass = [](Pass *, Operation *) { return true; };
+  return unwrap(passManager)
+      ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
+                         printModuleScope, printAfterOnlyOnChange,
+                         printAfterOnlyOnFailure, /*out=*/llvm::errs(),
+                         *unwrap(flags));
 }
 
 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {



More information about the Mlir-commits mailing list