[Mlir-commits] [mlir] 29cd792 - [MLIR][Python] Add get_parent_of_type helper (#185512)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 10 09:11:56 PDT 2026


Author: RattataKing
Date: 2026-03-10T12:11:51-04:00
New Revision: 29cd7921bcaa16a60afd7ae90b05737a837152ce

URL: https://github.com/llvm/llvm-project/commit/29cd7921bcaa16a60afd7ae90b05737a837152ce
DIFF: https://github.com/llvm/llvm-project/commit/29cd7921bcaa16a60afd7ae90b05737a837152ce.diff

LOG: [MLIR][Python] Add get_parent_of_type helper (#185512)

`op.parent` only returns the immediate parent, so downstream users have to
walk the parent chain themselves to find an enclosing op of a specific
type.
This PR adds a Python function `get_parent_of_type()` to `mlir.ir` that
provides an API for doing so.

The function mirrors the implementation here:

https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/Operation.h#L257-L273.
Instead of adding a new binding, it is reimplemented in Python using
`isinstance()`, which is simpler.

Added: 
    

Modified: 
    mlir/python/mlir/ir.py
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index f84792d4095f4..d86298a72c6f2 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -21,6 +21,42 @@
 )
 
 
+def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView | None:
+    """Return the closest enclosing parent operation of the given type.
+
+    Walks the parent chain of *op* (not including *op* itself) and returns
+    the first ancestor whose OpView is an instance of *op_class*. This mirrors
+    the C++ ``Operation::getParentOfType`` helper.
+
+    Args:
+      op: The starting operation.
+      op_class: The OpView subclass to search for (e.g. ``func.FuncOp``).
+
+    Returns:
+      The closest matching ancestor (an ``op_class`` instance), or ``None`` if
+      no ancestor matches or the parent chain cannot be walked (e.g. *op* is
+      detached).
+
+    Raises:
+      TypeError: If *op_class* is not an ``OpView`` subclass.
+    """
+    if not (isinstance(op_class, type) and issubclass(op_class, OpView)):
+        raise TypeError(f"op_class must be an OpView subclass, got {op_class!r}")
+    while True:
+        # ``parent`` raises ValueError when there is no parent chain to walk
+        # (e.g. for a detached operation). Guard every step, not only the
+        # first, so an ancestor in that state also yields ``None``.
+        try:
+            op = op.parent
+        except ValueError:
+            return None
+        if op is None:
+            return None
+        view = op.opview
+        if isinstance(view, op_class):
+            return view
+
+
 @contextmanager
 def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None]:
     """Enables automatic traceback-based locations for MLIR operations.

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 89f78ab1932a0..865dd226cbe2a 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1391,3 +1391,54 @@ def index_switch(index):
                 assert len([i for i in switch_op.caseRegions]) == 3
                 assert len(switch_op.caseRegions[1:]) == 2
                 assert len([i for i in switch_op.caseRegions[1:]]) == 2
+
+
+# CHECK-LABEL: TEST: testGetParentOfType
+@run
+def testGetParentOfType():
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        idx = IndexType.get()
+        # Build: func.func -> scf.for -> custom.base_op
+        func_op: func.FuncOp = func.FuncOp("test_fn", ([], []))
+        with InsertionPoint(func_op.add_entry_block()):
+            lower_bound = arith.ConstantOp(idx, 0)
+            upper_bound = arith.ConstantOp(idx, 10)
+            step = arith.ConstantOp(idx, 1)
+            for_op: scf.ForOp = scf.ForOp(lower_bound, upper_bound, step)
+            with InsertionPoint(for_op.body):
+                base_op: Operation = Operation.create("custom.base_op")
+                scf.YieldOp([])
+            func.ReturnOp([])
+
+        # CHECK: get_parent_of_type detached->func.func: None
+        detached: Operation = Operation.create("custom.detached")
+        res = get_parent_of_type(detached, func.FuncOp)
+        print(f"get_parent_of_type detached->func.func: {res}")
+        assert res is None
+
+        # CHECK: get_parent_of_type base_op->func.func: func.func
+        res = get_parent_of_type(base_op, func.FuncOp)
+        print(f"get_parent_of_type base_op->func.func: {res.operation.name}")
+        assert isinstance(res, func.FuncOp)
+
+        # CHECK: get_parent_of_type base_op->scf.for: scf.for
+        res = get_parent_of_type(base_op, scf.ForOp)
+        print(f"get_parent_of_type base_op->scf.for: {res.operation.name}")
+        assert isinstance(res, scf.ForOp)
+
+        # CHECK: get_parent_of_type func_op->func.func: None
+        res = get_parent_of_type(func_op, func.FuncOp)
+        print(f"get_parent_of_type func_op->func.func: {res}")
+        assert res is None
+
+        # CHECK: get_parent_of_type base_op->scf.if: None
+        res = get_parent_of_type(base_op, scf.IfOp)
+        print(f"get_parent_of_type base_op->scf.if: {res}")
+        assert res is None
+
+        try:
+            get_parent_of_type(base_op, int)
+            assert False, "expected TypeError"
+        except TypeError:
+            pass


        


More information about the Mlir-commits mailing list