[Mlir-commits] [mlir] 4714883 - [mlir][python] Add `walk` method to PyOperationBase (#87962)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 16 23:09:50 PDT 2024


Author: Hideto Ueno
Date: 2024-04-17T15:09:47+09:00
New Revision: 47148832d4e3bf4901430732f1af6673147accb2

URL: https://github.com/llvm/llvm-project/commit/47148832d4e3bf4901430732f1af6673147accb2
DIFF: https://github.com/llvm/llvm-project/commit/47148832d4e3bf4901430732f1af6673147accb2.diff

LOG: [mlir][python] Add `walk` method to PyOperationBase (#87962)

This commit adds a `walk` method to PyOperationBase that uses a Python
object as a callback, e.g. `op.walk(callback)`. Currently the callback
must return a walk result explicitly.

We (SiFive) have had a walk method implemented in Python in our internal
Python tool for a while. However, the overhead of Python is expensive and
it didn't scale well for large MLIR files. Just replacing that walk with
this version reduced the entire execution time of the tool by 30-40%.
Since there are a few configs for which the tool takes several hours to
finish, this commit significantly improves tool performance.

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/include/mlir/Bindings/Python/PybindAdaptors.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/CAPI/ir.c
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 82da511f807a34..32abacf353133e 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -705,6 +705,13 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
 MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
                                                 MlirOperation other);
 
+/// Operation walk result.
+typedef enum MlirWalkResult {
+  MlirWalkResultAdvance,
+  MlirWalkResultInterrupt,
+  MlirWalkResultSkip
+} MlirWalkResult;
+
 /// Traversal order for operation walk.
 typedef enum MlirWalkOrder {
   MlirWalkPreOrder,
@@ -713,7 +720,8 @@ typedef enum MlirWalkOrder {
 
 /// Operation walker type. The handler is passed an (opaque) reference to an
 /// operation and a pointer to a `userData`.
-typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
+typedef MlirWalkResult (*MlirOperationWalkCallback)(MlirOperation,
+                                                    void *userData);
 
 /// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
 /// `*userData` is passed to the callback as well and can be used to tunnel some

diff  --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 52f6321251919e..d8f22c7aa17096 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -18,6 +18,7 @@
 #ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
 #define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
 
+#include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/pytypes.h>
 #include <pybind11/stl.h>

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 734f2f7f3f94cf..d875f4eba2b139 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -674,6 +674,7 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
       data->rootOp.getOperation().getContext()->clearOperation(op);
     else
       data->rootSeen = true;
+    return MlirWalkResult::MlirWalkResultAdvance;
   };
   mlirOperationWalk(op.getOperation(), invalidatingCallback,
                     static_cast<void *>(&data), MlirWalkPreOrder);
@@ -1249,6 +1250,21 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
                               .str());
 }
 
+void PyOperationBase::walk(
+    std::function<MlirWalkResult(MlirOperation)> callback,
+    MlirWalkOrder walkOrder) {
+  PyOperation &operation = getOperation();
+  operation.checkValid();
+  MlirOperationWalkCallback walkCallback = [](MlirOperation op,
+                                              void *userData) {
+    auto *fn =
+        static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
+    return (*fn)(op);
+  };
+
+  mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
+}
+
 py::object PyOperationBase::getAsm(bool binary,
                                    std::optional<int64_t> largeElementsLimit,
                                    bool enableDebugInfo, bool prettyDebugInfo,
@@ -2511,6 +2527,15 @@ void mlir::python::populateIRCore(py::module &m) {
       .value("NOTE", MlirDiagnosticNote)
       .value("REMARK", MlirDiagnosticRemark);
 
+  py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
+      .value("PRE_ORDER", MlirWalkPreOrder)
+      .value("POST_ORDER", MlirWalkPostOrder);
+
+  py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
+      .value("ADVANCE", MlirWalkResultAdvance)
+      .value("INTERRUPT", MlirWalkResultInterrupt)
+      .value("SKIP", MlirWalkResultSkip);
+
   //----------------------------------------------------------------------------
   // Mapping of Diagnostics.
   //----------------------------------------------------------------------------
@@ -2989,8 +3014,7 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("binary") = false, kOperationPrintStateDocstring)
       .def("print",
            py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
-                             bool, py::object, bool>(
-               &PyOperationBase::print),
+                             bool, py::object, bool>(&PyOperationBase::print),
            // Careful: Lots of arguments must match up with print method.
            py::arg("large_elements_limit") = py::none(),
            py::arg("enable_debug_info") = false,
@@ -3038,7 +3062,9 @@ void mlir::python::populateIRCore(py::module &m) {
             return operation.createOpView();
           },
           "Detaches the operation from its parent block.")
-      .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
+      .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
+      .def("walk", &PyOperationBase::walk, py::arg("callback"),
+           py::arg("walk_order") = MlirWalkPostOrder);
 
   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
       .def_static("create", &PyOperation::create, py::arg("name"),

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9acfdde25ae047..b038a0c54d29b9 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -579,6 +579,10 @@ class PyOperationBase {
   void writeBytecode(const pybind11::object &fileObject,
                      std::optional<int64_t> bytecodeVersion);
 
+  // Implement the walk method.
+  void walk(std::function<MlirWalkResult(MlirOperation)> callback,
+            MlirWalkOrder walkOrder);
+
   /// Moves the operation before or after the other operation.
   void moveAfter(PyOperationBase &other);
   void moveBefore(PyOperationBase &other);

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index cdb64f4ec4a40f..a72cd247e73f60 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -717,17 +717,34 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
   return unwrap(op)->moveBefore(unwrap(other));
 }
 
+static mlir::WalkResult unwrap(MlirWalkResult result) {
+  switch (result) {
+  case MlirWalkResultAdvance:
+    return mlir::WalkResult::advance();
+
+  case MlirWalkResultInterrupt:
+    return mlir::WalkResult::interrupt();
+
+  case MlirWalkResultSkip:
+    return mlir::WalkResult::skip();
+  }
+}
+
 void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
                        void *userData, MlirWalkOrder walkOrder) {
   switch (walkOrder) {
 
   case MlirWalkPreOrder:
     unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
-        [callback, userData](Operation *op) { callback(wrap(op), userData); });
+        [callback, userData](Operation *op) {
+          return unwrap(callback(wrap(op), userData));
+        });
     break;
   case MlirWalkPostOrder:
     unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
-        [callback, userData](Operation *op) { callback(wrap(op), userData); });
+        [callback, userData](Operation *op) {
+          return unwrap(callback(wrap(op), userData));
+        });
   }
 }
 

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 8e79338c57a22a..3d05b2a12dd8ef 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2244,9 +2244,22 @@ typedef struct {
   const char *x;
 } callBackData;
 
-void walkCallBack(MlirOperation op, void *rootOpVoid) {
+MlirWalkResult walkCallBack(MlirOperation op, void *rootOpVoid) {
   fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
           mlirIdentifierStr(mlirOperationGetName(op)).data);
+  return MlirWalkResultAdvance;
+}
+
+MlirWalkResult walkCallBackTestWalkResult(MlirOperation op, void *rootOpVoid) {
+  fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
+          mlirIdentifierStr(mlirOperationGetName(op)).data);
+  if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "func.func") ==
+      0)
+    return MlirWalkResultSkip;
+  if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "arith.addi") ==
+      0)
+    return MlirWalkResultInterrupt;
+  return MlirWalkResultAdvance;
 }
 
 int testOperationWalk(MlirContext ctx) {
@@ -2259,6 +2272,9 @@ int testOperationWalk(MlirContext ctx) {
                              "    arith.addi %1, %1: i32\n"
                              "    return\n"
                              "  }\n"
+                             "  func.func @bar() {\n"
+                             "    return\n"
+                             "  }\n"
                              "}";
   MlirModule module =
       mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
@@ -2266,22 +2282,42 @@ int testOperationWalk(MlirContext ctx) {
   callBackData data;
   data.x = "i love you";
 
-  // CHECK: i love you: arith.constant
-  // CHECK: i love you: arith.addi
-  // CHECK: i love you: func.return
-  // CHECK: i love you: func.func
-  // CHECK: i love you: builtin.module
+  // CHECK-NEXT: i love you: arith.constant
+  // CHECK-NEXT: i love you: arith.addi
+  // CHECK-NEXT: i love you: func.return
+  // CHECK-NEXT: i love you: func.func
+  // CHECK-NEXT: i love you: func.return
+  // CHECK-NEXT: i love you: func.func
+  // CHECK-NEXT: i love you: builtin.module
   mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
                     (void *)(&data), MlirWalkPostOrder);
 
   data.x = "i don't love you";
-  // CHECK: i don't love you: builtin.module
-  // CHECK: i don't love you: func.func
-  // CHECK: i don't love you: arith.constant
-  // CHECK: i don't love you: arith.addi
-  // CHECK: i don't love you: func.return
+  // CHECK-NEXT: i don't love you: builtin.module
+  // CHECK-NEXT: i don't love you: func.func
+  // CHECK-NEXT: i don't love you: arith.constant
+  // CHECK-NEXT: i don't love you: arith.addi
+  // CHECK-NEXT: i don't love you: func.return
+  // CHECK-NEXT: i don't love you: func.func
+  // CHECK-NEXT: i don't love you: func.return
   mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
                     (void *)(&data), MlirWalkPreOrder);
+
+  data.x = "interrupt";
+  // Interrupted at `arith.addi`
+  // CHECK-NEXT: interrupt: arith.constant
+  // CHECK-NEXT: interrupt: arith.addi
+  mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
+                    (void *)(&data), MlirWalkPostOrder);
+
+  data.x = "skip";
+  // Skip at `func.func`
+  // CHECK-NEXT: skip: builtin.module
+  // CHECK-NEXT: skip: func.func
+  // CHECK-NEXT: skip: func.func
+  mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
+                    (void *)(&data), MlirWalkPreOrder);
+
   mlirModuleDestroy(module);
   return 0;
 }

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04f8a9936e31f7..9666e63bda1e0e 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1015,3 +1015,78 @@ def testOperationParse():
         print(
             f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
         )
+
+
+# CHECK-LABEL: TEST: testOpWalk
+@run
+def testOpWalk():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
+    builtin.module {
+      func.func @f() {
+        func.return
+      }
+    }
+  """,
+        ctx,
+    )
+
+    def callback(op):
+        print(op.name)
+        return WalkResult.ADVANCE
+
+    # Test post-order walk (default).
+    # CHECK-NEXT:  Post-order
+    # CHECK-NEXT:  func.return
+    # CHECK-NEXT:  func.func
+    # CHECK-NEXT:  builtin.module
+    print("Post-order")
+    module.operation.walk(callback)
+
+    # Test pre-order walk.
+    # CHECK-NEXT:  Pre-order
+    # CHECK-NEXT:  builtin.module
+    # CHECK-NEXT:  func.fun
+    # CHECK-NEXT:  func.return
+    print("Pre-order")
+    module.operation.walk(callback, WalkOrder.PRE_ORDER)
+
+    # Test interrupt.
+    # CHECK-NEXT:  Interrupt post-order
+    # CHECK-NEXT:  func.return
+    print("Interrupt post-order")
+
+    def callback(op):
+        print(op.name)
+        return WalkResult.INTERRUPT
+
+    module.operation.walk(callback)
+
+    # Test skip.
+    # CHECK-NEXT:  Skip pre-order
+    # CHECK-NEXT:  builtin.module
+    print("Skip pre-order")
+
+    def callback(op):
+        print(op.name)
+        return WalkResult.SKIP
+
+    module.operation.walk(callback, WalkOrder.PRE_ORDER)
+
+    # Test exception.
+    # CHECK: Exception
+    # CHECK-NEXT: func.return
+    # CHECK-NEXT: Exception raised
+    print("Exception")
+
+    def callback(op):
+        print(op.name)
+        raise ValueError
+        return WalkResult.ADVANCE
+
+    try:
+        module.operation.walk(callback)
+    except ValueError:
+        print("Exception raised")


        


More information about the Mlir-commits mailing list