[Mlir-commits] [mlir] [MLIR][Python] Register Containers as Sequences for `match` compatibility (PR #174091)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 31 05:31:22 PST 2025


https://github.com/MaPePeR created https://github.com/llvm/llvm-project/pull/174091

This allows these containers to be used in `match` statements, which allows extracting properties and asserting a shape at the same time.

It seems that a type can only be matched as _either_ a `Mapping` _or_ a `Sequence`, so `OpAttributeMap` is registered only as a `Mapping`.

I couldn't find a way to make these C++-based types properly inherit from `Sequence` or `Mapping`, so the mixin methods are not provided (nanobind only allows C++ parent classes, and modifying `__base__` complains about differing destructors).
`OpAttributeMap` was lacking the `get` method, so I simply copied it from `collections.abc.Mapping`.

While writing the tests, I ran into an error because I wrote `func.FuncOp(body=[Block(...)])` instead of `func.FuncOp(body=Region(blocks=[Block(...)]))`. So maybe turning `Region` itself into a `Sequence` would be a good addition as well? That would extend the scope of this PR, though.

@makslevental You suggested that I make this PR, so I'm tagging you here as a potential reviewer. I hope that is OK with you. :)

>From f773e8d4f18a1849333a071c187d426581edf667 Mon Sep 17 00:00:00 2001
From: MaPePeR <MaPePeR at users.noreply.github.com>
Date: Wed, 31 Dec 2025 00:19:54 +0000
Subject: [PATCH] [MLIR][Python] Register Containers as Sequences

This allows them to be used in `match` statements.
---
 mlir/include/mlir/Bindings/Python/Nanobind.h |  1 +
 mlir/lib/Bindings/Python/IRCore.cpp          | 16 ++++++
 mlir/test/python/ir/blocks.py                | 17 ++++++
 mlir/test/python/ir/operation.py             | 58 ++++++++++++++++++--
 4 files changed, 88 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h
index 8dc8a0d063d70..62837e8b8652b 100644
--- a/mlir/include/mlir/Bindings/Python/Nanobind.h
+++ b/mlir/include/mlir/Bindings/Python/Nanobind.h
@@ -21,6 +21,7 @@
 #pragma GCC diagnostic ignored "-Wc++98-compat-extra-semi"
 #pragma GCC diagnostic ignored "-Wcovered-switch-default"
 #endif
+#include <nanobind/eval.h>
 #include <nanobind/nanobind.h>
 #include <nanobind/ndarray.h>
 #include <nanobind/stl/function.h>
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 168c57955af07..84154ede9f3e1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -4951,6 +4951,22 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   PyRegionIterator::bind(m);
   PyRegionList::bind(m);
 
+  // Register containers as Sequences, so they can be used with `match`.
+  nanobind::object scope = m.attr("__dict__");
+  nanobind::exec("from collections.abc import Sequence, Mapping\n"
+                 "Sequence.register(BlockArgumentList)\n"
+                 "Sequence.register(BlockList)\n"
+                 "Sequence.register(BlockSuccessors)\n"
+                 "Sequence.register(BlockPredecessors)\n"
+                 "Sequence.register(OperationList)\n"
+                 "Sequence.register(OpOperandList)\n"
+                 "Sequence.register(OpResultList)\n"
+                 "Sequence.register(OpSuccessors)\n"
+                 "Sequence.register(RegionSequence)\n"
+                 "OpAttributeMap.get = Mapping.get\n"
+                 "Mapping.register(OpAttributeMap)\n",
+                 scope);
+
   // Debug bindings.
   PyGlobalDebugFlag::bind(m);
 
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index e876c00e0c52d..6c7501b11292e 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -72,6 +72,23 @@ def testBlockCreation():
 
         assert len(successor_block.successors) == 0
 
+        # Same checks but using structural pattern matching.
+        match entry_block:
+            case Block(
+                predecessors=[],
+                successors=[
+                    Block(
+                        predecessors=[matched_entry_block],
+                        successors=[
+                            Block(predecessors=[matched_middle_block], successors=[])
+                        ]
+                    )
+                ]
+            ) if entry_block == matched_entry_block and middle_block == matched_middle_block:
+                assert True
+            case _:
+                assert False
+
 
 # CHECK-LABEL: TEST: testBlockCreationArgLocs
 @run
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9a61911185284..a4f7ffa073994 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -190,8 +190,8 @@ def testBlockArgumentList():
     """,
             ctx,
         )
-        func = module.body.operations[0]
-        entry_block = func.regions[0].blocks[0]
+        func_op = module.body.operations[0]
+        entry_block = func_op.regions[0].blocks[0]
         assert len(entry_block.arguments) == 3
         # CHECK: Argument 0, type i32
         # CHECK: Argument 1, type f64
@@ -206,6 +206,21 @@ def testBlockArgumentList():
         # CHECK: Argument 2, type i24
         for arg in entry_block.arguments:
             print(f"Argument {arg.arg_number}, type {arg.type}")
+        
+        # CHECK: Matched Arg 0, type i8
+        # CHECK: Matched Arg 1, type i16
+        # CHECK: Matched Arg 2, type i24
+        match func_op:
+            case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1])])):
+                assert False
+            case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1, a2, a3])])):
+                assert False
+            case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1, a2])])):
+                print(f"Matched Arg 0, type {a0.type}")
+                print(f"Matched Arg 1, type {a1.type}")
+                print(f"Matched Arg 2, type {a2.type}")
+            case _:
+                assert False
 
         # Check that slicing works for block argument lists.
         # CHECK: Argument 1, type i16
@@ -250,14 +265,21 @@ def testOperationOperands():
         return
       }"""
         )
-        func = module.body.operations[0]
-        entry_block = func.regions[0].blocks[0]
+        func_op = module.body.operations[0]
+        entry_block = func_op.regions[0].blocks[0]
         consumer = entry_block.operations[1]
         assert len(consumer.operands) == 2
         # CHECK: Operand 0, type i32
         # CHECK: Operand 1, type i64
         for i, operand in enumerate(consumer.operands):
             print(f"Operand {i}, type {operand.type}")
+        
+        match module.body.operations:
+            case [
+                func.FuncOp(body=Region(blocks=[Block(operations=[_, OpView(operands=[o1, o2]) as matched_consumer, *_])]))
+            ]:
+                print(f"Matched Operand 0, type {o1.type}")
+                print(f"Matched Operand 1, type {o2.type}")
 
 
 # CHECK-LABEL: TEST: testOperationOperandsSlice
@@ -480,6 +502,18 @@ def testOperationResultList():
     for res in call.results:
         print(f"Result {res.result_number}, type {res.type}")
 
+    # CHECK: Matched Result r0, type i32
+    # CHECK: Matched Result r1, type f64
+    # CHECK: Matched Result r2, type index
+    match caller:
+        case func.FuncOp(body=Region(blocks=[Block(operations=[OpView(results=[r0, r1, r2]) as matched_call, *_])])):
+            assert matched_call == call
+            print(f"Matched Result r0, type {r0.type}")
+            print(f"Matched Result r1, type {r1.type}")
+            print(f"Matched Result r2, type {r2.type}")
+        case _:
+            assert False
+
     # CHECK: Result type i32
     # CHECK: Result type f64
     # CHECK: Result type index
@@ -598,6 +632,22 @@ def testOperationAttributes():
     # CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1}
     print("Dict mapping", d)
 
+    # Structural pattern matching test using Mapping
+
+    # CHECK: Matched Mapping Attribute 'some.attribute': 1
+    # CHECK: Matched Mapping Attribute 'other.attribute': 3.0
+    # CHECK: Matched Mapping Attribute 'dependent': text
+    match op:
+        case OpView(attributes={"does_not_exist": a0}):
+            assert False
+        case OpView(attributes={"some.attribute": Attribute() as some_attr, "other.attribute": Attribute() as other_attr, "dependent": Attribute() as dep_attr}):
+            print(f"Matched Mapping Attribute 'some.attribute': {some_attr.value}")
+            print(f"Matched Mapping Attribute 'other.attribute': {other_attr.value}")
+            print(f"Matched Mapping Attribute 'dependent': {dep_attr.value}")
+        case _:
+            print("Did not match!")
+            assert False
+
     # Check that exceptions are raised as expected.
     try:
         op.attributes["does_not_exist"]



More information about the Mlir-commits mailing list