[Mlir-commits] [mlir] [MLIR][Python] Add walk_of_type() binding and get_ops_of_type() utility (PR #186131)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 12 07:50:04 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: RattataKing (RattataKing)
<details>
<summary>Changes</summary>
MLIR's C++ `Operation::walk` supports type-filtered traversal (e.g. `op->walk([](arith::AddIOp op) { ... })`), but the Python binding `op.walk()` has no such filter, so users must check each operation's type inside the callback themselves.
This PR adds `PyOperationBase::walkOfType()`, a C++ method in the Python bindings that wraps the existing `walk()` with a type filter, and exposes it to Python as `op.walk_of_type(op_class, callback)`.
This PR also adds a convenience helper to `mlir.ir` that collects all ops of a given type into a list, so users can simply call `ops = ir.get_ops_of_type(root, op_class)`.
---
Full diff: https://github.com/llvm/llvm-project/pull/186131.diff
4 Files Affected:
- (modified) mlir/include/mlir/Bindings/Python/IRCore.h (+5)
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+30-1)
- (modified) mlir/python/mlir/ir.py (+19)
- (modified) mlir/test/python/ir/operation.py (+81)
``````````diff
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index bd2d49acbf681..836dedcff0ade 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -599,6 +599,11 @@ class MLIR_PYTHON_API_EXPORTED PyOperationBase {
void walk(std::function<PyWalkResult(MlirOperation)> callback,
PyWalkOrder walkOrder);
+ /// Like walk(), but only invokes `callback` on operations whose OpView is an
+ /// instance of `opClass`; all other operations are skipped and the walk
+ /// continues past them.
+ void walkOfType(nanobind::object opClass,
+ std::function<PyWalkResult(MlirOperation)> callback,
+ PyWalkOrder walkOrder);
+
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
void moveBefore(PyOperationBase &other);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b8637c57a3f48..ac3dc00c008a0 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1145,6 +1145,20 @@ void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
}
}
+void PyOperationBase::walkOfType(
+    nb::object opClass, std::function<PyWalkResult(MlirOperation)> callback,
+    PyWalkOrder walkOrder) {
+  // Dialect op classes such as arith.AddIOp are OpView subclasses, so match
+  // against the operation's OpView rather than the bare Operation. Operations
+  // that do not match are skipped and the walk advances past them.
+  auto filtered = [&](MlirOperation mlirOp) -> PyWalkResult {
+    nb::object opview = nb::cast(mlirOp).attr("opview");
+    if (nb::isinstance(opview, opClass))
+      return callback(mlirOp);
+    return PyWalkResult::Advance;
+  };
+  walk(filtered, walkOrder);
+}
+
nb::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
std::optional<int64_t> largeResourceLimit,
@@ -3868,7 +3882,22 @@ void populateIRCore(nb::module_ &m) {
Args:
callback: A callable that takes an Operation and returns a WalkResult.
- walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
+ walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)")
+ .def("walk_of_type", &PyOperationBase::walkOfType, "op_class"_a,
+ "callback"_a, "walk_order"_a = PyWalkOrder::PostOrder,
+ // clang-format off
+ nb::sig("def walk_of_type(self, op_class: type[OpView], callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = WalkOrder.POST_ORDER) -> None"),
+ // clang-format on
+ R"(
+ Walks the operation tree, invoking the callback only on operations of the specified type.
+
+ Args:
+ op_class: The OpView subclass to match (e.g. arith.AddIOp). An operation
+ matches when its opview is an instance of this class.
+ callback: A callable that takes an Operation and returns a WalkResult.
+ Use `op.opview` inside the callback to get the typed view.
+ walk_order: The traversal order (PRE_ORDER or POST_ORDER). Defaults to
+ POST_ORDER.
+
+ For example, op.walk_of_type(arith.AddIOp, callback) walks the operation tree
+ and invokes callback only on arith.AddIOp operations.)");
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
.def_static(
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index d86298a72c6f2..023808b34fbb7 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -45,6 +45,25 @@ def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView
return None
+def get_ops_of_type(
+    root: OpView | Operation | Module, op_class: type[OpView]
+) -> list[OpView]:
+    """Return all operations of the given type in the operation tree.
+
+    Args:
+        root: The operation, op view, or module to start traversing from.
+        op_class: The OpView subclass to collect (e.g. func.FuncOp).
+
+    Returns:
+        A list of the matching operations as OpView instances, in the order
+        produced by ``walk_of_type``'s default (post-order) traversal.
+    """
+    op = root.operation if isinstance(root, Module) else root
+    ops: list[OpView] = []
+
+    def collect_ops(visited: Operation) -> WalkResult:
+        ops.append(visited.opview)
+        return WalkResult.ADVANCE
+
+    op.walk_of_type(op_class, collect_ops)
+    return ops
+
+
@contextmanager
def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None]:
"""Enables automatic traceback-based locations for MLIR operations.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 865dd226cbe2a..f184ec2a3984c 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1323,6 +1323,58 @@ def callback(op):
module.operation.walk(callback)
except RuntimeError:
print("Exception raised")
+
+
+# CHECK-LABEL: TEST: testOpWalkOfType
+@run
+def testOpWalkOfType():
+    with Context(), Location.unknown():
+        module = Module.parse("""
+            module {
+              func.func @f() { return }
+              func.func @g() { return }
+              arith.constant dense<0> : tensor<i32>
+            }
+            """)
+
+        # Callback: only visits ops of the requested type.
+        # CHECK: only FuncOp visited: True
+        only_funcs = True
+
+        def check_type(op):
+            nonlocal only_funcs
+            if not isinstance(op.opview, func.FuncOp):
+                only_funcs = False
+            return WalkResult.ADVANCE
+
+        module.operation.walk_of_type(func.FuncOp, check_type)
+        print(f"only FuncOp visited: {only_funcs}")
+
+        # Callback: interrupt after first match.
+        # CHECK: interrupted after: 1
+        seen = []
+
+        def stop_after_first(op):
+            seen.append(op.opview)
+            return WalkResult.INTERRUPT
+
+        module.operation.walk_of_type(func.FuncOp, stop_after_first)
+        print(f"interrupted after: {len(seen)}")
+
+        # Callback: no match, callback never called.
+        # CHECK: never called: True
+        called = False
+
+        def should_not_run(op):
+            nonlocal called
+            called = True
+            return WalkResult.ADVANCE
+
+        module.operation.walk_of_type(scf.ForOp, should_not_run)
+        print(f"never called: {not called}")
+
+        # Callback: collect all matching ops.
+        # CHECK: collected func.FuncOp: ['"f"', '"g"']
+        collected = []
+
+        def collect(op):
+            collected.append(op.opview)
+            return WalkResult.ADVANCE
+
+        module.operation.walk_of_type(func.FuncOp, collect)
+        assert all(isinstance(r, func.FuncOp) for r in collected)
+        print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
# CHECK-LABEL: TEST: testOpReplaceUsesWith
@@ -1437,3 +1489,32 @@ def testGetParentOfType():
assert False, "expected TypeError"
except TypeError:
pass
+
+
+# CHECK-LABEL: TEST: testGetOpsOfType
+@run
+def testGetOpsOfType():
+    with Context(), Location.unknown():
+        module = Module.parse("""
+            module {
+              func.func @f() { return }
+              func.func @g() { return }
+            }
+            """)
+
+        # CHECK: get_ops_of_type func.func count: 2
+        results = get_ops_of_type(module, func.FuncOp)
+        print(f"get_ops_of_type func.func count: {len(results)}")
+        assert len(results) == 2
+        assert all(isinstance(r, func.FuncOp) for r in results)
+
+        # CHECK: get_ops_of_type scf.for count: 0
+        results = get_ops_of_type(module, scf.ForOp)
+        print(f"get_ops_of_type scf.for count: {len(results)}")
+        assert len(results) == 0
+
+        # Accepts Operation as root.
+        assert len(get_ops_of_type(module.operation, func.FuncOp)) == 2
+
+        # Accepts OpView as root.
+        func_op = get_ops_of_type(module, func.FuncOp)[0]
+        results = get_ops_of_type(func_op, func.ReturnOp)
+        assert len(results) == 1
+        assert isinstance(results[0], func.ReturnOp)
``````````
</details>
https://github.com/llvm/llvm-project/pull/186131
More information about the Mlir-commits
mailing list