[Mlir-commits] [mlir] 6d8dd3d - [MLIR][Python] Register Containers as Sequences for `match` compatibility (#174091)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 3 09:56:28 PST 2026
Author: MaPePeR
Date: 2026-01-03T09:56:24-08:00
New Revision: 6d8dd3da4beb50566c1dc1c2db51e75cc827a491
URL: https://github.com/llvm/llvm-project/commit/6d8dd3da4beb50566c1dc1c2db51e75cc827a491
DIFF: https://github.com/llvm/llvm-project/commit/6d8dd3da4beb50566c1dc1c2db51e75cc827a491.diff
LOG: [MLIR][Python] Register Containers as Sequences for `match` compatibility (#174091)
This allows these containers to be used in `match` statements, making it
possible to extract properties and assert a shape at the same time.
It seems to be possible to match a class only as _either_ a `Mapping` _or_ a
`Sequence`, so `OpAttributeMap` is registered only as a `Mapping`.
I couldn't find a way to make these C++-based types properly inherit
from `Sequence` or `Mapping`, so the mixin methods are not provided
(nanobind only allows C++ parent classes, and modifying `__base__` fails
with an error about differing destructors).
`OpAttributeMap` was lacking the `get` method, so I simply copied it
from `collections.abc.Mapping`.
When writing the tests, I ran into an error because I wrote
`func.FuncOp(body=[Block(...)])` instead of
`func.FuncOp(body=Region(blocks=[Block(...)]))`. So maybe turning
`Region` itself into a Sequence would be a good addition as well? That would
extend the scope of this PR, though.
@makslevental You suggested I make the PR, so I'm tagging you here as a
potential reviewer. I hope that is OK with you. :)
---------
Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
Added:
Modified:
mlir/python/mlir/_mlir_libs/__init__.py
mlir/test/python/ir/blocks.py
mlir/test/python/ir/operation.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 63244212ba42c..ce7e6bf93012a 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -31,6 +31,7 @@ def get_include_dirs() -> Sequence[str]:
# 1. Attempting to load initializer modules, specific to the distribution.
# 2. Defining the concrete mlir.ir.Context that does site specific
# initialization.
+# 3. Registering container classes with their respective protocols.
#
# Aside from just being far more convenient to do this at the Python level,
# it is actually quite hard/impossible to have such __init__ hooks, given
@@ -233,5 +234,17 @@ def __str__(self):
ir.MLIRError = MLIRError
+ # Register containers as Sequences, so they can be used with `match`.
+
+ Sequence.register(ir.BlockArgumentList)
+ Sequence.register(ir.BlockList)
+ Sequence.register(ir.BlockSuccessors)
+ Sequence.register(ir.BlockPredecessors)
+ Sequence.register(ir.OperationList)
+ Sequence.register(ir.OpOperandList)
+ Sequence.register(ir.OpResultList)
+ Sequence.register(ir.OpSuccessors)
+ Sequence.register(ir.RegionSequence)
+
_site_initialize()
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index e876c00e0c52d..1e6ecea310f19 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -72,6 +72,26 @@ def testBlockCreation():
assert len(successor_block.successors) == 0
+ # Same checks but using structural pattern matching.
+ match entry_block:
+ case Block(
+ predecessors=[],
+ successors=[
+ Block(
+ predecessors=[matched_entry_block],
+ successors=[
+ Block(predecessors=[matched_middle_block], successors=[]),
+ ],
+ ),
+ ],
+ ) if (
+ entry_block == matched_entry_block
+ and middle_block == matched_middle_block
+ ):
+ assert True
+ case _:
+ assert False
+
# CHECK-LABEL: TEST: testBlockCreationArgLocs
@run
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9a61911185284..b9242a7cc2bd9 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -190,8 +190,8 @@ def testBlockArgumentList():
""",
ctx,
)
- func = module.body.operations[0]
- entry_block = func.regions[0].blocks[0]
+ func_op = module.body.operations[0]
+ entry_block = func_op.regions[0].blocks[0]
assert len(entry_block.arguments) == 3
# CHECK: Argument 0, type i32
# CHECK: Argument 1, type f64
@@ -207,6 +207,21 @@ def testBlockArgumentList():
for arg in entry_block.arguments:
print(f"Argument {arg.arg_number}, type {arg.type}")
+ # CHECK: Matched Arg 0, type i8
+ # CHECK: Matched Arg 1, type i16
+ # CHECK: Matched Arg 2, type i24
+ match func_op:
+ case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1])])):
+ assert False
+ case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1, a2, a3])])):
+ assert False
+ case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1, a2])])):
+ print(f"Matched Arg 0, type {a0.type}")
+ print(f"Matched Arg 1, type {a1.type}")
+ print(f"Matched Arg 2, type {a2.type}")
+ case _:
+ assert False
+
# Check that slicing works for block argument lists.
# CHECK: Argument 1, type i16
# CHECK: Argument 2, type i24
@@ -250,8 +265,8 @@ def testOperationOperands():
return
}"""
)
- func = module.body.operations[0]
- entry_block = func.regions[0].blocks[0]
+ func_op = module.body.operations[0]
+ entry_block = func_op.regions[0].blocks[0]
consumer = entry_block.operations[1]
assert len(consumer.operands) == 2
# CHECK: Operand 0, type i32
@@ -259,6 +274,25 @@ def testOperationOperands():
for i, operand in enumerate(consumer.operands):
print(f"Operand {i}, type {operand.type}")
+ match module.body.operations:
+ case [
+ func.FuncOp(
+ body=Region(
+ blocks=[
+ Block(
+ operations=[
+ _,
+ OpView(operands=[o1, o2]) as matched_consumer,
+ *_,
+ ],
+ ),
+ ],
+ ),
+ ),
+ ]:
+ print(f"Matched Operand 0, type {o1.type}")
+ print(f"Matched Operand 1, type {o2.type}")
+
# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
@@ -480,6 +514,26 @@ def testOperationResultList():
for res in call.results:
print(f"Result {res.result_number}, type {res.type}")
+ # CHECK: Matched Result r0, type i32
+ # CHECK: Matched Result r1, type f64
+ # CHECK: Matched Result r2, type index
+ match caller:
+ case func.FuncOp(
+ body=Region(
+ blocks=[
+ Block(
+ operations=[OpView(results=[r0, r1, r2]) as matched_call, *_],
+ ),
+ ],
+ ),
+ ):
+ assert matched_call == call
+ print(f"Matched Result r0, type {r0.type}")
+ print(f"Matched Result r1, type {r1.type}")
+ print(f"Matched Result r2, type {r2.type}")
+ case _:
+ assert False
+
# CHECK: Result type i32
# CHECK: Result type f64
# CHECK: Result type index
More information about the Mlir-commits
mailing list