[Mlir-commits] [mlir] [mlir][py] Overload print with state. (PR #72064)
Jacques Pienaar
llvmlistbot at llvm.org
Sun Nov 12 11:28:52 PST 2023
https://github.com/jpienaar created https://github.com/llvm/llvm-project/pull/72064
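Enables reusing an AsmState when printing operations from Python. Also moves the `fileObject` and `binary` parameters to the end of the flag-based `print` overload. With `pybind11::object` as the first parameter, that overload accepted any positional argument, so the new state overload was only selected when `state=` was passed by keyword. As a result, callers that passed `file` or `binary` positionally must now pass them by keyword.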
>From 4e79e345c1d041eda7bc18003cd7a58e5cd97155 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Wed, 8 Nov 2023 15:13:02 -0800
Subject: [PATCH] [mlir][py] Overload print with state.
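Enables reusing an AsmState when printing operations from Python. Also
moves the `fileObject` and `binary` parameters to the end of the
flag-based `print` overload. With `pybind11::object` as the first
parameter, that overload accepted any positional argument, so the new
state overload was only selected when `state=` was passed by keyword.
Callers that passed `file` or `binary` positionally must now pass them
by keyword.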
---
mlir/lib/Bindings/Python/IRCore.cpp | 47 +++++++++++++++++++++++------
mlir/lib/Bindings/Python/IRModule.h | 9 ++++--
mlir/test/python/ir/operation.py | 7 ++++-
3 files changed, 50 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0f2ca666ccc050e..a4330b062532763 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -110,6 +110,15 @@ static const char kOperationPrintDocstring[] =
invalid, behavior is undefined.
)";
+static const char kOperationPrintStateDocstring[] =
+ R"(Prints the assembly form of the operation to a file like object.
+
+Args:
+ state: AsmState capturing the operation numbering and flags.
+ file: The file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write bytes (True) or str (False). Defaults to False.
+)";
+
static const char kOperationGetAsmDocstring[] =
R"(Gets the assembly form of the operation with all options available.
@@ -1169,11 +1178,11 @@ void PyOperation::checkValid() const {
}
}
-void PyOperationBase::print(py::object fileObject, bool binary,
- std::optional<int64_t> largeElementsLimit,
+void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
- bool assumeVerified) {
+ bool assumeVerified, py::object fileObject,
+ bool binary) {
PyOperation &operation = getOperation();
operation.checkValid();
if (fileObject.is_none())
@@ -1198,6 +1207,17 @@ void PyOperationBase::print(py::object fileObject, bool binary,
mlirOpPrintingFlagsDestroy(flags);
}
+void PyOperationBase::print(PyAsmState &state, pybind11::object fileObject, // Prints reusing `state` (numbering + flags); takes no per-call flags.
+ bool binary) {
+ PyOperation &operation = getOperation();
+ operation.checkValid(); // Same validity check as the flag-based overload.
+ if (fileObject.is_none()) // file=None means write to sys.stdout.
+ fileObject = py::module::import("sys").attr("stdout");
+ PyFileAccumulator accum(fileObject, binary); // Writes bytes (binary=True) or str (binary=False) to fileObject.
+ mlirOperationPrintWithState(operation, state.get(), accum.getCallback(), // Print through the C API with the caller's AsmState.
+ accum.getUserData());
+}
+
void PyOperationBase::writeBytecode(const py::object &fileObject,
std::optional<int64_t> bytecodeVersion) {
PyOperation &operation = getOperation();
@@ -1230,13 +1250,14 @@ py::object PyOperationBase::getAsm(bool binary,
} else {
fileObject = py::module::import("io").attr("StringIO")();
}
- print(fileObject, /*binary=*/binary,
- /*largeElementsLimit=*/largeElementsLimit,
+ print(/*largeElementsLimit=*/largeElementsLimit,
/*enableDebugInfo=*/enableDebugInfo,
/*prettyDebugInfo=*/prettyDebugInfo,
/*printGenericOpForm=*/printGenericOpForm,
/*useLocalScope=*/useLocalScope,
- /*assumeVerified=*/assumeVerified);
+ /*assumeVerified=*/assumeVerified,
+ /*fileObject=*/fileObject,
+ /*binary=*/binary);
return fileObject.attr("getvalue")();
}
@@ -2946,15 +2967,23 @@ void mlir::python::populateIRCore(py::module &m) {
/*assumeVerified=*/false);
},
"Returns the assembly form of the operation.")
- .def("print", &PyOperationBase::print,
+ .def("print",
+ py::overload_cast<PyAsmState &, pybind11::object, bool>(
+ &PyOperationBase::print),
+ py::arg("state"), py::arg("file") = py::none(),
+ py::arg("binary") = false, kOperationPrintStateDocstring)
+ .def("print",
+ py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
+ bool, pybind11::object, bool>(
+ &PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
- py::arg("file") = py::none(), py::arg("binary") = false,
py::arg("large_elements_limit") = py::none(),
py::arg("enable_debug_info") = false,
py::arg("pretty_debug_info") = false,
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false,
- py::arg("assume_verified") = false, kOperationPrintDocstring)
+ py::arg("assume_verified") = false, py::arg("file") = py::none(),
+ py::arg("binary") = false, kOperationPrintDocstring)
.def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
py::arg("desired_version") = py::none(),
kOperationPrintBytecodeDocstring)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index af55693f18fbbf9..3f856681881829a 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -550,16 +550,19 @@ class PyModule : public BaseContextObject {
pybind11::handle handle;
};
+class PyAsmState; // Forward declaration for PyOperationBase::print(PyAsmState &, ...).
+
/// Base class for PyOperation and PyOpView which exposes the primary, user
/// visible methods for manipulating it.
class PyOperationBase {
public:
virtual ~PyOperationBase() = default;
/// Implements the bound 'print' method and helps with others.
- void print(pybind11::object fileObject, bool binary,
- std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
+ void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
- bool assumeVerified);
+ bool assumeVerified, pybind11::object fileObject, bool binary);
+ void print(PyAsmState &state, pybind11::object fileObject, bool binary);
+
pybind11::object getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04239b048c1c641..04f8a9936e31f79 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -622,10 +622,15 @@ def testOperationPrint():
print(bytes_value.__class__)
print(bytes_value)
- # Test get_asm local_scope.
+ # Test print local_scope.
# CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
module.operation.print(enable_debug_info=True, use_local_scope=True)
+ # Test printing using state.
+ state = AsmState(module.operation)
+ # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ module.operation.print(state)
+
# Test get_asm with options.
# CHECK: value = dense_resource<__elided__> : tensor<4xi32>
# CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
More information about the Mlir-commits
mailing list