[Mlir-commits] [mlir] [mlir, python] Fix case when `FuncOp.arg_attrs` is not set (PR #117188)
llvmlistbot at llvm.org
Fri Nov 22 06:52:37 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Perry Gibson (Wheest)
<details>
<summary>Changes</summary>
FuncOps can have `arg_attrs`, an optional array of dictionary attributes with one entry per function argument.
For example:
```mlir
func.func @main(%arg0: tensor<8xf32> {test.attr_name = "value"}, %arg1: tensor<8x16xf32>)
```
The MLIR Python bindings expose these through the `arg_attrs` property, e.g. `my_funcop.arg_attrs`.
In this case, it would return `[{test.attr_name = "value"}, {}]`, i.e., `%arg1` has an empty `DictAttr`.
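For illustration only, the one-dictionary-per-argument correspondence can be modelled in plain Python. `render_signature` below is a hypothetical helper, not an MLIR API: plain dicts stand in for `DictAttr`, and only string-valued attributes are handled. It pairs the i-th argument type with the i-th dictionary and, as with `%arg1` above, prints an argument with an empty dictionary without braces.

```python
def render_signature(name, arg_types, arg_attrs):
    """Hypothetical helper: build a func.func header from parallel lists.

    arg_attrs[i] is a plain dict standing in for the DictAttr of argument i;
    an empty dict means "no attributes" and is omitted from the output.
    """
    if len(arg_types) != len(arg_attrs):
        raise ValueError("need exactly one attribute dictionary per argument")
    parts = []
    for i, (ty, attrs) in enumerate(zip(arg_types, arg_attrs)):
        text = f"%arg{i}: {ty}"
        if attrs:  # empty dictionaries are not printed
            body = ", ".join(f'{k} = "{v}"' for k, v in attrs.items())
            text += f" {{{body}}}"
        parts.append(text)
    return f"func.func @{name}({', '.join(parts)})"


sig = render_signature(
    "main",
    ["tensor<8xf32>", "tensor<8x16xf32>"],
    [{"test.attr_name": "value"}, {}],
)
print(sig)
# func.func @main(%arg0: tensor<8xf32> {test.attr_name = "value"}, %arg1: tensor<8x16xf32>)
```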
However, if I try to access this property on a FuncOp that has no `arg_attrs` attribute at all, e.g.,
```mlir
func.func @main(%arg0: tensor<8xf32>, %arg1: tensor<8x16xf32>)
```
This raises the error:
```python
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: 'attempt to access a non-existent attribute'
```
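This PR fixes this by having the getter populate a missing `arg_attrs` with one empty `DictAttr` per function input, so that it returns the expected `[{}, {}]`.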
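The logic of the fix can be sketched without the MLIR bindings. The stand-in below is a hypothetical illustration, not part of the MLIR Python API: `FakeFuncOp` replaces `func.FuncOp`, plain lists and dicts replace `ArrayAttr` and `DictAttr`, and the string value given to `ARGUMENT_ATTRIBUTE_NAME` is an assumption for the sketch. Like the patched getter, it stores one empty dictionary per input when the key is missing, then returns the stored value.

```python
# Assumed key name for this sketch; the real constant lives in mlir/dialects/func.py.
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"


class FakeFuncOp:
    """Hypothetical stand-in for func.FuncOp: a plain dict replaces the op's
    attribute dictionary, and lists/dicts replace ArrayAttr/DictAttr."""

    def __init__(self, input_types, attributes=None):
        self.input_types = list(input_types)
        self.attributes = dict(attributes or {})

    @property
    def arg_attrs(self):
        # Mirrors the patch: if the attribute was never set, store one empty
        # dictionary per function input before reading it back.
        if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = [{} for _ in self.input_types]
        return self.attributes[ARGUMENT_ATTRIBUTE_NAME]

    @arg_attrs.setter
    def arg_attrs(self, value):
        self.attributes[ARGUMENT_ATTRIBUTE_NAME] = list(value)


op = FakeFuncOp(["tensor<8xf32>", "tensor<8x16xf32>"])
print(op.arg_attrs)  # [{}, {}] rather than a KeyError
print(ARGUMENT_ATTRIBUTE_NAME in op.attributes)  # True: the read stored the default
```

Like the patch, this getter writes the default back on first read, so merely inspecting `arg_attrs` changes the op's attributes. A read-only alternative would build and return the list of empty dictionaries without storing it; the sketch does not settle which is preferable.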
---
Full diff: https://github.com/llvm/llvm-project/pull/117188.diff
2 Files Affected:
- (modified) mlir/python/mlir/dialects/func.py (+4)
- (modified) mlir/test/python/dialects/func.py (+29)
``````````diff
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 24fdcbcd85b29f..211027d88051a7 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -105,6 +105,10 @@ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
 
     @property
     def arg_attrs(self):
+        if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+                [DictAttr.get({}) for _ in self.type.inputs]
+            )
         return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
 
     @arg_attrs.setter
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index a2014c64d2fa53..6b3932ce64f137 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -104,3 +104,32 @@ def testFunctionCalls():
 # CHECK: %1 = call @qux() : () -> f32
 # CHECK: return
 # CHECK: }
+
+
+# CHECK-LABEL: TEST: testFunctionArgAttrs
+@constructAndPrintInModule
+def testFunctionArgAttrs():
+    foo = func.FuncOp("foo", ([F32Type.get()], []))
+    foo.sym_visibility = StringAttr.get("private")
+    foo2 = func.FuncOp("foo2", ([F32Type.get(), F32Type.get()], []))
+    foo2.sym_visibility = StringAttr.get("private")
+
+    empty_attr = DictAttr.get({})
+    test_attr = DictAttr.get({"test.foo": StringAttr.get("bar")})
+    test_attr2 = DictAttr.get({"test.baz": StringAttr.get("qux")})
+
+    assert len(foo.arg_attrs) == 1
+    assert foo.arg_attrs[0] == empty_attr
+
+    foo.arg_attrs = [test_attr]
+    assert foo.arg_attrs[0]["test.foo"] == StringAttr.get("bar")
+
+    assert len(foo2.arg_attrs) == 2
+    assert foo2.arg_attrs == ArrayAttr.get([empty_attr, empty_attr])
+
+    foo2.arg_attrs = [empty_attr, test_attr2]
+    assert foo2.arg_attrs == ArrayAttr.get([empty_attr, test_attr2])
+
+
+# CHECK: func private @foo(f32 {test.foo = "bar"})
+# CHECK: func private @foo2(f32, f32 {test.baz = "qux"})
``````````
</details>
https://github.com/llvm/llvm-project/pull/117188