[Mlir-commits] [mlir] [MLIR][Python] enhance python ir printing with printing flags (PR #117836)

Yuanqiang Liu llvmlistbot at llvm.org
Thu Dec 5 00:39:33 PST 2024


https://github.com/qingyunqu updated https://github.com/llvm/llvm-project/pull/117836

>From a7d8bd9a6dd19da559d245044ee3b62b7e452516 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu <liuyuanqiang.yqliu at bytedance.com>
Date: Wed, 27 Nov 2024 11:28:58 +0800
Subject: [PATCH 1/3] [MLIR][Python] enhance python ir printing with printing
 flags

---
 mlir/include/mlir-c/Pass.h                      |  3 ++-
 mlir/lib/Bindings/Python/Pass.cpp               | 17 +++++++++++++++--
 mlir/lib/CAPI/IR/Pass.cpp                       |  6 ++++--
 .../mlir/_mlir_libs/_mlir/passmanager.pyi       |  4 ++++
 4 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 6019071cfdaa29..8fd8e9956a65a3 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -81,7 +81,8 @@ mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
 MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
     MlirPassManager passManager, bool printBeforeAll, bool printAfterAll,
     bool printModuleScope, bool printAfterOnlyOnChange,
-    bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath);
+    bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags,
+    MlirStringRef treePrintingPath);
 
 /// Enable / disable verify-each.
 MLIR_CAPI_EXPORTED void
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index e8d28abe6d583a..e991deaae2daa5 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -76,20 +76,33 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           "enable_ir_printing",
           [](PyPassManager &passManager, bool printBeforeAll,
              bool printAfterAll, bool printModuleScope, bool printAfterChange,
-             bool printAfterFailure,
+             bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
+             bool enableDebugInfo, bool printGenericOpForm,
              std::optional<std::string> optionalTreePrintingPath) {
+            MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
+            if (largeElementsLimit)
+              mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
+                                                         *largeElementsLimit);
+            if (enableDebugInfo)
+              mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
+                                                 /*prettyForm=*/false);
+            if (printGenericOpForm)
+              mlirOpPrintingFlagsPrintGenericOpForm(flags);
             std::string treePrintingPath = "";
             if (optionalTreePrintingPath.has_value())
               treePrintingPath = optionalTreePrintingPath.value();
             mlirPassManagerEnableIRPrinting(
                 passManager.get(), printBeforeAll, printAfterAll,
-                printModuleScope, printAfterChange, printAfterFailure,
+                printModuleScope, printAfterChange, printAfterFailure, flags,
                 mlirStringRefCreate(treePrintingPath.data(),
                                     treePrintingPath.size()));
+            mlirOpPrintingFlagsDestroy(flags);
           },
           "print_before_all"_a = false, "print_after_all"_a = true,
           "print_module_scope"_a = false, "print_after_change"_a = false,
           "print_after_failure"_a = false,
+          "large_elements_limit"_a = py::none(), "enable_debug_info"_a = false,
+          "print_generic_op_form"_a = false,
           "tree_printing_dir_path"_a = py::none(),
           "Enable IR printing, default as mlir-print-ir-after-all.")
       .def(
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index 01151eafeb5268..883b7e8bb832d2 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -49,6 +49,7 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
                                      bool printModuleScope,
                                      bool printAfterOnlyOnChange,
                                      bool printAfterOnlyOnFailure,
+                                     MlirOpPrintingFlags flags,
                                      MlirStringRef treePrintingPath) {
   auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
     return printBeforeAll;
@@ -60,13 +61,14 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
     return unwrap(passManager)
         ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
                            printModuleScope, printAfterOnlyOnChange,
-                           printAfterOnlyOnFailure);
+                           printAfterOnlyOnFailure, /*out=*/llvm::errs(),
+                           *unwrap(flags));
 
   unwrap(passManager)
       ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
                                    printModuleScope, printAfterOnlyOnChange,
                                    printAfterOnlyOnFailure,
-                                   unwrap(treePrintingPath));
+                                   unwrap(treePrintingPath), *unwrap(flags));
 }
 
 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
index 229979ae33608c..0d2eaffe16d3ec 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
@@ -22,6 +22,10 @@ class PassManager:
         print_module_scope: bool = False,
         print_after_change: bool = False,
         print_after_failure: bool = False,
+        large_elements_limit: int | None = None,
+        enable_debug_info: bool = False,
+        print_generic_op_form: bool = False,
+        tree_printing_dir_path: str | None = None,
     ) -> None: ...
     def enable_verifier(self, enable: bool) -> None: ...
     @staticmethod

>From d452e81131a6db19af92b2f9c86ec07646473211 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu <liuyuanqiang.yqliu at bytedance.com>
Date: Wed, 4 Dec 2024 15:16:38 +0800
Subject: [PATCH 2/3] add test for IR printing with large_elements_limit

---
 mlir/test/python/pass_manager.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index a794a3fc6fa006..0e555d0dc48583 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -342,6 +342,33 @@ def testPrintIrBeforeAndAfterAll():
         pm.run(module)
 
 
+# CHECK-LABEL: TEST: testPrintIrLargeLimitElements
+@run
+def testPrintIrLargeLimitElements():
+    with Context() as ctx:
+        module = ModuleOp.parse(
+            """
+          module {
+            func.func @main() -> tensor<3xi64> {
+              %0 = arith.constant dense<[1, 2, 3]> : tensor<3xi64>
+              return %0 : tensor<3xi64>
+            }
+          }
+        """
+        )
+        pm = PassManager.parse("builtin.module(canonicalize)")
+        ctx.enable_multithreading(False)
+        pm.enable_ir_printing(large_elements_limit=2)
+        # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- //
+        # CHECK: module {
+        # CHECK:   func.func @main() -> tensor<3xi64> {
+        # CHECK:     %[[CST:.*]] = arith.constant dense_resource<__elided__> : tensor<3xi64>
+        # CHECK:     return %[[CST]] : tensor<3xi64>
+        # CHECK:   }
+        # CHECK: }
+        pm.run(module)
+
+
 # CHECK-LABEL: TEST: testPrintIrTree
 @run
 def testPrintIrTree():

>From ea27fb4481be5a42fa4227359ff0682e4d814fae Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu <liuyuanqiang.yqliu at bytedance.com>
Date: Thu, 5 Dec 2024 16:39:14 +0800
Subject: [PATCH 3/3] update test to check only the elided constant line

---
 mlir/test/python/pass_manager.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 0e555d0dc48583..ecac57e3302f01 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -359,13 +359,7 @@ def testPrintIrLargeLimitElements():
         pm = PassManager.parse("builtin.module(canonicalize)")
         ctx.enable_multithreading(False)
         pm.enable_ir_printing(large_elements_limit=2)
-        # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- //
-        # CHECK: module {
-        # CHECK:   func.func @main() -> tensor<3xi64> {
         # CHECK:     %[[CST:.*]] = arith.constant dense_resource<__elided__> : tensor<3xi64>
-        # CHECK:     return %[[CST]] : tensor<3xi64>
-        # CHECK:   }
-        # CHECK: }
         pm.run(module)
 
 



More information about the Mlir-commits mailing list