[Mlir-commits] [mlir] [mlir][pybind] export more options on enable_ir_printing() api (PR #65854)

Yuanqiang Liu llvmlistbot at llvm.org
Sun Apr 14 09:11:16 PDT 2024


https://github.com/qingyunqu updated https://github.com/llvm/llvm-project/pull/65854

>From 835ad09a5fab61af1504779030763dd3f9a306da Mon Sep 17 00:00:00 2001
From: LiuYuanqiang <liuyuanqiang.yqliu at bytedance.com>
Date: Sun, 10 Sep 2023 01:49:49 +0800
Subject: [PATCH] [mlir][pybind] export more options on enable_ir_printing()
 api

---
 mlir/include/mlir-c/IR.h             |  9 ++++++++
 mlir/include/mlir-c/Pass.h           |  8 ++++++-
 mlir/include/mlir/CAPI/IR.h          |  3 +++
 mlir/include/mlir/Pass/PassManager.h | 11 +++++++++-
 mlir/lib/Bindings/Python/Pass.cpp    | 33 ++++++++++++++++++++++++++--
 mlir/lib/CAPI/IR/IR.cpp              |  8 +++++++
 mlir/lib/CAPI/IR/Pass.cpp            | 18 +++++++++++++--
 mlir/lib/Pass/IRPrinting.cpp         |  5 ++++-
 8 files changed, 88 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 82da511f807a34..422c3eb9cf58ad 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -56,6 +56,7 @@ DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
 DEFINE_C_API_STRUCT(MlirOperation, void);
 DEFINE_C_API_STRUCT(MlirOpOperand, void);
 DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
+DEFINE_C_API_STRUCT(MlirIRPrinterConfig, void);
 DEFINE_C_API_STRUCT(MlirBlock, void);
 DEFINE_C_API_STRUCT(MlirRegion, void);
 DEFINE_C_API_STRUCT(MlirSymbolTable, void);
@@ -450,6 +451,18 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
 MLIR_CAPI_EXPORTED void
 mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);
 
+//===----------------------------------------------------------------------===//
+// IR Printing config API.
+//===----------------------------------------------------------------------===//
+
+/// Creates an IR printer config with every boolean option set to false and
+/// default op printing flags. The config is owned by the caller and must be
+/// destroyed with `mlirIRPrinterConfigDestroy`.
+MLIR_CAPI_EXPORTED MlirIRPrinterConfig mlirIRPrinterConfigCreate(void);
+
+/// Destroys a config created with `mlirIRPrinterConfigCreate`.
+MLIR_CAPI_EXPORTED void mlirIRPrinterConfigDestroy(MlirIRPrinterConfig config);
+
 //===----------------------------------------------------------------------===//
 // Bytecode printing flags API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 35db138305d1e2..6c0fcb98409a51 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -75,8 +75,13 @@ MLIR_CAPI_EXPORTED MlirLogicalResult
 mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
 
-/// Enable mlir-print-ir-after-all.
+/// Enable IR printing on the pass manager. Whether the IR is printed before
+/// and/or after each pass, at module scope, only on change or only on failure,
+/// and the op printing flags to use are all read from `config`. The config is
+/// not retained, so the caller may destroy it once this call returns. The IR
+/// is printed to stderr.
 MLIR_CAPI_EXPORTED void
-mlirPassManagerEnableIRPrinting(MlirPassManager passManager);
+mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
+                                MlirIRPrinterConfig config);
 
 /// Enable / disable verify-each.
 MLIR_CAPI_EXPORTED void
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 1836cb0acb67e7..488c6fb80836c3 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/Pass/PassManager.h"
 
 DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState)
 DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig)
@@ -30,6 +31,8 @@ DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
 DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
 DEFINE_C_API_PTR_METHODS(MlirOpOperand, mlir::OpOperand)
 DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
+DEFINE_C_API_PTR_METHODS(MlirIRPrinterConfig,
+                         mlir::PassManager::IRPrinterConfig)
 DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
 DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
 
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 1b2e6a3bc82bb4..5e383bdf373ddd 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -307,7 +307,12 @@ class PassManager : public OpPassManager {
     ///   IR.
+    /// * 'printBeforePass' signals that the IR should be printed before every
+    ///   pass.
+    /// * 'printAfterPass' signals that the IR should be printed after every
+    ///   pass.
     explicit IRPrinterConfig(
         bool printModuleScope = false, bool printAfterOnlyOnChange = false,
-        bool printAfterOnlyOnFailure = false,
+        bool printAfterOnlyOnFailure = false, bool printBeforePass = false,
+        bool printAfterPass = false,
         OpPrintingFlags opPrintingFlags = OpPrintingFlags());
     virtual ~IRPrinterConfig();
 
@@ -338,6 +343,12 @@ class PassManager : public OpPassManager {
       return printAfterOnlyOnFailure;
     }
 
+    /// Returns true if the IR should be printed before every pass.
+    bool shouldPrintBeforePass() const { return printBeforePass; }
+
+    /// Returns true if the IR should be printed after every pass.
+    bool shouldPrintAfterPass() const { return printAfterPass; }
+
     /// Returns the printing flags to be used to print the IR.
     OpPrintingFlags getOpPrintingFlags() const { return opPrintingFlags; }
 
@@ -353,6 +364,12 @@ class PassManager : public OpPassManager {
     /// the pass failed.
     bool printAfterOnlyOnFailure;
 
+    /// A flag that indicates that the IR should be printed before every pass.
+    bool printBeforePass;
+
+    /// A flag that indicates that the IR should be printed after every pass.
+    bool printAfterPass;
+
     /// Flags to control printing behavior.
     OpPrintingFlags opPrintingFlags;
   };
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index a68421b61641f6..943cc1cbc19e69 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -74,9 +74,27 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
            "Releases (leaks) the backing pass manager (testing)")
       .def(
           "enable_ir_printing",
-          [](PyPassManager &passManager) {
-            mlirPassManagerEnableIRPrinting(passManager.get());
+          [](PyPassManager &passManager, bool printBeforePass,
+             bool printAfterPass, bool printModuleScope,
+             bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure,
+             std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
+             bool printGenericOpForm) {
+            // TODO: forward the keyword arguments. MlirIRPrinterConfig can
+            // currently only be default-constructed through the C API (it has
+            // no setters), so every argument above is ignored and the pass
+            // manager receives a config with all printing options disabled.
+            MlirIRPrinterConfig config = mlirIRPrinterConfigCreate();
+            mlirPassManagerEnableIRPrinting(passManager.get(), config);
+            mlirIRPrinterConfigDestroy(config);
           },
+          py::arg("print_before_pass") = true,
+          py::arg("print_after_pass") = true,
+          py::arg("print_module_scope") = true,
+          py::arg("print_after_only_on_change") = true,
+          py::arg("print_after_only_on_failure") = false,
+          py::arg("large_elements_limit") = py::none(),
+          py::arg("enable_debug_info") = false,
+          py::arg("print_generic_op_form") = false,
           "Enable mlir-print-ir-after-all.")
       .def(
           "enable_verifier",
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index cdb64f4ec4a40f..74f5211f5e9914 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -219,6 +219,14 @@ void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
   unwrap(flags)->assumeVerified();
 }
 
+MlirIRPrinterConfig mlirIRPrinterConfigCreate() {
+  return wrap(new PassManager::IRPrinterConfig());
+}
+
+void mlirIRPrinterConfigDestroy(MlirIRPrinterConfig config) {
+  delete unwrap(config);
+}
+
 //===----------------------------------------------------------------------===//
 // Bytecode printing flags API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index d242baae99c086..0bd77036ccbf0b 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -44,8 +44,28 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
   return wrap(unwrap(passManager)->run(unwrap(op)));
 }
 
-void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
-  return unwrap(passManager)->enableIRPrinting();
+void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
+                                     MlirIRPrinterConfig config) {
+  const PassManager::IRPrinterConfig &printerConfig = *unwrap(config);
+  const bool printBeforePass = printerConfig.shouldPrintBeforePass();
+  const bool printAfterPass = printerConfig.shouldPrintAfterPass();
+  // Always install both filter callbacks, even when a direction is disabled:
+  // the BasicIRPrinterConfig built by `enableIRPrinting` asserts that at least
+  // one of them is set, so mapping a disabled direction to a null callback
+  // would assert when both directions are disabled (the default config).
+  auto shouldPrintBeforePass = [printBeforePass](Pass *, Operation *) {
+    return printBeforePass;
+  };
+  auto shouldPrintAfterPass = [printAfterPass](Pass *, Operation *) {
+    return printAfterPass;
+  };
+  unwrap(passManager)
+      ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
+                         printerConfig.shouldPrintAtModuleScope(),
+                         printerConfig.shouldPrintAfterOnlyOnChange(),
+                         printerConfig.shouldPrintAfterOnlyOnFailure(),
+                         /*out=*/llvm::errs(),
+                         printerConfig.getOpPrintingFlags());
 }
 
 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp
index 72b94eeb0123fc..a9d847c7354b7e 100644
--- a/mlir/lib/Pass/IRPrinting.cpp
+++ b/mlir/lib/Pass/IRPrinting.cpp
@@ -133,10 +133,13 @@ void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
 PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope,
                                               bool printAfterOnlyOnChange,
                                               bool printAfterOnlyOnFailure,
+                                              bool printBeforePass,
+                                              bool printAfterPass,
                                               OpPrintingFlags opPrintingFlags)
     : printModuleScope(printModuleScope),
       printAfterOnlyOnChange(printAfterOnlyOnChange),
       printAfterOnlyOnFailure(printAfterOnlyOnFailure),
+      printBeforePass(printBeforePass), printAfterPass(printAfterPass),
       opPrintingFlags(opPrintingFlags) {}
 PassManager::IRPrinterConfig::~IRPrinterConfig() = default;
 
@@ -172,7 +175,8 @@ struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
       bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags,
       raw_ostream &out)
       : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
-                        printAfterOnlyOnFailure, opPrintingFlags),
+                        printAfterOnlyOnFailure, /*printBeforePass=*/false,
+                        /*printAfterPass=*/false, opPrintingFlags),
         shouldPrintBeforePass(std::move(shouldPrintBeforePass)),
         shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) {
     assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) &&



More information about the Mlir-commits mailing list