[Mlir-commits] [mlir] [MLIR][Python] Add walk_of_type() binding and get_ops_of_type() utility (PR #186131)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 12 07:53:50 PDT 2026


https://github.com/RattataKing updated https://github.com/llvm/llvm-project/pull/186131

>From 881d72a4e64b62d61075fe713b4262f7f98cc770 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 14:18:03 +0000
Subject: [PATCH 1/4] Add walk_of_type() and test

---
 mlir/include/mlir/Bindings/Python/IRCore.h |  6 +++
 mlir/lib/Bindings/Python/IRCore.cpp        | 33 +++++++++++++-
 mlir/test/python/ir/operation.py           | 52 ++++++++++++++++++++++
 3 files changed, 89 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index bd2d49acbf681..1569a870e3996 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -599,6 +599,12 @@ class MLIR_PYTHON_API_EXPORTED PyOperationBase {
   void walk(std::function<PyWalkResult(MlirOperation)> callback,
             PyWalkOrder walkOrder);
 
+  // Wrap the walk method with a type filter. Works same as op.walk([](OpClass
+  // op) { ... } );
+  void walkOfType(nanobind::object opClass,
+                  std::function<PyWalkResult(MlirOperation)> callback,
+                  PyWalkOrder walkOrder);
+
   /// Moves the operation before or after the other operation.
   void moveAfter(PyOperationBase &other);
   void moveBefore(PyOperationBase &other);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b8637c57a3f48..2c5634b239b77 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1145,6 +1145,20 @@ void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
   }
 }
 
+void PyOperationBase::walkOfType(
+    nb::object opClass, std::function<PyWalkResult(MlirOperation)> callback,
+    PyWalkOrder walkOrder) {
+  // Forward only operations whose `opview` is an instance of `opClass` to
+  // `callback`; all other operations just advance the underlying walk().
+  auto filtered = [&](MlirOperation mlirOp) -> PyWalkResult {
+    nb::object opview = nb::cast(mlirOp).attr("opview");
+    if (!nb::isinstance(opview, opClass))
+      return PyWalkResult::Advance;
+    return callback(mlirOp);
+  };
+  walk(filtered, walkOrder);
+}
+
 nb::object PyOperationBase::getAsm(bool binary,
                                    std::optional<int64_t> largeElementsLimit,
                                    std::optional<int64_t> largeResourceLimit,
@@ -2707,7 +2721,7 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
     if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
       continue;
 
-    // co_qualname and PyCode_Addr2Location added in py3.11
+      // co_qualname and PyCode_Addr2Location added in py3.11
 #if PY_VERSION_HEX < 0x030B00F0
     std::string name =
         nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
@@ -3868,7 +3882,22 @@ void populateIRCore(nb::module_ &m) {
 
              Args:
                callback: A callable that takes an Operation and returns a WalkResult.
-               walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
+               walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)")
+      .def("walk_of_type", &PyOperationBase::walkOfType, "op_class"_a,
+           "callback"_a, "walk_order"_a = PyWalkOrder::PostOrder,
+           // clang-format off
+     nb::sig("def walk_of_type(self, op_class: type[OpView], callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
+           // clang-format on
+           R"(
+              Walks the operation tree, invoking the callback only on operations of the specified type.
+
+              Args:
+                op_class: The operation type to match.
+                callback: A callable that takes an Operation and returns a WalkResult.
+                walk_order: The traversal order (PRE_ORDER or POST_ORDER).
+
+              For example, op.walk_of_type(arith.AddIOp, callback) walks the operation tree
+              and invokes callback only on arith.AddIOp operations.)");
 
   nb::class_<PyOperation, PyOperationBase>(m, "Operation")
       .def_static(
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 865dd226cbe2a..f3cd729ece9a8 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1323,6 +1323,58 @@ def callback(op):
         module.operation.walk(callback)
     except RuntimeError:
         print("Exception raised")
+
+
+# CHECK-LABEL: TEST: testOpWalkOfType
+@run
+def testOpWalkOfType():
+    with Context(), Location.unknown():
+        module = Module.parse("""
+            module {
+                func.func @f() { return }
+                func.func @g() { return }
+                arith.constant dense<0> : tensor<i32>
+            }
+        """)
+    # Callback: only visits ops of the requested type.
+    # CHECK: only FuncOp visited: True
+    only_funcs = True
+    def check_type(op):
+        nonlocal only_funcs
+        if not isinstance(op.opview, func.FuncOp):
+            only_funcs = False
+        return WalkResult.ADVANCE
+    module.operation.walk_of_type(func.FuncOp, check_type)
+    print(f"only FuncOp visited: {only_funcs}")
+
+    # Callback: interrupt after first match.
+    # CHECK: interrupted after: 1
+    seen = []
+    def stop_after_first(op):
+        seen.append(op.opview)
+        return WalkResult.INTERRUPT
+    module.operation.walk_of_type(func.FuncOp, stop_after_first)
+    print(f"interrupted after: {len(seen)}")
+
+    # Callback: no match, callback never called.
+    # CHECK: never called: True
+    called = False
+    def should_not_run(op):
+        nonlocal called
+        called = True
+        return WalkResult.ADVANCE
+    module.operation.walk_of_type(scf.ForOp, should_not_run)
+    print(f"never called: {not called}")
+
+    # Callback: collect all matching ops.
+    # CHECK: collected func.FuncOp: ['"f"', '"g"']
+    collected = []
+    def collect(op):
+        collected.append(op.opview)
+        return WalkResult.ADVANCE
+    module.operation.walk_of_type(func.FuncOp, collect)
+    assert all(isinstance(r, func.FuncOp) for r in collected)
+    print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
 
 
 # CHECK-LABEL: TEST: testOpReplaceUsesWith

>From 51e64162fdcd1869d2d369ea36fc24db18ea0d05 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 14:22:38 +0000
Subject: [PATCH 2/4] Add get_ops_of_type() and test

---
 mlir/python/mlir/ir.py           | 20 ++++++++++++++++++++
 mlir/test/python/ir/operation.py | 29 +++++++++++++++++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index d86298a72c6f2..a495a4db77861 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -45,6 +45,26 @@ def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView
     return None
 
 
+def get_ops_of_type(root: OpView | Operation | Module, op_class: type[OpView]) -> list[OpView]:
+    """Return all operations of the given type in the operation tree.
+
+
+    Args:
+      root: The operation or module to start traversing from.
+      op_class: The OpView subclass to visit for (e.g. func.FuncOp).
+
+    Returns:
+      The opview of every matching operation, in the order the walk visits them.
+    """
+    start = root.operation if isinstance(root, Module) else root
+    matches = []
+    def record(matched: Operation):
+        matches.append(matched.opview)
+        return WalkResult.ADVANCE
+    start.walk_of_type(op_class, record)
+    return matches
+
+
 @contextmanager
 def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None]:
     """Enables automatic traceback-based locations for MLIR operations.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index f3cd729ece9a8..f184ec2a3984c 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1489,3 +1489,32 @@ def testGetParentOfType():
             assert False, "expected TypeError"
         except TypeError:
             pass
+
+
+# CHECK-LABEL: TEST: test_get_ops_of_type
+@run
+def test_get_ops_of_type():
+    with Context(), Location.unknown():
+        module = Module.parse("""
+            module {
+                func.func @f() { return }
+                func.func @g() { return }
+            }
+        """)
+
+        # CHECK: get_ops_of_type func.func count: 2
+        results = get_ops_of_type(module, func.FuncOp)
+        print(f"get_ops_of_type func.func count: {len(results)}")
+        assert len(results) == 2
+        assert all(isinstance(r, func.FuncOp) for r in results)
+
+        # CHECK: get_ops_of_type scf.for count: 0
+        results = get_ops_of_type(module, scf.ForOp)
+        print(f"get_ops_of_type scf.for count: {len(results)}")
+        assert len(results) == 0
+
+        # Accepts OpView as root.
+        func_op = get_ops_of_type(module, func.FuncOp)[0]
+        results = get_ops_of_type(func_op, func.ReturnOp)
+        assert len(results) == 1
+        assert isinstance(results[0], func.ReturnOp)

>From 36e089fbc02e13935fe7d010ba7983f3b6424f94 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 14:28:39 +0000
Subject: [PATCH 3/4] Fix comments

---
 mlir/include/mlir/Bindings/Python/IRCore.h | 3 +--
 mlir/lib/Bindings/Python/IRCore.cpp        | 2 +-
 mlir/python/mlir/ir.py                     | 1 -
 3 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 1569a870e3996..836dedcff0ade 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -599,8 +599,7 @@ class MLIR_PYTHON_API_EXPORTED PyOperationBase {
   void walk(std::function<PyWalkResult(MlirOperation)> callback,
             PyWalkOrder walkOrder);
 
-  // Wrap the walk method with a type filter. Works same as op.walk([](OpClass
-  // op) { ... } );
+  // Wrap the walk method with a type filter.
   void walkOfType(nanobind::object opClass,
                   std::function<PyWalkResult(MlirOperation)> callback,
                   PyWalkOrder walkOrder);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2c5634b239b77..ac3dc00c008a0 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2721,7 +2721,7 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
     if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
       continue;
 
-      // co_qualname and PyCode_Addr2Location added in py3.11
+    // co_qualname and PyCode_Addr2Location added in py3.11
 #if PY_VERSION_HEX < 0x030B00F0
     std::string name =
         nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index a495a4db77861..023808b34fbb7 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -48,7 +48,6 @@ def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView
 def get_ops_of_type(root: OpView | Operation | Module, op_class: type[OpView]) -> list[OpView]:
     """Return all operations of the given type in the operation tree.
 
-
     Args:
       root: The operation or module to start traversing from.
       op_class: The OpView subclass to visit for (e.g. func.FuncOp).

>From 4ec803dc8f4870fd57a695cb6d85ee0fbacdce95 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 14:53:33 +0000
Subject: [PATCH 4/4] Fix format

---
 mlir/lib/Bindings/Python/IRCore.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index ac3dc00c008a0..2c5634b239b77 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2721,7 +2721,7 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
     if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
       continue;
 
-    // co_qualname and PyCode_Addr2Location added in py3.11
+      // co_qualname and PyCode_Addr2Location added in py3.11
 #if PY_VERSION_HEX < 0x030B00F0
     std::string name =
         nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));



More information about the Mlir-commits mailing list