[Mlir-commits] [mlir] [mlir][python] Add `walk` method to PyOperationBase (PR #87962)

Hideto Ueno llvmlistbot at llvm.org
Thu Apr 11 05:51:51 PDT 2024


https://github.com/uenoku updated https://github.com/llvm/llvm-project/pull/87962

>From 77b9bbf120f5cead085fe642412a5ea3550009d2 Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Fri, 5 Apr 2024 03:58:44 -0700
Subject: [PATCH 1/3] [mlir][python] Add `walk` method to PyOperationBase

This commit adds a `walk` method to PyOperationBase that takes a Python
callable and invokes it on each operation visited by the walk.
---
 mlir/lib/Bindings/Python/IRCore.cpp | 18 +++++++++++++++-
 mlir/lib/Bindings/Python/IRModule.h |  3 +++
 mlir/test/python/ir/operation.py    | 32 +++++++++++++++++++++++++++++
 3 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 734f2f7f3f94cf..848d918e16a7d1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1249,6 +1249,20 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
                               .str());
 }
 
+void PyOperationBase::walk(py::object callback, bool usePreOrder) {
+  PyOperation &operation = getOperation();
+  operation.checkValid();
+  MlirOperationWalkCallback walkCallback =
+   [](MlirOperation op,
+                                              void *userData) {
+    py::object *fn = static_cast<py::object *>(userData);
+    (*fn)(op);
+  };
+  mlirOperationWalk(operation, walkCallback, &callback,
+                    usePreOrder ? MlirWalkOrder::MlirWalkPreOrder
+                                : MlirWalkOrder::MlirWalkPostOrder);
+}
+
 py::object PyOperationBase::getAsm(bool binary,
                                    std::optional<int64_t> largeElementsLimit,
                                    bool enableDebugInfo, bool prettyDebugInfo,
@@ -3038,7 +3052,9 @@ void mlir::python::populateIRCore(py::module &m) {
             return operation.createOpView();
           },
           "Detaches the operation from its parent block.")
-      .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
+      .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
+      .def("walk", &PyOperationBase::walk, py::arg("callback"),
+           py::arg("use_pre_order") = py::bool_(false));
 
   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
       .def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9acfdde25ae047..ed15dd4f87c2b4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -579,6 +579,9 @@ class PyOperationBase {
   void writeBytecode(const pybind11::object &fileObject,
                      std::optional<int64_t> bytecodeVersion);
 
+  // Implement the walk method.
+  void walk(pybind11::object callback, bool usePreOrder);
+
   /// Moves the operation before or after the other operation.
   void moveAfter(PyOperationBase &other);
   void moveBefore(PyOperationBase &other);
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04f8a9936e31f7..92a4f1b1545c20 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1015,3 +1015,35 @@ def testOperationParse():
         print(
             f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
         )
+
+# CHECK-LABEL: TEST: testOpWalk
+ at run
+def testOpWalk():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
+    builtin.module {
+      func.func @f() {
+        func.return
+      }
+    }
+  """,
+    ctx,
+    )
+    callback = lambda op: print(op.name)
+    # Test post-order walk (default).
+    # CHECK-NEXT:  Post-order
+    # CHECK-NEXT:  func.return
+    # CHECK-NEXT:  func.func
+    # CHECK-NEXT:  builtin.module
+    print("Post-order")
+    module.operation.walk(callback)
+
+    # Test pre-order walk.
+    # CHECK-NEXT:  Pre-order
+    # CHECK-NEXT:  builtin.module
+    # CHECK-NEXT:  func.fun
+    # CHECK-NEXT:  func.return
+    print("Pre-order")
+    module.operation.walk(callback, True)

>From 218279d09cabc7f6f9ad995d6686cba4bafd76cc Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Sun, 7 Apr 2024 23:11:49 -0700
Subject: [PATCH 2/3] format

---
 mlir/lib/Bindings/Python/IRCore.cpp | 6 ++----
 mlir/test/python/ir/operation.py    | 3 ++-
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 848d918e16a7d1..1ee9571065215b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1252,8 +1252,7 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
 void PyOperationBase::walk(py::object callback, bool usePreOrder) {
   PyOperation &operation = getOperation();
   operation.checkValid();
-  MlirOperationWalkCallback walkCallback =
-   [](MlirOperation op,
+  MlirOperationWalkCallback walkCallback = [](MlirOperation op,
                                               void *userData) {
     py::object *fn = static_cast<py::object *>(userData);
     (*fn)(op);
@@ -3003,8 +3002,7 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("binary") = false, kOperationPrintStateDocstring)
       .def("print",
            py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
-                             bool, py::object, bool>(
-               &PyOperationBase::print),
+                             bool, py::object, bool>(&PyOperationBase::print),
            // Careful: Lots of arguments must match up with print method.
            py::arg("large_elements_limit") = py::none(),
            py::arg("enable_debug_info") = false,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 92a4f1b1545c20..18790cd4f89c88 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1016,6 +1016,7 @@ def testOperationParse():
             f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
         )
 
+
 # CHECK-LABEL: TEST: testOpWalk
 @run
 def testOpWalk():
@@ -1029,7 +1030,7 @@ def testOpWalk():
       }
     }
   """,
-    ctx,
+        ctx,
     )
     callback = lambda op: print(op.name)
     # Test post-order walk (default).

>From 96f76ac00aa45421b180df7786a40767522e0bdf Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Thu, 11 Apr 2024 05:49:45 -0700
Subject: [PATCH 3/3] Add CAPI for WalkResult. Add python enum definition of
 WalkResult/WalkOrder

---
 mlir/include/mlir-c/IR.h                      | 10 +++-
 .../mlir/Bindings/Python/PybindAdaptors.h     |  1 +
 mlir/lib/Bindings/Python/IRCore.cpp           | 26 ++++++---
 mlir/lib/Bindings/Python/IRModule.h           |  3 +-
 mlir/lib/CAPI/IR/IR.cpp                       | 21 ++++++-
 mlir/test/CAPI/ir.c                           | 58 +++++++++++++++----
 mlir/test/python/ir/operation.py              | 41 ++++++++++++-
 7 files changed, 136 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 82da511f807a34..32abacf353133e 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -705,6 +705,13 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
 MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
                                                 MlirOperation other);
 
+/// Tells an operation walk how to proceed after a callback returns.
+typedef enum MlirWalkResult {
+  MlirWalkResultAdvance = 0,   ///< Continue the walk as usual.
+  MlirWalkResultInterrupt = 1, ///< Stop; visit no further operations.
+  MlirWalkResultSkip = 2       ///< Skip unvisited ops nested in this one.
+} MlirWalkResult;
+
 /// Traversal order for operation walk.
 typedef enum MlirWalkOrder {
   MlirWalkPreOrder,
@@ -713,7 +720,8 @@ typedef enum MlirWalkOrder {
 
 /// Operation walker type. The handler is passed an (opaque) reference to an
 /// operation and a pointer to a `userData`.
-typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
+typedef MlirWalkResult (*MlirOperationWalkCallback)(MlirOperation,
+                                                    void *userData);
 
 /// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
 /// `*userData` is passed to the callback as well and can be used to tunnel some
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 52f6321251919e..d8f22c7aa17096 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -18,6 +18,7 @@
 #ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
 #define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
 
+#include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/pytypes.h>
 #include <pybind11/stl.h>
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 1ee9571065215b..d875f4eba2b139 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -674,6 +674,7 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
       data->rootOp.getOperation().getContext()->clearOperation(op);
     else
       data->rootSeen = true;
+    return MlirWalkResult::MlirWalkResultAdvance;
   };
   mlirOperationWalk(op.getOperation(), invalidatingCallback,
                     static_cast<void *>(&data), MlirWalkPreOrder);
@@ -1249,17 +1250,37 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
                               .str());
 }

-void PyOperationBase::walk(py::object callback, bool usePreOrder) {
+/// Walks the operations nested in (and including) this operation in
+/// `walkOrder`, letting the WalkResult returned by `callback` steer the walk.
+void PyOperationBase::walk(
+    std::function<MlirWalkResult(MlirOperation)> callback,
+    MlirWalkOrder walkOrder) {
   PyOperation &operation = getOperation();
   operation.checkValid();
+  // mlirOperationWalk is a C API entry point, so C++ exceptions must not
+  // unwind through it. Anything `callback` throws (py::error_already_set for a
+  // Python exception, py::cast_error if it does not return a WalkResult) is
+  // captured in `state`, the walk is interrupted, and the exception is
+  // rethrown once mlirOperationWalk has returned.
+  struct WalkState {
+    std::function<MlirWalkResult(MlirOperation)> callback;
+    std::exception_ptr error;
+  };
+  WalkState state{std::move(callback), nullptr};
   MlirOperationWalkCallback walkCallback = [](MlirOperation op,
                                               void *userData) {
-    py::object *fn = static_cast<py::object *>(userData);
-    (*fn)(op);
+    auto *walkState = static_cast<WalkState *>(userData);
+    try {
+      return walkState->callback(op);
+    } catch (...) {
+      walkState->error = std::current_exception();
+      return MlirWalkResultInterrupt;
+    }
   };
-  mlirOperationWalk(operation, walkCallback, &callback,
-                    usePreOrder ? MlirWalkOrder::MlirWalkPreOrder
-                                : MlirWalkOrder::MlirWalkPostOrder);
+
+  mlirOperationWalk(operation, walkCallback, &state, walkOrder);
+  if (state.error)
+    std::rethrow_exception(state.error);
 }

 py::object PyOperationBase::getAsm(bool binary,
@@ -2524,6 +2527,15 @@ void mlir::python::populateIRCore(py::module &m) {
       .value("NOTE", MlirDiagnosticNote)
       .value("REMARK", MlirDiagnosticRemark);
 
+  py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
+      .value("PRE_ORDER", MlirWalkPreOrder)
+      .value("POST_ORDER", MlirWalkPostOrder);
+
+  py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
+      .value("ADVANCE", MlirWalkResultAdvance)
+      .value("INTERRUPT", MlirWalkResultInterrupt)
+      .value("SKIP", MlirWalkResultSkip);
+
   //----------------------------------------------------------------------------
   // Mapping of Diagnostics.
   //----------------------------------------------------------------------------
@@ -3052,7 +3064,7 @@ void mlir::python::populateIRCore(py::module &m) {
           "Detaches the operation from its parent block.")
       .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
       .def("walk", &PyOperationBase::walk, py::arg("callback"),
-           py::arg("use_pre_order") = py::bool_(false));
+           py::arg("walk_order") = MlirWalkPostOrder);
 
   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
       .def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index ed15dd4f87c2b4..b038a0c54d29b9 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -580,7 +580,8 @@ class PyOperationBase {
                      std::optional<int64_t> bytecodeVersion);
 
   // Implement the walk method.
-  void walk(pybind11::object callback, bool usePreOrder);
+  void walk(std::function<MlirWalkResult(MlirOperation)> callback,
+            MlirWalkOrder walkOrder);
 
   /// Moves the operation before or after the other operation.
   void moveAfter(PyOperationBase &other);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index cdb64f4ec4a40f..065a8580eae497 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -717,17 +717,34 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
   return unwrap(op)->moveBefore(unwrap(other));
 }
 
+/// Converts a C API walk result into the equivalent mlir::WalkResult.
+static mlir::WalkResult translateWalkResult(MlirWalkResult result) {
+  switch (result) {
+  case MlirWalkResultAdvance:
+    return mlir::WalkResult::advance();
+  case MlirWalkResultInterrupt:
+    return mlir::WalkResult::interrupt();
+  case MlirWalkResultSkip:
+    return mlir::WalkResult::skip();
+  }
+  llvm_unreachable("unknown MlirWalkResult value");
+}
+
 void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
                        void *userData, MlirWalkOrder walkOrder) {
   switch (walkOrder) {
 
   case MlirWalkPreOrder:
     unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
-        [callback, userData](Operation *op) { callback(wrap(op), userData); });
+        [callback, userData](Operation *op) {
+          return translateWalkResult(callback(wrap(op), userData));
+        });
     break;
   case MlirWalkPostOrder:
     unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
-        [callback, userData](Operation *op) { callback(wrap(op), userData); });
+        [callback, userData](Operation *op) {
+          return translateWalkResult(callback(wrap(op), userData));
+        });
   }
 }
 
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 8e79338c57a22a..3d05b2a12dd8ef 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2244,9 +2244,22 @@ typedef struct {
   const char *x;
 } callBackData;
 
-void walkCallBack(MlirOperation op, void *rootOpVoid) {
+MlirWalkResult walkCallBack(MlirOperation op, void *rootOpVoid) {
   fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
           mlirIdentifierStr(mlirOperationGetName(op)).data);
+  return MlirWalkResultAdvance;
+}
+
+// Like walkCallBack, but Skip at func.func and Interrupt at arith.addi.
+MlirWalkResult walkCallBackTestWalkResult(MlirOperation op, void *rootOpVoid) {
+  MlirStringRef name = mlirIdentifierStr(mlirOperationGetName(op));
+  fprintf(stderr, "%s: %.*s\n", ((callBackData *)(rootOpVoid))->x,
+          (int)name.length, name.data);
+  if (mlirStringRefEqual(name, mlirStringRefCreateFromCString("func.func")))
+    return MlirWalkResultSkip;
+  if (mlirStringRefEqual(name, mlirStringRefCreateFromCString("arith.addi")))
+    return MlirWalkResultInterrupt;
+  return MlirWalkResultAdvance;
+}
 
 int testOperationWalk(MlirContext ctx) {
@@ -2259,6 +2272,9 @@ int testOperationWalk(MlirContext ctx) {
                              "    arith.addi %1, %1: i32\n"
                              "    return\n"
                              "  }\n"
+                             "  func.func @bar() {\n"
+                             "    return\n"
+                             "  }\n"
                              "}";
   MlirModule module =
       mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
@@ -2266,22 +2282,42 @@ int testOperationWalk(MlirContext ctx) {
   callBackData data;
   data.x = "i love you";
 
-  // CHECK: i love you: arith.constant
-  // CHECK: i love you: arith.addi
-  // CHECK: i love you: func.return
-  // CHECK: i love you: func.func
-  // CHECK: i love you: builtin.module
+  // CHECK-NEXT: i love you: arith.constant
+  // CHECK-NEXT: i love you: arith.addi
+  // CHECK-NEXT: i love you: func.return
+  // CHECK-NEXT: i love you: func.func
+  // CHECK-NEXT: i love you: func.return
+  // CHECK-NEXT: i love you: func.func
+  // CHECK-NEXT: i love you: builtin.module
   mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
                     (void *)(&data), MlirWalkPostOrder);
 
   data.x = "i don't love you";
-  // CHECK: i don't love you: builtin.module
-  // CHECK: i don't love you: func.func
-  // CHECK: i don't love you: arith.constant
-  // CHECK: i don't love you: arith.addi
-  // CHECK: i don't love you: func.return
+  // CHECK-NEXT: i don't love you: builtin.module
+  // CHECK-NEXT: i don't love you: func.func
+  // CHECK-NEXT: i don't love you: arith.constant
+  // CHECK-NEXT: i don't love you: arith.addi
+  // CHECK-NEXT: i don't love you: func.return
+  // CHECK-NEXT: i don't love you: func.func
+  // CHECK-NEXT: i don't love you: func.return
   mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
                     (void *)(&data), MlirWalkPreOrder);
+
+  data.x = "interrupt";
+  // Interrupted at `arith.addi`
+  // CHECK-NEXT: interrupt: arith.constant
+  // CHECK-NEXT: interrupt: arith.addi
+  mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
+                    (void *)(&data), MlirWalkPostOrder);
+
+  data.x = "skip";
+  // Skip at `func.func`
+  // CHECK-NEXT: skip: builtin.module
+  // CHECK-NEXT: skip: func.func
+  // CHECK-NEXT: skip: func.func
+  mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
+                    (void *)(&data), MlirWalkPreOrder);
+
   mlirModuleDestroy(module);
   return 0;
 }
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 18790cd4f89c88..20483ae8f2092a 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1032,7 +1032,10 @@ def testOpWalk():
   """,
         ctx,
     )
-    callback = lambda op: print(op.name)
+
+    def callback(op):
+        print(op.name)
+        return WalkResult.ADVANCE
     # Test post-order walk (default).
     # CHECK-NEXT:  Post-order
     # CHECK-NEXT:  func.return
@@ -1047,4 +1050,38 @@ def testOpWalk():
     # CHECK-NEXT:  func.fun
     # CHECK-NEXT:  func.return
     print("Pre-order")
-    module.operation.walk(callback, True)
+    module.operation.walk(callback, WalkOrder.PRE_ORDER)
+
+    # Test interrupt.
+    # CHECK-NEXT:  Interrupt post-order
+    # CHECK-NEXT:  func.return
+    print("Interrupt post-order")
+    def callback(op):
+        print(op.name)
+        return WalkResult.INTERRUPT
+    module.operation.walk(callback)
+
+    # Test skip.
+    # CHECK-NEXT:  Skip pre-order
+    # CHECK-NEXT:  builtin.module
+    print("Skip pre-order")
+    def callback(op):
+        print(op.name)
+        return WalkResult.SKIP
+    module.operation.walk(callback, WalkOrder.PRE_ORDER)
+
+
+    # Test exception
+    print("Exception")
+    def callback(op):
+        print(op.name)
+        raise ValueError
+        return WalkResult.ADVANCE
+    try:
+        module.operation.walk(callback)
+    except ValueError:
+        print("Exception raised")
+    # CHECK: Exception
+    # CHECK-NEXT: func.return
+    # CHECK-NEXT: Exception raised
+



More information about the Mlir-commits mailing list