[Mlir-commits] [mlir] [MLIR][Python] Register Containers as Sequences for `match` compatibility (PR #174091)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 31 05:32:11 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (MaPePeR)
<details>
<summary>Changes</summary>
This allows these containers to be used in `match` statements, which makes it possible to extract properties and assert a structure at the same time.
It seems that a type can only be matched as _either_ a `Mapping` _or_ a `Sequence`, so `OpAttributeMap` is registered only as a `Mapping`.
I couldn't find a way to make these C++-based types properly inherit from `Sequence` or `Mapping`, so the mixin methods are not provided. nanobind only allows C++ parent classes, and modifying `__base__` fails with an error about differing destructors.
`OpAttributeMap` lacked a `get` method, so I simply copied it from `collections.abc.Mapping`.
While writing the tests, I made the mistake of writing `func.FuncOp(body=[Block(...)])` instead of `func.FuncOp(body=Region(blocks=[Block(...)]))`. So maybe turning `Region` itself into a Sequence would be a good addition as well? That would extend the scope of this PR, though.
@<!-- -->makslevental You suggested I make the PR, so I'm tagging you here as a potential reviewer. I hope that is OK with you. :)
---
Full diff: https://github.com/llvm/llvm-project/pull/174091.diff
4 Files Affected:
- (modified) mlir/include/mlir/Bindings/Python/Nanobind.h (+1)
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+16)
- (modified) mlir/test/python/ir/blocks.py (+17)
- (modified) mlir/test/python/ir/operation.py (+54-4)
``````````diff
diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h
index 8dc8a0d063d70..62837e8b8652b 100644
--- a/mlir/include/mlir/Bindings/Python/Nanobind.h
+++ b/mlir/include/mlir/Bindings/Python/Nanobind.h
@@ -21,6 +21,7 @@
#pragma GCC diagnostic ignored "-Wc++98-compat-extra-semi"
#pragma GCC diagnostic ignored "-Wcovered-switch-default"
#endif
+#include <nanobind/eval.h>
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/function.h>
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 168c57955af07..84154ede9f3e1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -4951,6 +4951,22 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyRegionIterator::bind(m);
PyRegionList::bind(m);
+ // Register containers as Sequences, so they can be used with `match`.
+ nanobind::object scope = m.attr("__dict__");
+ nanobind::exec("from collections.abc import Sequence, Mapping\n"
+ "Sequence.register(BlockArgumentList)\n"
+ "Sequence.register(BlockList)\n"
+ "Sequence.register(BlockSuccessors)\n"
+ "Sequence.register(BlockPredecessors)\n"
+ "Sequence.register(OperationList)\n"
+ "Sequence.register(OpOperandList)\n"
+ "Sequence.register(OpResultList)\n"
+ "Sequence.register(OpSuccessors)\n"
+ "Sequence.register(RegionSequence)\n"
+ "OpAttributeMap.get = Mapping.get\n"
+ "Mapping.register(OpAttributeMap)\n",
+ scope);
+
// Debug bindings.
PyGlobalDebugFlag::bind(m);
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index e876c00e0c52d..6c7501b11292e 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -72,6 +72,23 @@ def testBlockCreation():
assert len(successor_block.successors) == 0
+ # Same checks but using structural pattern matching.
+ match entry_block:
+ case Block(
+ predecessors=[],
+ successors=[
+ Block(
+ predecessors=[matched_entry_block],
+ successors=[
+ Block(predecessors=[matched_middle_block], successors=[])
+ ]
+ )
+ ]
+ ) if entry_block == matched_entry_block and middle_block == matched_middle_block:
+ assert True
+ case _:
+ assert False
+
# CHECK-LABEL: TEST: testBlockCreationArgLocs
@run
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9a61911185284..a4f7ffa073994 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -190,8 +190,8 @@ def testBlockArgumentList():
""",
ctx,
)
- func = module.body.operations[0]
- entry_block = func.regions[0].blocks[0]
+ func_op = module.body.operations[0]
+ entry_block = func_op.regions[0].blocks[0]
assert len(entry_block.arguments) == 3
# CHECK: Argument 0, type i32
# CHECK: Argument 1, type f64
@@ -206,6 +206,21 @@ def testBlockArgumentList():
# CHECK: Argument 2, type i24
for arg in entry_block.arguments:
print(f"Argument {arg.arg_number}, type {arg.type}")
+
+ # CHECK: Matched Arg 0, type i8
+ # CHECK: Matched Arg 1, type i16
+ # CHECK: Matched Arg 2, type i24
+ match func_op:
+ case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1])])):
+ assert False
+ case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1, a2, a3])])):
+ assert False
+ case func.FuncOp(body=Region(blocks=[Block(arguments=[a0, a1, a2])])):
+ print(f"Matched Arg 0, type {a0.type}")
+ print(f"Matched Arg 1, type {a1.type}")
+ print(f"Matched Arg 2, type {a2.type}")
+ case _:
+ assert False
# Check that slicing works for block argument lists.
# CHECK: Argument 1, type i16
@@ -250,14 +265,21 @@ def testOperationOperands():
return
}"""
)
- func = module.body.operations[0]
- entry_block = func.regions[0].blocks[0]
+ func_op = module.body.operations[0]
+ entry_block = func_op.regions[0].blocks[0]
consumer = entry_block.operations[1]
assert len(consumer.operands) == 2
# CHECK: Operand 0, type i32
# CHECK: Operand 1, type i64
for i, operand in enumerate(consumer.operands):
print(f"Operand {i}, type {operand.type}")
+
+ match module.body.operations:
+ case [
+ func.FuncOp(body=Region(blocks=[Block(operations=[_, OpView(operands=[o1, o2]) as matched_consumer, *_])]))
+ ]:
+ print(f"Matched Operand 0, type {o1.type}")
+ print(f"Matched Operand 1, type {o2.type}")
# CHECK-LABEL: TEST: testOperationOperandsSlice
@@ -480,6 +502,18 @@ def testOperationResultList():
for res in call.results:
print(f"Result {res.result_number}, type {res.type}")
+ # CHECK: Matched Result r0, type i32
+ # CHECK: Matched Result r1, type f64
+ # CHECK: Matched Result r2, type index
+ match caller:
+ case func.FuncOp(body=Region(blocks=[Block(operations=[OpView(results=[r0, r1, r2]) as matched_call, *_])])):
+ assert matched_call == call
+ print(f"Matched Result r0, type {r0.type}")
+ print(f"Matched Result r1, type {r1.type}")
+ print(f"Matched Result r2, type {r2.type}")
+ case _:
+ assert False
+
# CHECK: Result type i32
# CHECK: Result type f64
# CHECK: Result type index
@@ -598,6 +632,22 @@ def testOperationAttributes():
# CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1}
print("Dict mapping", d)
+ # Structural pattern matching test using Mapping
+
+ # CHECK: Matched Mapping Attribute 'some.attribute': 1
+ # CHECK: Matched Mapping Attribute 'other.attribute': 3.0
+ # CHECK: Matched Mapping Attribute 'dependent': text
+ match op:
+ case OpView(attributes={"does_not_exist": a0}):
+ assert False
+ case OpView(attributes={"some.attribute": Attribute() as some_attr, "other.attribute": Attribute() as other_attr, "dependent": Attribute() as dep_attr}):
+ print(f"Matched Mapping Attribute 'some.attribute': {some_attr.value}")
+ print(f"Matched Mapping Attribute 'other.attribute': {other_attr.value}")
+ print(f"Matched Mapping Attribute 'dependent': {dep_attr.value}")
+ case _:
+ print("Did not match!")
+ assert False
+
# Check that exceptions are raised as expected.
try:
op.attributes["does_not_exist"]
``````````
</details>
https://github.com/llvm/llvm-project/pull/174091
More information about the Mlir-commits
mailing list