[Mlir-commits] [mlir] [mlir][python] Add `walk` method to PyOperationBase (PR #87962)

Hideto Ueno llvmlistbot at llvm.org
Sun Apr 7 23:12:15 PDT 2024


https://github.com/uenoku updated https://github.com/llvm/llvm-project/pull/87962

>From 77b9bbf120f5cead085fe642412a5ea3550009d2 Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Fri, 5 Apr 2024 03:58:44 -0700
Subject: [PATCH 1/2] [mlir][python] Add `walk` method to PyOperationBase

This commit adds a `walk` method to PyOperationBase that invokes
a Python callable on each operation it visits.
---
 mlir/lib/Bindings/Python/IRCore.cpp | 18 +++++++++++++++-
 mlir/lib/Bindings/Python/IRModule.h |  3 +++
 mlir/test/python/ir/operation.py    | 32 +++++++++++++++++++++++++++++
 3 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 734f2f7f3f94cf..848d918e16a7d1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1249,6 +1249,20 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
                               .str());
 }
 
+void PyOperationBase::walk(py::object callback, bool usePreOrder) {
+  PyOperation &operation = getOperation();
+  operation.checkValid();
+  MlirOperationWalkCallback walkCallback =
+   [](MlirOperation op,
+                                              void *userData) {
+    py::object *fn = static_cast<py::object *>(userData);
+    (*fn)(op);
+  };
+  mlirOperationWalk(operation, walkCallback, &callback,
+                    usePreOrder ? MlirWalkOrder::MlirWalkPreOrder
+                                : MlirWalkOrder::MlirWalkPostOrder);
+}
+
 py::object PyOperationBase::getAsm(bool binary,
                                    std::optional<int64_t> largeElementsLimit,
                                    bool enableDebugInfo, bool prettyDebugInfo,
@@ -3038,7 +3052,9 @@ void mlir::python::populateIRCore(py::module &m) {
             return operation.createOpView();
           },
           "Detaches the operation from its parent block.")
-      .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
+      .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
+      .def("walk", &PyOperationBase::walk, py::arg("callback"),
+           py::arg("use_pre_order") = py::bool_(false));
 
   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
       .def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9acfdde25ae047..ed15dd4f87c2b4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -579,6 +579,9 @@ class PyOperationBase {
   void writeBytecode(const pybind11::object &fileObject,
                      std::optional<int64_t> bytecodeVersion);
 
+  /// Walks this op and its nested ops, calling `callback` on each.
+  void walk(pybind11::object callback, bool usePreOrder);
+
   /// Moves the operation before or after the other operation.
   void moveAfter(PyOperationBase &other);
   void moveBefore(PyOperationBase &other);
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04f8a9936e31f7..92a4f1b1545c20 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1015,3 +1015,35 @@ def testOperationParse():
         print(
             f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
         )
+
+# CHECK-LABEL: TEST: testOpWalk
+@run
+def testOpWalk():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
+    builtin.module {
+      func.func @f() {
+        func.return
+      }
+    }
+  """,
+    ctx,
+    )
+    callback = lambda op: print(op.name)
+    # Test post-order walk (default).
+    # CHECK-NEXT:  Post-order
+    # CHECK-NEXT:  func.return
+    # CHECK-NEXT:  func.func
+    # CHECK-NEXT:  builtin.module
+    print("Post-order")
+    module.operation.walk(callback)
+
+    # Test pre-order walk.
+    # CHECK-NEXT:  Pre-order
+    # CHECK-NEXT:  builtin.module
+    # CHECK-NEXT:  func.func
+    # CHECK-NEXT:  func.return
+    print("Pre-order")
+    module.operation.walk(callback, True)

>From 218279d09cabc7f6f9ad995d6686cba4bafd76cc Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Sun, 7 Apr 2024 23:11:49 -0700
Subject: [PATCH 2/2] format

---
 mlir/lib/Bindings/Python/IRCore.cpp | 6 ++----
 mlir/test/python/ir/operation.py    | 3 ++-
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 848d918e16a7d1..1ee9571065215b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1252,8 +1252,7 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
 void PyOperationBase::walk(py::object callback, bool usePreOrder) {
   PyOperation &operation = getOperation();
   operation.checkValid();
-  MlirOperationWalkCallback walkCallback =
-   [](MlirOperation op,
+  MlirOperationWalkCallback walkCallback = [](MlirOperation op,
                                               void *userData) {
     py::object *fn = static_cast<py::object *>(userData);
     (*fn)(op);
@@ -3003,8 +3002,7 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("binary") = false, kOperationPrintStateDocstring)
       .def("print",
            py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
-                             bool, py::object, bool>(
-               &PyOperationBase::print),
+                             bool, py::object, bool>(&PyOperationBase::print),
            // Careful: Lots of arguments must match up with print method.
            py::arg("large_elements_limit") = py::none(),
            py::arg("enable_debug_info") = false,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 92a4f1b1545c20..18790cd4f89c88 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1016,6 +1016,7 @@ def testOperationParse():
             f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
         )
 
+
 # CHECK-LABEL: TEST: testOpWalk
 @run
 def testOpWalk():
@@ -1029,7 +1030,7 @@ def testOpWalk():
       }
     }
   """,
-    ctx,
+        ctx,
     )
     callback = lambda op: print(op.name)
     # Test post-order walk (default).



More information about the Mlir-commits mailing list