[Mlir-commits] [mlir] [MLIR][Python] Add the `--mlir-print-ir-tree-dir` to the C and Python API (PR #117339)

Mehdi Amini llvmlistbot at llvm.org
Sat Nov 23 08:59:33 PST 2024


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/117339

>From dc0dad333c3d2dab7938771b76c79f0cb2966cae Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Fri, 22 Nov 2024 14:37:40 +0100
Subject: [PATCH] [MLIR][Python] Add the `--mlir-print-ir-tree-dir` to the C
 and Python API

---
 mlir/include/mlir-c/Pass.h        |  5 +++-
 mlir/lib/Bindings/Python/Pass.cpp | 11 ++++++--
 mlir/lib/CAPI/IR/Pass.cpp         | 18 +++++++++----
 mlir/test/python/pass_manager.py  | 44 ++++++++++++++++++++++++++++++-
 4 files changed, 69 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 2218ec0f47d199..6019071cfdaa29 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -75,10 +75,13 @@ MLIR_CAPI_EXPORTED MlirLogicalResult
 mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
 
 /// Enable IR printing.
+/// The treePrintingPath argument is an optional path to a directory
+/// where the dumps will be produced. If it is empty, the dumps are
+/// printed to stderr instead.
 MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
     MlirPassManager passManager, bool printBeforeAll, bool printAfterAll,
     bool printModuleScope, bool printAfterOnlyOnChange,
-    bool printAfterOnlyOnFailure);
+    bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath);
 
 /// Enable / disable verify-each.
 MLIR_CAPI_EXPORTED void
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 1d0e5ce2115a0a..e8d28abe6d583a 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -76,14 +76,21 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           "enable_ir_printing",
           [](PyPassManager &passManager, bool printBeforeAll,
              bool printAfterAll, bool printModuleScope, bool printAfterChange,
-             bool printAfterFailure) {
+             bool printAfterFailure,
+             std::optional<std::string> optionalTreePrintingPath) {
+            std::string treePrintingPath = "";
+            if (optionalTreePrintingPath.has_value())
+              treePrintingPath = optionalTreePrintingPath.value();
             mlirPassManagerEnableIRPrinting(
                 passManager.get(), printBeforeAll, printAfterAll,
-                printModuleScope, printAfterChange, printAfterFailure);
+                printModuleScope, printAfterChange, printAfterFailure,
+                mlirStringRefCreate(treePrintingPath.data(),
+                                    treePrintingPath.size()));
           },
           "print_before_all"_a = false, "print_after_all"_a = true,
           "print_module_scope"_a = false, "print_after_change"_a = false,
           "print_after_failure"_a = false,
+          "tree_printing_dir_path"_a = py::none(),
           "Enable IR printing, default as mlir-print-ir-after-all.")
       .def(
           "enable_verifier",
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index a6c9fbd08d45a6..01151eafeb5268 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -48,17 +48,25 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
                                      bool printBeforeAll, bool printAfterAll,
                                      bool printModuleScope,
                                      bool printAfterOnlyOnChange,
-                                     bool printAfterOnlyOnFailure) {
+                                     bool printAfterOnlyOnFailure,
+                                     MlirStringRef treePrintingPath) {
   auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
     return printBeforeAll;
   };
   auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
     return printAfterAll;
   };
-  return unwrap(passManager)
-      ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
-                         printModuleScope, printAfterOnlyOnChange,
-                         printAfterOnlyOnFailure);
+  if (unwrap(treePrintingPath).empty())
+    return unwrap(passManager)
+        ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
+                           printModuleScope, printAfterOnlyOnChange,
+                           printAfterOnlyOnFailure);
+
+  unwrap(passManager)
+      ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
+                                   printModuleScope, printAfterOnlyOnChange,
+                                   printAfterOnlyOnFailure,
+                                   unwrap(treePrintingPath));
 }
 
 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 74967032562351..f0698d8d3fb73d 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -1,6 +1,6 @@
 # RUN: %PYTHON %s 2>&1 | FileCheck %s
 
-import gc, sys
+import gc, os, sys, tempfile
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.dialects.func import FuncOp
@@ -340,3 +340,45 @@ def testPrintIrBeforeAndAfterAll():
         # CHECK:   }
         # CHECK: }
         pm.run(module)
+
+
+# CHECK-LABEL: TEST: testPrintIrTree
+@run
+def testPrintIrTree():
+    with Context() as ctx:
+        module = ModuleOp.parse(
+            """
+          module {
+            func.func @main() {
+              %0 = arith.constant 10
+              return
+            }
+          }
+        """
+        )
+        pm = PassManager.parse("builtin.module(canonicalize)")
+        ctx.enable_multithreading(False)
+        pm.enable_ir_printing()
+        # CHECK-LABEL: // Tree printing begin
+        # CHECK: └── builtin_module_no-symbol-name
+        # CHECK:     └── 0_canonicalize.mlir
+        # CHECK-LABEL: // Tree printing end
+        pm.run(module)
+        log("// Tree printing begin")
+        with tempfile.TemporaryDirectory() as temp_dir:
+            pm.enable_ir_printing(tree_printing_dir_path=temp_dir)
+            pm.run(module)
+
+            def print_file_tree(directory, prefix=""):
+                entries = sorted(os.listdir(directory))
+                for i, entry in enumerate(entries):
+                    path = os.path.join(directory, entry)
+                    connector = "└── " if i == len(entries) - 1 else "├── "
+                    log(f"{prefix}{connector}{entry}")
+                    if os.path.isdir(path):
+                        print_file_tree(
+                            path, prefix + ("    " if i == len(entries) - 1 else "│   ")
+                        )
+
+            print_file_tree(temp_dir)
+        log("// Tree printing end")



More information about the Mlir-commits mailing list