[Mlir-commits] [mlir] [mlir][python] Add `walk` method to PyOperationBase (PR #87962)
Hideto Ueno
llvmlistbot at llvm.org
Thu Apr 11 05:51:51 PDT 2024
https://github.com/uenoku updated https://github.com/llvm/llvm-project/pull/87962
>From 77b9bbf120f5cead085fe642412a5ea3550009d2 Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Fri, 5 Apr 2024 03:58:44 -0700
Subject: [PATCH 1/3] [mlir][python] Add `walk` method to PyOperationBase
This commit adds a `walk` method to PyOperationBase that takes a Python
callable as the callback.
---
mlir/lib/Bindings/Python/IRCore.cpp | 18 +++++++++++++++-
mlir/lib/Bindings/Python/IRModule.h | 3 +++
mlir/test/python/ir/operation.py | 32 +++++++++++++++++++++++++++++
3 files changed, 52 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 734f2f7f3f94cf..848d918e16a7d1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1249,6 +1249,20 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
.str());
}
+/// Walks this operation and the operations nested in it, calling the Python
+/// `callback` with each visited operation; the walk is post-order unless
+/// `usePreOrder` is true.
+/// NOTE(review): an exception raised by `callback` (presumably surfaced as
+/// pybind11::error_already_set) escapes the capture-less trampoline below and
+/// unwinds through mlirOperationWalk's C-API frames — confirm this is safe.
+void PyOperationBase::walk(py::object callback, bool usePreOrder) {
+ PyOperation &operation = getOperation();
+ operation.checkValid();
+ MlirOperationWalkCallback walkCallback =
+ [](MlirOperation op,
+ void *userData) {
+ py::object *fn = static_cast<py::object *>(userData);
+ (*fn)(op);
+ };
+ mlirOperationWalk(operation, walkCallback, &callback,
+ usePreOrder ? MlirWalkOrder::MlirWalkPreOrder
+ : MlirWalkOrder::MlirWalkPostOrder);
+}
+
py::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
@@ -3038,7 +3052,9 @@ void mlir::python::populateIRCore(py::module &m) {
return operation.createOpView();
},
"Detaches the operation from its parent block.")
- .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
+ .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
+ .def("walk", &PyOperationBase::walk, py::arg("callback"),
+ py::arg("use_pre_order") = py::bool_(false));
py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
.def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9acfdde25ae047..ed15dd4f87c2b4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -579,6 +579,9 @@ class PyOperationBase {
void writeBytecode(const pybind11::object &fileObject,
std::optional<int64_t> bytecodeVersion);
+ // Implement the walk method.
+ void walk(pybind11::object callback, bool usePreOrder);
+
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
void moveBefore(PyOperationBase &other);
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04f8a9936e31f7..92a4f1b1545c20 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1015,3 +1015,35 @@ def testOperationParse():
print(
f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
)
+
+# CHECK-LABEL: TEST: testOpWalk
+ at run
+def testOpWalk():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
+ builtin.module {
+ func.func @f() {
+ func.return
+ }
+ }
+ """,
+ ctx,
+ )
+ callback = lambda op: print(op.name)
+ # Test post-order walk (default).
+ # CHECK-NEXT: Post-order
+ # CHECK-NEXT: func.return
+ # CHECK-NEXT: func.func
+ # CHECK-NEXT: builtin.module
+ print("Post-order")
+ module.operation.walk(callback)
+
+ # Test pre-order walk.
+ # CHECK-NEXT: Pre-order
+ # CHECK-NEXT: builtin.module
+ # CHECK-NEXT: func.fun
+ # CHECK-NEXT: func.return
+ print("Pre-order")
+ module.operation.walk(callback, True)
>From 218279d09cabc7f6f9ad995d6686cba4bafd76cc Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Sun, 7 Apr 2024 23:11:49 -0700
Subject: [PATCH 2/3] format
---
mlir/lib/Bindings/Python/IRCore.cpp | 6 ++----
mlir/test/python/ir/operation.py | 3 ++-
2 files changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 848d918e16a7d1..1ee9571065215b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1252,8 +1252,7 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
void PyOperationBase::walk(py::object callback, bool usePreOrder) {
PyOperation &operation = getOperation();
operation.checkValid();
- MlirOperationWalkCallback walkCallback =
- [](MlirOperation op,
+ MlirOperationWalkCallback walkCallback = [](MlirOperation op,
void *userData) {
py::object *fn = static_cast<py::object *>(userData);
(*fn)(op);
@@ -3003,8 +3002,7 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
- bool, py::object, bool>(
- &PyOperationBase::print),
+ bool, py::object, bool>(&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
py::arg("large_elements_limit") = py::none(),
py::arg("enable_debug_info") = false,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 92a4f1b1545c20..18790cd4f89c88 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1016,6 +1016,7 @@ def testOperationParse():
f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
)
+
# CHECK-LABEL: TEST: testOpWalk
@run
def testOpWalk():
@@ -1029,7 +1030,7 @@ def testOpWalk():
}
}
""",
- ctx,
+ ctx,
)
callback = lambda op: print(op.name)
# Test post-order walk (default).
>From 96f76ac00aa45421b180df7786a40767522e0bdf Mon Sep 17 00:00:00 2001
From: Hideto Ueno <uenoku.tokotoko at gmail.com>
Date: Thu, 11 Apr 2024 05:49:45 -0700
Subject: [PATCH 3/3] Add CAPI for WalkResult. Add python enum definition of
WalkResult/WalkOrder
---
mlir/include/mlir-c/IR.h | 10 +++-
.../mlir/Bindings/Python/PybindAdaptors.h | 1 +
mlir/lib/Bindings/Python/IRCore.cpp | 26 ++++++---
mlir/lib/Bindings/Python/IRModule.h | 3 +-
mlir/lib/CAPI/IR/IR.cpp | 21 ++++++-
mlir/test/CAPI/ir.c | 58 +++++++++++++++----
mlir/test/python/ir/operation.py | 41 ++++++++++++-
7 files changed, 136 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 82da511f807a34..32abacf353133e 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -705,6 +705,13 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
MlirOperation other);
+/// Result returned by an operation walk callback to control the walk.
+typedef enum MlirWalkResult {
+ /// Continue the walk with the next operation.
+ MlirWalkResultAdvance,
+ /// Stop the walk; no further operations are visited.
+ MlirWalkResultInterrupt,
+ /// Do not visit the operations nested in the current operation (relevant
+ /// for pre-order walks, where they come after it), then continue the walk.
+ MlirWalkResultSkip
+} MlirWalkResult;
+
/// Traversal order for operation walk.
typedef enum MlirWalkOrder {
MlirWalkPreOrder,
@@ -713,7 +720,8 @@ typedef enum MlirWalkOrder {
/// Operation walker type. The handler is passed an (opaque) reference to an
/// operation and a pointer to a `userData`.
-typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
+/// The handler returns an MlirWalkResult that tells the walker whether to
+/// advance, skip the operations nested in the current operation, or
+/// interrupt the walk.
+typedef MlirWalkResult (*MlirOperationWalkCallback)(MlirOperation,
+ void *userData);
/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
/// `*userData` is passed to the callback as well and can be used to tunnel some
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 52f6321251919e..d8f22c7aa17096 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -18,6 +18,7 @@
#ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
#define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
+#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 1ee9571065215b..d875f4eba2b139 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -674,6 +674,7 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
data->rootOp.getOperation().getContext()->clearOperation(op);
else
data->rootSeen = true;
+ return MlirWalkResult::MlirWalkResultAdvance;
};
mlirOperationWalk(op.getOperation(), invalidatingCallback,
static_cast<void *>(&data), MlirWalkPreOrder);
@@ -1249,17 +1250,19 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
.str());
}
-void PyOperationBase::walk(py::object callback, bool usePreOrder) {
+/// Walks this operation and the operations nested in it in `walkOrder`,
+/// calling `callback` on each one. The MlirWalkResult returned by the
+/// callback decides whether the walk advances, skips the nested operations,
+/// or is interrupted.
+///
+/// Any C++ exception thrown by `callback` (e.g. a Python error raised by a
+/// Python callable, or a cast failure of its return value) must not unwind
+/// through the C API: it is captured, the walk is interrupted, and the
+/// exception is rethrown unchanged once mlirOperationWalk has returned.
+void PyOperationBase::walk(
+    std::function<MlirWalkResult(MlirOperation)> callback,
+    MlirWalkOrder walkOrder) {
PyOperation &operation = getOperation();
operation.checkValid();
+  // State shared with the capture-less C trampoline through `userData`: the
+  // callback to invoke and the exception it threw, if any.
+  struct WalkState {
+    std::function<MlirWalkResult(MlirOperation)> &callback;
+    std::exception_ptr pendingException;
+  };
+  WalkState state{callback, nullptr};
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
void *userData) {
- py::object *fn = static_cast<py::object *>(userData);
- (*fn)(op);
+    auto *walkState = static_cast<WalkState *>(userData);
+    try {
+      return walkState->callback(op);
+    } catch (...) {
+      // Remember the exception and stop the walk instead of unwinding
+      // through the C API frames of mlirOperationWalk.
+      walkState->pendingException = std::current_exception();
+      return MlirWalkResult::MlirWalkResultInterrupt;
+    }
};
- mlirOperationWalk(operation, walkCallback, &callback,
- usePreOrder ? MlirWalkOrder::MlirWalkPreOrder
- : MlirWalkOrder::MlirWalkPostOrder);
+
+  mlirOperationWalk(operation, walkCallback, &state, walkOrder);
+  // Surface the callback's exception (e.g. the original Python error type)
+  // to the caller now that we are back on the C++ side.
+  if (state.pendingException)
+    std::rethrow_exception(state.pendingException);
}
py::object PyOperationBase::getAsm(bool binary,
@@ -2524,6 +2527,15 @@ void mlir::python::populateIRCore(py::module &m) {
.value("NOTE", MlirDiagnosticNote)
.value("REMARK", MlirDiagnosticRemark);
+ py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
+ .value("PRE_ORDER", MlirWalkPreOrder)
+ .value("POST_ORDER", MlirWalkPostOrder);
+
+ py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
+ .value("ADVANCE", MlirWalkResultAdvance)
+ .value("INTERRUPT", MlirWalkResultInterrupt)
+ .value("SKIP", MlirWalkResultSkip);
+
//----------------------------------------------------------------------------
// Mapping of Diagnostics.
//----------------------------------------------------------------------------
@@ -3052,7 +3064,7 @@ void mlir::python::populateIRCore(py::module &m) {
"Detaches the operation from its parent block.")
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
.def("walk", &PyOperationBase::walk, py::arg("callback"),
- py::arg("use_pre_order") = py::bool_(false));
+ py::arg("walk_order") = MlirWalkPostOrder);
py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
.def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index ed15dd4f87c2b4..b038a0c54d29b9 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -580,7 +580,8 @@ class PyOperationBase {
std::optional<int64_t> bytecodeVersion);
// Implement the walk method.
- void walk(pybind11::object callback, bool usePreOrder);
+ void walk(std::function<MlirWalkResult(MlirOperation)> callback,
+ MlirWalkOrder walkOrder);
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index cdb64f4ec4a40f..065a8580eae497 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -717,17 +717,34 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
return unwrap(op)->moveBefore(unwrap(other));
}
+/// Converts a C API walk result into the corresponding mlir::WalkResult.
+static mlir::WalkResult translateWalkResult(MlirWalkResult result) {
+  switch (result) {
+  case MlirWalkResultAdvance:
+    return mlir::WalkResult::advance();
+
+  case MlirWalkResultInterrupt:
+    return mlir::WalkResult::interrupt();
+
+  case MlirWalkResultSkip:
+    return mlir::WalkResult::skip();
+  }
+  // Every enumerator returns above. Without this, control can reach the end
+  // of a non-void function (-Wreturn-type) for a value cast into the enum;
+  // assertion-enabled builds abort here instead.
+  llvm_unreachable("unknown MlirWalkResult");
+}
+
void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
void *userData, MlirWalkOrder walkOrder) {
+  // Single adaptor shared by both walk orders: wrap each visited operation
+  // for the C callback and translate its answer back into a WalkResult.
+  auto visit = [callback, userData](Operation *visited) {
+    return translateWalkResult(callback(wrap(visited), userData));
+  };
switch (walkOrder) {
case MlirWalkPreOrder:
unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
- [callback, userData](Operation *op) { callback(wrap(op), userData); });
+      visit);
break;
case MlirWalkPostOrder:
unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
- [callback, userData](Operation *op) { callback(wrap(op), userData); });
+      visit);
}
}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 8e79338c57a22a..3d05b2a12dd8ef 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2244,9 +2244,22 @@ typedef struct {
const char *x;
} callBackData;
-void walkCallBack(MlirOperation op, void *rootOpVoid) {
+MlirWalkResult walkCallBack(MlirOperation op, void *rootOpVoid) {
fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
mlirIdentifierStr(mlirOperationGetName(op)).data);
+ return MlirWalkResultAdvance;
+}
+
+// Prints each visited operation like walkCallBack, then skips the body of
+// every `func.func` and interrupts the walk at the first `arith.addi`.
+MlirWalkResult walkCallBackTestWalkResult(MlirOperation op, void *rootOpVoid) {
+  const char *opName = mlirIdentifierStr(mlirOperationGetName(op)).data;
+  fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x, opName);
+  if (!strcmp(opName, "func.func"))
+    return MlirWalkResultSkip;
+  if (!strcmp(opName, "arith.addi"))
+    return MlirWalkResultInterrupt;
+  return MlirWalkResultAdvance;
}
int testOperationWalk(MlirContext ctx) {
@@ -2259,6 +2272,9 @@ int testOperationWalk(MlirContext ctx) {
" arith.addi %1, %1: i32\n"
" return\n"
" }\n"
+ " func.func @bar() {\n"
+ " return\n"
+ " }\n"
"}";
MlirModule module =
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
@@ -2266,22 +2282,42 @@ int testOperationWalk(MlirContext ctx) {
callBackData data;
data.x = "i love you";
- // CHECK: i love you: arith.constant
- // CHECK: i love you: arith.addi
- // CHECK: i love you: func.return
- // CHECK: i love you: func.func
- // CHECK: i love you: builtin.module
+ // CHECK-NEXT: i love you: arith.constant
+ // CHECK-NEXT: i love you: arith.addi
+ // CHECK-NEXT: i love you: func.return
+ // CHECK-NEXT: i love you: func.func
+ // CHECK-NEXT: i love you: func.return
+ // CHECK-NEXT: i love you: func.func
+ // CHECK-NEXT: i love you: builtin.module
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
(void *)(&data), MlirWalkPostOrder);
data.x = "i don't love you";
- // CHECK: i don't love you: builtin.module
- // CHECK: i don't love you: func.func
- // CHECK: i don't love you: arith.constant
- // CHECK: i don't love you: arith.addi
- // CHECK: i don't love you: func.return
+ // CHECK-NEXT: i don't love you: builtin.module
+ // CHECK-NEXT: i don't love you: func.func
+ // CHECK-NEXT: i don't love you: arith.constant
+ // CHECK-NEXT: i don't love you: arith.addi
+ // CHECK-NEXT: i don't love you: func.return
+ // CHECK-NEXT: i don't love you: func.func
+ // CHECK-NEXT: i don't love you: func.return
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
(void *)(&data), MlirWalkPreOrder);
+
+ data.x = "interrupt";
+ // Interrupted at `arith.addi`
+ // CHECK-NEXT: interrupt: arith.constant
+ // CHECK-NEXT: interrupt: arith.addi
+ mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
+ (void *)(&data), MlirWalkPostOrder);
+
+ data.x = "skip";
+ // Skip at `func.func`
+ // CHECK-NEXT: skip: builtin.module
+ // CHECK-NEXT: skip: func.func
+ // CHECK-NEXT: skip: func.func
+ mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
+ (void *)(&data), MlirWalkPreOrder);
+
mlirModuleDestroy(module);
return 0;
}
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 18790cd4f89c88..20483ae8f2092a 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1032,7 +1032,10 @@ def testOpWalk():
""",
ctx,
)
- callback = lambda op: print(op.name)
+
+ def callback(op):
+ print(op.name)
+ return WalkResult.ADVANCE
# Test post-order walk (default).
# CHECK-NEXT: Post-order
# CHECK-NEXT: func.return
@@ -1047,4 +1050,38 @@ def testOpWalk():
# CHECK-NEXT: func.fun
# CHECK-NEXT: func.return
print("Pre-order")
- module.operation.walk(callback, True)
+ module.operation.walk(callback, WalkOrder.PRE_ORDER)
+
+ # Test interrput.
+ # CHECK-NEXT: Interrupt post-order
+ # CHECK-NEXT: func.return
+ print("Interrupt post-order")
+ def callback(op):
+ print(op.name)
+ return WalkResult.INTERRUPT
+ module.operation.walk(callback)
+
+ # Test skipk.
+ # CHECK-NEXT: Skip pre-order
+ # CHECK-NEXT: builtin.module
+ print("Skip pre-order")
+ def callback(op):
+ print(op.name)
+ return WalkResult.SKIP
+ module.operation.walk(callback, WalkOrder.PRE_ORDER)
+
+
+ # Test exception
+ print("Exception")
+ def callback(op):
+ print(op.name)
+ raise ValueError
+ return WalkResult.ADVANCE
+ try:
+ module.operation.walk(callback)
+ except ValueError:
+ print("Exception raised")
+ # CHECK: Exception
+ # CHECK-NEXT: func.return
+ # CHECK-NEXT: Exception raised
+
More information about the Mlir-commits
mailing list