[Mlir-commits] [mlir] [mlir][python] Add `walk` method to PyOperationBase (PR #87962)
Hideto Ueno
llvmlistbot at llvm.org
Sun Apr 7 23:12:15 PDT 2024
https://github.com/uenoku updated https://github.com/llvm/llvm-project/pull/87962
>From 77b9bbf120f5cead085fe642412a5ea3550009d2 Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Fri, 5 Apr 2024 03:58:44 -0700
Subject: [PATCH 1/2] [mlir][python] Add `walk` method to PyOperationBase
This commit adds a `walk` method to PyOperationBase that takes a Python
callable and invokes it on each operation visited during the walk.
---
mlir/lib/Bindings/Python/IRCore.cpp | 18 +++++++++++++++-
mlir/lib/Bindings/Python/IRModule.h | 3 +++
mlir/test/python/ir/operation.py | 32 +++++++++++++++++++++++++++++
3 files changed, 52 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 734f2f7f3f94cf..848d918e16a7d1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1249,6 +1249,20 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
.str());
}
+/// Walks this operation and every operation nested within it, calling the
+/// Python `callback` with each visited operation. Operations are visited in
+/// post-order unless `usePreOrder` is true.
+///
+/// If `callback` raises, the exception is re-raised to the Python caller once
+/// the walk has returned; the callback is not invoked again after a failure.
+void PyOperationBase::walk(py::object callback, bool usePreOrder) {
+  PyOperation &operation = getOperation();
+  operation.checkValid();
+
+  // State shared with the C callback through its `void *userData` argument.
+  struct WalkState {
+    py::object &callback;
+    // First exception thrown while invoking `callback`, if any.
+    std::exception_ptr error;
+  };
+  WalkState state{callback, nullptr};
+
+  // Nothing may unwind through `mlirOperationWalk`: it is a C API entry point
+  // and the MLIR core libraries are typically built without C++ exception
+  // support. Capture any exception raised by the Python callback (e.g.
+  // `py::error_already_set`) and rethrow it after the walk has returned.
+  // This callback signature cannot interrupt the walk, so after a failure the
+  // remaining operations are still visited, but Python is no longer called.
+  MlirOperationWalkCallback walkCallback = [](MlirOperation op,
+                                              void *userData) {
+    auto *walkState = static_cast<WalkState *>(userData);
+    if (walkState->error)
+      return;
+    try {
+      walkState->callback(op);
+    } catch (...) {
+      walkState->error = std::current_exception();
+    }
+  };
+  mlirOperationWalk(operation, walkCallback, &state,
+                    usePreOrder ? MlirWalkOrder::MlirWalkPreOrder
+                                : MlirWalkOrder::MlirWalkPostOrder);
+  if (state.error)
+    std::rethrow_exception(state.error);
+}
+
py::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
@@ -3038,7 +3052,9 @@ void mlir::python::populateIRCore(py::module &m) {
return operation.createOpView();
},
"Detaches the operation from its parent block.")
- .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
+ .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
+ .def("walk", &PyOperationBase::walk, py::arg("callback"),
+ py::arg("use_pre_order") = py::bool_(false));
py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
.def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9acfdde25ae047..ed15dd4f87c2b4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -579,6 +579,9 @@ class PyOperationBase {
void writeBytecode(const pybind11::object &fileObject,
std::optional<int64_t> bytecodeVersion);
+  /// Walks this operation and all operations nested within it, invoking
+  /// `callback` on each visited operation. Operations are visited in
+  /// post-order unless `usePreOrder` is true.
+  void walk(pybind11::object callback, bool usePreOrder);
+
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
void moveBefore(PyOperationBase &other);
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04f8a9936e31f7..92a4f1b1545c20 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1015,3 +1015,35 @@ def testOperationParse():
print(
f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
)
+
+# CHECK-LABEL: TEST: testOpWalk
+@run
+def testOpWalk():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
+    builtin.module {
+      func.func @f() {
+        func.return
+      }
+    }
+  """,
+        ctx,
+    )
+
+    # Print the name of every operation the walk visits.
+    def callback(op):
+        print(op.name)
+
+    # Test post-order walk (default): nested operations are visited before
+    # the operations that contain them.
+    # CHECK-NEXT: Post-order
+    # CHECK-NEXT: func.return
+    # CHECK-NEXT: func.func
+    # CHECK-NEXT: builtin.module
+    print("Post-order")
+    module.operation.walk(callback)
+
+    # Test pre-order walk: containing operations are visited before the
+    # operations nested inside them.
+    # CHECK-NEXT: Pre-order
+    # CHECK-NEXT: builtin.module
+    # CHECK-NEXT: func.func
+    # CHECK-NEXT: func.return
+    print("Pre-order")
+    module.operation.walk(callback, use_pre_order=True)
>From 218279d09cabc7f6f9ad995d6686cba4bafd76cc Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Sun, 7 Apr 2024 23:11:49 -0700
Subject: [PATCH 2/2] format
---
mlir/lib/Bindings/Python/IRCore.cpp | 6 ++----
mlir/test/python/ir/operation.py | 3 ++-
2 files changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 848d918e16a7d1..1ee9571065215b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1252,8 +1252,7 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
void PyOperationBase::walk(py::object callback, bool usePreOrder) {
PyOperation &operation = getOperation();
operation.checkValid();
- MlirOperationWalkCallback walkCallback =
- [](MlirOperation op,
+ MlirOperationWalkCallback walkCallback = [](MlirOperation op,
void *userData) {
py::object *fn = static_cast<py::object *>(userData);
(*fn)(op);
@@ -3003,8 +3002,7 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
- bool, py::object, bool>(
- &PyOperationBase::print),
+ bool, py::object, bool>(&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
py::arg("large_elements_limit") = py::none(),
py::arg("enable_debug_info") = false,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 92a4f1b1545c20..18790cd4f89c88 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1016,6 +1016,7 @@ def testOperationParse():
f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
)
+
# CHECK-LABEL: TEST: testOpWalk
@run
def testOpWalk():
@@ -1029,7 +1030,7 @@ def testOpWalk():
}
}
""",
- ctx,
+ ctx,
)
callback = lambda op: print(op.name)
# Test post-order walk (default).
More information about the Mlir-commits
mailing list