[Mlir-commits] [mlir] [mlir, python] Fix case when `FuncOp.arg_attrs` is not set (PR #117188)
Perry Gibson
llvmlistbot at llvm.org
Thu Nov 21 08:57:37 PST 2024
https://github.com/Wheest updated https://github.com/llvm/llvm-project/pull/117188
>From bad0d96c1f3be74b9fb7f9da9c7cd05c58b62375 Mon Sep 17 00:00:00 2001
From: pez <perry at fractile.ai>
Date: Thu, 21 Nov 2024 16:40:04 +0000
Subject: [PATCH 1/2] Add check to see if `FuncOp.arg_attrs` is set
---
mlir/python/mlir/dialects/func.py | 4 ++++
mlir/test/python/dialects/func.py | 13 +++++++++++++
2 files changed, 17 insertions(+)
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 24fdcbcd85b29f..211027d88051a7 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -105,6 +105,10 @@ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
@property
def arg_attrs(self):
+ if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+ [DictAttr.get({}) for _ in self.type.inputs]
+ )
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
@arg_attrs.setter
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index a2014c64d2fa53..bcfaace853bc64 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -104,3 +104,16 @@ def testFunctionCalls():
# CHECK: %1 = call @qux() : () -> f32
# CHECK: return
# CHECK: }
+
+
+# CHECK-LABEL: TEST: testFunctionArgAttrs
+@constructAndPrintInModule
+def testFunctionArgAttrs():
+ foo = func.FuncOp("foo", ([("arg0", F32Type.get())], []))
+
+ assert len(foo.arg_attrs) == 1
+ assert foo.arg_attrs[0] = ir.DictAttr.get({})
+
+ foo.arg_attrs = [DictAttr.get({"test.foo": StringAttr.get("bar")})]
+
+ assert foo.arg_attrs[0]["test.foo"] == StringAttr.get("bar")
>From 049bacebd0c9dbe38e23568fe16113731093a81f Mon Sep 17 00:00:00 2001
From: pez <perry at fractile.ai>
Date: Thu, 21 Nov 2024 16:56:32 +0000
Subject: [PATCH 2/2] Linting fix
---
mlir/test/python/dialects/func.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index bcfaace853bc64..cc2d616c6407ce 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -112,7 +112,7 @@ def testFunctionArgAttrs():
foo = func.FuncOp("foo", ([("arg0", F32Type.get())], []))
assert len(foo.arg_attrs) == 1
- assert foo.arg_attrs[0] = ir.DictAttr.get({})
+ assert foo.arg_attrs[0] == DictAttr.get({})
foo.arg_attrs = [DictAttr.get({"test.foo": StringAttr.get("bar")})]
More information about the Mlir-commits mailing list