[Mlir-commits] [mlir] [mlir][python] Add `walk` method to PyOperationBase (PR #87962)

Hideto Ueno llvmlistbot at llvm.org
Sun Apr 7 23:05:28 PDT 2024


https://github.com/uenoku created https://github.com/llvm/llvm-project/pull/87962

This commit adds a `walk` method to PyOperationBase that takes a Python callable as the callback, e.g. `op.walk(lambda op: print(op))`. The optional second argument is a boolean that selects the walk order: pre-order when true, post-order (the default) when false.

>From 77b9bbf120f5cead085fe642412a5ea3550009d2 Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Fri, 5 Apr 2024 03:58:44 -0700
Subject: [PATCH] [mlir][python] Add `walk` method to PyOperationBase

This commit adds `walk` method that uses a python object as
a callback
---
 mlir/lib/Bindings/Python/IRCore.cpp | 18 +++++++++++++++-
 mlir/lib/Bindings/Python/IRModule.h |  3 +++
 mlir/test/python/ir/operation.py    | 32 +++++++++++++++++++++++++++++
 3 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 734f2f7f3f94cf..848d918e16a7d1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1249,6 +1249,20 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
                               .str());
 }
 
+/// Walks this operation and every operation nested under it, invoking the
+/// Python `callback` once per visited operation. Operations are visited in
+/// pre-order when `usePreOrder` is true and in post-order otherwise.
+///
+/// A Python exception raised by `callback` should not be allowed to unwind
+/// through `mlirOperationWalk`, which is an MLIR C API entry point. The first
+/// such exception is therefore captured inside the trampoline, `callback` is
+/// not invoked for the remaining operations, and the exception is rethrown
+/// once the walk has returned.
+void PyOperationBase::walk(py::object callback, bool usePreOrder) {
+  PyOperation &operation = getOperation();
+  operation.checkValid();
+
+  // State shared with the trampoline through the `userData` pointer.
+  struct WalkState {
+    py::object &callback;
+    std::optional<py::error_already_set> error;
+  };
+  WalkState state{callback, std::nullopt};
+
+  MlirOperationWalkCallback walkCallback = [](MlirOperation op,
+                                              void *userData) {
+    auto *walkState = static_cast<WalkState *>(userData);
+    // The trampoline returns void, so the walk itself cannot be interrupted;
+    // once an error is pending, skip the remaining callbacks instead.
+    if (walkState->error)
+      return;
+    try {
+      walkState->callback(op);
+    } catch (py::error_already_set &e) {
+      walkState->error.emplace(std::move(e));
+    }
+  };
+  mlirOperationWalk(operation, walkCallback, &state,
+                    usePreOrder ? MlirWalkOrder::MlirWalkPreOrder
+                                : MlirWalkOrder::MlirWalkPostOrder);
+
+  // Surface the captured Python error to the caller now that the C API call
+  // has completed.
+  if (state.error)
+    throw std::move(*state.error);
+}
+
 py::object PyOperationBase::getAsm(bool binary,
                                    std::optional<int64_t> largeElementsLimit,
                                    bool enableDebugInfo, bool prettyDebugInfo,
@@ -3038,7 +3052,9 @@ void mlir::python::populateIRCore(py::module &m) {
             return operation.createOpView();
           },
           "Detaches the operation from its parent block.")
-      .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
+      .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
+      .def("walk", &PyOperationBase::walk, py::arg("callback"),
+           py::arg("use_pre_order") = py::bool_(false));
 
   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
       .def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9acfdde25ae047..ed15dd4f87c2b4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -579,6 +579,9 @@ class PyOperationBase {
   void writeBytecode(const pybind11::object &fileObject,
                      std::optional<int64_t> bytecodeVersion);
 
+  /// Calls `callback` on this op and all nested ops, in pre- or post-order.
+  void walk(pybind11::object callback, bool usePreOrder);
+
   /// Moves the operation before or after the other operation.
   void moveAfter(PyOperationBase &other);
   void moveBefore(PyOperationBase &other);
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04f8a9936e31f7..92a4f1b1545c20 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1015,3 +1015,35 @@ def testOperationParse():
         print(
             f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
         )
+
+# CHECK-LABEL: TEST: testOpWalk
+ at run
+def testOpWalk():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
+    builtin.module {
+      func.func @f() {
+        func.return
+      }
+    }
+  """,
+    ctx,
+    )
+    callback = lambda op: print(op.name)
+    # Test post-order walk (default).
+    # CHECK-NEXT:  Post-order
+    # CHECK-NEXT:  func.return
+    # CHECK-NEXT:  func.func
+    # CHECK-NEXT:  builtin.module
+    print("Post-order")
+    module.operation.walk(callback)
+
+    # Test pre-order walk.
+    # CHECK-NEXT:  Pre-order
+    # CHECK-NEXT:  builtin.module
+    # CHECK-NEXT:  func.fun
+    # CHECK-NEXT:  func.return
+    print("Pre-order")
+    module.operation.walk(callback, True)



More information about the Mlir-commits mailing list