[Mlir-commits] [mlir] 524fde8 - [MLIR][Python] Register `OpAttributeMap` as `Mapping` for `match` compatibility (#174292)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 10 09:09:08 PST 2026
Author: MaPePeR
Date: 2026-01-10T09:09:05-08:00
New Revision: 524fde8a4de43d4114ad58349d4c4c12609108a5
URL: https://github.com/llvm/llvm-project/commit/524fde8a4de43d4114ad58349d4c4c12609108a5
DIFF: https://github.com/llvm/llvm-project/commit/524fde8a4de43d4114ad58349d4c4c12609108a5.diff
LOG: [MLIR][Python] Register `OpAttributeMap` as `Mapping` for `match` compatibility (#174292)
This is a continuation of the idea from #174091 to add `match` support
for MLIR containers. In this PR, the `OpAttributeMap` container is
registered as a `Mapping` so that it can be matched like a dictionary in
`match` statements.
For this to work, a `get(key, default=None)` method had to be
implemented, because mapping patterns look up keys through the subject's
`get` method. It is essentially a copy of `dunderGetItemNamed` with an
added `default` argument and `nb::object` as the return type, because it
can now return types other than `PyAttribute`. No indexed counterpart of
`dunderGetItemIndexed` was added: `get` accepts only `str` keys.
I was unsure whether to refactor `dunderGetItemNamed` to use the new
`get` method or to keep a separate method. For simplicity's sake, I kept
it as a copy for now.
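Even though `OpAttributeMap` supports indexing by both `int` and `str`,
Python's `match` cannot treat a class as both a `Sequence` and a
`Mapping`. Registering it with one of these ABCs overrides the other for
pattern-matching purposes. If it were registered as a `Sequence`, a
sequence pattern would capture only the attribute names as strings, not
`NamedAttribute` objects.
Integer keys are not supported in `dict`-like match patterns. Because
`get` accepts only `str` keys, a pattern such as `{0: a}` raises
`TypeError`, as checked by the added test. Even if integer keys were
allowed, such a pattern would not constrain the number of attributes,
so it would not be recommended. The example output below, which shows
an integer-key match succeeding, appears to predate this restriction.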
<details><summary>Example</summary>
```python
from mlir.ir import Context, Module, OpAttributeMap
from collections.abc import Sequence
ctx = Context()
ctx.allow_unregistered_dialects = True
module = Module.parse(
r"""
"some.op"() { some.attribute = 1 : i8,
other.attribute = 3.0,
dependent = "text" } : () -> ()
""",
ctx,
)
op = module.body.operations[0]
def test(attr):
match attr:
case [*args]:
print("matched a Sequence", args)
case _:
print("Didn't match as Sequence")
match attr:
case {"some.attribute": a, "other.attribute": b, "dependent": c}:
print("Matched as Mapping individually", a, b, c)
case _:
print("Didn't match a Mapping")
match attr:
case {0: a, 1: b}:
print("Matched as Mapping with 2 int keys", a, b)
case _:
print("Didn't match as Mapping with 2 int keys")
print("Registered as Mapping only:")
test(op.attributes)
print("\nAfter additionally registering as Sequence:")
Sequence.register(OpAttributeMap)
test(op.attributes)
```
Output:
```
Registered as Mapping only:
Didn't match as Sequence
Matched as Mapping individually 1 : i8 3.000000e+00 : f64 "text"
Matched as Mapping with 2 int keys NamedAttribute(dependent="text") NamedAttribute(other.attribute=3.000000e+00 : f64)
After additionally registering as Sequence:
matched a Sequence ['dependent', 'other.attribute', 'some.attribute']
Didn't match a Mapping
Didn't match as Mapping with 2 int keys
```
</details>
makslevental, it would be great if you could take another look ❤️
---------
Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
Added:
Modified:
mlir/include/mlir/Bindings/Python/IRCore.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/python/mlir/_mlir_libs/__init__.py
mlir/test/python/ir/operation.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 330318683c15e5..fa1b88d1a80ddc 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1801,6 +1801,9 @@ class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
PyNamedAttribute dunderGetItemIndexed(intptr_t index);
+ nanobind::typed<nanobind::object, std::optional<PyAttribute>>
+ get(const std::string &key, nanobind::object defaultValue);
+
void dunderSetItem(const std::string &name, const PyAttribute &attr);
void dunderDelItem(const std::string &name);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 19db41fae4fe2a..a544648a47f458 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2388,6 +2388,15 @@ PyOpAttributeMap::dunderGetItemNamed(const std::string &name) {
return PyAttribute(operation->getContext(), attr).maybeDownCast();
}
+nb::typed<nb::object, std::optional<PyAttribute>>
+PyOpAttributeMap::get(const std::string &key, nb::object defaultValue) {
+ MlirAttribute attr =
+ mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(key));
+ if (mlirAttributeIsNull(attr))
+ return defaultValue;
+ return PyAttribute(operation->getContext(), attr).maybeDownCast();
+}
+
PyNamedAttribute PyOpAttributeMap::dunderGetItemIndexed(intptr_t index) {
if (index < 0) {
index += dunderLen();
@@ -2450,6 +2459,10 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
"Sets an attribute with the given name.")
.def("__delitem__", &PyOpAttributeMap::dunderDelItem, "name"_a,
"Deletes an attribute with the given name.")
+ .def("get", &PyOpAttributeMap::get, nb::arg("key"),
+ nb::arg("default") = nb::none(),
+ "Gets an attribute by name or the default value, if it does not "
+ "exist.")
.def(
"__iter__",
[](PyOpAttributeMap &self) {
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index ce7e6bf93012a0..c0e8775149d41e 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -2,7 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Any, Sequence
+from typing import Any, Mapping, Sequence
import os
@@ -245,6 +245,7 @@ def __str__(self):
Sequence.register(ir.OpResultList)
Sequence.register(ir.OpSuccessors)
Sequence.register(ir.RegionSequence)
+ Mapping.register(ir.OpAttributeMap)
_site_initialize()
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index b9242a7cc2bd91..89f78ab1932a03 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -652,6 +652,34 @@ def testOperationAttributes():
# CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1}
print("Dict mapping", d)
+ # Structural pattern matching test using Mapping
+
+ # CHECK: Matched Mapping Attribute 'some.attribute': 1
+ # CHECK: Matched Mapping Attribute 'other.attribute': 3.0
+ # CHECK: Matched Mapping Attribute 'dependent': text
+ match op:
+ case OpView(attributes={"does_not_exist": a0}):
+ assert False
+ case OpView(
+ attributes={
+ "some.attribute": IntegerAttr(value=some_attr_val) as some_attr,
+ "other.attribute": FloatAttr() as other_attr,
+ "dependent": StringAttr() as dep_attr,
+ **other_attributes,
+ }
+ ):
+ print(f"Matched Mapping Attribute 'some.attribute': {some_attr_val}")
+ print(f"Matched Mapping Attribute 'other.attribute': {other_attr.value}")
+ print(f"Matched Mapping Attribute 'dependent': {dep_attr.value}")
+ assert type(other_attributes) == dict
+ assert len(other_attributes) == 0
+ assert some_attr == op.attributes.get("some.attribute")
+ assert other_attr == op.attributes.get("other.attribute", None)
+ assert dep_attr == op.attributes.get("dependent", "Default value")
+ case _:
+ print("Did not match!")
+ assert False
+
# Check that exceptions are raised as expected.
try:
op.attributes["does_not_exist"]
@@ -667,6 +695,41 @@ def testOperationAttributes():
else:
assert False, "expected IndexError on accessing an out-of-bounds attribute"
+ # Check that exceptions are raised when `get` is used with non-str arg.
+ try:
+ op.attributes.get(0)
+ except TypeError:
+ pass
+ else:
+ assert False, "expected TypeError using int as key for get()"
+
+ try:
+ op.attributes.get(0, None)
+ except TypeError:
+ pass
+ else:
+ assert False, "expected TypeError using int as key for get()"
+
+ try:
+ op.attributes.get([], None)
+ except TypeError:
+ pass
+ else:
+ assert False, "expected TypeError using list as key for get()"
+
+ try:
+ match op:
+ case OpView(attributes={0: a}):
+ assert False
+ except TypeError:
+ pass
+ else:
+ assert False, "expected TypeError matching OpAttributeMap with int-key "
+
+ # get() does not throw for non existent attributes.
+ assert op.attributes.get("does_not_exist") is None
+ assert op.attributes.get("does_not_exist", "default_value") == "default_value"
+
# CHECK-LABEL: TEST: testOperationPrint
@run
More information about the Mlir-commits
mailing list