[Mlir-commits] [mlir] 39b3b2e - [MLIR][Python] Add type filter to walk() binding and add get_ops_of_type() utility (#186131)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 13 10:36:35 PDT 2026


Author: RattataKing
Date: 2026-03-13T13:36:30-04:00
New Revision: 39b3b2e0e8d1a0c2f32976a1e440d0a64137208d

URL: https://github.com/llvm/llvm-project/commit/39b3b2e0e8d1a0c2f32976a1e440d0a64137208d
DIFF: https://github.com/llvm/llvm-project/commit/39b3b2e0e8d1a0c2f32976a1e440d0a64137208d.diff

LOG: [MLIR][Python] Add type filter to walk() binding and add get_ops_of_type() utility (#186131)

MLIR's C++ `Operation::walk` supports type-filtered traversal (e.g.
`op->walk([](arith::AddIOp op) { ... })`), but the Python binding
`op.walk()` requires users to manually implement type filtering inside
the callback function.

This PR adds type filtering to the Python binding `op.walk()`: if
users pass `op_class`, `walk()` only applies the callback to matching ops.

This PR also adds a general-purpose helper to `mlir.ir` that collects all ops
of a given type into a list. Users can simply call
`ops = ir.get_ops_of_type(root, op_class)`; if `op_class` is omitted, every op
in the tree is collected.

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/python/mlir/ir.py
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 92ea44605b01a..7eb59d61b0d57 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3927,17 +3927,40 @@ void populateIRCore(nb::module_ &m) {
 
             Note:
               After erasing, any Python references to the operation become invalid.)")
-      .def("walk", &PyOperationBase::walk, "callback"_a,
-           "walk_order"_a = PyWalkOrder::PostOrder,
-           // clang-format off
-          nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = ...) -> None"),
-           // clang-format on
-           R"(
+      .def(
+          "walk",
+          [](PyOperationBase &self,
+             std::function<PyWalkResult(MlirOperation)> callback,
+             PyWalkOrder walkOrder, std::optional<nb::object> opClass) {
+            if (!opClass)
+              return self.walk(callback, walkOrder);
+            self.walk(
+                [&](MlirOperation mlirOp) -> PyWalkResult {
+                  nb::object opview =
+                      PyOperation::forOperation(
+                          self.getOperation().getContext(), mlirOp)
+                          ->createOpView();
+                  if (nb::isinstance(opview, *opClass))
+                    return callback(mlirOp);
+                  return PyWalkResult::Advance;
+                },
+                walkOrder);
+          },
+          "callback"_a, "walk_order"_a = PyWalkOrder::PostOrder,
+          "op_class"_a = nb::none(),
+          // clang-format off
+           nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = ..., op_class: type[OpView] | None = None) -> None"),
+          // clang-format on
+          R"(
              Walks the operation tree with a callback function.
 
+             If op_class is provided, the callback is only invoked on operations
+             of that type; all other operations are skipped silently.
+
              Args:
                callback: A callable that takes an Operation and returns a WalkResult.
-               walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
+               walk_order: The order of traversal (PRE_ORDER or POST_ORDER).
+               op_class: If provided, only operations of this type are passed to the callback.)");
 
   nb::class_<PyOperation, PyOperationBase>(m, "Operation")
       .def_static(

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 99bd135b49636..210465daad0d8 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -45,6 +45,30 @@ def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView
     return None
 
 
+def get_ops_of_type(
+    root: OpView | Operation | Module, op_class: type[OpView] | None = None
+) -> list[OpView]:
+    """Return all operations of the given type in the operation tree.
+
+    Args:
+      root: The operation or module to start traversing from.
+      op_class: The OpView subclass to filter by (e.g. func.FuncOp). If None,
+        collects all operations in the tree.
+
+    Returns:
+      A list of operations of the given type.
+    """
+    op = root.operation if isinstance(root, Module) else root
+    ops = []
+
+    def collect_ops(op: Operation):
+        ops.append(op.opview)
+        return WalkResult.ADVANCE
+
+    op.walk(collect_ops, op_class=op_class)
+    return ops
+
+
 @contextmanager
 def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None, None, None]:
     """Enables automatic traceback-based locations for MLIR operations.

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 865dd226cbe2a..f561a1bc624d8 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1324,6 +1324,73 @@ def callback(op):
     except RuntimeError:
         print("Exception raised")
 
+    # Test op_class filter: only visits ops of the requested type.
+    module = Module.parse(
+        r"""
+    module {
+      func.func @f() {
+        func.return
+      }
+      func.func @g() {
+        func.return
+      }
+      arith.constant dense<0> : tensor<i32>
+    }
+  """,
+        ctx,
+    )
+
+    # CHECK-NEXT: only FuncOp visited: True
+    only_funcs = True
+
+    def check_type(op):
+        nonlocal only_funcs
+        if not isinstance(op.opview, func.FuncOp):
+            only_funcs = False
+        return WalkResult.ADVANCE
+
+    module.operation.walk(check_type, op_class=func.FuncOp)
+    print(f"only FuncOp visited: {only_funcs}")
+
+    # CHECK-NEXT: interrupted after: 1
+    seen = []
+
+    def stop_after_first(op):
+        seen.append(op.opview)
+        return WalkResult.INTERRUPT
+
+    module.operation.walk(stop_after_first, op_class=func.FuncOp)
+    print(f"interrupted after: {len(seen)}")
+
+    # CHECK-NEXT: never called: True
+    called = False
+
+    def should_not_run(op):
+        nonlocal called
+        called = True
+        return WalkResult.ADVANCE
+
+    module.operation.walk(should_not_run, op_class=scf.ForOp)
+    print(f"never called: {not called}")
+
+    # CHECK-NEXT: collected func.FuncOp: ['"f"', '"g"']
+    collected = []
+
+    def collect(op):
+        collected.append(op.opview)
+        return WalkResult.ADVANCE
+
+    module.operation.walk(collect, op_class=func.FuncOp)
+    assert all(isinstance(r, func.FuncOp) for r in collected)
+    print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
+
+    # Test op_class with walk_order: pre-order visits FuncOps in source order.
+    # CHECK-NEXT: pre-order FuncOp names: ['"f"', '"g"']
+    collected.clear()
+    module.operation.walk(collect, WalkOrder.PRE_ORDER, op_class=func.FuncOp)
+    assert all(isinstance(r, func.FuncOp) for r in collected)
+    print(f"pre-order FuncOp names: {[str(r.name) for r in collected]}")
+
 
 # CHECK-LABEL: TEST: testOpReplaceUsesWith
 @run
@@ -1437,3 +1504,48 @@ def testGetParentOfType():
             assert False, "expected TypeError"
         except TypeError:
             pass
+
+
+# CHECK-LABEL: TEST: test_get_ops_of_type
+@run
+def test_get_ops_of_type():
+    with Context(), Location.unknown():
+        module = Module.parse(
+            r"""
+    module {
+      func.func @f() {
+        func.return
+      }
+      func.func @g() {
+        func.return
+      }
+    }
+  """
+        )
+
+        # CHECK: get_ops_of_type func.func count: 2
+        results = get_ops_of_type(module, func.FuncOp)
+        print(f"get_ops_of_type func.func count: {len(results)}")
+        assert len(results) == 2
+        assert all(isinstance(r, func.FuncOp) for r in results)
+
+        # CHECK: get_ops_of_type scf.for count: 0
+        results = get_ops_of_type(module, scf.ForOp)
+        print(f"get_ops_of_type scf.for count: {len(results)}")
+        assert len(results) == 0
+
+        # CHECK: get_ops_of_type func_op->func.ReturnOp count: 1
+        # Accepts OpView as root.
+        func_op = get_ops_of_type(module, func.FuncOp)[0]
+        results = get_ops_of_type(func_op, func.ReturnOp)
+        print(f"get_ops_of_type func_op->func.ReturnOp count: {len(results)}")
+        assert len(results) == 1
+        assert isinstance(results[0], func.ReturnOp)
+
+        # CHECK: get_ops_of_type no filter count: 5
+        # No op_class collects all ops.
+        results = get_ops_of_type(module)
+        print(f"get_ops_of_type no filter count: {len(results)}")
+        assert len(results) == 5
+        assert any(isinstance(r, func.FuncOp) for r in results)
+        assert any(isinstance(r, func.ReturnOp) for r in results)


        


More information about the Mlir-commits mailing list