[Mlir-commits] [mlir] 39b3b2e - [MLIR][Python] Add type filter to walk() binding and add get_ops_of_type() utility (#186131)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 13 10:36:35 PDT 2026
Author: RattataKing
Date: 2026-03-13T13:36:30-04:00
New Revision: 39b3b2e0e8d1a0c2f32976a1e440d0a64137208d
URL: https://github.com/llvm/llvm-project/commit/39b3b2e0e8d1a0c2f32976a1e440d0a64137208d
DIFF: https://github.com/llvm/llvm-project/commit/39b3b2e0e8d1a0c2f32976a1e440d0a64137208d.diff
LOG: [MLIR][Python] Add type filter to walk() binding and add get_ops_of_type() utility (#186131)
MLIR's C++ `Operation::walk` supports type-filtered traversal (e.g.
`op->walk([](arith::AddIOp op) { ... })`), but the Python binding
`op.walk()` requires users to manually implement type filtering inside
the callback function.
This PR adds type filtering to the Python binding `op.walk()`: if users
pass `op_class`, `walk()` invokes the callback only on matching ops.
This PR also adds a convenience helper to `mlir.ir` that collects all ops
of a given type into a list. Users can simply call
`ops = ir.get_ops_of_type(root, op_class)`.
Added:
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/python/mlir/ir.py
mlir/test/python/ir/operation.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 92ea44605b01a..7eb59d61b0d57 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3927,17 +3927,40 @@ void populateIRCore(nb::module_ &m) {
Note:
After erasing, any Python references to the operation become invalid.)")
- .def("walk", &PyOperationBase::walk, "callback"_a,
- "walk_order"_a = PyWalkOrder::PostOrder,
- // clang-format off
- nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = ...) -> None"),
- // clang-format on
- R"(
+ .def(
+ "walk",
+ [](PyOperationBase &self,
+ std::function<PyWalkResult(MlirOperation)> callback,
+ PyWalkOrder walkOrder, std::optional<nb::object> opClass) {
+ if (!opClass)
+ return self.walk(callback, walkOrder);
+ self.walk(
+ [&](MlirOperation mlirOp) -> PyWalkResult {
+ nb::object opview =
+ PyOperation::forOperation(
+ self.getOperation().getContext(), mlirOp)
+ ->createOpView();
+ if (nb::isinstance(opview, *opClass))
+ return callback(mlirOp);
+ return PyWalkResult::Advance;
+ },
+ walkOrder);
+ },
+ "callback"_a, "walk_order"_a = PyWalkOrder::PostOrder,
+ "op_class"_a = nb::none(),
+ // clang-format off
+ nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = ..., op_class: type[OpView] | None = None) -> None"),
+ // clang-format on
+ R"(
Walks the operation tree with a callback function.
+ If op_class is provided, the callback is only invoked on operations
+ of that type; all other operations are skipped silently.
+
Args:
callback: A callable that takes an Operation and returns a WalkResult.
- walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
+ walk_order: The order of traversal (PRE_ORDER or POST_ORDER).
+ op_class: If provided, only operations of this type are passed to the callback.)");
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
.def_static(
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 99bd135b49636..210465daad0d8 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -45,6 +45,30 @@ def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView
return None
+def get_ops_of_type(
+ root: OpView | Operation | Module, op_class: type[OpView] | None = None
+) -> list[OpView]:
+ """Return all operations of the given type in the operation tree.
+
+ Args:
+ root: The operation or module to start traversing from.
+ op_class: The OpView subclass to filter by (e.g. func.FuncOp). If None,
+ collects all operations in the tree.
+
+ Returns:
+ A list of operations of the given type.
+ """
+ op = root.operation if isinstance(root, Module) else root
+ ops = []
+
+ def collect_ops(op: Operation):
+ ops.append(op.opview)
+ return WalkResult.ADVANCE
+
+ op.walk(collect_ops, op_class=op_class)
+ return ops
+
+
@contextmanager
def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None, None, None]:
"""Enables automatic traceback-based locations for MLIR operations.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 865dd226cbe2a..f561a1bc624d8 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1324,6 +1324,73 @@ def callback(op):
except RuntimeError:
print("Exception raised")
+ # Test op_class filter: only visits ops of the requested type.
+ module = Module.parse(
+ r"""
+ module {
+ func.func @f() {
+ func.return
+ }
+ func.func @g() {
+ func.return
+ }
+ arith.constant dense<0> : tensor<i32>
+ }
+ """,
+ ctx,
+ )
+
+ # CHECK-NEXT: only FuncOp visited: True
+ only_funcs = True
+
+ def check_type(op):
+ nonlocal only_funcs
+ if not isinstance(op.opview, func.FuncOp):
+ only_funcs = False
+ return WalkResult.ADVANCE
+
+ module.operation.walk(check_type, op_class=func.FuncOp)
+ print(f"only FuncOp visited: {only_funcs}")
+
+ # CHECK-NEXT: interrupted after: 1
+ seen = []
+
+ def stop_after_first(op):
+ seen.append(op.opview)
+ return WalkResult.INTERRUPT
+
+ module.operation.walk(stop_after_first, op_class=func.FuncOp)
+ print(f"interrupted after: {len(seen)}")
+
+ # CHECK-NEXT: never called: True
+ called = False
+
+ def should_not_run(op):
+ nonlocal called
+ called = True
+ return WalkResult.ADVANCE
+
+ module.operation.walk(should_not_run, op_class=scf.ForOp)
+ print(f"never called: {not called}")
+
+ # CHECK-NEXT: collected func.FuncOp: ['"f"', '"g"']
+ collected = []
+
+ def collect(op):
+ collected.append(op.opview)
+ return WalkResult.ADVANCE
+
+ module.operation.walk(collect, op_class=func.FuncOp)
+ assert all(isinstance(r, func.FuncOp) for r in collected)
+ print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
+
+ # Test op_class with walk_order: pre-order visits FuncOps in source order.
+ # CHECK-NEXT: pre-order FuncOp names: ['"f"', '"g"']
+ collected.clear()
+ module.operation.walk(collect, WalkOrder.PRE_ORDER, op_class=func.FuncOp)
+ assert all(isinstance(r, func.FuncOp) for r in collected)
+ print(f"pre-order FuncOp names: {[str(r.name) for r in collected]}")
+
# CHECK-LABEL: TEST: testOpReplaceUsesWith
@run
@@ -1437,3 +1504,48 @@ def testGetParentOfType():
assert False, "expected TypeError"
except TypeError:
pass
+
+
+# CHECK-LABEL: TEST: test_get_ops_of_type
+ at run
+def test_get_ops_of_type():
+ with Context(), Location.unknown():
+ module = Module.parse(
+ r"""
+ module {
+ func.func @f() {
+ func.return
+ }
+ func.func @g() {
+ func.return
+ }
+ }
+ """
+ )
+
+ # CHECK: get_ops_of_type func.func count: 2
+ results = get_ops_of_type(module, func.FuncOp)
+ print(f"get_ops_of_type func.func count: {len(results)}")
+ assert len(results) == 2
+ assert all(isinstance(r, func.FuncOp) for r in results)
+
+ # CHECK: get_ops_of_type scf.for count: 0
+ results = get_ops_of_type(module, scf.ForOp)
+ print(f"get_ops_of_type scf.for count: {len(results)}")
+ assert len(results) == 0
+
+ # CHECK: get_ops_of_type func_op->func.ReturnOp count: 1
+ # Accepts OpView as root.
+ func_op = get_ops_of_type(module, func.FuncOp)[0]
+ results = get_ops_of_type(func_op, func.ReturnOp)
+ print(f"get_ops_of_type func_op->func.ReturnOp count: {len(results)}")
+ assert len(results) == 1
+ assert isinstance(results[0], func.ReturnOp)
+
+ # CHECK: get_ops_of_type no filter count: 5
+ # No op_class collects all ops.
+ results = get_ops_of_type(module)
+ print(f"get_ops_of_type no filter count: {len(results)}")
+ assert len(results) == 5
+ assert any(isinstance(r, func.FuncOp) for r in results)
+ assert any(isinstance(r, func.ReturnOp) for r in results)
More information about the Mlir-commits
mailing list