[Mlir-commits] [mlir] 6d8dd3d - [MLIR][Python] Register Containers as Sequences for `match` compatibility (#174091)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 3 09:56:28 PST 2026


Author: MaPePeR
Date: 2026-01-03T09:56:24-08:00
New Revision: 6d8dd3da4beb50566c1dc1c2db51e75cc827a491

URL: https://github.com/llvm/llvm-project/commit/6d8dd3da4beb50566c1dc1c2db51e75cc827a491
DIFF: https://github.com/llvm/llvm-project/commit/6d8dd3da4beb50566c1dc1c2db51e75cc827a491.diff

LOG: [MLIR][Python] Register Containers as Sequences for `match` compatibility (#174091)

This allows these containers to be used in `match` statements, which
makes it possible to extract properties and assert a shape at the same time.

It seems only possible to match as _either_ a `Mapping` _or_ a
`Sequence`, so `OpAttributeMap` is only a `Mapping`.

I couldn't find a way to make these C++-based types properly inherit
from `Sequence` or `Mapping`, so the mixin methods are not provided (nanobind
only allows C++ parent classes, and modifying `__base__` fails with a
complaint about differing destructors).
`OpAttributeMap` was lacking a `get` method, so I simply copied it
from `collections.abc.Mapping`.

While writing the tests, I made the mistake of writing
`func.FuncOp(body=[Block(...)])` instead of
`func.FuncOp(body=Region(blocks=[Block(...)]))`. So perhaps turning
`Region` itself into a `Sequence` would be a good addition as well? That
would extend the scope of this PR, though.

@makslevental, you suggested I make this PR, so I'm tagging you here as a
potential reviewer. I hope that is OK with you. :)

---------

Co-authored-by: Maksim Levental <maksim.levental at gmail.com>

Added: 
    

Modified: 
    mlir/python/mlir/_mlir_libs/__init__.py
    mlir/test/python/ir/blocks.py
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 63244212ba42c..ce7e6bf93012a 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -31,6 +31,7 @@ def get_include_dirs() -> Sequence[str]:
 #   1. Attempting to load initializer modules, specific to the distribution.
 #   2. Defining the concrete mlir.ir.Context that does site specific
 #      initialization.
+#   3. Registering container classes with their respective protocols.
 #
 # Aside from just being far more convenient to do this at the Python level,
 # it is actually quite hard/impossible to have such __init__ hooks, given
@@ -233,5 +234,17 @@ def __str__(self):
 
     ir.MLIRError = MLIRError
 
+    # Register containers as Sequences, so they can be used with `match`.
+
+    Sequence.register(ir.BlockArgumentList)
+    Sequence.register(ir.BlockList)
+    Sequence.register(ir.BlockSuccessors)
+    Sequence.register(ir.BlockPredecessors)
+    Sequence.register(ir.OperationList)
+    Sequence.register(ir.OpOperandList)
+    Sequence.register(ir.OpResultList)
+    Sequence.register(ir.OpSuccessors)
+    Sequence.register(ir.RegionSequence)
+
 
 _site_initialize()

diff  --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index e876c00e0c52d..1e6ecea310f19 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -72,6 +72,26 @@ def testBlockCreation():
 
         assert len(successor_block.successors) == 0
 
+        # Same checks but using structural pattern matching.
+        match entry_block:
+            case Block(
+                predecessors=[],
+                successors=[
+                    Block(
+                        predecessors=[matched_entry_block],
+                        successors=[
+                            Block(predecessors=[matched_middle_block], successors=[]),
+                        ],
+                    ),
+                ],
+            ) if (
+                entry_block == matched_entry_block
+                and middle_block == matched_middle_block
+            ):
+                assert True
+            case _:
+                assert False
+
 
 # CHECK-LABEL: TEST: testBlockCreationArgLocs
 @run

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9a61911185284..b9242a7cc2bd9 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -190,8 +190,8 @@ def testBlockArgumentList():
     """,
             ctx,
         )
-        func = module.body.operations[0]
-        entry_block = func.regions[0].blocks[0]
+        func_op = module.body.operations[0]
+        entry_block = func_op.regions[0].blocks[0]
         assert len(entry_block.arguments) == 3
         # CHECK: Argument 0, type i32
         # CHECK: Argument 1, type f64
@@ -207,6 +207,21 @@ def testBlockArgumentList():
         for arg in entry_block.arguments:
             print(f"Argument {arg.arg_number}, type {arg.type}")
 
+        # CHECK: Matched Arg 0, type i8
+        # CHECK: Matched Arg 1, type i16
+        # CHECK: Matched Arg 2, type i24
+        match func_op:
+            case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1])])):
+                assert False
+            case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1, a2, a3])])):
+                assert False
+            case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1, a2])])):
+                print(f"Matched Arg 0, type {a0.type}")
+                print(f"Matched Arg 1, type {a1.type}")
+                print(f"Matched Arg 2, type {a2.type}")
+            case _:
+                assert False
+
         # Check that slicing works for block argument lists.
         # CHECK: Argument 1, type i16
         # CHECK: Argument 2, type i24
@@ -250,8 +265,8 @@ def testOperationOperands():
         return
       }"""
         )
-        func = module.body.operations[0]
-        entry_block = func.regions[0].blocks[0]
+        func_op = module.body.operations[0]
+        entry_block = func_op.regions[0].blocks[0]
         consumer = entry_block.operations[1]
         assert len(consumer.operands) == 2
         # CHECK: Operand 0, type i32
@@ -259,6 +274,25 @@ def testOperationOperands():
         for i, operand in enumerate(consumer.operands):
             print(f"Operand {i}, type {operand.type}")
 
+        match module.body.operations:
+            case [
+                func.FuncOp(
+                    body=Region(
+                        blocks=[
+                            Block(
+                                operations=[
+                                    _,
+                                    OpView(operands=[o1, o2]) as matched_consumer,
+                                    *_,
+                                ],
+                            ),
+                        ],
+                    ),
+                ),
+            ]:
+                print(f"Matched Operand 0, type {o1.type}")
+                print(f"Matched Operand 1, type {o2.type}")
+
 
 # CHECK-LABEL: TEST: testOperationOperandsSlice
 @run
@@ -480,6 +514,26 @@ def testOperationResultList():
     for res in call.results:
         print(f"Result {res.result_number}, type {res.type}")
 
+    # CHECK: Matched Result r0, type i32
+    # CHECK: Matched Result r1, type f64
+    # CHECK: Matched Result r2, type index
+    match caller:
+        case func.FuncOp(
+            body=Region(
+                blocks=[
+                    Block(
+                        operations=[OpView(results=[r0, r1, r2]) as matched_call, *_],
+                    ),
+                ],
+            ),
+        ):
+            assert matched_call == call
+            print(f"Matched Result r0, type {r0.type}")
+            print(f"Matched Result r1, type {r1.type}")
+            print(f"Matched Result r2, type {r2.type}")
+        case _:
+            assert False
+
     # CHECK: Result type i32
     # CHECK: Result type f64
     # CHECK: Result type index


        


More information about the Mlir-commits mailing list