[Mlir-commits] [mlir] [mlir][python] Add `walk` method to PyOperationBase (PR #87962)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Apr 7 23:05:56 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Hideto Ueno (uenoku)
<details>
<summary>Changes</summary>
This commit adds a `walk` method to PyOperationBase that uses a Python object as a callback, e.g. `op.walk(lambda op: print(op))`. The second, optional argument is a boolean that selects the walk order: pre-order when true, post-order (the default) when false.
---
Full diff: https://github.com/llvm/llvm-project/pull/87962.diff
3 Files Affected:
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+17-1)
- (modified) mlir/lib/Bindings/Python/IRModule.h (+3)
- (modified) mlir/test/python/ir/operation.py (+32)
``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 734f2f7f3f94cf..848d918e16a7d1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1249,6 +1249,20 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
.str());
}
+void PyOperationBase::walk(py::object callback, bool usePreOrder) {
+ PyOperation &operation = getOperation();
+ operation.checkValid();
+ MlirOperationWalkCallback walkCallback =
+ [](MlirOperation op,
+ void *userData) {
+ py::object *fn = static_cast<py::object *>(userData);
+ (*fn)(op);
+ };
+ mlirOperationWalk(operation, walkCallback, &callback,
+ usePreOrder ? MlirWalkOrder::MlirWalkPreOrder
+ : MlirWalkOrder::MlirWalkPostOrder);
+}
+
py::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
@@ -3038,7 +3052,9 @@ void mlir::python::populateIRCore(py::module &m) {
return operation.createOpView();
},
"Detaches the operation from its parent block.")
- .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
+ .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
+ .def("walk", &PyOperationBase::walk, py::arg("callback"),
+ py::arg("use_pre_order") = py::bool_(false));
py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
.def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9acfdde25ae047..ed15dd4f87c2b4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -579,6 +579,9 @@ class PyOperationBase {
void writeBytecode(const pybind11::object &fileObject,
std::optional<int64_t> bytecodeVersion);
+ // Implement the walk method.
+ void walk(pybind11::object callback, bool usePreOrder);
+
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
void moveBefore(PyOperationBase &other);
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04f8a9936e31f7..92a4f1b1545c20 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1015,3 +1015,35 @@ def testOperationParse():
print(
f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
)
+
+# CHECK-LABEL: TEST: testOpWalk
+@run
+def testOpWalk():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
+ builtin.module {
+ func.func @f() {
+ func.return
+ }
+ }
+ """,
+ ctx,
+ )
+ callback = lambda op: print(op.name)
+ # Test post-order walk (default).
+ # CHECK-NEXT: Post-order
+ # CHECK-NEXT: func.return
+ # CHECK-NEXT: func.func
+ # CHECK-NEXT: builtin.module
+ print("Post-order")
+ module.operation.walk(callback)
+
+ # Test pre-order walk.
+ # CHECK-NEXT: Pre-order
+ # CHECK-NEXT: builtin.module
+ # CHECK-NEXT: func.fun
+ # CHECK-NEXT: func.return
+ print("Pre-order")
+ module.operation.walk(callback, True)
``````````
</details>
https://github.com/llvm/llvm-project/pull/87962
More information about the Mlir-commits
mailing list