[Mlir-commits] [mlir] c8b837a - [MLIR][Python] Add the `--mlir-print-ir-tree-dir` to the C and Python API (#117339)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Nov 23 11:17:29 PST 2024


Author: Mehdi Amini
Date: 2024-11-23T20:17:25+01:00
New Revision: c8b837ad8ce4f36a3b2e47f1f1367dc0b41fca7b

URL: https://github.com/llvm/llvm-project/commit/c8b837ad8ce4f36a3b2e47f1f1367dc0b41fca7b
DIFF: https://github.com/llvm/llvm-project/commit/c8b837ad8ce4f36a3b2e47f1f1367dc0b41fca7b.diff

LOG: [MLIR][Python] Add the `--mlir-print-ir-tree-dir` to the C and Python API (#117339)

Added: 
    

Modified: 
    mlir/include/mlir-c/Pass.h
    mlir/lib/Bindings/Python/Pass.cpp
    mlir/lib/CAPI/IR/Pass.cpp
    mlir/test/python/pass_manager.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 2218ec0f47d199..6019071cfdaa29 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -75,10 +75,13 @@ MLIR_CAPI_EXPORTED MlirLogicalResult
 mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
 
 /// Enable IR printing.
+/// The treePrintingPath argument is an optional path to a directory
+/// where the dumps will be produced. If it isn't provided then dumps
+/// are produced to stderr.
 MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
     MlirPassManager passManager, bool printBeforeAll, bool printAfterAll,
     bool printModuleScope, bool printAfterOnlyOnChange,
-    bool printAfterOnlyOnFailure);
+    bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath);
 
 /// Enable / disable verify-each.
 MLIR_CAPI_EXPORTED void

diff  --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 1d0e5ce2115a0a..e8d28abe6d583a 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -76,14 +76,21 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           "enable_ir_printing",
           [](PyPassManager &passManager, bool printBeforeAll,
              bool printAfterAll, bool printModuleScope, bool printAfterChange,
-             bool printAfterFailure) {
+             bool printAfterFailure,
+             std::optional<std::string> optionalTreePrintingPath) {
+            std::string treePrintingPath = "";
+            if (optionalTreePrintingPath.has_value())
+              treePrintingPath = optionalTreePrintingPath.value();
             mlirPassManagerEnableIRPrinting(
                 passManager.get(), printBeforeAll, printAfterAll,
-                printModuleScope, printAfterChange, printAfterFailure);
+                printModuleScope, printAfterChange, printAfterFailure,
+                mlirStringRefCreate(treePrintingPath.data(),
+                                    treePrintingPath.size()));
           },
           "print_before_all"_a = false, "print_after_all"_a = true,
           "print_module_scope"_a = false, "print_after_change"_a = false,
           "print_after_failure"_a = false,
+          "tree_printing_dir_path"_a = py::none(),
           "Enable IR printing, default as mlir-print-ir-after-all.")
       .def(
           "enable_verifier",

diff  --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index a6c9fbd08d45a6..01151eafeb5268 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -48,17 +48,25 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
                                      bool printBeforeAll, bool printAfterAll,
                                      bool printModuleScope,
                                      bool printAfterOnlyOnChange,
-                                     bool printAfterOnlyOnFailure) {
+                                     bool printAfterOnlyOnFailure,
+                                     MlirStringRef treePrintingPath) {
   auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
     return printBeforeAll;
   };
   auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
     return printAfterAll;
   };
-  return unwrap(passManager)
-      ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
-                         printModuleScope, printAfterOnlyOnChange,
-                         printAfterOnlyOnFailure);
+  if (unwrap(treePrintingPath).empty())
+    return unwrap(passManager)
+        ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
+                           printModuleScope, printAfterOnlyOnChange,
+                           printAfterOnlyOnFailure);
+
+  unwrap(passManager)
+      ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
+                                   printModuleScope, printAfterOnlyOnChange,
+                                   printAfterOnlyOnFailure,
+                                   unwrap(treePrintingPath));
 }
 
 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {

diff  --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 74967032562351..a794a3fc6fa006 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -1,6 +1,6 @@
 # RUN: %PYTHON %s 2>&1 | FileCheck %s
 
-import gc, sys
+import gc, os, sys, tempfile
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.dialects.func import FuncOp
@@ -340,3 +340,45 @@ def testPrintIrBeforeAndAfterAll():
         # CHECK:   }
         # CHECK: }
         pm.run(module)
+
+
+# CHECK-LABEL: TEST: testPrintIrTree
+@run
+def testPrintIrTree():
+    with Context() as ctx:
+        module = ModuleOp.parse(
+            """
+          module {
+            func.func @main() {
+              %0 = arith.constant 10
+              return
+            }
+          }
+        """
+        )
+        pm = PassManager.parse("builtin.module(canonicalize)")
+        ctx.enable_multithreading(False)
+        pm.enable_ir_printing()
+        # CHECK-LABEL: // Tree printing begin
+        # CHECK: \-- builtin_module_no-symbol-name
+        # CHECK:     \-- 0_canonicalize.mlir
+        # CHECK-LABEL: // Tree printing end
+        pm.run(module)
+        log("// Tree printing begin")
+        with tempfile.TemporaryDirectory() as temp_dir:
+            pm.enable_ir_printing(tree_printing_dir_path=temp_dir)
+            pm.run(module)
+
+            def print_file_tree(directory, prefix=""):
+                entries = sorted(os.listdir(directory))
+                for i, entry in enumerate(entries):
+                    path = os.path.join(directory, entry)
+                    connector = "\-- " if i == len(entries) - 1 else "|-- "
+                    log(f"{prefix}{connector}{entry}")
+                    if os.path.isdir(path):
+                        print_file_tree(
+                            path, prefix + ("    " if i == len(entries) - 1 else "│   ")
+                        )
+
+            print_file_tree(temp_dir)
+        log("// Tree printing end")


        


More information about the Mlir-commits mailing list