[Mlir-commits] [mlir] [MLIR][Python] Add type filter to walk() binding and add get_ops_of_type() utility (PR #186131)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 12 18:15:54 PDT 2026


https://github.com/RattataKing updated https://github.com/llvm/llvm-project/pull/186131

>From 3f189d7730fae931c3d5014b67d4b5dbf158b1a3 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 14:18:03 +0000
Subject: [PATCH 01/17] Add walk_of_type() and test

---
 mlir/include/mlir/Bindings/Python/IRCore.h |  6 +++
 mlir/lib/Bindings/Python/IRCore.cpp        | 33 +++++++++++++-
 mlir/test/python/ir/operation.py           | 52 ++++++++++++++++++++++
 3 files changed, 89 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 557e32e9a612d..9838b5d740cbd 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -600,6 +600,12 @@ class MLIR_PYTHON_API_EXPORTED PyOperationBase {
   void walk(std::function<PyWalkResult(MlirOperation)> callback,
             PyWalkOrder walkOrder);
 
+  // Wrap the walk method with a type filter. Works same as op.walk([](OpClass
+  // op) { ... } );
+  void walkOfType(nanobind::object opClass,
+                  std::function<PyWalkResult(MlirOperation)> callback,
+                  PyWalkOrder walkOrder);
+
   /// Moves the operation before or after the other operation.
   void moveAfter(PyOperationBase &other);
   void moveBefore(PyOperationBase &other);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 92ea44605b01a..6df879daaf35e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1145,6 +1145,20 @@ void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
   }
 }
 
+void PyOperationBase::walkOfType(
+    nb::object opClass, std::function<PyWalkResult(MlirOperation)> callback,
+    PyWalkOrder walkOrder) {
+
+  auto filtered = [&](MlirOperation mlirOp) -> PyWalkResult {
+    nb::object opview = nb::cast(mlirOp).attr("opview");
+    if (nb::isinstance(opview, opClass)) {
+      return callback(mlirOp);
+    };
+    return PyWalkResult::Advance;
+  };
+  walk(filtered, walkOrder);
+}
+
 nb::object PyOperationBase::getAsm(bool binary,
                                    std::optional<int64_t> largeElementsLimit,
                                    std::optional<int64_t> largeResourceLimit,
@@ -2707,7 +2721,7 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
     if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
       continue;
 
-    // co_qualname and PyCode_Addr2Location added in py3.11
+      // co_qualname and PyCode_Addr2Location added in py3.11
 #if PY_VERSION_HEX < 0x030B00F0
     std::string name =
         nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
@@ -3937,7 +3951,22 @@ void populateIRCore(nb::module_ &m) {
 
              Args:
                callback: A callable that takes an Operation and returns a WalkResult.
-               walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
+               walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)")
+      .def("walk_of_type", &PyOperationBase::walkOfType, "op_class"_a,
+           "callback"_a, "walk_order"_a = PyWalkOrder::PostOrder,
+           // clang-format off
+     nb::sig("def walk_of_type(self, op_class: type[OpView], callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
+           // clang-format on
+           R"(
+              Walks the operation tree, invoking the callback only on operations of the specified type.
+
+              Args:
+                op_class: The operation type to match.
+                callback: A callable that takes an Operation and returns a WalkResult.
+                walk_order: The traversal order (PRE_ORDER or POST_ORDER).
+
+              For example, op.walk_of_type(arith.AddIOp, callback) walks the operation tree
+              and invokes callback only on arith.AddIOp operations.)");
 
   nb::class_<PyOperation, PyOperationBase>(m, "Operation")
       .def_static(
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 865dd226cbe2a..f3cd729ece9a8 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1323,6 +1323,58 @@ def callback(op):
         module.operation.walk(callback)
     except RuntimeError:
         print("Exception raised")
+        
+# CHECK-LABEL: TEST: testOpWalkOfType
+@run
+def testOpWalkOfType():
+    with Context(), Location.unknown():
+        module = Module.parse("""
+            module {
+                func.func @f() { return }
+                func.func @g() { return }
+                arith.constant dense<0> : tensor<i32>
+            }
+        """)
+
+    # Callback: only visits ops of the requested type.
+    # CHECK: only FuncOp visited: True
+    only_funcs = True
+    def check_type(op):
+        nonlocal only_funcs
+        if not isinstance(op.opview, func.FuncOp):
+            only_funcs = False
+        return WalkResult.ADVANCE
+    module.operation.walk_of_type(func.FuncOp, check_type)
+    print(f"only FuncOp visited: {only_funcs}")
+
+    # Callback: interrupt after first match.
+    # CHECK: interrupted after: 1
+    seen = []
+    def stop_after_first(op):
+        seen.append(op.opview)
+        return WalkResult.INTERRUPT
+    module.operation.walk_of_type(func.FuncOp, stop_after_first)
+    print(f"interrupted after: {len(seen)}")
+
+    # Callback: no match, callback never called.
+    # CHECK: never called: True
+    called = False
+    def should_not_run(op):
+        nonlocal called
+        called = True
+        return WalkResult.ADVANCE
+    module.operation.walk_of_type(scf.ForOp, should_not_run)
+    print(f"never called: {not called}")
+
+    # Callback: collect all matching ops.
+    # CHECK: collected func.FuncOp: ['"f"', '"g"']
+    collected = []
+    def collect(op):
+        collected.append(op.opview)
+        return WalkResult.ADVANCE
+    module.operation.walk_of_type(func.FuncOp, collect)
+    assert all(isinstance(r, func.FuncOp) for r in collected)
+    print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
 
 
 # CHECK-LABEL: TEST: testOpReplaceUsesWith

>From c31a6e3bb24cd87591816b120f80f010d280d02b Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 14:22:38 +0000
Subject: [PATCH 02/17] Add get_ops_of_type() and test

---
 mlir/python/mlir/ir.py           | 20 ++++++++++++++++++++
 mlir/test/python/ir/operation.py | 29 +++++++++++++++++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 99bd135b49636..2a3e801d1ad15 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -45,6 +45,26 @@ def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView
     return None
 
 
+def get_ops_of_type(root: OpView | Operation | Module, op_class: type[OpView]) -> list[OpView]:
+    """Return all operations of the given type in the operation tree.
+
+
+    Args:
+      root: The operation or module to start traversing from.
+      op_class: The OpView subclass to visit for (e.g. func.FuncOp).
+
+    Returns:
+      A list of operations of the given type.
+    """
+    op = root.operation if isinstance(root, Module) else root
+    ops = []
+    def collect_ops(op: Operation):
+        ops.append(op.opview)
+        return WalkResult.ADVANCE
+    op.walk_of_type(op_class, collect_ops)
+    return ops
+
+
 @contextmanager
 def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None, None, None]:
     """Enables automatic traceback-based locations for MLIR operations.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index f3cd729ece9a8..f184ec2a3984c 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1489,3 +1489,32 @@ def testGetParentOfType():
             assert False, "expected TypeError"
         except TypeError:
             pass
+
+
+# CHECK-LABEL: TEST: test_get_ops_of_type
+@run
+def test_get_ops_of_type():
+    with Context(), Location.unknown():
+        module = Module.parse("""
+            module {
+                func.func @f() { return }
+                func.func @g() { return }
+            }
+        """)
+
+        # CHECK: get_ops_of_type func.func count: 2
+        results = get_ops_of_type(module, func.FuncOp)
+        print(f"get_ops_of_type func.func count: {len(results)}")
+        assert len(results) == 2
+        assert all(isinstance(r, func.FuncOp) for r in results)
+
+        # CHECK: get_ops_of_type scf.for count: 0
+        results = get_ops_of_type(module, scf.ForOp)
+        print(f"get_ops_of_type scf.for count: {len(results)}")
+        assert len(results) == 0
+
+        # Accepts OpView as root.
+        func_op = get_ops_of_type(module, func.FuncOp)[0]
+        results = get_ops_of_type(func_op, func.ReturnOp)
+        assert len(results) == 1
+        assert isinstance(results[0], func.ReturnOp)

>From ccee41467ad1d6efbae2b788f6466e27e2d3ffb6 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 14:28:39 +0000
Subject: [PATCH 03/17] Fix comments

---
 mlir/include/mlir/Bindings/Python/IRCore.h | 3 +--
 mlir/lib/Bindings/Python/IRCore.cpp        | 2 +-
 mlir/python/mlir/ir.py                     | 1 -
 3 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 9838b5d740cbd..c8cf8fe3d99dc 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -600,8 +600,7 @@ class MLIR_PYTHON_API_EXPORTED PyOperationBase {
   void walk(std::function<PyWalkResult(MlirOperation)> callback,
             PyWalkOrder walkOrder);
 
-  // Wrap the walk method with a type filter. Works same as op.walk([](OpClass
-  // op) { ... } );
+  // Wrap the walk method with a type filter.
   void walkOfType(nanobind::object opClass,
                   std::function<PyWalkResult(MlirOperation)> callback,
                   PyWalkOrder walkOrder);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 6df879daaf35e..f5b074b62328e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2721,7 +2721,7 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
     if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
       continue;
 
-      // co_qualname and PyCode_Addr2Location added in py3.11
+    // co_qualname and PyCode_Addr2Location added in py3.11
 #if PY_VERSION_HEX < 0x030B00F0
     std::string name =
         nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 2a3e801d1ad15..fd4a0bf2aa56d 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -48,7 +48,6 @@ def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView
 def get_ops_of_type(root: OpView | Operation | Module, op_class: type[OpView]) -> list[OpView]:
     """Return all operations of the given type in the operation tree.
 
-
     Args:
       root: The operation or module to start traversing from.
       op_class: The OpView subclass to visit for (e.g. func.FuncOp).

>From 846858f6af8a16af5f5fd2ac2dfa72671e13c029 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 14:53:33 +0000
Subject: [PATCH 04/17] Fix format

---
 mlir/lib/Bindings/Python/IRCore.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f5b074b62328e..6df879daaf35e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2721,7 +2721,7 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
     if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
       continue;
 
-    // co_qualname and PyCode_Addr2Location added in py3.11
+      // co_qualname and PyCode_Addr2Location added in py3.11
 #if PY_VERSION_HEX < 0x030B00F0
     std::string name =
         nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));

>From 7ffccd56dc0a92425dfbdf70d23737f670849eff Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 15:00:36 +0000
Subject: [PATCH 05/17] Fix format

---
 mlir/python/mlir/ir.py           |  6 +++++-
 mlir/test/python/ir/operation.py | 11 ++++++++++-
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index fd4a0bf2aa56d..9dca45c561bc6 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -45,7 +45,9 @@ def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView
     return None
 
 
-def get_ops_of_type(root: OpView | Operation | Module, op_class: type[OpView]) -> list[OpView]:
+def get_ops_of_type(
+    root: OpView | Operation | Module, op_class: type[OpView]
+) -> list[OpView]:
     """Return all operations of the given type in the operation tree.
 
     Args:
@@ -57,9 +59,11 @@ def get_ops_of_type(root: OpView | Operation | Module, op_class: type[OpView]) -
     """
     op = root.operation if isinstance(root, Module) else root
     ops = []
+
     def collect_ops(op: Operation):
         ops.append(op.opview)
         return WalkResult.ADVANCE
+
     op.walk_of_type(op_class, collect_ops)
     return ops
 
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index f184ec2a3984c..d579badc3e8dd 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1323,7 +1323,8 @@ def callback(op):
         module.operation.walk(callback)
     except RuntimeError:
         print("Exception raised")
-        
+
+
 # CHECK-LABEL: TEST: testOpWalkOfType
 @run
 def testOpWalkOfType():
@@ -1339,39 +1340,47 @@ def testOpWalkOfType():
     # Callback: only visits ops of the requested type.
     # CHECK: only FuncOp visited: True
     only_funcs = True
+
     def check_type(op):
         nonlocal only_funcs
         if not isinstance(op.opview, func.FuncOp):
             only_funcs = False
         return WalkResult.ADVANCE
+
     module.operation.walk_of_type(func.FuncOp, check_type)
     print(f"only FuncOp visited: {only_funcs}")
 
     # Callback: interrupt after first match.
     # CHECK: interrupted after: 1
     seen = []
+
     def stop_after_first(op):
         seen.append(op.opview)
         return WalkResult.INTERRUPT
+
     module.operation.walk_of_type(func.FuncOp, stop_after_first)
     print(f"interrupted after: {len(seen)}")
 
     # Callback: no match, callback never called.
     # CHECK: never called: True
     called = False
+
     def should_not_run(op):
         nonlocal called
         called = True
         return WalkResult.ADVANCE
+
     module.operation.walk_of_type(scf.ForOp, should_not_run)
     print(f"never called: {not called}")
 
     # Callback: collect all matching ops.
     # CHECK: collected func.FuncOp: ['"f"', '"g"']
     collected = []
+
     def collect(op):
         collected.append(op.opview)
         return WalkResult.ADVANCE
+
     module.operation.walk_of_type(func.FuncOp, collect)
     assert all(isinstance(r, func.FuncOp) for r in collected)
     print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")

>From 0529d7d5adfc3b04479dae0c31a062f093987a73 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 15:16:49 +0000
Subject: [PATCH 06/17] Fix format

---
 mlir/lib/Bindings/Python/IRCore.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 6df879daaf35e..f5b074b62328e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2721,7 +2721,7 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
     if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
       continue;
 
-      // co_qualname and PyCode_Addr2Location added in py3.11
+    // co_qualname and PyCode_Addr2Location added in py3.11
 #if PY_VERSION_HEX < 0x030B00F0
     std::string name =
         nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));

>From e6c49f00da7041e99b6d96643171c219356d1db5 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 15:29:45 +0000
Subject: [PATCH 07/17] Embed walkOfType into walk

---
 mlir/include/mlir/Bindings/Python/IRCore.h |  5 --
 mlir/lib/Bindings/Python/IRCore.cpp        | 65 ++++++++++------------
 mlir/python/mlir/ir.py                     |  2 +-
 mlir/test/python/ir/operation.py           | 65 +---------------------
 4 files changed, 32 insertions(+), 105 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index c8cf8fe3d99dc..557e32e9a612d 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -600,11 +600,6 @@ class MLIR_PYTHON_API_EXPORTED PyOperationBase {
   void walk(std::function<PyWalkResult(MlirOperation)> callback,
             PyWalkOrder walkOrder);
 
-  // Wrap the walk method with a type filter.
-  void walkOfType(nanobind::object opClass,
-                  std::function<PyWalkResult(MlirOperation)> callback,
-                  PyWalkOrder walkOrder);
-
   /// Moves the operation before or after the other operation.
   void moveAfter(PyOperationBase &other);
   void moveBefore(PyOperationBase &other);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f5b074b62328e..455bea10e9437 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1145,20 +1145,6 @@ void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
   }
 }
 
-void PyOperationBase::walkOfType(
-    nb::object opClass, std::function<PyWalkResult(MlirOperation)> callback,
-    PyWalkOrder walkOrder) {
-
-  auto filtered = [&](MlirOperation mlirOp) -> PyWalkResult {
-    nb::object opview = nb::cast(mlirOp).attr("opview");
-    if (nb::isinstance(opview, opClass)) {
-      return callback(mlirOp);
-    };
-    return PyWalkResult::Advance;
-  };
-  walk(filtered, walkOrder);
-}
-
 nb::object PyOperationBase::getAsm(bool binary,
                                    std::optional<int64_t> largeElementsLimit,
                                    std::optional<int64_t> largeResourceLimit,
@@ -3941,32 +3927,39 @@ void populateIRCore(nb::module_ &m) {
 
             Note:
               After erasing, any Python references to the operation become invalid.)")
-      .def("walk", &PyOperationBase::walk, "callback"_a,
-           "walk_order"_a = PyWalkOrder::PostOrder,
-           // clang-format off
-          nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = ...) -> None"),
-           // clang-format on
-           R"(
+      .def(
+          "walk",
+          [](PyOperationBase &self,
+             std::function<PyWalkResult(MlirOperation)> callback,
+             std::optional<nb::object> opClass, PyWalkOrder walkOrder) {
+            if (opClass) {
+              self.walk(
+                  [&](MlirOperation mlirOp) -> PyWalkResult {
+                    nb::object opview = nb::cast(mlirOp).attr("opview");
+                    if (nb::isinstance(opview, *opClass))
+                      return callback(mlirOp);
+                    return PyWalkResult::Advance;
+                  },
+                  walkOrder);
+            } else {
+              self.walk(callback, walkOrder);
+            }
+          },
+          "callback"_a, "op_class"_a = nb::none(),
+          "walk_order"_a = PyWalkOrder::PostOrder,
+          // clang-format off
+           nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], op_class: type[OpView] | None = None, walk_order: WalkOrder = WalkOrder.POST_ORDER) -> None"),
+          // clang-format on
+          R"(
              Walks the operation tree with a callback function.
 
+             If op_class is provided, the callback is only invoked on operations
+             of that type; all other operations are skipped silently.
+
              Args:
                callback: A callable that takes an Operation and returns a WalkResult.
-               walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)")
-      .def("walk_of_type", &PyOperationBase::walkOfType, "op_class"_a,
-           "callback"_a, "walk_order"_a = PyWalkOrder::PostOrder,
-           // clang-format off
-     nb::sig("def walk_of_type(self, op_class: type[OpView], callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
-           // clang-format on
-           R"(
-              Walks the operation tree, invoking the callback only on operations of the specified type.
-
-              Args:
-                op_class: The operation type to match.
-                callback: A callable that takes an Operation and returns a WalkResult.
-                walk_order: The traversal order (PRE_ORDER or POST_ORDER).
-
-              For example, op.walk_of_type(arith.AddIOp, callback) walks the operation tree
-              and invokes callback only on arith.AddIOp operations.)");
+               op_class: If provided, only operations of this type are passed to the callback.
+               walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
 
   nb::class_<PyOperation, PyOperationBase>(m, "Operation")
       .def_static(
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 9dca45c561bc6..eadd333d1a9f5 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -64,7 +64,7 @@ def collect_ops(op: Operation):
         ops.append(op.opview)
         return WalkResult.ADVANCE
 
-    op.walk_of_type(op_class, collect_ops)
+    op.walk(collect_ops, op_class=op_class)
     return ops
 
 
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index d579badc3e8dd..911a3bef3cc39 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1284,7 +1284,7 @@ def callback(op):
     # CHECK-NEXT:  func.fun
     # CHECK-NEXT:  func.return
     print("Pre-order")
-    module.operation.walk(callback, WalkOrder.PRE_ORDER)
+    module.operation.walk(callback, walk_order=WalkOrder.PRE_ORDER)
 
     # Test interrput.
     # CHECK-NEXT:  Interrupt post-order
@@ -1306,7 +1306,7 @@ def callback(op):
         print(op.name)
         return WalkResult.SKIP
 
-    module.operation.walk(callback, WalkOrder.PRE_ORDER)
+    module.operation.walk(callback, walk_order=WalkOrder.PRE_ORDER)
 
     # Test exception.
     # CHECK: Exception
@@ -1325,67 +1325,6 @@ def callback(op):
         print("Exception raised")
 
 
-# CHECK-LABEL: TEST: testOpWalkOfType
-@run
-def testOpWalkOfType():
-    with Context(), Location.unknown():
-        module = Module.parse("""
-            module {
-                func.func @f() { return }
-                func.func @g() { return }
-                arith.constant dense<0> : tensor<i32>
-            }
-        """)
-
-    # Callback: only visits ops of the requested type.
-    # CHECK: only FuncOp visited: True
-    only_funcs = True
-
-    def check_type(op):
-        nonlocal only_funcs
-        if not isinstance(op.opview, func.FuncOp):
-            only_funcs = False
-        return WalkResult.ADVANCE
-
-    module.operation.walk_of_type(func.FuncOp, check_type)
-    print(f"only FuncOp visited: {only_funcs}")
-
-    # Callback: interrupt after first match.
-    # CHECK: interrupted after: 1
-    seen = []
-
-    def stop_after_first(op):
-        seen.append(op.opview)
-        return WalkResult.INTERRUPT
-
-    module.operation.walk_of_type(func.FuncOp, stop_after_first)
-    print(f"interrupted after: {len(seen)}")
-
-    # Callback: no match, callback never called.
-    # CHECK: never called: True
-    called = False
-
-    def should_not_run(op):
-        nonlocal called
-        called = True
-        return WalkResult.ADVANCE
-
-    module.operation.walk_of_type(scf.ForOp, should_not_run)
-    print(f"never called: {not called}")
-
-    # Callback: collect all matching ops.
-    # CHECK: collected func.FuncOp: ['"f"', '"g"']
-    collected = []
-
-    def collect(op):
-        collected.append(op.opview)
-        return WalkResult.ADVANCE
-
-    module.operation.walk_of_type(func.FuncOp, collect)
-    assert all(isinstance(r, func.FuncOp) for r in collected)
-    print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
-
-
 # CHECK-LABEL: TEST: testOpReplaceUsesWith
 @run
 def testOpReplaceUsesWith():

>From 0ab6cab0405ddd15a34adccb654c79aff1881089 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 15:33:47 +0000
Subject: [PATCH 08/17] Merge test

---
 mlir/test/python/ir/operation.py | 56 ++++++++++++++++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 911a3bef3cc39..bc883643d53fc 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1324,6 +1324,62 @@ def callback(op):
     except RuntimeError:
         print("Exception raised")
 
+    # Test op_class filter: only visits ops of the requested type.
+    module = Module.parse(
+        """
+        module {
+            func.func @f() { return }
+            func.func @g() { return }
+            arith.constant dense<0> : tensor<i32>
+        }
+    """,
+        ctx,
+    )
+
+    # CHECK-NEXT: only FuncOp visited: True
+    only_funcs = True
+
+    def check_type(op):
+        nonlocal only_funcs
+        if not isinstance(op.opview, func.FuncOp):
+            only_funcs = False
+        return WalkResult.ADVANCE
+
+    module.operation.walk(check_type, op_class=func.FuncOp)
+    print(f"only FuncOp visited: {only_funcs}")
+
+    # CHECK-NEXT: interrupted after: 1
+    seen = []
+
+    def stop_after_first(op):
+        seen.append(op.opview)
+        return WalkResult.INTERRUPT
+
+    module.operation.walk(stop_after_first, op_class=func.FuncOp)
+    print(f"interrupted after: {len(seen)}")
+
+    # CHECK-NEXT: never called: True
+    called = False
+
+    def should_not_run(op):
+        nonlocal called
+        called = True
+        return WalkResult.ADVANCE
+
+    module.operation.walk(should_not_run, op_class=scf.ForOp)
+    print(f"never called: {not called}")
+
+    # CHECK-NEXT: collected func.FuncOp: ['"f"', '"g"']
+    collected = []
+
+    def collect(op):
+        collected.append(op.opview)
+        return WalkResult.ADVANCE
+
+    module.operation.walk(collect, op_class=func.FuncOp)
+    assert all(isinstance(r, func.FuncOp) for r in collected)
+    print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
+
 
 # CHECK-LABEL: TEST: testOpReplaceUsesWith
 @run

>From 47f1fb5c764d4adaa6e395655c701c2714d97b5d Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 15:46:28 +0000
Subject: [PATCH 09/17] Fix format

---
 mlir/test/python/ir/operation.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index bc883643d53fc..46850ab960be5 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1499,12 +1499,14 @@ def testGetParentOfType():
 @run
 def test_get_ops_of_type():
     with Context(), Location.unknown():
-        module = Module.parse("""
+        module = Module.parse(
+            """
             module {
                 func.func @f() { return }
                 func.func @g() { return }
             }
-        """)
+        """
+        )
 
         # CHECK: get_ops_of_type func.func count: 2
         results = get_ops_of_type(module, func.FuncOp)

>From 82cae94a410a7fb48df3d94b322c91946dcff588 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 17:06:00 +0000
Subject: [PATCH 10/17] Update get opview

---
 mlir/lib/Bindings/Python/IRCore.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 455bea10e9437..bf024e379ecc1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3935,7 +3935,9 @@ void populateIRCore(nb::module_ &m) {
             if (opClass) {
               self.walk(
                   [&](MlirOperation mlirOp) -> PyWalkResult {
-                    nb::object opview = nb::cast(mlirOp).attr("opview");
+                     nb::object opview = PyOperation::forOperation(
+                         self.getOperation().getContext(), mlirOp)
+                         ->createOpView();
                     if (nb::isinstance(opview, *opClass))
                       return callback(mlirOp);
                     return PyWalkResult::Advance;

>From 601c98f9f66fce1cf1c99050055d3a0e9f83fd96 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 17:26:14 +0000
Subject: [PATCH 11/17] Fix format

---
 mlir/lib/Bindings/Python/IRCore.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index bf024e379ecc1..361825834d75a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3935,9 +3935,10 @@ void populateIRCore(nb::module_ &m) {
             if (opClass) {
               self.walk(
                   [&](MlirOperation mlirOp) -> PyWalkResult {
-                     nb::object opview = PyOperation::forOperation(
-                         self.getOperation().getContext(), mlirOp)
-                         ->createOpView();
+                    nb::object opview =
+                        PyOperation::forOperation(
+                            self.getOperation().getContext(), mlirOp)
+                            ->createOpView();
                     if (nb::isinstance(opview, *opClass))
                       return callback(mlirOp);
                     return PyWalkResult::Advance;

>From b8d6d723fb39642ce6adb7d47e3a196394423e57 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 17:56:28 +0000
Subject: [PATCH 12/17] Move new optional attr to end

---
 mlir/lib/Bindings/Python/IRCore.cpp | 42 ++++++++++++++---------------
 1 file changed, 20 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 361825834d75a..0cbe2fcb6d6b1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3931,27 +3931,25 @@ void populateIRCore(nb::module_ &m) {
           "walk",
           [](PyOperationBase &self,
              std::function<PyWalkResult(MlirOperation)> callback,
-             std::optional<nb::object> opClass, PyWalkOrder walkOrder) {
-            if (opClass) {
-              self.walk(
-                  [&](MlirOperation mlirOp) -> PyWalkResult {
-                    nb::object opview =
-                        PyOperation::forOperation(
-                            self.getOperation().getContext(), mlirOp)
-                            ->createOpView();
-                    if (nb::isinstance(opview, *opClass))
-                      return callback(mlirOp);
-                    return PyWalkResult::Advance;
-                  },
-                  walkOrder);
-            } else {
-              self.walk(callback, walkOrder);
-            }
-          },
-          "callback"_a, "op_class"_a = nb::none(),
-          "walk_order"_a = PyWalkOrder::PostOrder,
+             PyWalkOrder walkOrder, std::optional<nb::object> opClass) {
+            if (!opClass)
+              return self.walk(callback, walkOrder);
+            self.walk(
+                [&](MlirOperation mlirOp) -> PyWalkResult {
+                  nb::object opview =
+                      PyOperation::forOperation(
+                          self.getOperation().getContext(), mlirOp)
+                          ->createOpView();
+                  if (nb::isinstance(opview, *opClass))
+                    return callback(mlirOp);
+                  return PyWalkResult::Advance;
+                },
+                walkOrder);
+          },
+          "callback"_a, "walk_order"_a = PyWalkOrder::PostOrder,
+          "op_class"_a = nb::none(),
           // clang-format off
-           nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], op_class: type[OpView] | None = None, walk_order: WalkOrder = WalkOrder.POST_ORDER) -> None"),
+           nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = WalkOrder.POST_ORDER, op_class: type[OpView] | None = None) -> None"),
           // clang-format on
           R"(
              Walks the operation tree with a callback function.
@@ -3961,8 +3959,8 @@ void populateIRCore(nb::module_ &m) {
 
              Args:
                callback: A callable that takes an Operation and returns a WalkResult.
-               op_class: If provided, only operations of this type are passed to the callback.
-               walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
+               walk_order: The order of traversal (PRE_ORDER or POST_ORDER).
+               op_class: If provided, only operations of this type are passed to the callback.)");
 
   nb::class_<PyOperation, PyOperationBase>(m, "Operation")
       .def_static(

>From 6ae9de4592d10065f502f8c6d61c3fa3496aaf61 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 18:03:49 +0000
Subject: [PATCH 13/17] Update walk_order stubgen

---
 mlir/lib/Bindings/Python/IRCore.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0cbe2fcb6d6b1..7eb59d61b0d57 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3949,7 +3949,7 @@ void populateIRCore(nb::module_ &m) {
           "callback"_a, "walk_order"_a = PyWalkOrder::PostOrder,
           "op_class"_a = nb::none(),
           // clang-format off
-           nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = WalkOrder.POST_ORDER, op_class: type[OpView] | None = None) -> None"),
+           nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = ..., op_class: type[OpView] | None = None) -> None"),
           // clang-format on
           R"(
              Walks the operation tree with a callback function.

>From 66f9cba4936dda3e3ad6f00b25153ce6e76d1db2 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 18:13:34 +0000
Subject: [PATCH 14/17] Edit optional attr type hint and add test

---
 mlir/python/mlir/ir.py           |  5 +++--
 mlir/test/python/ir/operation.py | 10 ++++++++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index eadd333d1a9f5..210465daad0d8 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -46,13 +46,14 @@ def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView
 
 
 def get_ops_of_type(
-    root: OpView | Operation | Module, op_class: type[OpView]
+    root: OpView | Operation | Module, op_class: type[OpView] | None = None
 ) -> list[OpView]:
     """Return all operations of the given type in the operation tree.
 
     Args:
       root: The operation or module to start traversing from.
-      op_class: The OpView subclass to visit for (e.g. func.FuncOp).
+      op_class: The OpView subclass to filter by (e.g. func.FuncOp). If None,
+        collects all operations in the tree.
 
     Returns:
       A list of operations of the given type.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 46850ab960be5..e4836b23d7602 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1519,8 +1519,18 @@ def test_get_ops_of_type():
         print(f"get_ops_of_type scf.for count: {len(results)}")
         assert len(results) == 0
 
+        # CHECK: get_ops_of_type func_op->func.ReturnOp count: 1
         # Accepts OpView as root.
         func_op = get_ops_of_type(module, func.FuncOp)[0]
         results = get_ops_of_type(func_op, func.ReturnOp)
+        print(f"get_ops_of_type func_op->func.ReturnOp count: {len(results)}")
         assert len(results) == 1
         assert isinstance(results[0], func.ReturnOp)
+
+        # CHECK: get_ops_of_type no filter count: 5
+        # No op_class collects all ops.
+        results = get_ops_of_type(module)
+        print(f"get_ops_of_type no filter count: {len(results)}")
+        assert len(results) == 5
+        assert any(isinstance(r, func.FuncOp) for r in results)
+        assert any(isinstance(r, func.ReturnOp) for r in results)

>From 3bd24d996484287c193f3523ea5815bd8893f02c Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 18:17:10 +0000
Subject: [PATCH 15/17] Revert test case walk_order and add new test

---
 mlir/test/python/ir/operation.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index e4836b23d7602..4efb06c5f5588 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1284,7 +1284,7 @@ def callback(op):
     # CHECK-NEXT:  func.fun
     # CHECK-NEXT:  func.return
     print("Pre-order")
-    module.operation.walk(callback, walk_order=WalkOrder.PRE_ORDER)
+    module.operation.walk(callback, WalkOrder.PRE_ORDER)
 
     # Test interrput.
     # CHECK-NEXT:  Interrupt post-order
@@ -1306,7 +1306,7 @@ def callback(op):
         print(op.name)
         return WalkResult.SKIP
 
-    module.operation.walk(callback, walk_order=WalkOrder.PRE_ORDER)
+    module.operation.walk(callback, WalkOrder.PRE_ORDER)
 
     # Test exception.
     # CHECK: Exception
@@ -1380,6 +1380,13 @@ def collect(op):
     assert all(isinstance(r, func.FuncOp) for r in collected)
     print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
 
+    # Test op_class with walk_order: pre-order visits FuncOps in source order.
+    # CHECK-NEXT: pre-order FuncOp names: ['"f"', '"g"']
+    collected.clear()
+    module.operation.walk(collect, WalkOrder.PRE_ORDER, op_class=func.FuncOp)
+    assert all(isinstance(r, func.FuncOp) for r in collected)
+    print(f"pre-order FuncOp names: {[str(r.name) for r in collected]}")
+
 
 # CHECK-LABEL: TEST: testOpReplaceUsesWith
 @run

>From dda2de52025f5e9c7d79a9988adbbbdacf17b84f Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 18:20:22 +0000
Subject: [PATCH 16/17] Update indentation of test module str

---
 mlir/test/python/ir/operation.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 4efb06c5f5588..9fd83f90241be 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1326,13 +1326,17 @@ def callback(op):
 
     # Test op_class filter: only visits ops of the requested type.
     module = Module.parse(
-        """
-        module {
-            func.func @f() { return }
-            func.func @g() { return }
-            arith.constant dense<0> : tensor<i32>
-        }
-    """,
+        r"""
+    module {
+      func.func @f() {
+        func.return
+      }
+      func.func @g() {
+        func.return
+      }
+      arith.constant dense<0> : tensor<i32>
+    }
+  """,
         ctx,
     )
 

>From ec973705f1acdcbb9316d904979d6f0ffcc419f8 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 18:24:54 +0000
Subject: [PATCH 17/17] Update indentation of test module str

---
 mlir/test/python/ir/operation.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9fd83f90241be..f561a1bc624d8 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1511,12 +1511,16 @@ def testGetParentOfType():
 def test_get_ops_of_type():
     with Context(), Location.unknown():
         module = Module.parse(
-            """
-            module {
-                func.func @f() { return }
-                func.func @g() { return }
-            }
-        """
+            r"""
+    module {
+      func.func @f() {
+        func.return
+      }
+      func.func @g() {
+        func.return
+      }
+    }
+  """
         )
 
         # CHECK: get_ops_of_type func.func count: 2



More information about the Mlir-commits mailing list