[Mlir-commits] [mlir] [mlir][py] Overload print with state. (PR #72064)

Jacques Pienaar llvmlistbot at llvm.org
Sun Nov 12 11:28:52 PST 2023


https://github.com/jpienaar created https://github.com/llvm/llvm-project/pull/72064

Enables reusing the AsmState when printing from Python. Also moves the `fileObject` and `binary` parameters to the end of the argument list (with a `pybind11::object` as the leading parameter, the new overload was not selected unless `state=` was specified explicitly).

>From 4e79e345c1d041eda7bc18003cd7a58e5cd97155 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Wed, 8 Nov 2023 15:13:02 -0800
Subject: [PATCH] [mlir][py] Overload print with state.

Enables reusing the AsmState when printing from Python. Also moves the
fileObject and binary to the end (pybind11::object was resulting in the
overload not working unless `state=` was specified).
---
 mlir/lib/Bindings/Python/IRCore.cpp | 47 +++++++++++++++++++++++------
 mlir/lib/Bindings/Python/IRModule.h |  9 ++++--
 mlir/test/python/ir/operation.py    |  7 ++++-
 3 files changed, 50 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0f2ca666ccc050e..a4330b062532763 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -110,6 +110,15 @@ static const char kOperationPrintDocstring[] =
     invalid, behavior is undefined.
 )";
 
+static const char kOperationPrintStateDocstring[] =
+    R"(Prints the assembly form of the operation to a file like object.
+
+Args:
+  file: The file like object to write to. Defaults to sys.stdout.
+  binary: Whether to write bytes (True) or str (False). Defaults to False.
+  state: AsmState capturing the operation numbering and flags.
+)";
+
 static const char kOperationGetAsmDocstring[] =
     R"(Gets the assembly form of the operation with all options available.
 
@@ -1169,11 +1178,11 @@ void PyOperation::checkValid() const {
   }
 }
 
-void PyOperationBase::print(py::object fileObject, bool binary,
-                            std::optional<int64_t> largeElementsLimit,
+void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
                             bool enableDebugInfo, bool prettyDebugInfo,
                             bool printGenericOpForm, bool useLocalScope,
-                            bool assumeVerified) {
+                            bool assumeVerified, py::object fileObject,
+                            bool binary) {
   PyOperation &operation = getOperation();
   operation.checkValid();
   if (fileObject.is_none())
@@ -1198,6 +1207,17 @@ void PyOperationBase::print(py::object fileObject, bool binary,
   mlirOpPrintingFlagsDestroy(flags);
 }
 
+void PyOperationBase::print(PyAsmState &state, pybind11::object fileObject,
+                            bool binary) {
+  PyOperation &operation = getOperation();
+  operation.checkValid();
+  if (fileObject.is_none())
+    fileObject = py::module::import("sys").attr("stdout");
+  PyFileAccumulator accum(fileObject, binary);
+  mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
+                              accum.getUserData());
+}
+
 void PyOperationBase::writeBytecode(const py::object &fileObject,
                                     std::optional<int64_t> bytecodeVersion) {
   PyOperation &operation = getOperation();
@@ -1230,13 +1250,14 @@ py::object PyOperationBase::getAsm(bool binary,
   } else {
     fileObject = py::module::import("io").attr("StringIO")();
   }
-  print(fileObject, /*binary=*/binary,
-        /*largeElementsLimit=*/largeElementsLimit,
+  print(/*largeElementsLimit=*/largeElementsLimit,
         /*enableDebugInfo=*/enableDebugInfo,
         /*prettyDebugInfo=*/prettyDebugInfo,
         /*printGenericOpForm=*/printGenericOpForm,
         /*useLocalScope=*/useLocalScope,
-        /*assumeVerified=*/assumeVerified);
+        /*assumeVerified=*/assumeVerified,
+        /*fileObject=*/fileObject,
+        /*binary=*/binary);
 
   return fileObject.attr("getvalue")();
 }
@@ -2946,15 +2967,23 @@ void mlir::python::populateIRCore(py::module &m) {
                                /*assumeVerified=*/false);
           },
           "Returns the assembly form of the operation.")
-      .def("print", &PyOperationBase::print,
+      .def("print",
+           py::overload_cast<PyAsmState &, pybind11::object, bool>(
+               &PyOperationBase::print),
+           py::arg("state"), py::arg("file") = py::none(),
+           py::arg("binary") = false, kOperationPrintStateDocstring)
+      .def("print",
+           py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
+                             bool, pybind11::object, bool>(
+               &PyOperationBase::print),
            // Careful: Lots of arguments must match up with print method.
-           py::arg("file") = py::none(), py::arg("binary") = false,
            py::arg("large_elements_limit") = py::none(),
            py::arg("enable_debug_info") = false,
            py::arg("pretty_debug_info") = false,
            py::arg("print_generic_op_form") = false,
            py::arg("use_local_scope") = false,
-           py::arg("assume_verified") = false, kOperationPrintDocstring)
+           py::arg("assume_verified") = false, py::arg("file") = py::none(),
+           py::arg("binary") = false, kOperationPrintDocstring)
       .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
            py::arg("desired_version") = py::none(),
            kOperationPrintBytecodeDocstring)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index af55693f18fbbf9..3f856681881829a 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -550,16 +550,19 @@ class PyModule : public BaseContextObject {
   pybind11::handle handle;
 };
 
+class PyAsmState;
+
 /// Base class for PyOperation and PyOpView which exposes the primary, user
 /// visible methods for manipulating it.
 class PyOperationBase {
 public:
   virtual ~PyOperationBase() = default;
   /// Implements the bound 'print' method and helps with others.
-  void print(pybind11::object fileObject, bool binary,
-             std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
+  void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
              bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
-             bool assumeVerified);
+             bool assumeVerified, pybind11::object fileObject, bool binary);
+  void print(PyAsmState &state, pybind11::object fileObject, bool binary);
+
   pybind11::object getAsm(bool binary,
                           std::optional<int64_t> largeElementsLimit,
                           bool enableDebugInfo, bool prettyDebugInfo,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04239b048c1c641..04f8a9936e31f79 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -622,10 +622,15 @@ def testOperationPrint():
     print(bytes_value.__class__)
     print(bytes_value)
 
-    # Test get_asm local_scope.
+    # Test print local_scope.
     # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
     module.operation.print(enable_debug_info=True, use_local_scope=True)
 
+    # Test printing using state.
+    state = AsmState(module.operation)
+    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+    module.operation.print(state)
+
     # Test get_asm with options.
     # CHECK: value = dense_resource<__elided__> : tensor<4xi32>
     # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7



More information about the Mlir-commits mailing list