[Mlir-commits] [mlir] c8b837a - [MLIR][Python] Add the `--mlir-print-ir-tree-dir` to the C and Python API (#117339)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Nov 23 11:17:29 PST 2024
Author: Mehdi Amini
Date: 2024-11-23T20:17:25+01:00
New Revision: c8b837ad8ce4f36a3b2e47f1f1367dc0b41fca7b
URL: https://github.com/llvm/llvm-project/commit/c8b837ad8ce4f36a3b2e47f1f1367dc0b41fca7b
DIFF: https://github.com/llvm/llvm-project/commit/c8b837ad8ce4f36a3b2e47f1f1367dc0b41fca7b.diff
LOG: [MLIR][Python] Add the `--mlir-print-ir-tree-dir` to the C and Python API (#117339)
Added:
Modified:
mlir/include/mlir-c/Pass.h
mlir/lib/Bindings/Python/Pass.cpp
mlir/lib/CAPI/IR/Pass.cpp
mlir/test/python/pass_manager.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 2218ec0f47d199..6019071cfdaa29 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -75,10 +75,13 @@ MLIR_CAPI_EXPORTED MlirLogicalResult
mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
/// Enable IR printing.
+/// The treePrintingPath argument is an optional path to a directory
+/// where the dumps will be produced. If it is empty (zero length), the
+/// dumps are printed to stderr instead.
MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
MlirPassManager passManager, bool printBeforeAll, bool printAfterAll,
bool printModuleScope, bool printAfterOnlyOnChange,
- bool printAfterOnlyOnFailure);
+ bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath);
/// Enable / disable verify-each.
MLIR_CAPI_EXPORTED void
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 1d0e5ce2115a0a..e8d28abe6d583a 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -76,14 +76,21 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
"enable_ir_printing",
[](PyPassManager &passManager, bool printBeforeAll,
bool printAfterAll, bool printModuleScope, bool printAfterChange,
- bool printAfterFailure) {
+ bool printAfterFailure,
+ std::optional<std::string> optionalTreePrintingPath) {
+ std::string treePrintingPath = "";
+ if (optionalTreePrintingPath.has_value())
+ treePrintingPath = optionalTreePrintingPath.value();
mlirPassManagerEnableIRPrinting(
passManager.get(), printBeforeAll, printAfterAll,
- printModuleScope, printAfterChange, printAfterFailure);
+ printModuleScope, printAfterChange, printAfterFailure,
+ mlirStringRefCreate(treePrintingPath.data(),
+ treePrintingPath.size()));
},
"print_before_all"_a = false, "print_after_all"_a = true,
"print_module_scope"_a = false, "print_after_change"_a = false,
"print_after_failure"_a = false,
+ "tree_printing_dir_path"_a = py::none(),
"Enable IR printing, default as mlir-print-ir-after-all.")
.def(
"enable_verifier",
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index a6c9fbd08d45a6..01151eafeb5268 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -48,17 +48,25 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
bool printBeforeAll, bool printAfterAll,
bool printModuleScope,
bool printAfterOnlyOnChange,
- bool printAfterOnlyOnFailure) {
+ bool printAfterOnlyOnFailure,
+ MlirStringRef treePrintingPath) {
auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
return printBeforeAll;
};
auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
return printAfterAll;
};
- return unwrap(passManager)
- ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
- printModuleScope, printAfterOnlyOnChange,
- printAfterOnlyOnFailure);
+ if (unwrap(treePrintingPath).empty())
+ return unwrap(passManager)
+ ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
+ printModuleScope, printAfterOnlyOnChange,
+ printAfterOnlyOnFailure);
+
+ unwrap(passManager)
+ ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
+ printModuleScope, printAfterOnlyOnChange,
+ printAfterOnlyOnFailure,
+ unwrap(treePrintingPath));
}
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 74967032562351..a794a3fc6fa006 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -1,6 +1,6 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s
-import gc, sys
+import gc, os, sys, tempfile
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.func import FuncOp
@@ -340,3 +340,45 @@ def testPrintIrBeforeAndAfterAll():
# CHECK: }
# CHECK: }
pm.run(module)
+
+
+# CHECK-LABEL: TEST: testPrintIrTree
+ at run
+def testPrintIrTree():
+ with Context() as ctx:
+ module = ModuleOp.parse(
+ """
+ module {
+ func.func @main() {
+ %0 = arith.constant 10
+ return
+ }
+ }
+ """
+ )
+ pm = PassManager.parse("builtin.module(canonicalize)")
+ ctx.enable_multithreading(False)
+ pm.enable_ir_printing()
+ # CHECK-LABEL: // Tree printing begin
+ # CHECK: \-- builtin_module_no-symbol-name
+ # CHECK: \-- 0_canonicalize.mlir
+ # CHECK-LABEL: // Tree printing end
+ pm.run(module)
+ log("// Tree printing begin")
+ with tempfile.TemporaryDirectory() as temp_dir:
+ pm.enable_ir_printing(tree_printing_dir_path=temp_dir)
+ pm.run(module)
+
+ def print_file_tree(directory, prefix=""):
+ entries = sorted(os.listdir(directory))
+ for i, entry in enumerate(entries):
+ path = os.path.join(directory, entry)
+ connector = "\-- " if i == len(entries) - 1 else "|-- "
+ log(f"{prefix}{connector}{entry}")
+ if os.path.isdir(path):
+ print_file_tree(
+ path, prefix + (" " if i == len(entries) - 1 else "│ ")
+ )
+
+ print_file_tree(temp_dir)
+ log("// Tree printing end")
More information about the Mlir-commits
mailing list