[Mlir-commits] [mlir] [MLIR][Python] Add walk_of_type() binding and get_ops_of_type() utility (PR #186131)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 12 07:57:34 PDT 2026


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {darker}-->


:warning: The Python code formatter darker found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
darker --check --diff -r origin/main...HEAD mlir/python/mlir/ir.py mlir/test/python/ir/operation.py
``````````

:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:

</details>

<details>
<summary>
View the diff from darker here.
</summary>

``````````diff
--- python/mlir/ir.py	2026-03-12 14:53:43.000000 +0000
+++ python/mlir/ir.py	2026-03-12 14:55:18.359666 +0000
@@ -43,11 +43,13 @@
             return parent.opview
         parent = parent.parent
     return None
 
 
-def get_ops_of_type(root: OpView | Operation | Module, op_class: type[OpView]) -> list[OpView]:
+def get_ops_of_type(
+    root: OpView | Operation | Module, op_class: type[OpView]
+) -> list[OpView]:
     """Return all operations of the given type in the operation tree.
 
     Args:
       root: The operation or module to start traversing from.
       op_class: The OpView subclass to visit for (e.g. func.FuncOp).
@@ -55,13 +57,15 @@
     Returns:
       A list of operations of the given type.
     """
     op = root.operation if isinstance(root, Module) else root
     ops = []
+
     def collect_ops(op: Operation):
         ops.append(op.opview)
         return WalkResult.ADVANCE
+
     op.walk_of_type(op_class, collect_ops)
     return ops
 
 
 @contextmanager
--- test/python/ir/operation.py	2026-03-12 14:53:33.000000 +0000
+++ test/python/ir/operation.py	2026-03-12 14:55:18.787522 +0000
@@ -1321,59 +1321,70 @@
 
     try:
         module.operation.walk(callback)
     except RuntimeError:
         print("Exception raised")
-        
+
+
 # CHECK-LABEL: TEST: testOpWalkOfType
 @run
 def testOpWalkOfType():
     with Context(), Location.unknown():
-        module = Module.parse("""
+        module = Module.parse(
+            """
             module {
                 func.func @f() { return }
                 func.func @g() { return }
                 arith.constant dense<0> : tensor<i32>
             }
-        """)
+        """
+        )
 
     # Callback: only visits ops of the requested type.
     # CHECK: only FuncOp visited: True
     only_funcs = True
+
     def check_type(op):
         nonlocal only_funcs
         if not isinstance(op.opview, func.FuncOp):
             only_funcs = False
         return WalkResult.ADVANCE
+
     module.operation.walk_of_type(func.FuncOp, check_type)
     print(f"only FuncOp visited: {only_funcs}")
 
     # Callback: interrupt after first match.
     # CHECK: interrupted after: 1
     seen = []
+
     def stop_after_first(op):
         seen.append(op.opview)
         return WalkResult.INTERRUPT
+
     module.operation.walk_of_type(func.FuncOp, stop_after_first)
     print(f"interrupted after: {len(seen)}")
 
     # Callback: no match, callback never called.
     # CHECK: never called: True
     called = False
+
     def should_not_run(op):
         nonlocal called
         called = True
         return WalkResult.ADVANCE
+
     module.operation.walk_of_type(scf.ForOp, should_not_run)
     print(f"never called: {not called}")
 
     # Callback: collect all matching ops.
     # CHECK: collected func.FuncOp: ['"f"', '"g"']
     collected = []
+
     def collect(op):
         collected.append(op.opview)
         return WalkResult.ADVANCE
+
     module.operation.walk_of_type(func.FuncOp, collect)
     assert all(isinstance(r, func.FuncOp) for r in collected)
     print(f"collected func.FuncOp: {[str(r.name) for r in collected]}")
 
 
@@ -1493,16 +1504,18 @@
 
 # CHECK-LABEL: TEST: test_get_ops_of_type
 @run
 def test_get_ops_of_type():
     with Context(), Location.unknown():
-        module = Module.parse("""
+        module = Module.parse(
+            """
             module {
                 func.func @f() { return }
                 func.func @g() { return }
             }
-        """)
+        """
+        )
 
         # CHECK: get_ops_of_type func.func count: 2
         results = get_ops_of_type(module, func.FuncOp)
         print(f"get_ops_of_type func.func count: {len(results)}")
         assert len(results) == 2

``````````

</details>


https://github.com/llvm/llvm-project/pull/186131


More information about the Mlir-commits mailing list