[Mlir-commits] [mlir] [MLIR][Python] Add walk_of_type() binding and get_ops_of_type() utility (PR #186131)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 12 07:57:34 PDT 2026
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {darker}-->
:warning: The Python code formatter darker found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
darker --check --diff -r origin/main...HEAD mlir/python/mlir/ir.py mlir/test/python/ir/operation.py
``````````
:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:
</details>
<details>
<summary>
View the diff from darker here.
</summary>
``````````diff
--- python/mlir/ir.py 2026-03-12 14:53:43.000000 +0000
+++ python/mlir/ir.py 2026-03-12 14:55:18.359666 +0000
@@ -43,11 +43,13 @@
return parent.opview
parent = parent.parent
return None
-def get_ops_of_type(root: OpView | Operation | Module, op_class: type[OpView]) -> list[OpView]:
+def get_ops_of_type(
+ root: OpView | Operation | Module, op_class: type[OpView]
+) -> list[OpView]:
"""Return all operations of the given type in the operation tree.
Args:
root: The operation or module to start traversing from.
op_class: The OpView subclass to visit for (e.g. func.FuncOp).
@@ -55,13 +57,15 @@
Returns:
A list of operations of the given type.
"""
op = root.operation if isinstance(root, Module) else root
ops = []
+
def collect_ops(op: Operation):
ops.append(op.opview)
return WalkResult.ADVANCE
+
op.walk_of_type(op_class, collect_ops)
return ops
@contextmanager
--- test/python/ir/operation.py 2026-03-12 14:53:33.000000 +0000
+++ test/python/ir/operation.py 2026-03-12 14:55:18.787522 +0000
@@ -1321,59 +1321,70 @@
try:
module.operation.walk(callback)
except RuntimeError:
print("Exception raised")
-
+
+
# CHECK-LABEL: TEST: testOpWalkOfType
@run
def testOpWalkOfType():
with Context(), Location.unknown():
- module = Module.parse("""
+ module = Module.parse(
+ """
module {
func.func @f() { return }
func.func @g() { return }
arith.constant dense<0> : tensor<i32>
}
- """)
+ """
+ )
# Callback: only visits ops of the requested type.
# CHECK: only FuncOp visited: True
only_funcs = True
+
def check_type(op):
nonlocal only_funcs
if not isinstance(op.opview, func.FuncOp):
only_funcs = False
return WalkResult.ADVANCE
+
module.operation.walk_of_type(func.FuncOp, check_type)
print(f"only FuncOp visited: {only_funcs}")
# Callback: interrupt after first match.
# CHECK: interrupted after: 1
seen = []
+
def stop_after_first(op):
seen.append(op.opview)
return WalkResult.INTERRUPT
+
module.operation.walk_of_type(func.FuncOp, stop_after_first)
print(f"interrupted after: {len(seen)}")
# Callback: no match, callback never called.
# CHECK: never called: True
called = False
+
def should_not_run(op):
nonlocal called
called = True
return WalkResult.ADVANCE
+
module.operation.walk_of_type(scf.ForOp, should_not_run)
print(f"never called: {not called}")
# Callback: collect all matching ops.
# CHECK: collected func.FuncOp: ['"f"', '"g"']
collected = []
+
def collect(op):
collected.append(op.opview)
return WalkResult.ADVANCE
+
module.operation.walk_of_type(func.FuncOp, collect)
assert all(isinstance(r, func.FuncOp) for r in collected)
print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
@@ -1493,16 +1504,18 @@
# CHECK-LABEL: TEST: test_get_ops_of_type
@run
def test_get_ops_of_type():
with Context(), Location.unknown():
- module = Module.parse("""
+ module = Module.parse(
+ """
module {
func.func @f() { return }
func.func @g() { return }
}
- """)
+ """
+ )
# CHECK: get_ops_of_type func.func count: 2
results = get_ops_of_type(module, func.FuncOp)
print(f"get_ops_of_type func.func count: {len(results)}")
assert len(results) == 2
``````````
</details>
https://github.com/llvm/llvm-project/pull/186131
More information about the Mlir-commits
mailing list