[Mlir-commits] [mlir] [MLIR][Python] Register Containers as Sequences for `match` compatibility (PR #174091)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 1 01:59:10 PST 2026
https://github.com/MaPePeR updated https://github.com/llvm/llvm-project/pull/174091
>From f773e8d4f18a1849333a071c187d426581edf667 Mon Sep 17 00:00:00 2001
From: MaPePeR <MaPePeR at users.noreply.github.com>
Date: Wed, 31 Dec 2025 00:19:54 +0000
Subject: [PATCH 1/4] [MLIR][Python] Register Containers as Sequences
This allows them to be used in `match` statements.
---
mlir/include/mlir/Bindings/Python/Nanobind.h | 1 +
mlir/lib/Bindings/Python/IRCore.cpp | 16 ++++++
mlir/test/python/ir/blocks.py | 17 ++++++
mlir/test/python/ir/operation.py | 58 ++++++++++++++++++--
4 files changed, 88 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h
index 8dc8a0d063d70..62837e8b8652b 100644
--- a/mlir/include/mlir/Bindings/Python/Nanobind.h
+++ b/mlir/include/mlir/Bindings/Python/Nanobind.h
@@ -21,6 +21,7 @@
#pragma GCC diagnostic ignored "-Wc++98-compat-extra-semi"
#pragma GCC diagnostic ignored "-Wcovered-switch-default"
#endif
+#include <nanobind/eval.h>
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/function.h>
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 168c57955af07..84154ede9f3e1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -4951,6 +4951,22 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyRegionIterator::bind(m);
PyRegionList::bind(m);
+ // Register containers as Sequences, so they can be used with `match`.
+ nanobind::object scope = m.attr("__dict__");
+ nanobind::exec("from collections.abc import Sequence, Mapping\n"
+ "Sequence.register(BlockArgumentList)\n"
+ "Sequence.register(BlockList)\n"
+ "Sequence.register(BlockSuccessors)\n"
+ "Sequence.register(BlockPredecessors)\n"
+ "Sequence.register(OperationList)\n"
+ "Sequence.register(OpOperandList)\n"
+ "Sequence.register(OpResultList)\n"
+ "Sequence.register(OpSuccessors)\n"
+ "Sequence.register(RegionSequence)\n"
+ "OpAttributeMap.get = Mapping.get\n"
+ "Mapping.register(OpAttributeMap)\n",
+ scope);
+
// Debug bindings.
PyGlobalDebugFlag::bind(m);
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index e876c00e0c52d..6c7501b11292e 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -72,6 +72,23 @@ def testBlockCreation():
assert len(successor_block.successors) == 0
+ # Same checks but using structural pattern matching.
+ match entry_block:
+ case Block(
+ predecessors=[],
+ successors=[
+ Block(
+ predecessors=[matched_entry_block],
+ successors=[
+ Block(predecessors=[matched_middle_block], successors=[])
+ ]
+ )
+ ]
+ ) if entry_block == matched_entry_block and middle_block == matched_middle_block:
+ assert True
+ case _:
+ assert False
+
# CHECK-LABEL: TEST: testBlockCreationArgLocs
@run
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9a61911185284..a4f7ffa073994 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -190,8 +190,8 @@ def testBlockArgumentList():
""",
ctx,
)
- func = module.body.operations[0]
- entry_block = func.regions[0].blocks[0]
+ func_op = module.body.operations[0]
+ entry_block = func_op.regions[0].blocks[0]
assert len(entry_block.arguments) == 3
# CHECK: Argument 0, type i32
# CHECK: Argument 1, type f64
@@ -206,6 +206,21 @@ def testBlockArgumentList():
# CHECK: Argument 2, type i24
for arg in entry_block.arguments:
print(f"Argument {arg.arg_number}, type {arg.type}")
+
+ # CHECK: Matched Arg 0, type i8
+ # CHECK: Matched Arg 1, type i16
+ # CHECK: Matched Arg 2, type i24
+ match func_op:
+ case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1])])):
+ assert False
+ case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1, a2, a3])])):
+ assert False
+ case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1, a2])])):
+ print(f"Matched Arg 0, type {a0.type}")
+ print(f"Matched Arg 1, type {a1.type}")
+ print(f"Matched Arg 2, type {a2.type}")
+ case _:
+ assert False
# Check that slicing works for block argument lists.
# CHECK: Argument 1, type i16
@@ -250,14 +265,21 @@ def testOperationOperands():
return
}"""
)
- func = module.body.operations[0]
- entry_block = func.regions[0].blocks[0]
+ func_op = module.body.operations[0]
+ entry_block = func_op.regions[0].blocks[0]
consumer = entry_block.operations[1]
assert len(consumer.operands) == 2
# CHECK: Operand 0, type i32
# CHECK: Operand 1, type i64
for i, operand in enumerate(consumer.operands):
print(f"Operand {i}, type {operand.type}")
+
+ match module.body.operations:
+ case [
+ func.FuncOp(body=Region(blocks=[Block(operations=[_, OpView(operands=[o1, o2]) as matched_consumer, *_])]))
+ ]:
+ print(f"Matched Operand 0, type {o1.type}")
+ print(f"Matched Operand 1, type {o2.type}")
# CHECK-LABEL: TEST: testOperationOperandsSlice
@@ -480,6 +502,18 @@ def testOperationResultList():
for res in call.results:
print(f"Result {res.result_number}, type {res.type}")
+ # CHECK: Matched Result r0, type i32
+ # CHECK: Matched Result r1, type f64
+ # CHECK: Matched Result r2, type index
+ match caller:
+ case func.FuncOp(body=Region(blocks=[Block(operations=[OpView(results=[r0, r1, r2]) as matched_call, *_])])):
+ assert matched_call == call
+ print(f"Matched Result r0, type {r0.type}")
+ print(f"Matched Result r1, type {r1.type}")
+ print(f"Matched Result r2, type {r2.type}")
+ case _:
+ assert False
+
# CHECK: Result type i32
# CHECK: Result type f64
# CHECK: Result type index
@@ -598,6 +632,22 @@ def testOperationAttributes():
# CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1}
print("Dict mapping", d)
+ # Structural pattern matching test using Mapping
+
+ # CHECK: Matched Mapping Attribute 'some.attribute': 1
+ # CHECK: Matched Mapping Attribute 'other.attribute': 3.0
+ # CHECK: Matched Mapping Attribute 'dependent': text
+ match op:
+ case OpView(attributes={"does_not_exist": a0}):
+ assert False
+ case OpView(attributes={"some.attribute": Attribute() as some_attr, "other.attribute": Attribute() as other_attr, "dependent": Attribute() as dep_attr}):
+ print(f"Matched Mapping Attribute 'some.attribute': {some_attr.value}")
+ print(f"Matched Mapping Attribute 'other.attribute': {other_attr.value}")
+ print(f"Matched Mapping Attribute 'dependent': {dep_attr.value}")
+ case _:
+ print("Did not match!")
+ assert False
+
# Check that exceptions are raised as expected.
try:
op.attributes["does_not_exist"]
>From 8926b5f68d4fc668d792afd47bad447e92114c52 Mon Sep 17 00:00:00 2001
From: MaPePeR <MaPePeR at users.noreply.github.com>
Date: Wed, 31 Dec 2025 14:26:54 +0000
Subject: [PATCH 2/4] Python code style
---
mlir/test/python/ir/blocks.py | 13 ++++++++-----
mlir/test/python/ir/operation.py | 28 ++++++++++++++++++++++++----
2 files changed, 32 insertions(+), 9 deletions(-)
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index 6c7501b11292e..1e6ecea310f19 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -80,11 +80,14 @@ def testBlockCreation():
Block(
predecessors=[matched_entry_block],
successors=[
- Block(predecessors=[matched_middle_block], successors=[])
- ]
- )
- ]
- ) if entry_block == matched_entry_block and middle_block == matched_middle_block:
+ Block(predecessors=[matched_middle_block], successors=[]),
+ ],
+ ),
+ ],
+ ) if (
+ entry_block == matched_entry_block
+ and middle_block == matched_middle_block
+ ):
assert True
case _:
assert False
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index a4f7ffa073994..e6a6931589929 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -206,7 +206,7 @@ def testBlockArgumentList():
# CHECK: Argument 2, type i24
for arg in entry_block.arguments:
print(f"Argument {arg.arg_number}, type {arg.type}")
-
+
# CHECK: Matched Arg 0, type i8
# CHECK: Matched Arg 1, type i16
# CHECK: Matched Arg 2, type i24
@@ -273,10 +273,22 @@ def testOperationOperands():
# CHECK: Operand 1, type i64
for i, operand in enumerate(consumer.operands):
print(f"Operand {i}, type {operand.type}")
-
+
match module.body.operations:
case [
- func.FuncOp(body=Region(blocks=[Block(operations=[_, OpView(operands=[o1, o2]) as matched_consumer, *_])]))
+ func.FuncOp(
+ body=Region(
+ blocks=[
+ Block(
+ operations=[
+ _,
+ OpView(operands=[o1, o2]) as matched_consumer,
+ *_,
+ ],
+ ),
+ ],
+ ),
+ ),
]:
print(f"Matched Operand 0, type {o1.type}")
print(f"Matched Operand 1, type {o2.type}")
@@ -506,7 +518,15 @@ def testOperationResultList():
# CHECK: Matched Result r1, type f64
# CHECK: Matched Result r2, type index
match caller:
- case func.FuncOp(body=Region(blocks=[Block(operations=[OpView(results=[r0, r1, r2]) as matched_call, *_])])):
+ case func.FuncOp(
+ body=Region(
+ blocks=[
+ Block(
+ operations=[OpView(results=[r0, r1, r2]) as matched_call, *_],
+ ),
+ ],
+ ),
+ ):
assert matched_call == call
print(f"Matched Result r0, type {r0.type}")
print(f"Matched Result r1, type {r1.type}")
>From 7479390ccc9362c3adf90fb95c41dd26022595db Mon Sep 17 00:00:00 2001
From: MaPePeR <MaPePeR at users.noreply.github.com>
Date: Wed, 31 Dec 2025 19:10:29 +0000
Subject: [PATCH 3/4] Don't register OpAttributeMap as a Mapping
---
mlir/include/mlir/Bindings/Python/Nanobind.h | 1 -
mlir/lib/Bindings/Python/IRCore.cpp | 4 +---
mlir/test/python/ir/operation.py | 16 ----------------
3 files changed, 1 insertion(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h
index 62837e8b8652b..8dc8a0d063d70 100644
--- a/mlir/include/mlir/Bindings/Python/Nanobind.h
+++ b/mlir/include/mlir/Bindings/Python/Nanobind.h
@@ -21,7 +21,6 @@
#pragma GCC diagnostic ignored "-Wc++98-compat-extra-semi"
#pragma GCC diagnostic ignored "-Wcovered-switch-default"
#endif
-#include <nanobind/eval.h>
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/function.h>
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 84154ede9f3e1..1c2cf4401f7fb 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -4962,9 +4962,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Sequence.register(OpOperandList)\n"
"Sequence.register(OpResultList)\n"
"Sequence.register(OpSuccessors)\n"
- "Sequence.register(RegionSequence)\n"
- "OpAttributeMap.get = Mapping.get\n"
- "Mapping.register(OpAttributeMap)\n",
+ "Sequence.register(RegionSequence)\n",
scope);
// Debug bindings.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index e6a6931589929..b9242a7cc2bd9 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -652,22 +652,6 @@ def testOperationAttributes():
# CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1}
print("Dict mapping", d)
- # Structural pattern matching test using Mapping
-
- # CHECK: Matched Mapping Attribute 'some.attribute': 1
- # CHECK: Matched Mapping Attribute 'other.attribute': 3.0
- # CHECK: Matched Mapping Attribute 'dependent': text
- match op:
- case OpView(attributes={"does_not_exist": a0}):
- assert False
- case OpView(attributes={"some.attribute": Attribute() as some_attr, "other.attribute": Attribute() as other_attr, "dependent": Attribute() as dep_attr}):
- print(f"Matched Mapping Attribute 'some.attribute': {some_attr.value}")
- print(f"Matched Mapping Attribute 'other.attribute': {other_attr.value}")
- print(f"Matched Mapping Attribute 'dependent': {dep_attr.value}")
- case _:
- print("Did not match!")
- assert False
-
# Check that exceptions are raised as expected.
try:
op.attributes["does_not_exist"]
>From a81e6c36d02cdbe8bdbfd683b62d115b3ef4af35 Mon Sep 17 00:00:00 2001
From: MaPePeR <MaPePeR at users.noreply.github.com>
Date: Wed, 31 Dec 2025 19:31:32 +0000
Subject: [PATCH 4/4] Register sequences in `mlir/_mlir_libs/__init__.py`
instead of C++
---
mlir/lib/Bindings/Python/IRCore.cpp | 14 --------------
mlir/python/mlir/_mlir_libs/__init__.py | 13 +++++++++++++
2 files changed, 13 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 1c2cf4401f7fb..168c57955af07 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -4951,20 +4951,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyRegionIterator::bind(m);
PyRegionList::bind(m);
- // Register containers as Sequences, so they can be used with `match`.
- nanobind::object scope = m.attr("__dict__");
- nanobind::exec("from collections.abc import Sequence, Mapping\n"
- "Sequence.register(BlockArgumentList)\n"
- "Sequence.register(BlockList)\n"
- "Sequence.register(BlockSuccessors)\n"
- "Sequence.register(BlockPredecessors)\n"
- "Sequence.register(OperationList)\n"
- "Sequence.register(OpOperandList)\n"
- "Sequence.register(OpResultList)\n"
- "Sequence.register(OpSuccessors)\n"
- "Sequence.register(RegionSequence)\n",
- scope);
-
// Debug bindings.
PyGlobalDebugFlag::bind(m);
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 63244212ba42c..3cbf5603e22f7 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -31,6 +31,7 @@ def get_include_dirs() -> Sequence[str]:
# 1. Attempting to load initializer modules, specific to the distribution.
# 2. Defining the concrete mlir.ir.Context that does site specific
# initialization.
+# 3. Registering container classes.
#
# Aside from just being far more convenient to do this at the Python level,
# it is actually quite hard/impossible to have such __init__ hooks, given
@@ -233,5 +234,17 @@ def __str__(self):
ir.MLIRError = MLIRError
+ # Register containers as Sequences, so they can be used with `match`.
+
+ Sequence.register(ir.BlockArgumentList)
+ Sequence.register(ir.BlockList)
+ Sequence.register(ir.BlockSuccessors)
+ Sequence.register(ir.BlockPredecessors)
+ Sequence.register(ir.OperationList)
+ Sequence.register(ir.OpOperandList)
+ Sequence.register(ir.OpResultList)
+ Sequence.register(ir.OpSuccessors)
+ Sequence.register(ir.RegionSequence)
+
_site_initialize()
More information about the Mlir-commits
mailing list