[Mlir-commits] [mlir] [MLIR][Python] Register `OpAttributeMap` as `Mapping` for `match` compatibility (PR #174292)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 3 13:39:11 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (MaPePeR)
<details>
<summary>Changes</summary>
This is a continuation of the idea from #<!-- -->174091 to add `match` support for MLIR containers. In this PR, the `OpAttributeMap` container is registered as a `Mapping`, so it can be matched as a "dictionary" in `match` statements.
For this to work, the `get(key, default=None)` method had to be implemented. The new methods are essentially copies of `dunderGetItemNamed` and `dunderGetItemIndexed`, with an added `default` argument and `nb::object` as the return type, because they can now return types other than `PyAttribute`. I was unsure whether to refactor `dunderGetItem...` to use the new `getWithDefault...` methods or whether separate methods are preferred, so for simplicity's sake I kept them as copies for now.
Even though `OpAttributeMap` supports indexing by both `int` and `str`, Python does not allow registering it as a `Sequence` and a `Mapping` at the same time: additionally registering it as a `Sequence` makes it stop matching as a `Mapping` (see the example below). When matched as a `Sequence`, it yields only the attribute names as strings, not `NamedAttribute`s. It is technically possible to use integer keys in a `dict`-like match as well, but that does not constrain the number of attributes, etc., so it is probably not recommended.
<details><summary>Example</summary>
```python
from mlir.ir import Context, Module, OpAttributeMap
from collections.abc import Sequence
ctx = Context()
ctx.allow_unregistered_dialects = True
module = Module.parse(
r"""
"some.op"() { some.attribute = 1 : i8,
other.attribute = 3.0,
dependent = "text" } : () -> ()
""",
ctx,
)
op = module.body.operations[0]
def test(attr):
match attr:
case [*args]:
print("matched a Sequence", args)
case _:
print("Didn't match as Sequence")
match attr:
case {"some.attribute": a, "other.attribute": b, "dependent": c}:
print("Matched as Mapping individually", a, b, c)
case _:
print("Didn't match a Mapping")
match attr:
case {0: a, 1: b}:
print("Matched as Mapping with 2 int keys", a, b)
case _:
print("Didn't match as Mapping with 2 int keys")
print("Registered as Mapping only:")
test(op.attributes)
print("\nAfter additonally registering as Sequence:")
Sequence.register(OpAttributeMap)
test(op.attributes)
```
Output:
```
Registered as Mapping only:
Didn't match as Sequence
Matched as Mapping individually 1 : i8 3.000000e+00 : f64 "text"
Matched as Mapping with 2 int keys NamedAttribute(dependent="text") NamedAttribute(other.attribute=3.000000e+00 : f64)
After additonally registering as Sequence:
matched a Sequence ['dependent', 'other.attribute', 'some.attribute']
Didn't match a Mapping
Didn't match as Mapping with 2 int keys
```
</details>
@<!-- -->makslevental Would be great if you could take a look again ❤️
---
Full diff: https://github.com/llvm/llvm-project/pull/174292.diff
3 Files Affected:
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+33)
- (modified) mlir/python/mlir/_mlir_libs/__init__.py (+2-1)
- (modified) mlir/test/python/ir/operation.py (+25)
``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 168c57955af07..f743c4989bc6c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2614,6 +2614,16 @@ class PyOpAttributeMap {
return PyAttribute(operation->getContext(), attr).maybeDownCast();
}
+  /// Mapping.get: the attribute named `key`, or `defaultValue` if absent.
+  nb::object getWithDefaultNamed(const std::string &key,
+                                 nb::object defaultValue) {
+    MlirAttribute found =
+        mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(key));
+    if (!mlirAttributeIsNull(found))
+      return PyAttribute(operation->getContext(), found).maybeDownCast();
+    return defaultValue;
+  }
+
PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
if (index < 0) {
index += dunderLen();
@@ -2629,6 +2639,21 @@ class PyOpAttributeMap {
mlirIdentifierStr(namedAttr.name).length));
}
+  /// Mapping.get for integer keys: the NamedAttribute at position `key`
+  /// (negative values count from the end), or `defaultValue` if out of range.
+  nb::object getWithDefaultIndexed(intptr_t key, nb::object defaultValue) {
+    intptr_t numAttributes = dunderLen();
+    if (key < 0)
+      key += numAttributes;
+    if (key < 0 || key >= numAttributes)
+      return defaultValue;
+    MlirNamedAttribute namedAttr =
+        mlirOperationGetAttribute(operation->get(), key);
+    MlirStringRef name = mlirIdentifierStr(namedAttr.name);
+    return nb::cast(PyNamedAttribute(namedAttr.attribute,
+                                     std::string(name.data, name.length)));
+  }
+
void dunderSetItem(const std::string &name, const PyAttribute &attr) {
mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
attr);
@@ -2675,6 +2700,14 @@ class PyOpAttributeMap {
nb::arg("attr"), "Sets an attribute with the given name.")
.def("__delitem__", &PyOpAttributeMap::dunderDelItem, nb::arg("name"),
"Deletes an attribute with the given name.")
+ .def("get", &PyOpAttributeMap::getWithDefaultNamed, nb::arg("key"),
+ nb::arg("default") = nb::none(),
+ "Gets an attribute by name or the default value, if it does not "
+ "exist.")
+ .def("get", &PyOpAttributeMap::getWithDefaultIndexed, nb::arg("key"),
+ nb::arg("default") = nb::none(),
+ "Gets a named attribute by index or the default value, if it does "
+ "not exist.")
.def(
"__iter__",
[](PyOpAttributeMap &self) {
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index ce7e6bf93012a..c0e8775149d41 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -2,7 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Any, Sequence
+from typing import Any, Mapping, Sequence
import os
@@ -245,6 +245,7 @@ def __str__(self):
Sequence.register(ir.OpResultList)
Sequence.register(ir.OpSuccessors)
Sequence.register(ir.RegionSequence)
+ Mapping.register(ir.OpAttributeMap)
_site_initialize()
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index b9242a7cc2bd9..382738fdffa9c 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -652,6 +652,31 @@ def testOperationAttributes():
# CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1}
print("Dict mapping", d)
+ # Structural pattern matching test using Mapping
+
+ # CHECK: Matched Mapping Attribute 'some.attribute': 1
+ # CHECK: Matched Mapping Attribute 'other.attribute': 3.0
+ # CHECK: Matched Mapping Attribute 'dependent': text
+ match op:
+ case OpView(attributes={"does_not_exist": a0}):
+ assert False
+ case OpView(
+ attributes={
+ "some.attribute": IntegerAttr(value=some_attr_val),
+ "other.attribute": FloatAttr() as other_attr,
+ "dependent": StringAttr() as dep_attr,
+ **other_attributes,
+ }
+ ):
+ print(f"Matched Mapping Attribute 'some.attribute': {some_attr_val}")
+ print(f"Matched Mapping Attribute 'other.attribute': {other_attr.value}")
+ print(f"Matched Mapping Attribute 'dependent': {dep_attr.value}")
+ assert type(other_attributes) == dict
+ assert len(other_attributes) == 0
+ case _:
+ print("Did not match!")
+ assert False
+
# Check that exceptions are raised as expected.
try:
op.attributes["does_not_exist"]
``````````
</details>
https://github.com/llvm/llvm-project/pull/174292
More information about the Mlir-commits
mailing list