[Mlir-commits] [mlir] [mlir][pybind] export more options on enable_ir_printing() api (PR #65854)
Yuanqiang Liu
llvmlistbot at llvm.org
Sun Apr 14 09:11:16 PDT 2024
https://github.com/qingyunqu updated https://github.com/llvm/llvm-project/pull/65854
>From 835ad09a5fab61af1504779030763dd3f9a306da Mon Sep 17 00:00:00 2001
From: LiuYuanqiang <liuyuanqiang.yqliu at bytedance.com>
Date: Sun, 10 Sep 2023 01:49:49 +0800
Subject: [PATCH] [mlir][pybind] export more options on enable_ir_printing()
api
---
mlir/include/mlir-c/IR.h | 9 ++++++++
mlir/include/mlir-c/Pass.h | 8 ++++++-
mlir/include/mlir/CAPI/IR.h | 3 +++
mlir/include/mlir/Pass/PassManager.h | 11 +++++++++-
mlir/lib/Bindings/Python/Pass.cpp | 33 ++++++++++++++++++++++++++--
mlir/lib/CAPI/IR/IR.cpp | 8 +++++++
mlir/lib/CAPI/IR/Pass.cpp | 18 +++++++++++++--
mlir/lib/Pass/IRPrinting.cpp | 5 ++++-
8 files changed, 88 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 82da511f807a34..422c3eb9cf58ad 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -56,6 +56,7 @@ DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
DEFINE_C_API_STRUCT(MlirOperation, void);
DEFINE_C_API_STRUCT(MlirOpOperand, void);
DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
+DEFINE_C_API_STRUCT(MlirIRPrinterConfig, void);
DEFINE_C_API_STRUCT(MlirBlock, void);
DEFINE_C_API_STRUCT(MlirRegion, void);
DEFINE_C_API_STRUCT(MlirSymbolTable, void);
@@ -450,6 +451,17 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
MLIR_CAPI_EXPORTED void
mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);
+//===----------------------------------------------------------------------===//
+// IR Printing config API.
+//===----------------------------------------------------------------------===//
+
+/// Creates an IR printer configuration with default settings. The caller owns
+/// the result and must release it with mlirIRPrinterConfigDestroy.
+MLIR_CAPI_EXPORTED MlirIRPrinterConfig mlirIRPrinterConfigCreate(void);
+
+/// Destroys an IR printer configuration created by mlirIRPrinterConfigCreate.
+MLIR_CAPI_EXPORTED void mlirIRPrinterConfigDestroy(MlirIRPrinterConfig config);
+
//===----------------------------------------------------------------------===//
// Bytecode printing flags API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 35db138305d1e2..6c0fcb98409a51 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -75,8 +75,10 @@ MLIR_CAPI_EXPORTED MlirLogicalResult
mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
-/// Enable mlir-print-ir-after-all.
+/// Enables IR printing on `passManager` with the settings in `config`.
+/// `config` is only read during this call and remains owned by the caller.
MLIR_CAPI_EXPORTED void
-mlirPassManagerEnableIRPrinting(MlirPassManager passManager);
+mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
+                                MlirIRPrinterConfig config);
/// Enable / disable verify-each.
MLIR_CAPI_EXPORTED void
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 1836cb0acb67e7..488c6fb80836c3 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -20,6 +20,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Pass/PassManager.h" // For PassManager::IRPrinterConfig.
DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState)
DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig)
@@ -30,6 +31,8 @@ DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
DEFINE_C_API_PTR_METHODS(MlirOpOperand, mlir::OpOperand)
DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
+DEFINE_C_API_PTR_METHODS(MlirIRPrinterConfig,
+ mlir::PassManager::IRPrinterConfig)
DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 1b2e6a3bc82bb4..5e383bdf373ddd 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -307,7 +307,8 @@ class PassManager : public OpPassManager {
/// IR.
explicit IRPrinterConfig(
bool printModuleScope = false, bool printAfterOnlyOnChange = false,
- bool printAfterOnlyOnFailure = false,
+ bool printAfterOnlyOnFailure = false, bool printBeforePass = false,
+ bool printAfterPass = false,
OpPrintingFlags opPrintingFlags = OpPrintingFlags());
virtual ~IRPrinterConfig();
@@ -338,6 +339,12 @@ class PassManager : public OpPassManager {
return printAfterOnlyOnFailure;
}
+    /// Returns true if printing the IR before passes was requested.
+    bool shouldPrintBeforePass() const { return printBeforePass; }
+
+    /// Returns true if printing the IR after passes was requested.
+    bool shouldPrintAfterPass() const { return printAfterPass; }
+
/// Returns the printing flags to be used to print the IR.
OpPrintingFlags getOpPrintingFlags() const { return opPrintingFlags; }
@@ -353,6 +360,12 @@ class PassManager : public OpPassManager {
/// the pass failed.
bool printAfterOnlyOnFailure;
+    /// A flag that indicates that printing the IR before passes was requested.
+    bool printBeforePass;
+
+    /// A flag that indicates that printing the IR after passes was requested.
+    bool printAfterPass;
+
/// Flags to control printing behavior.
OpPrintingFlags opPrintingFlags;
};
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index a68421b61641f6..943cc1cbc19e69 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -74,9 +74,27 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
"Releases (leaks) the backing pass manager (testing)")
.def(
"enable_ir_printing",
- [](PyPassManager &passManager) {
- mlirPassManagerEnableIRPrinting(passManager.get());
+          [](PyPassManager &passManager, bool printBeforePass,
+             bool printAfterPass, bool printModuleScope,
+             bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure,
+             std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
+             bool printGenericOpForm) {
+            // TODO: forward the keyword arguments. The C API only exposes
+            // mlirIRPrinterConfigCreate/Destroy and offers no way to set the
+            // individual options, so all of them are currently ignored.
+            MlirIRPrinterConfig config = mlirIRPrinterConfigCreate();
+            mlirPassManagerEnableIRPrinting(passManager.get(), config);
+            // mlirPassManagerEnableIRPrinting does not take ownership.
+            mlirIRPrinterConfigDestroy(config);
},
+          py::arg("print_before_pass") = true,
+          py::arg("print_after_pass") = true,
+          py::arg("print_module_scope") = true,
+          py::arg("print_after_only_on_change") = true,
+          py::arg("print_after_only_on_failure") = false,
+          py::arg("large_elements_limit") = py::none(),
+          py::arg("enable_debug_info") = false,
+          py::arg("print_generic_op_form") = false,
"Enable mlir-print-ir-after-all.")
.def(
"enable_verifier",
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index cdb64f4ec4a40f..74f5211f5e9914 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -219,6 +219,22 @@ void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
unwrap(flags)->assumeVerified();
}
+// The defaults mirror the keyword defaults of the Python `enable_ir_printing`
+// binding: print before and after every pass, at module scope, and after a
+// pass only when it changed the IR. They are intended to match what the
+// argument-less PassManager::enableIRPrinting() did before this API existed.
+// An all-false config would leave both print callbacks unset.
+MlirIRPrinterConfig mlirIRPrinterConfigCreate() {
+  return wrap(new PassManager::IRPrinterConfig(
+      /*printModuleScope=*/true, /*printAfterOnlyOnChange=*/true,
+      /*printAfterOnlyOnFailure=*/false, /*printBeforePass=*/true,
+      /*printAfterPass=*/true));
+}
+
+void mlirIRPrinterConfigDestroy(MlirIRPrinterConfig config) {
+  delete unwrap(config);
+}
+
//===----------------------------------------------------------------------===//
// Bytecode printing flags API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index d242baae99c086..0bd77036ccbf0b 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -13,6 +13,7 @@
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/Pass/PassManager.h"
+#include <functional>
#include <optional>
using namespace mlir;
@@ -44,8 +45,29 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
return wrap(unwrap(passManager)->run(unwrap(op)));
}
-void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
- return unwrap(passManager)->enableIRPrinting();
+void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
+                                     MlirIRPrinterConfig config) {
+  const PassManager::IRPrinterConfig &cfg = *unwrap(config);
+  // The std::function overload of enableIRPrinting() builds a
+  // BasicIRPrinterConfig (IRPrinting.cpp), which asserts that at least one
+  // callback is set. When neither before- nor after-pass printing is
+  // requested there is nothing to print, so leave the pass manager untouched.
+  if (!cfg.shouldPrintBeforePass() && !cfg.shouldPrintAfterPass())
+    return;
+  auto printAlways = [](Pass *, Operation *) { return true; };
+  std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
+  std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
+  if (cfg.shouldPrintBeforePass())
+    shouldPrintBeforePass = printAlways;
+  if (cfg.shouldPrintAfterPass())
+    shouldPrintAfterPass = printAlways;
+  unwrap(passManager)
+      ->enableIRPrinting(std::move(shouldPrintBeforePass),
+                         std::move(shouldPrintAfterPass),
+                         cfg.shouldPrintAtModuleScope(),
+                         cfg.shouldPrintAfterOnlyOnChange(),
+                         cfg.shouldPrintAfterOnlyOnFailure(),
+                         /*out=*/llvm::errs(), cfg.getOpPrintingFlags());
}
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp
index 72b94eeb0123fc..a9d847c7354b7e 100644
--- a/mlir/lib/Pass/IRPrinting.cpp
+++ b/mlir/lib/Pass/IRPrinting.cpp
@@ -133,10 +133,13 @@ void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope,
bool printAfterOnlyOnChange,
bool printAfterOnlyOnFailure,
+ bool printBeforePass,
+ bool printAfterPass,
OpPrintingFlags opPrintingFlags)
: printModuleScope(printModuleScope),
printAfterOnlyOnChange(printAfterOnlyOnChange),
printAfterOnlyOnFailure(printAfterOnlyOnFailure),
+ printBeforePass(printBeforePass), printAfterPass(printAfterPass),
opPrintingFlags(opPrintingFlags) {}
PassManager::IRPrinterConfig::~IRPrinterConfig() = default;
@@ -172,7 +175,12 @@ struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags,
raw_ostream &out)
: IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
- printAfterOnlyOnFailure, opPrintingFlags),
+                        printAfterOnlyOnFailure,
+                        // Record which callbacks are set so that the base
+                        // shouldPrint{Before,After}Pass() accessors agree.
+                        /*printBeforePass=*/bool(shouldPrintBeforePass),
+                        /*printAfterPass=*/bool(shouldPrintAfterPass),
+                        opPrintingFlags),
shouldPrintBeforePass(std::move(shouldPrintBeforePass)),
shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) {
assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) &&
More information about the Mlir-commits
mailing list