[Mlir-commits] [mlir] [MLIR][Python] Add type filter to walk() binding and add get_ops_of_type() utility (PR #186131)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 12 10:06:19 PDT 2026
https://github.com/RattataKing updated https://github.com/llvm/llvm-project/pull/186131
>From 881d72a4e64b62d61075fe713b4262f7f98cc770 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 14:18:03 +0000
Subject: [PATCH 01/10] Add walk_of_type() and test
---
mlir/include/mlir/Bindings/Python/IRCore.h | 6 +++
mlir/lib/Bindings/Python/IRCore.cpp | 33 +++++++++++++-
mlir/test/python/ir/operation.py | 52 ++++++++++++++++++++++
3 files changed, 89 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index bd2d49acbf681..1569a870e3996 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -599,6 +599,12 @@ class MLIR_PYTHON_API_EXPORTED PyOperationBase {
void walk(std::function<PyWalkResult(MlirOperation)> callback,
PyWalkOrder walkOrder);
+ // Wrap the walk method with a type filter. Works same as op.walk([](OpClass
+ // op) { ... } );
+ void walkOfType(nanobind::object opClass,
+ std::function<PyWalkResult(MlirOperation)> callback,
+ PyWalkOrder walkOrder);
+
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
void moveBefore(PyOperationBase &other);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b8637c57a3f48..2c5634b239b77 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1145,6 +1145,20 @@ void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
}
}
+void PyOperationBase::walkOfType(
+ nb::object opClass, std::function<PyWalkResult(MlirOperation)> callback,
+ PyWalkOrder walkOrder) {
+
+ auto filtered = [&](MlirOperation mlirOp) -> PyWalkResult {
+ nb::object opview = nb::cast(mlirOp).attr("opview");
+ if (nb::isinstance(opview, opClass)) {
+ return callback(mlirOp);
+ };
+ return PyWalkResult::Advance;
+ };
+ walk(filtered, walkOrder);
+}
+
nb::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
std::optional<int64_t> largeResourceLimit,
@@ -2707,7 +2721,7 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
continue;
- // co_qualname and PyCode_Addr2Location added in py3.11
+ // co_qualname and PyCode_Addr2Location added in py3.11
#if PY_VERSION_HEX < 0x030B00F0
std::string name =
nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
@@ -3868,7 +3882,22 @@ void populateIRCore(nb::module_ &m) {
Args:
callback: A callable that takes an Operation and returns a WalkResult.
- walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
+ walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)")
+ .def("walk_of_type", &PyOperationBase::walkOfType, "op_class"_a,
+ "callback"_a, "walk_order"_a = PyWalkOrder::PostOrder,
+ // clang-format off
+ nb::sig("def walk_of_type(self, op_class: type[OpView], callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
+ // clang-format on
+ R"(
+ Walks the operation tree, invoking the callback only on operations of the specified type.
+
+ Args:
+ op_class: The operation type to match.
+ callback: A callable that takes an Operation and returns a WalkResult.
+ walk_order: The traversal order (PRE_ORDER or POST_ORDER).
+
+ For example, op.walk_of_type(arith.AddIOp, callback) walks the operation tree
+ and invokes callback only on arith.AddIOp operations.)");
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
.def_static(
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 865dd226cbe2a..f3cd729ece9a8 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1323,6 +1323,58 @@ def callback(op):
module.operation.walk(callback)
except RuntimeError:
print("Exception raised")
+
+# CHECK-LABEL: TEST: testOpWalkOfType
+ at run
+def testOpWalkOfType():
+ with Context(), Location.unknown():
+ module = Module.parse("""
+ module {
+ func.func @f() { return }
+ func.func @g() { return }
+ arith.constant dense<0> : tensor<i32>
+ }
+ """)
+
+ # Callback: only visits ops of the requested type.
+ # CHECK: only FuncOp visited: True
+ only_funcs = True
+ def check_type(op):
+ nonlocal only_funcs
+ if not isinstance(op.opview, func.FuncOp):
+ only_funcs = False
+ return WalkResult.ADVANCE
+ module.operation.walk_of_type(func.FuncOp, check_type)
+ print(f"only FuncOp visited: {only_funcs}")
+
+ # Callback: interrupt after first match.
+ # CHECK: interrupted after: 1
+ seen = []
+ def stop_after_first(op):
+ seen.append(op.opview)
+ return WalkResult.INTERRUPT
+ module.operation.walk_of_type(func.FuncOp, stop_after_first)
+ print(f"interrupted after: {len(seen)}")
+
+ # Callback: no match, callback never called.
+ # CHECK: never called: True
+ called = False
+ def should_not_run(op):
+ nonlocal called
+ called = True
+ return WalkResult.ADVANCE
+ module.operation.walk_of_type(scf.ForOp, should_not_run)
+ print(f"never called: {not called}")
+
+ # Callback: collect all matching ops.
+ # CHECK: collected func.FuncOp: ['"f"', '"g"']
+ collected = []
+ def collect(op):
+ collected.append(op.opview)
+ return WalkResult.ADVANCE
+ module.operation.walk_of_type(func.FuncOp, collect)
+ assert all(isinstance(r, func.FuncOp) for r in collected)
+ print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
# CHECK-LABEL: TEST: testOpReplaceUsesWith
>From 51e64162fdcd1869d2d369ea36fc24db18ea0d05 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 14:22:38 +0000
Subject: [PATCH 02/10] Add get_ops_of_type() and test
---
mlir/python/mlir/ir.py | 20 ++++++++++++++++++++
mlir/test/python/ir/operation.py | 29 +++++++++++++++++++++++++++++
2 files changed, 49 insertions(+)
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index d86298a72c6f2..a495a4db77861 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -45,6 +45,26 @@ def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView
return None
+def get_ops_of_type(root: OpView | Operation | Module, op_class: type[OpView]) -> list[OpView]:
+ """Return all operations of the given type in the operation tree.
+
+
+ Args:
+ root: The operation or module to start traversing from.
+ op_class: The OpView subclass to visit for (e.g. func.FuncOp).
+
+ Returns:
+ A list of operations of the given type.
+ """
+ op = root.operation if isinstance(root, Module) else root
+ ops = []
+ def collect_ops(op: Operation):
+ ops.append(op.opview)
+ return WalkResult.ADVANCE
+ op.walk_of_type(op_class, collect_ops)
+ return ops
+
+
@contextmanager
def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None]:
"""Enables automatic traceback-based locations for MLIR operations.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index f3cd729ece9a8..f184ec2a3984c 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1489,3 +1489,32 @@ def testGetParentOfType():
assert False, "expected TypeError"
except TypeError:
pass
+
+
+# CHECK-LABEL: TEST: test_get_ops_of_type
+ at run
+def test_get_ops_of_type():
+ with Context(), Location.unknown():
+ module = Module.parse("""
+ module {
+ func.func @f() { return }
+ func.func @g() { return }
+ }
+ """)
+
+ # CHECK: get_ops_of_type func.func count: 2
+ results = get_ops_of_type(module, func.FuncOp)
+ print(f"get_ops_of_type func.func count: {len(results)}")
+ assert len(results) == 2
+ assert all(isinstance(r, func.FuncOp) for r in results)
+
+ # CHECK: get_ops_of_type scf.for count: 0
+ results = get_ops_of_type(module, scf.ForOp)
+ print(f"get_ops_of_type scf.for count: {len(results)}")
+ assert len(results) == 0
+
+ # Accepts OpView as root.
+ func_op = get_ops_of_type(module, func.FuncOp)[0]
+ results = get_ops_of_type(func_op, func.ReturnOp)
+ assert len(results) == 1
+ assert isinstance(results[0], func.ReturnOp)
>From 36e089fbc02e13935fe7d010ba7983f3b6424f94 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 14:28:39 +0000
Subject: [PATCH 03/10] Fix comments
---
mlir/include/mlir/Bindings/Python/IRCore.h | 3 +--
mlir/lib/Bindings/Python/IRCore.cpp | 2 +-
mlir/python/mlir/ir.py | 1 -
3 files changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 1569a870e3996..836dedcff0ade 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -599,8 +599,7 @@ class MLIR_PYTHON_API_EXPORTED PyOperationBase {
void walk(std::function<PyWalkResult(MlirOperation)> callback,
PyWalkOrder walkOrder);
- // Wrap the walk method with a type filter. Works same as op.walk([](OpClass
- // op) { ... } );
+ // Wrap the walk method with a type filter.
void walkOfType(nanobind::object opClass,
std::function<PyWalkResult(MlirOperation)> callback,
PyWalkOrder walkOrder);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2c5634b239b77..ac3dc00c008a0 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2721,7 +2721,7 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
continue;
- // co_qualname and PyCode_Addr2Location added in py3.11
+ // co_qualname and PyCode_Addr2Location added in py3.11
#if PY_VERSION_HEX < 0x030B00F0
std::string name =
nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index a495a4db77861..023808b34fbb7 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -48,7 +48,6 @@ def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView
def get_ops_of_type(root: OpView | Operation | Module, op_class: type[OpView]) -> list[OpView]:
"""Return all operations of the given type in the operation tree.
-
Args:
root: The operation or module to start traversing from.
op_class: The OpView subclass to visit for (e.g. func.FuncOp).
>From 4ec803dc8f4870fd57a695cb6d85ee0fbacdce95 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 14:53:33 +0000
Subject: [PATCH 04/10] Fix format
---
mlir/lib/Bindings/Python/IRCore.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index ac3dc00c008a0..2c5634b239b77 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2721,7 +2721,7 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
continue;
- // co_qualname and PyCode_Addr2Location added in py3.11
+ // co_qualname and PyCode_Addr2Location added in py3.11
#if PY_VERSION_HEX < 0x030B00F0
std::string name =
nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
>From ae9205f1fa163f6a4902dbb05fb683805dfd8670 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 15:00:36 +0000
Subject: [PATCH 05/10] Fix format
---
mlir/python/mlir/ir.py | 6 +++++-
mlir/test/python/ir/operation.py | 11 ++++++++++-
2 files changed, 15 insertions(+), 2 deletions(-)
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 023808b34fbb7..b148c5fd28aed 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -45,7 +45,9 @@ def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView
return None
-def get_ops_of_type(root: OpView | Operation | Module, op_class: type[OpView]) -> list[OpView]:
+def get_ops_of_type(
+ root: OpView | Operation | Module, op_class: type[OpView]
+) -> list[OpView]:
"""Return all operations of the given type in the operation tree.
Args:
@@ -57,9 +59,11 @@ def get_ops_of_type(root: OpView | Operation | Module, op_class: type[OpView]) -
"""
op = root.operation if isinstance(root, Module) else root
ops = []
+
def collect_ops(op: Operation):
ops.append(op.opview)
return WalkResult.ADVANCE
+
op.walk_of_type(op_class, collect_ops)
return ops
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index f184ec2a3984c..d579badc3e8dd 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1323,7 +1323,8 @@ def callback(op):
module.operation.walk(callback)
except RuntimeError:
print("Exception raised")
-
+
+
# CHECK-LABEL: TEST: testOpWalkOfType
@run
def testOpWalkOfType():
@@ -1339,39 +1340,47 @@ def testOpWalkOfType():
# Callback: only visits ops of the requested type.
# CHECK: only FuncOp visited: True
only_funcs = True
+
def check_type(op):
nonlocal only_funcs
if not isinstance(op.opview, func.FuncOp):
only_funcs = False
return WalkResult.ADVANCE
+
module.operation.walk_of_type(func.FuncOp, check_type)
print(f"only FuncOp visited: {only_funcs}")
# Callback: interrupt after first match.
# CHECK: interrupted after: 1
seen = []
+
def stop_after_first(op):
seen.append(op.opview)
return WalkResult.INTERRUPT
+
module.operation.walk_of_type(func.FuncOp, stop_after_first)
print(f"interrupted after: {len(seen)}")
# Callback: no match, callback never called.
# CHECK: never called: True
called = False
+
def should_not_run(op):
nonlocal called
called = True
return WalkResult.ADVANCE
+
module.operation.walk_of_type(scf.ForOp, should_not_run)
print(f"never called: {not called}")
# Callback: collect all matching ops.
# CHECK: collected func.FuncOp: ['"f"', '"g"']
collected = []
+
def collect(op):
collected.append(op.opview)
return WalkResult.ADVANCE
+
module.operation.walk_of_type(func.FuncOp, collect)
assert all(isinstance(r, func.FuncOp) for r in collected)
print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
>From b9acc03375664b01d94e3ddd4ef3f2b3c876bc74 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 15:16:49 +0000
Subject: [PATCH 06/10] Fix format
---
mlir/lib/Bindings/Python/IRCore.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2c5634b239b77..ac3dc00c008a0 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2721,7 +2721,7 @@ MlirLocation tracebackToLocation(MlirContext ctx) {
if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
continue;
- // co_qualname and PyCode_Addr2Location added in py3.11
+ // co_qualname and PyCode_Addr2Location added in py3.11
#if PY_VERSION_HEX < 0x030B00F0
std::string name =
nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
>From b7bb2cc125c4fc0453f41acf6cd0329b0bf3eb70 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 15:29:45 +0000
Subject: [PATCH 07/10] Embed walkOfType into walk
---
mlir/include/mlir/Bindings/Python/IRCore.h | 5 --
mlir/lib/Bindings/Python/IRCore.cpp | 65 ++++++++++------------
mlir/python/mlir/ir.py | 2 +-
mlir/test/python/ir/operation.py | 65 +---------------------
4 files changed, 32 insertions(+), 105 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 836dedcff0ade..bd2d49acbf681 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -599,11 +599,6 @@ class MLIR_PYTHON_API_EXPORTED PyOperationBase {
void walk(std::function<PyWalkResult(MlirOperation)> callback,
PyWalkOrder walkOrder);
- // Wrap the walk method with a type filter.
- void walkOfType(nanobind::object opClass,
- std::function<PyWalkResult(MlirOperation)> callback,
- PyWalkOrder walkOrder);
-
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
void moveBefore(PyOperationBase &other);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index ac3dc00c008a0..21e6c41ae6e56 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1145,20 +1145,6 @@ void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
}
}
-void PyOperationBase::walkOfType(
- nb::object opClass, std::function<PyWalkResult(MlirOperation)> callback,
- PyWalkOrder walkOrder) {
-
- auto filtered = [&](MlirOperation mlirOp) -> PyWalkResult {
- nb::object opview = nb::cast(mlirOp).attr("opview");
- if (nb::isinstance(opview, opClass)) {
- return callback(mlirOp);
- };
- return PyWalkResult::Advance;
- };
- walk(filtered, walkOrder);
-}
-
nb::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
std::optional<int64_t> largeResourceLimit,
@@ -3872,32 +3858,39 @@ void populateIRCore(nb::module_ &m) {
Note:
After erasing, any Python references to the operation become invalid.)")
- .def("walk", &PyOperationBase::walk, "callback"_a,
- "walk_order"_a = PyWalkOrder::PostOrder,
- // clang-format off
- nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
- // clang-format on
- R"(
+ .def(
+ "walk",
+ [](PyOperationBase &self,
+ std::function<PyWalkResult(MlirOperation)> callback,
+ PyWalkOrder walkOrder, std::optional<nb::object> opClass) {
+ if (opClass) {
+ self.walk(
+ [&](MlirOperation mlirOp) -> PyWalkResult {
+ nb::object opview = nb::cast(mlirOp).attr("opview");
+ if (nb::isinstance(opview, *opClass))
+ return callback(mlirOp);
+ return PyWalkResult::Advance;
+ },
+ walkOrder);
+ } else {
+ self.walk(callback, walkOrder);
+ }
+ },
+ "callback"_a, "walk_order"_a = PyWalkOrder::PostOrder,
+ "op_class"_a = nb::none(), // Last, so existing walk(cb, order) calls keep working.
+ // clang-format off
+ nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = WalkOrder.POST_ORDER, op_class: type[OpView] | None = None) -> None"),
+ // clang-format on
+ R"(
 Walks the operation tree with a callback function.
+ If op_class is provided, the callback is only invoked on operations
+ of that type; all other operations are skipped silently.
+
 Args:
 callback: A callable that takes an Operation and returns a WalkResult.
- walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)")
- .def("walk_of_type", &PyOperationBase::walkOfType, "op_class"_a,
- "callback"_a, "walk_order"_a = PyWalkOrder::PostOrder,
- // clang-format off
- nb::sig("def walk_of_type(self, op_class: type[OpView], callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
- // clang-format on
- R"(
- Walks the operation tree, invoking the callback only on operations of the specified type.
-
- Args:
- op_class: The operation type to match.
- callback: A callable that takes an Operation and returns a WalkResult.
- walk_order: The traversal order (PRE_ORDER or POST_ORDER).
-
- For example, op.walk_of_type(arith.AddIOp, callback) walks the operation tree
- and invokes callback only on arith.AddIOp operations.)");
+ walk_order: The order of traversal (PRE_ORDER or POST_ORDER).
+ op_class: If provided, only operations of this type are passed to the callback.)");
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
.def_static(
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index b148c5fd28aed..7d8bd0ae7cfef 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -64,7 +64,7 @@ def collect_ops(op: Operation):
ops.append(op.opview)
return WalkResult.ADVANCE
- op.walk_of_type(op_class, collect_ops)
+ op.walk(collect_ops, op_class=op_class)
return ops
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index d579badc3e8dd..911a3bef3cc39 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1284,7 +1284,7 @@ def callback(op):
# CHECK-NEXT: func.fun
# CHECK-NEXT: func.return
print("Pre-order")
- module.operation.walk(callback, WalkOrder.PRE_ORDER)
+ module.operation.walk(callback, walk_order=WalkOrder.PRE_ORDER)
# Test interrput.
# CHECK-NEXT: Interrupt post-order
@@ -1306,7 +1306,7 @@ def callback(op):
print(op.name)
return WalkResult.SKIP
- module.operation.walk(callback, WalkOrder.PRE_ORDER)
+ module.operation.walk(callback, walk_order=WalkOrder.PRE_ORDER)
# Test exception.
# CHECK: Exception
@@ -1325,67 +1325,6 @@ def callback(op):
print("Exception raised")
-# CHECK-LABEL: TEST: testOpWalkOfType
- at run
-def testOpWalkOfType():
- with Context(), Location.unknown():
- module = Module.parse("""
- module {
- func.func @f() { return }
- func.func @g() { return }
- arith.constant dense<0> : tensor<i32>
- }
- """)
-
- # Callback: only visits ops of the requested type.
- # CHECK: only FuncOp visited: True
- only_funcs = True
-
- def check_type(op):
- nonlocal only_funcs
- if not isinstance(op.opview, func.FuncOp):
- only_funcs = False
- return WalkResult.ADVANCE
-
- module.operation.walk_of_type(func.FuncOp, check_type)
- print(f"only FuncOp visited: {only_funcs}")
-
- # Callback: interrupt after first match.
- # CHECK: interrupted after: 1
- seen = []
-
- def stop_after_first(op):
- seen.append(op.opview)
- return WalkResult.INTERRUPT
-
- module.operation.walk_of_type(func.FuncOp, stop_after_first)
- print(f"interrupted after: {len(seen)}")
-
- # Callback: no match, callback never called.
- # CHECK: never called: True
- called = False
-
- def should_not_run(op):
- nonlocal called
- called = True
- return WalkResult.ADVANCE
-
- module.operation.walk_of_type(scf.ForOp, should_not_run)
- print(f"never called: {not called}")
-
- # Callback: collect all matching ops.
- # CHECK: collected func.FuncOp: ['"f"', '"g"']
- collected = []
-
- def collect(op):
- collected.append(op.opview)
- return WalkResult.ADVANCE
-
- module.operation.walk_of_type(func.FuncOp, collect)
- assert all(isinstance(r, func.FuncOp) for r in collected)
- print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
-
-
# CHECK-LABEL: TEST: testOpReplaceUsesWith
@run
def testOpReplaceUsesWith():
>From 4b18b54b8c3ab176775a8a38db8164b927df70f5 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 15:33:47 +0000
Subject: [PATCH 08/10] Merge test
---
mlir/test/python/ir/operation.py | 56 ++++++++++++++++++++++++++++++++
1 file changed, 56 insertions(+)
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 911a3bef3cc39..bc883643d53fc 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1324,6 +1324,62 @@ def callback(op):
except RuntimeError:
print("Exception raised")
+ # Test op_class filter: only visits ops of the requested type.
+ module = Module.parse(
+ """
+ module {
+ func.func @f() { return }
+ func.func @g() { return }
+ arith.constant dense<0> : tensor<i32>
+ }
+ """,
+ ctx,
+ )
+
+ # CHECK-NEXT: only FuncOp visited: True
+ only_funcs = True
+
+ def check_type(op):
+ nonlocal only_funcs
+ if not isinstance(op.opview, func.FuncOp):
+ only_funcs = False
+ return WalkResult.ADVANCE
+
+ module.operation.walk(check_type, op_class=func.FuncOp)
+ print(f"only FuncOp visited: {only_funcs}")
+
+ # CHECK-NEXT: interrupted after: 1
+ seen = []
+
+ def stop_after_first(op):
+ seen.append(op.opview)
+ return WalkResult.INTERRUPT
+
+ module.operation.walk(stop_after_first, op_class=func.FuncOp)
+ print(f"interrupted after: {len(seen)}")
+
+ # CHECK-NEXT: never called: True
+ called = False
+
+ def should_not_run(op):
+ nonlocal called
+ called = True
+ return WalkResult.ADVANCE
+
+ module.operation.walk(should_not_run, op_class=scf.ForOp)
+ print(f"never called: {not called}")
+
+ # CHECK-NEXT: collected func.FuncOp: ['"f"', '"g"']
+ collected = []
+
+ def collect(op):
+ collected.append(op.opview)
+ return WalkResult.ADVANCE
+
+ module.operation.walk(collect, op_class=func.FuncOp)
+ assert all(isinstance(r, func.FuncOp) for r in collected)
+ print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
+
# CHECK-LABEL: TEST: testOpReplaceUsesWith
@run
>From a06ba1ce175f3fb4b7a725571a775689bc902c38 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 15:46:28 +0000
Subject: [PATCH 09/10] Fix format
---
mlir/test/python/ir/operation.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index bc883643d53fc..46850ab960be5 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1499,12 +1499,14 @@ def testGetParentOfType():
@run
def test_get_ops_of_type():
with Context(), Location.unknown():
- module = Module.parse("""
+ module = Module.parse(
+ """
module {
func.func @f() { return }
func.func @g() { return }
}
- """)
+ """
+ )
# CHECK: get_ops_of_type func.func count: 2
results = get_ops_of_type(module, func.FuncOp)
>From b0059e85ab5b7997b928891cd6758ee698f8be4a Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Thu, 12 Mar 2026 17:06:00 +0000
Subject: [PATCH 10/10] Update get opview
---
mlir/lib/Bindings/Python/IRCore.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 21e6c41ae6e56..6fca7c6fb18b6 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3866,7 +3866,9 @@ void populateIRCore(nb::module_ &m) {
if (opClass) {
self.walk(
[&](MlirOperation mlirOp) -> PyWalkResult {
- nb::object opview = nb::cast(mlirOp).attr("opview");
+ nb::object opview = PyOperation::forOperation(
+ self.getOperation().getContext(), mlirOp)
+ ->createOpView();
if (nb::isinstance(opview, *opClass))
return callback(mlirOp);
return PyWalkResult::Advance;
More information about the Mlir-commits
mailing list