[Mlir-commits] [mlir] [MLIR][Python] Add get_parent_of_type helper (PR #185512)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 10 08:02:21 PDT 2026


https://github.com/RattataKing updated https://github.com/llvm/llvm-project/pull/185512

>From e0cf1c904888840d7dc76a01951db723dcb6df98 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Mon, 9 Mar 2026 20:26:57 +0000
Subject: [PATCH 1/7] Add get_parent api and test

---
 mlir/python/mlir/ir.py           | 27 +++++++++++++++++++++
 mlir/test/python/ir/operation.py | 41 ++++++++++++++++++++++++++++++++
 2 files changed, 68 insertions(+)

diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index f84792d4095f4..72beb987a5790 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -21,6 +21,33 @@
 )
 
 
+def get_parent_of_type(
+    op: "OpView | Operation", op_class: "type[OpView]"
+) -> "OpView | None":
+    """Return the closest enclosing parent operation of the given type.
+
+    Walks the parent chain of *op* and returns the first ancestor that is an instance of *op_class*.
+    Returns ``None`` if no matching parent is found.
+
+    Args:
+      op: The starting operation.
+      op_class: The OpView subclass to search for (e.g. ``func.FuncOp``).
+
+    """
+    if not (isinstance(op_class, type) and issubclass(op_class, OpView)):
+        raise TypeError(f"op_class must be an OpView subclass, got {op_class!r}")
+    try:
+        parent = op.operation.parent
+    except ValueError:
+        return None  # No parent chain.
+    while parent is not None:
+        opview = parent.opview
+        if isinstance(opview, op_class):
+            return opview
+        parent = parent.parent
+    return None
+
+
 @contextmanager
 def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None]:
     """Enables automatic traceback-based locations for MLIR operations.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 89f78ab1932a0..c647ef6b46a11 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -4,6 +4,7 @@
 import io
 from tempfile import NamedTemporaryFile
 from mlir.ir import *
+from mlir.ir import get_parent_of_type
 from mlir.dialects.builtin import ModuleOp
 from mlir.dialects import arith, func, scf, shape
 from mlir.dialects._ods_common import _cext
@@ -1391,3 +1392,43 @@ def index_switch(index):
                 assert len([i for i in switch_op.caseRegions]) == 3
                 assert len(switch_op.caseRegions[1:]) == 2
                 assert len([i for i in switch_op.caseRegions[1:]]) == 2
+
+
+# CHECK-LABEL: TEST: testGetParentOfType
+@run
+def testGetParentOfType():
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        idx = IndexType.get()
+        # Build: func.func -> scf.for -> custom.base_op
+        func_op = func.FuncOp("test_fn", ([], []))
+        with InsertionPoint(func_op.add_entry_block()):
+            lower_bound = arith.ConstantOp(idx, 0)
+            upper_bound = arith.ConstantOp(idx, 10)
+            step = arith.ConstantOp(idx, 1)
+            for_op = scf.ForOp(lower_bound, upper_bound, step)
+            with InsertionPoint(for_op.body):
+                base_op = Operation.create("custom.base_op")
+                scf.YieldOp([])
+            func.ReturnOp([])
+
+        # CHECK: get_parent_of_type base_op->func.func: func.func
+        res = get_parent_of_type(base_op, func.FuncOp)
+        print(f"get_parent_of_type base_op->func.func: {res.operation.name}")
+        assert isinstance(res, func.FuncOp)
+
+        # CHECK: get_parent_of_type func_op->func.func: None
+        res = get_parent_of_type(func_op, func.FuncOp)
+        print(f"get_parent_of_type func_op->func.func: {res}")
+        assert res is None
+
+        # CHECK: get_parent_of_type base_op->scf.if: None
+        res = get_parent_of_type(base_op, scf.IfOp)
+        print(f"get_parent_of_type base_op->scf.if: {res}")
+        assert res is None
+
+        try:
+            get_parent_of_type(base_op, int)
+            assert False, "expected TypeError"
+        except TypeError:
+            pass

>From 58f4fc20d7e28643bf4a0b3e3a6747d9ae584647 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Mon, 9 Mar 2026 20:29:45 +0000
Subject: [PATCH 2/7] Remove extra import

---
 mlir/test/python/ir/operation.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index c647ef6b46a11..b4413bf2340a1 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -4,7 +4,6 @@
 import io
 from tempfile import NamedTemporaryFile
 from mlir.ir import *
-from mlir.ir import get_parent_of_type
 from mlir.dialects.builtin import ModuleOp
 from mlir.dialects import arith, func, scf, shape
 from mlir.dialects._ods_common import _cext

>From eec7c89437c7f78fb74e061dbc45e5716ca91695 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Mon, 9 Mar 2026 21:12:15 +0000
Subject: [PATCH 3/7] Fix type hints

---
 mlir/python/mlir/ir.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 72beb987a5790..6816627004b02 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -22,8 +22,8 @@
 
 
 def get_parent_of_type(
-    op: "OpView | Operation", op_class: "type[OpView]"
-) -> "OpView | None":
+    op: OpView | Operation, op_class: type[OpView]
+) -> OpView | None:
     """Return the closest enclosing parent operation of the given type.
 
     Walks the parent chain of *op* and returns the first ancestor that is an instance of *op_class*.

>From 094bfcf92cfdd2a1010c3e5417525cec191ab38a Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Mon, 9 Mar 2026 21:36:00 +0000
Subject: [PATCH 4/7] Add detach op test

---
 mlir/python/mlir/ir.py           | 11 +++++------
 mlir/test/python/ir/operation.py |  6 ++++++
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 6816627004b02..6eb10ae5db140 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -37,14 +37,13 @@ def get_parent_of_type(
     if not (isinstance(op_class, type) and issubclass(op_class, OpView)):
         raise TypeError(f"op_class must be an OpView subclass, got {op_class!r}")
     try:
-        parent = op.operation.parent
+        op = op.parent
     except ValueError:
         return None  # No parent chain.
-    while parent is not None:
-        opview = parent.opview
-        if isinstance(opview, op_class):
-            return opview
-        parent = parent.parent
+    while op is not None:
+        if isinstance(op.opview, op_class):
+            return op.opview
+        op = op.parent
     return None
 
 
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index b4413bf2340a1..cc231e89d4d55 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1410,6 +1410,12 @@ def testGetParentOfType():
                 base_op = Operation.create("custom.base_op")
                 scf.YieldOp([])
             func.ReturnOp([])
+            
+        # CHECK: get_parent_of_type detached->func.func: None
+        detached = Operation.create("custom.detached")
+        res = get_parent_of_type(detached, func.FuncOp)
+        print(f"get_parent_of_type detached->func.func: {res}")
+        assert res is None
 
         # CHECK: get_parent_of_type base_op->func.func: func.func
         res = get_parent_of_type(base_op, func.FuncOp)

>From 015d4964293275693ce27d40372ec17cab71b785 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Mon, 9 Mar 2026 21:42:31 +0000
Subject: [PATCH 5/7] Fix format

---
 mlir/python/mlir/ir.py           | 4 +---
 mlir/test/python/ir/operation.py | 3 ++-
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 6eb10ae5db140..d86298a72c6f2 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -21,9 +21,7 @@
 )
 
 
-def get_parent_of_type(
-    op: OpView | Operation, op_class: type[OpView]
-) -> OpView | None:
+def get_parent_of_type(op: OpView | Operation, op_class: type[OpView]) -> OpView | None:
     """Return the closest enclosing parent operation of the given type.
 
     Walks the parent chain of *op* and returns the first ancestor that is an instance of *op_class*.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index cc231e89d4d55..cedf6c67cead4 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1410,7 +1410,8 @@ def testGetParentOfType():
                 base_op = Operation.create("custom.base_op")
                 scf.YieldOp([])
             func.ReturnOp([])
-            
+
+
         # CHECK: get_parent_of_type detached->func.func: None
         detached = Operation.create("custom.detached")
         res = get_parent_of_type(detached, func.FuncOp)

>From 35a976f9b1308679247e235c2c2033a776b014d3 Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Mon, 9 Mar 2026 21:43:22 +0000
Subject: [PATCH 6/7] Fix format

---
 mlir/test/python/ir/operation.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index cedf6c67cead4..511e7bcb9e8bc 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1411,7 +1411,6 @@ def testGetParentOfType():
                 scf.YieldOp([])
             func.ReturnOp([])
 
-
         # CHECK: get_parent_of_type detached->func.func: None
         detached = Operation.create("custom.detached")
         res = get_parent_of_type(detached, func.FuncOp)

>From 2310d2df5821d152164c3dd173f42a438d64d23f Mon Sep 17 00:00:00 2001
From: Amily Wu <amilywu2 at amd.com>
Date: Tue, 10 Mar 2026 15:02:03 +0000
Subject: [PATCH 7/7] Add type hints

---
 mlir/test/python/ir/operation.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 511e7bcb9e8bc..865dd226cbe2a 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1400,19 +1400,19 @@ def testGetParentOfType():
         ctx.allow_unregistered_dialects = True
         idx = IndexType.get()
         # Build: func.func -> scf.for -> custom.base_op
-        func_op = func.FuncOp("test_fn", ([], []))
+        func_op: func.FuncOp = func.FuncOp("test_fn", ([], []))
         with InsertionPoint(func_op.add_entry_block()):
             lower_bound = arith.ConstantOp(idx, 0)
             upper_bound = arith.ConstantOp(idx, 10)
             step = arith.ConstantOp(idx, 1)
-            for_op = scf.ForOp(lower_bound, upper_bound, step)
+            for_op: scf.ForOp = scf.ForOp(lower_bound, upper_bound, step)
             with InsertionPoint(for_op.body):
-                base_op = Operation.create("custom.base_op")
+                base_op: Operation = Operation.create("custom.base_op")
                 scf.YieldOp([])
             func.ReturnOp([])
 
         # CHECK: get_parent_of_type detached->func.func: None
-        detached = Operation.create("custom.detached")
+        detached: Operation = Operation.create("custom.detached")
         res = get_parent_of_type(detached, func.FuncOp)
         print(f"get_parent_of_type detached->func.func: {res}")
         assert res is None



More information about the Mlir-commits mailing list