[Mlir-commits] [mlir] d898ff6 - [mlir, python] Fix case when `FuncOp.arg_attrs` is not set (#117188)

llvmlistbot at llvm.org
Mon Dec 2 08:55:55 PST 2024


Author: Perry Gibson
Date: 2024-12-02T08:55:51-08:00
New Revision: d898ff650ae09e3ef942592aee2e87627f45d7c6

URL: https://github.com/llvm/llvm-project/commit/d898ff650ae09e3ef942592aee2e87627f45d7c6
DIFF: https://github.com/llvm/llvm-project/commit/d898ff650ae09e3ef942592aee2e87627f45d7c6.diff

LOG: [mlir,python] Fix case when `FuncOp.arg_attrs` is not set (#117188)

FuncOps can have `arg_attrs`, an array of dictionary attributes
associated with their arguments.

E.g., 

```mlir
func.func @main(%arg0: tensor<8xf32> {test.attr_name = "value"}, %arg1: tensor<8x16xf32>)
```

These are exposed in the MLIR Python bindings through the
`my_funcop.arg_attrs` property.

In this case, it would return `[{test.attr_name = "value"}, {}]`, i.e.,
`%arg1` has an empty `DictAttr`.
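To make the layout concrete, here is a plain-Python model of the value described above. It uses ordinary dicts as hypothetical stand-ins for `DictAttr` and a list for `ArrayAttr`; it is not the bindings API, only an illustration of the positional one-dictionary-per-argument shape.

```python
# Plain-Python stand-in for the `arg_attrs` value of
#   func.func @main(%arg0: tensor<8xf32> {test.attr_name = "value"}, %arg1: tensor<8x16xf32>)
# Dicts model DictAttr and the list models ArrayAttr (illustration only).
arg_names = ["%arg0", "%arg1"]
arg_attrs = [{"test.attr_name": "value"}, {}]

# The array is positional: entry i holds the attributes of argument i,
# and an argument without attributes still gets an (empty) entry.
assert len(arg_attrs) == len(arg_names)

for name, attrs in zip(arg_names, arg_attrs):
    print(name, attrs)
```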

However, if I try to access this property on a FuncOp that has no
`arg_attrs` attribute set, e.g.,

```mlir
func.func @main(%arg0: tensor<8xf32>, %arg1: tensor<8x16xf32>)
```

This raises the error:

```python
    return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
                     ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: 'attempt to access a non-existent attribute'
```

This PR fixes this by returning an `ArrayAttr` containing one empty
`DictAttr` per function argument, i.e., the expected `[{}, {}]`.
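The before/after behavior of the getter can be sketched in plain Python. `FakeFuncOp` is a hypothetical stand-in (not the MLIR Python API): its `attributes` dict plays the role of the op's attribute dictionary and `inputs` plays the role of `self.type.inputs`. The attribute name `"arg_attrs"` is assumed to match `ARGUMENT_ATTRIBUTE_NAME` in `mlir/python/mlir/dialects/func.py`.

```python
from typing import Any, Dict, List, Optional

# Assumed to mirror ARGUMENT_ATTRIBUTE_NAME in mlir/python/mlir/dialects/func.py.
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"


class FakeFuncOp:
    """Hypothetical stand-in for func.FuncOp (not the MLIR Python API)."""

    def __init__(self, inputs: List[str], attributes: Optional[Dict[str, Any]] = None):
        self.inputs = inputs  # stands in for self.type.inputs
        self.attributes = dict(attributes or {})  # stands in for the op's attributes

    def arg_attrs_before_fix(self) -> List[Dict[str, Any]]:
        # Old behavior: looking up a missing attribute raises KeyError.
        return list(self.attributes[ARGUMENT_ATTRIBUTE_NAME])

    def arg_attrs_after_fix(self) -> List[Dict[str, Any]]:
        # New behavior: when the attribute is unset, synthesize one empty
        # dictionary per function argument instead of failing.
        if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
            return [{} for _ in self.inputs]
        return list(self.attributes[ARGUMENT_ATTRIBUTE_NAME])


# Models func.func @main(%arg0: tensor<8xf32>, %arg1: tensor<8x16xf32>) with no arg_attrs.
op = FakeFuncOp(["tensor<8xf32>", "tensor<8x16xf32>"])
try:
    op.arg_attrs_before_fix()
except KeyError:
    print("KeyError before the fix")
print(op.arg_attrs_after_fix())  # [{}, {}]
```

In the actual patch the fallback builds `DictAttr.get({})` values inside `ArrayAttr.get(...)` and sizes the list from `self.type.inputs`, so the result has the same shape as the one returned for a FuncOp whose argument attributes are set explicitly.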

Added: 
    

Modified: 
    mlir/python/mlir/dialects/func.py
    mlir/test/python/dialects/func.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 24fdcbcd85b29f..1898fc1565cd49 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -105,6 +105,8 @@ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
 
     @property
     def arg_attrs(self):
+        if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
+            return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs])
         return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
 
     @arg_attrs.setter

diff  --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index a2014c64d2fa53..6b3932ce64f137 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -104,3 +104,32 @@ def testFunctionCalls():
 # CHECK:   %1 = call @qux() : () -> f32
 # CHECK:   return
 # CHECK: }
+
+
+# CHECK-LABEL: TEST: testFunctionArgAttrs
+@constructAndPrintInModule
+def testFunctionArgAttrs():
+    foo = func.FuncOp("foo", ([F32Type.get()], []))
+    foo.sym_visibility = StringAttr.get("private")
+    foo2 = func.FuncOp("foo2", ([F32Type.get(), F32Type.get()], []))
+    foo2.sym_visibility = StringAttr.get("private")
+
+    empty_attr = DictAttr.get({})
+    test_attr = DictAttr.get({"test.foo": StringAttr.get("bar")})
+    test_attr2 = DictAttr.get({"test.baz": StringAttr.get("qux")})
+
+    assert len(foo.arg_attrs) == 1
+    assert foo.arg_attrs[0] == empty_attr
+
+    foo.arg_attrs = [test_attr]
+    assert foo.arg_attrs[0]["test.foo"] == StringAttr.get("bar")
+
+    assert len(foo2.arg_attrs) == 2
+    assert foo2.arg_attrs == ArrayAttr.get([empty_attr, empty_attr])
+
+    foo2.arg_attrs = [empty_attr, test_attr2]
+    assert foo2.arg_attrs == ArrayAttr.get([empty_attr, test_attr2])
+
+
+# CHECK: func private @foo(f32 {test.foo = "bar"})
+# CHECK: func private @foo2(f32, f32  {test.baz = "qux"})

