[Mlir-commits] [mlir] [MLIR][Python] Register Containers as Sequences for `match` compatibility (PR #174091)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 1 01:59:10 PST 2026


https://github.com/MaPePeR updated https://github.com/llvm/llvm-project/pull/174091

>From f773e8d4f18a1849333a071c187d426581edf667 Mon Sep 17 00:00:00 2001
From: MaPePeR <MaPePeR at users.noreply.github.com>
Date: Wed, 31 Dec 2025 00:19:54 +0000
Subject: [PATCH 1/4] [MLIR][Python] Register Containers as Sequences

This allows them to be matched by sequence patterns (e.g. `case [a, b]:`) in `match` statements.
---
 mlir/include/mlir/Bindings/Python/Nanobind.h |  1 +
 mlir/lib/Bindings/Python/IRCore.cpp          | 16 ++++++
 mlir/test/python/ir/blocks.py                | 17 ++++++
 mlir/test/python/ir/operation.py             | 58 ++++++++++++++++++--
 4 files changed, 88 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h
index 8dc8a0d063d70..62837e8b8652b 100644
--- a/mlir/include/mlir/Bindings/Python/Nanobind.h
+++ b/mlir/include/mlir/Bindings/Python/Nanobind.h
@@ -21,6 +21,7 @@
 #pragma GCC diagnostic ignored "-Wc++98-compat-extra-semi"
 #pragma GCC diagnostic ignored "-Wcovered-switch-default"
 #endif
+#include <nanobind/eval.h>
 #include <nanobind/nanobind.h>
 #include <nanobind/ndarray.h>
 #include <nanobind/stl/function.h>
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 168c57955af07..84154ede9f3e1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -4951,6 +4951,22 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   PyRegionIterator::bind(m);
   PyRegionList::bind(m);
 
+  // Register containers as Sequences, so they can be used with `match`.
+  nanobind::object scope = m.attr("__dict__");
+  nanobind::exec("from collections.abc import Sequence, Mapping\n"
+                 "Sequence.register(BlockArgumentList)\n"
+                 "Sequence.register(BlockList)\n"
+                 "Sequence.register(BlockSuccessors)\n"
+                 "Sequence.register(BlockPredecessors)\n"
+                 "Sequence.register(OperationList)\n"
+                 "Sequence.register(OpOperandList)\n"
+                 "Sequence.register(OpResultList)\n"
+                 "Sequence.register(OpSuccessors)\n"
+                 "Sequence.register(RegionSequence)\n"
+                 "OpAttributeMap.get = Mapping.get\n"
+                 "Mapping.register(OpAttributeMap)\n",
+                 scope);
+
   // Debug bindings.
   PyGlobalDebugFlag::bind(m);
 
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index e876c00e0c52d..6c7501b11292e 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -72,6 +72,23 @@ def testBlockCreation():
 
         assert len(successor_block.successors) == 0
 
+        # Same checks but using structural pattern matching.
+        match entry_block:
+            case Block(
+                predecessors=[],
+                successors=[
+                    Block(
+                        predecessors=[matched_entry_block],
+                        successors=[
+                            Block(predecessors=[matched_middle_block], successors=[])
+                        ]
+                    )
+                ]
+            ) if entry_block == matched_entry_block and middle_block == matched_middle_block:
+                assert True
+            case _:
+                assert False
+
 
 # CHECK-LABEL: TEST: testBlockCreationArgLocs
 @run
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9a61911185284..a4f7ffa073994 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -190,8 +190,8 @@ def testBlockArgumentList():
     """,
             ctx,
         )
-        func = module.body.operations[0]
-        entry_block = func.regions[0].blocks[0]
+        func_op = module.body.operations[0]
+        entry_block = func_op.regions[0].blocks[0]
         assert len(entry_block.arguments) == 3
         # CHECK: Argument 0, type i32
         # CHECK: Argument 1, type f64
@@ -206,6 +206,21 @@ def testBlockArgumentList():
         # CHECK: Argument 2, type i24
         for arg in entry_block.arguments:
             print(f"Argument {arg.arg_number}, type {arg.type}")
+        
+        # CHECK: Matched Arg 0, type i8
+        # CHECK: Matched Arg 1, type i16
+        # CHECK: Matched Arg 2, type i24
+        match func_op:
+            case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1])])):
+                assert False
+            case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1, a2, a3])])):
+                assert False
+            case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1, a2])])):
+                print(f"Matched Arg 0, type {a0.type}")
+                print(f"Matched Arg 1, type {a1.type}")
+                print(f"Matched Arg 2, type {a2.type}")
+            case _:
+                assert False
 
         # Check that slicing works for block argument lists.
         # CHECK: Argument 1, type i16
@@ -250,14 +265,21 @@ def testOperationOperands():
         return
       }"""
         )
-        func = module.body.operations[0]
-        entry_block = func.regions[0].blocks[0]
+        func_op = module.body.operations[0]
+        entry_block = func_op.regions[0].blocks[0]
         consumer = entry_block.operations[1]
         assert len(consumer.operands) == 2
         # CHECK: Operand 0, type i32
         # CHECK: Operand 1, type i64
         for i, operand in enumerate(consumer.operands):
             print(f"Operand {i}, type {operand.type}")
+        
+        match module.body.operations:
+            case [
+                func.FuncOp(body=Region(blocks=[Block(operations=[_, OpView(operands=[o1, o2]) as matched_consumer, *_])]))
+            ]:
+                print(f"Matched Operand 0, type {o1.type}")
+                print(f"Matched Operand 1, type {o2.type}")
 
 
 # CHECK-LABEL: TEST: testOperationOperandsSlice
@@ -480,6 +502,18 @@ def testOperationResultList():
     for res in call.results:
         print(f"Result {res.result_number}, type {res.type}")
 
+    # CHECK: Matched Result r0, type i32
+    # CHECK: Matched Result r1, type f64
+    # CHECK: Matched Result r2, type index
+    match caller:
+        case func.FuncOp(body=Region(blocks=[Block(operations=[OpView(results=[r0, r1, r2]) as matched_call, *_])])):
+            assert matched_call == call
+            print(f"Matched Result r0, type {r0.type}")
+            print(f"Matched Result r1, type {r1.type}")
+            print(f"Matched Result r2, type {r2.type}")
+        case _:
+            assert False
+
     # CHECK: Result type i32
     # CHECK: Result type f64
     # CHECK: Result type index
@@ -598,6 +632,22 @@ def testOperationAttributes():
     # CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1}
     print("Dict mapping", d)
 
+    # Structural pattern matching test using Mapping
+
+    # CHECK: Matched Mapping Attribute 'some.attribute': 1
+    # CHECK: Matched Mapping Attribute 'other.attribute': 3.0
+    # CHECK: Matched Mapping Attribute 'dependent': text
+    match op:
+        case OpView(attributes={"does_not_exist": a0}):
+            assert False
+        case OpView(attributes={"some.attribute": Attribute() as some_attr, "other.attribute": Attribute() as other_attr, "dependent": Attribute() as dep_attr}):
+            print(f"Matched Mapping Attribute 'some.attribute': {some_attr.value}")
+            print(f"Matched Mapping Attribute 'other.attribute': {other_attr.value}")
+            print(f"Matched Mapping Attribute 'dependent': {dep_attr.value}")
+        case _:
+            print("Did not match!")
+            assert False
+
     # Check that exceptions are raised as expected.
     try:
         op.attributes["does_not_exist"]

>From 8926b5f68d4fc668d792afd47bad447e92114c52 Mon Sep 17 00:00:00 2001
From: MaPePeR <MaPePeR at users.noreply.github.com>
Date: Wed, 31 Dec 2025 14:26:54 +0000
Subject: [PATCH 2/4] Fix Python code style in tests

---
 mlir/test/python/ir/blocks.py    | 13 ++++++++-----
 mlir/test/python/ir/operation.py | 28 ++++++++++++++++++++++++----
 2 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index 6c7501b11292e..1e6ecea310f19 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -80,11 +80,14 @@ def testBlockCreation():
                     Block(
                         predecessors=[matched_entry_block],
                         successors=[
-                            Block(predecessors=[matched_middle_block], successors=[])
-                        ]
-                    )
-                ]
-            ) if entry_block == matched_entry_block and middle_block == matched_middle_block:
+                            Block(predecessors=[matched_middle_block], successors=[]),
+                        ],
+                    ),
+                ],
+            ) if (
+                entry_block == matched_entry_block
+                and middle_block == matched_middle_block
+            ):
                 assert True
             case _:
                 assert False
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index a4f7ffa073994..e6a6931589929 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -206,7 +206,7 @@ def testBlockArgumentList():
         # CHECK: Argument 2, type i24
         for arg in entry_block.arguments:
             print(f"Argument {arg.arg_number}, type {arg.type}")
-        
+
         # CHECK: Matched Arg 0, type i8
         # CHECK: Matched Arg 1, type i16
         # CHECK: Matched Arg 2, type i24
@@ -273,10 +273,22 @@ def testOperationOperands():
         # CHECK: Operand 1, type i64
         for i, operand in enumerate(consumer.operands):
             print(f"Operand {i}, type {operand.type}")
-        
+
         match module.body.operations:
             case [
-                func.FuncOp(body=Region(blocks=[Block(operations=[_, OpView(operands=[o1, o2]) as matched_consumer, *_])]))
+                func.FuncOp(
+                    body=Region(
+                        blocks=[
+                            Block(
+                                operations=[
+                                    _,
+                                    OpView(operands=[o1, o2]) as matched_consumer,
+                                    *_,
+                                ],
+                            ),
+                        ],
+                    ),
+                ),
             ]:
                 print(f"Matched Operand 0, type {o1.type}")
                 print(f"Matched Operand 1, type {o2.type}")
@@ -506,7 +518,15 @@ def testOperationResultList():
     # CHECK: Matched Result r1, type f64
     # CHECK: Matched Result r2, type index
     match caller:
-        case func.FuncOp(body=Region(blocks=[Block(operations=[OpView(results=[r0, r1, r2]) as matched_call, *_])])):
+        case func.FuncOp(
+            body=Region(
+                blocks=[
+                    Block(
+                        operations=[OpView(results=[r0, r1, r2]) as matched_call, *_],
+                    ),
+                ],
+            ),
+        ):
             assert matched_call == call
             print(f"Matched Result r0, type {r0.type}")
             print(f"Matched Result r1, type {r1.type}")

>From 7479390ccc9362c3adf90fb95c41dd26022595db Mon Sep 17 00:00:00 2001
From: MaPePeR <MaPePeR at users.noreply.github.com>
Date: Wed, 31 Dec 2025 19:10:29 +0000
Subject: [PATCH 3/4] Don't register OpAttributeMap as a Mapping

---
 mlir/include/mlir/Bindings/Python/Nanobind.h |  1 -
 mlir/lib/Bindings/Python/IRCore.cpp          |  4 +---
 mlir/test/python/ir/operation.py             | 16 ----------------
 3 files changed, 1 insertion(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h
index 62837e8b8652b..8dc8a0d063d70 100644
--- a/mlir/include/mlir/Bindings/Python/Nanobind.h
+++ b/mlir/include/mlir/Bindings/Python/Nanobind.h
@@ -21,7 +21,6 @@
 #pragma GCC diagnostic ignored "-Wc++98-compat-extra-semi"
 #pragma GCC diagnostic ignored "-Wcovered-switch-default"
 #endif
-#include <nanobind/eval.h>
 #include <nanobind/nanobind.h>
 #include <nanobind/ndarray.h>
 #include <nanobind/stl/function.h>
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 84154ede9f3e1..1c2cf4401f7fb 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -4962,9 +4962,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                  "Sequence.register(OpOperandList)\n"
                  "Sequence.register(OpResultList)\n"
                  "Sequence.register(OpSuccessors)\n"
-                 "Sequence.register(RegionSequence)\n"
-                 "OpAttributeMap.get = Mapping.get\n"
-                 "Mapping.register(OpAttributeMap)\n",
+                 "Sequence.register(RegionSequence)\n",
                  scope);
 
   // Debug bindings.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index e6a6931589929..b9242a7cc2bd9 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -652,22 +652,6 @@ def testOperationAttributes():
     # CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1}
     print("Dict mapping", d)
 
-    # Structural pattern matching test using Mapping
-
-    # CHECK: Matched Mapping Attribute 'some.attribute': 1
-    # CHECK: Matched Mapping Attribute 'other.attribute': 3.0
-    # CHECK: Matched Mapping Attribute 'dependent': text
-    match op:
-        case OpView(attributes={"does_not_exist": a0}):
-            assert False
-        case OpView(attributes={"some.attribute": Attribute() as some_attr, "other.attribute": Attribute() as other_attr, "dependent": Attribute() as dep_attr}):
-            print(f"Matched Mapping Attribute 'some.attribute': {some_attr.value}")
-            print(f"Matched Mapping Attribute 'other.attribute': {other_attr.value}")
-            print(f"Matched Mapping Attribute 'dependent': {dep_attr.value}")
-        case _:
-            print("Did not match!")
-            assert False
-
     # Check that exceptions are raised as expected.
     try:
         op.attributes["does_not_exist"]

>From a81e6c36d02cdbe8bdbfd683b62d115b3ef4af35 Mon Sep 17 00:00:00 2001
From: MaPePeR <MaPePeR at users.noreply.github.com>
Date: Wed, 31 Dec 2025 19:31:32 +0000
Subject: [PATCH 4/4] Register containers as Sequences in
 `mlir/_mlir_libs/__init__.py` instead of C++

---
 mlir/lib/Bindings/Python/IRCore.cpp     | 14 --------------
 mlir/python/mlir/_mlir_libs/__init__.py | 13 +++++++++++++
 2 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 1c2cf4401f7fb..168c57955af07 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -4951,20 +4951,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   PyRegionIterator::bind(m);
   PyRegionList::bind(m);
 
-  // Register containers as Sequences, so they can be used with `match`.
-  nanobind::object scope = m.attr("__dict__");
-  nanobind::exec("from collections.abc import Sequence, Mapping\n"
-                 "Sequence.register(BlockArgumentList)\n"
-                 "Sequence.register(BlockList)\n"
-                 "Sequence.register(BlockSuccessors)\n"
-                 "Sequence.register(BlockPredecessors)\n"
-                 "Sequence.register(OperationList)\n"
-                 "Sequence.register(OpOperandList)\n"
-                 "Sequence.register(OpResultList)\n"
-                 "Sequence.register(OpSuccessors)\n"
-                 "Sequence.register(RegionSequence)\n",
-                 scope);
-
   // Debug bindings.
   PyGlobalDebugFlag::bind(m);
 
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 63244212ba42c..3cbf5603e22f7 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -31,6 +31,7 @@ def get_include_dirs() -> Sequence[str]:
 #   1. Attempting to load initializer modules, specific to the distribution.
 #   2. Defining the concrete mlir.ir.Context that does site specific
 #      initialization.
+#   3. Registering container classes.
 #
 # Aside from just being far more convenient to do this at the Python level,
 # it is actually quite hard/impossible to have such __init__ hooks, given
@@ -233,5 +234,17 @@ def __str__(self):
 
     ir.MLIRError = MLIRError
 
+    # Register containers as Sequences, so they can be used with `match`.
+
+    Sequence.register(ir.BlockArgumentList)
+    Sequence.register(ir.BlockList)
+    Sequence.register(ir.BlockSuccessors)
+    Sequence.register(ir.BlockPredecessors)
+    Sequence.register(ir.OperationList)
+    Sequence.register(ir.OpOperandList)
+    Sequence.register(ir.OpResultList)
+    Sequence.register(ir.OpSuccessors)
+    Sequence.register(ir.RegionSequence)
+
 
 _site_initialize()



More information about the Mlir-commits mailing list