[Mlir-commits] [mlir] [mlir, python] Fix case when `FuncOp.arg_attrs` is not set (PR #117188)

Perry Gibson llvmlistbot at llvm.org
Tue Nov 26 05:54:26 PST 2024


https://github.com/Wheest updated https://github.com/llvm/llvm-project/pull/117188

>From bad0d96c1f3be74b9fb7f9da9c7cd05c58b62375 Mon Sep 17 00:00:00 2001
From: pez <perry at fractile.ai>
Date: Thu, 21 Nov 2024 16:40:04 +0000
Subject: [PATCH 1/4] Add check to see if `FuncOp.arg_attrs` is set

---
 mlir/python/mlir/dialects/func.py |  4 ++++
 mlir/test/python/dialects/func.py | 13 +++++++++++++
 2 files changed, 17 insertions(+)

diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 24fdcbcd85b29f..211027d88051a7 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -105,6 +105,10 @@ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
 
     @property
     def arg_attrs(self):
+        if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+                [DictAttr.get({}) for _ in self.type.inputs]
+            )
         return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
 
     @arg_attrs.setter
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index a2014c64d2fa53..bcfaace853bc64 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -104,3 +104,16 @@ def testFunctionCalls():
 # CHECK:   %1 = call @qux() : () -> f32
 # CHECK:   return
 # CHECK: }
+
+
+# CHECK-LABEL: TEST: testFunctionArgAttrs
+@constructAndPrintInModule
+def testFunctionArgAttrs():
+    foo = func.FuncOp("foo", ([("arg0", F32Type.get())], []))
+
+    assert len(foo.arg_attrs) == 1
+    assert foo.arg_attrs[0] = ir.DictAttr.get({})
+
+    foo.arg_attrs = [DictAttr.get({"test.foo": StringAttr.get("bar")})]
+
+    assert foo.arg_attrs[0]["test.foo"] == StringAttr.get("bar")

>From 049bacebd0c9dbe38e23568fe16113731093a81f Mon Sep 17 00:00:00 2001
From: pez <perry at fractile.ai>
Date: Thu, 21 Nov 2024 16:56:32 +0000
Subject: [PATCH 2/4] Linting fix

---
 mlir/test/python/dialects/func.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index bcfaace853bc64..cc2d616c6407ce 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -112,7 +112,7 @@ def testFunctionArgAttrs():
     foo = func.FuncOp("foo", ([("arg0", F32Type.get())], []))
 
     assert len(foo.arg_attrs) == 1
-    assert foo.arg_attrs[0] = ir.DictAttr.get({})
+    assert foo.arg_attrs[0] == DictAttr.get({})
 
     foo.arg_attrs = [DictAttr.get({"test.foo": StringAttr.get("bar")})]
 

>From 2ae4815fe814948ad012aba1e3a7c9dd12e05909 Mon Sep 17 00:00:00 2001
From: pez <perry at fractile.ai>
Date: Fri, 22 Nov 2024 14:33:45 +0000
Subject: [PATCH 3/4] Add additional test case

---
 mlir/test/python/dialects/func.py | 24 ++++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index cc2d616c6407ce..6b3932ce64f137 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -109,11 +109,27 @@ def testFunctionCalls():
 # CHECK-LABEL: TEST: testFunctionArgAttrs
 @constructAndPrintInModule
 def testFunctionArgAttrs():
-    foo = func.FuncOp("foo", ([("arg0", F32Type.get())], []))
+    foo = func.FuncOp("foo", ([F32Type.get()], []))
+    foo.sym_visibility = StringAttr.get("private")
+    foo2 = func.FuncOp("foo2", ([F32Type.get(), F32Type.get()], []))
+    foo2.sym_visibility = StringAttr.get("private")
 
-    assert len(foo.arg_attrs) == 1
-    assert foo.arg_attrs[0] == DictAttr.get({})
+    empty_attr = DictAttr.get({})
+    test_attr = DictAttr.get({"test.foo": StringAttr.get("bar")})
+    test_attr2 = DictAttr.get({"test.baz": StringAttr.get("qux")})
 
-    foo.arg_attrs = [DictAttr.get({"test.foo": StringAttr.get("bar")})]
+    assert len(foo.arg_attrs) == 1
+    assert foo.arg_attrs[0] == empty_attr
 
+    foo.arg_attrs = [test_attr]
     assert foo.arg_attrs[0]["test.foo"] == StringAttr.get("bar")
+
+    assert len(foo2.arg_attrs) == 2
+    assert foo2.arg_attrs == ArrayAttr.get([empty_attr, empty_attr])
+
+    foo2.arg_attrs = [empty_attr, test_attr2]
+    assert foo2.arg_attrs == ArrayAttr.get([empty_attr, test_attr2])
+
+
+# CHECK: func private @foo(f32 {test.foo = "bar"})
+# CHECK: func private @foo2(f32, f32  {test.baz = "qux"})

>From c6afea5b86ccb35b10afdf5c03859378749f73c8 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at fractile.ai>
Date: Tue, 26 Nov 2024 13:53:54 +0000
Subject: [PATCH 4/4] Return empty dict list, do not mutate IR

---
 mlir/python/mlir/dialects/func.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 211027d88051a7..1898fc1565cd49 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -106,9 +106,7 @@ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
     @property
     def arg_attrs(self):
         if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
-            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
-                [DictAttr.get({}) for _ in self.type.inputs]
-            )
+            return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs])
         return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
 
     @arg_attrs.setter



More information about the Mlir-commits mailing list