[Mlir-commits] [mlir] dd74e6b - [mlir] FunctionOpInterface: arg and result attrs dispatch to interface

Jeff Niu llvmlistbot at llvm.org
Thu Dec 8 11:32:53 PST 2022


Author: Jeff Niu
Date: 2022-12-08T11:32:38-08:00
New Revision: dd74e6b6f4fb7a4685086a4895c1934e043f875b

URL: https://github.com/llvm/llvm-project/commit/dd74e6b6f4fb7a4685086a4895c1934e043f875b
DIFF: https://github.com/llvm/llvm-project/commit/dd74e6b6f4fb7a4685086a4895c1934e043f875b.diff

LOG: [mlir] FunctionOpInterface: arg and result attrs dispatch to interface

This patch removes the `arg_attrs` and `res_attrs` named attributes as a
requirement for FunctionOpInterface and replaces them with interface
methods for the getters, setters, and removers of the relevant
attributes. This allows operations to use their own storage for the
argument and result attributes.

Depends on D139471

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D139472

Added: 
    

Modified: 
    mlir/examples/toy/Ch2/include/toy/Ops.td
    mlir/examples/toy/Ch2/mlir/Dialect.cpp
    mlir/examples/toy/Ch3/include/toy/Ops.td
    mlir/examples/toy/Ch3/mlir/Dialect.cpp
    mlir/examples/toy/Ch4/include/toy/Ops.td
    mlir/examples/toy/Ch4/mlir/Dialect.cpp
    mlir/examples/toy/Ch5/include/toy/Ops.td
    mlir/examples/toy/Ch5/mlir/Dialect.cpp
    mlir/examples/toy/Ch6/include/toy/Ops.td
    mlir/examples/toy/Ch6/mlir/Dialect.cpp
    mlir/examples/toy/Ch7/include/toy/Ops.td
    mlir/examples/toy/Ch7/mlir/Dialect.cpp
    mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
    mlir/include/mlir/Dialect/Func/IR/FuncOps.td
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
    mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/include/mlir/IR/FunctionImplementation.h
    mlir/include/mlir/IR/FunctionInterfaces.h
    mlir/include/mlir/IR/FunctionInterfaces.td
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
    mlir/lib/Dialect/Async/IR/Async.cpp
    mlir/lib/Dialect/Func/IR/FuncOps.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
    mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/IR/FunctionImplementation.cpp
    mlir/lib/IR/FunctionInterfaces.cpp
    mlir/test/IR/invalid-func-op.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/examples/toy/Ch2/include/toy/Ops.td b/mlir/examples/toy/Ch2/include/toy/Ops.td
index 380536bc3f945..4e2fb9ec397e5 100644
--- a/mlir/examples/toy/Ch2/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch2/include/toy/Ops.td
@@ -134,7 +134,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 

diff  --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
index ac12c5c0ac512..201f9c7d91b83 100644
--- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
@@ -218,8 +218,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/examples/toy/Ch3/include/toy/Ops.td b/mlir/examples/toy/Ch3/include/toy/Ops.td
index e526fe5bdaf51..1a4e6a1a29ad0 100644
--- a/mlir/examples/toy/Ch3/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch3/include/toy/Ops.td
@@ -133,7 +133,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 

diff  --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
index 75cb57ebd7b17..4bd10551b73f3 100644
--- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
@@ -205,8 +205,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td
index 4956b0ed2f793..cbece4767d150 100644
--- a/mlir/examples/toy/Ch4/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch4/include/toy/Ops.td
@@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 

diff  --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index 2d5a369630a0d..3a02ea36bc815 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -294,8 +294,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Returns the region on the function operation that is callable.

diff  --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td
index f4e7b08732ed3..70e482dd76ebe 100644
--- a/mlir/examples/toy/Ch5/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch5/include/toy/Ops.td
@@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 

diff  --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index 280bf3122fbd5..49ce3d95fa7e5 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -294,8 +294,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Returns the region on the function operation that is callable.

diff  --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td
index ea9323ece6259..cf2bc3f504806 100644
--- a/mlir/examples/toy/Ch6/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch6/include/toy/Ops.td
@@ -163,7 +163,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 

diff  --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index 280bf3122fbd5..49ce3d95fa7e5 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -294,8 +294,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Returns the region on the function operation that is callable.

diff  --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
index 45ecdd3bc78f9..08671a7347c19 100644
--- a/mlir/examples/toy/Ch7/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -186,7 +186,9 @@ def FuncOp : Toy_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region AnyRegion:$body);
 

diff  --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index b0d213027f4c9..cb65a95eca6a2 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -321,8 +321,9 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
 void FuncOp::print(mlir::OpAsmPrinter &p) {
   // Dispatch to the FunctionOpInterface provided utility method that prints the
   // function operation.
-  mlir::function_interface_impl::printFunctionOp(p, *this,
-                                                 /*isVariadic=*/false);
+  mlir::function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Returns the region on the function operation that is callable.

diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index 30895e577ad6c..14146cd10ac63 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -140,7 +140,9 @@ def Async_FuncOp : Async_Op<"func",
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
-                       OptionalAttr<StrAttr>:$sym_visibility);
+                       OptionalAttr<StrAttr>:$sym_visibility,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs);
 
   let regions = (region AnyRegion:$body);
 

diff  --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
index f1b7cfdb63022..4922689fbef60 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -251,7 +251,9 @@ def FuncOp : Func_Op<"func", [
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
-                       OptionalAttr<StrAttr>:$sym_visibility);
+                       OptionalAttr<StrAttr>:$sym_visibility,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs);
   let regions = (region AnyRegion:$body);
 
   let builders = [OpBuilder<(ins

diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 0642b1865b5fd..f9fff78861ea1 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -242,7 +242,9 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
     attribution.
   }];
 
-  let arguments = (ins TypeAttrOf<FunctionType>:$function_type);
+  let arguments = (ins TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs);
   let regions = (region AnyRegion:$body);
 
   let skipDefaultBuilders = 1;

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index c36e390377b32..4d2b2f9f94f9d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1308,7 +1308,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     DefaultValuedAttr<CConv, "CConv::C">:$CConv,
     OptionalAttr<FlatSymbolRefAttr>:$personality,
     OptionalAttr<StrAttr>:$garbageCollector,
-    OptionalAttr<ArrayAttr>:$passthrough
+    OptionalAttr<ArrayAttr>:$passthrough,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
 
   let regions = (region AnyRegion:$body);

diff  --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
index 422680a06deaa..db6c7733130cb 100644
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
@@ -52,6 +52,8 @@ def MLProgram_FuncOp : MLProgram_Op<"func", [
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs,
                        OptionalAttr<StrAttr>:$sym_visibility);
   let regions = (region AnyRegion:$body);
 
@@ -401,6 +403,8 @@ def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs,
                        OptionalAttr<StrAttr>:$sym_visibility);
   let regions = (region AnyRegion:$body);
 

diff  --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 42a48cdb2ff94..6ecbed2ac1b67 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -652,7 +652,9 @@ def PDLInterp_FuncOp : PDLInterp_Op<"func", [
 
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    TypeAttrOf<FunctionType>:$function_type
+    TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let regions = (region MinSizedRegion<1>:$body);
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 147705ecb873c..8339afc4f7a33 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -291,6 +291,8 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
 
   let arguments = (ins
     TypeAttrOf<FunctionType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs,
     StrAttr:$sym_name,
     SPIRV_FunctionControlAttr:$function_control
   );

diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index c3697f01bb776..97d1f0c228421 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -1107,6 +1107,8 @@ def Shape_FuncOp : Shape_Op<"func",
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs,
                        OptionalAttr<StrAttr>:$sym_visibility);
   let regions = (region AnyRegion:$body);
 

diff  --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h
index f4c0cc03050fe..eb7979095c33c 100644
--- a/mlir/include/mlir/IR/FunctionImplementation.h
+++ b/mlir/include/mlir/IR/FunctionImplementation.h
@@ -39,10 +39,12 @@ class VariadicFlag {
 /// with special names given by getResultAttrName, getArgumentAttrName.
 void addArgAndResultAttrs(Builder &builder, OperationState &result,
                           ArrayRef<DictionaryAttr> argAttrs,
-                          ArrayRef<DictionaryAttr> resultAttrs);
+                          ArrayRef<DictionaryAttr> resultAttrs,
+                          StringAttr argAttrsName, StringAttr resAttrsName);
 void addArgAndResultAttrs(Builder &builder, OperationState &result,
-                          ArrayRef<OpAsmParser::Argument> argAttrs,
-                          ArrayRef<DictionaryAttr> resultAttrs);
+                          ArrayRef<OpAsmParser::Argument> args,
+                          ArrayRef<DictionaryAttr> resultAttrs,
+                          StringAttr argAttrsName, StringAttr resAttrsName);
 
 /// Callback type for `parseFunctionOp`, the callback should produce the
 /// type that will be associated with a function-like operation from lists of
@@ -77,15 +79,17 @@ Type getFunctionType(Builder &builder, ArrayRef<OpAsmParser::Argument> argAttrs,
 /// type, report the error or delegate the reporting to the op's verifier.
 ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
                             bool allowVariadic, StringAttr typeAttrName,
-                            FuncTypeBuilder funcTypeBuilder);
+                            FuncTypeBuilder funcTypeBuilder,
+                            StringAttr argAttrsName, StringAttr resAttrsName);
 
 /// Printer implementation for function-like operations.
 void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
-                     StringRef typeAttrName);
+                     StringRef typeAttrName, StringAttr argAttrsName,
+                     StringAttr resAttrsName);
 
 /// Prints the signature of the function-like operation `op`. Assumes `op` has
 /// is a FunctionOpInterface and has passed verification.
-void printFunctionSignature(OpAsmPrinter &p, Operation *op,
+void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
                             ArrayRef<Type> argTypes, bool isVariadic,
                             ArrayRef<Type> resultTypes);
 

diff  --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h
index bc2ec4751c582..3beb3db4e5662 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.h
+++ b/mlir/include/mlir/IR/FunctionInterfaces.h
@@ -26,48 +26,30 @@ class FunctionOpInterface;
 
 namespace function_interface_impl {
 
-/// Return the name of the attribute used for function argument attributes.
-inline StringRef getArgDictAttrName() { return "arg_attrs"; }
-
-/// Return the name of the attribute used for function argument attributes.
-inline StringRef getResultDictAttrName() { return "res_attrs"; }
-
 /// Returns the dictionary attribute corresponding to the argument at 'index'.
 /// If there are no argument attributes at 'index', a null attribute is
 /// returned.
-DictionaryAttr getArgAttrDict(Operation *op, unsigned index);
+DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index);
 
 /// Returns the dictionary attribute corresponding to the result at 'index'.
 /// If there are no result attributes at 'index', a null attribute is
 /// returned.
-DictionaryAttr getResultAttrDict(Operation *op, unsigned index);
+DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index);
 
-namespace detail {
-/// Update the given index into an argument or result attribute dictionary.
-void setArgResAttrDict(Operation *op, StringRef attrName,
-                       unsigned numTotalIndices, unsigned index,
-                       DictionaryAttr attrs);
-} // namespace detail
+/// Return all of the attributes for the argument at 'index'.
+ArrayRef<NamedAttribute> getArgAttrs(FunctionOpInterface op, unsigned index);
+
+/// Return all of the attributes for the result at 'index'.
+ArrayRef<NamedAttribute> getResultAttrs(FunctionOpInterface op, unsigned index);
 
 /// Set all of the argument or result attribute dictionaries for a function. The
 /// size of `attrs` is expected to match the number of arguments/results of the
 /// given `op`.
-void setAllArgAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
-void setAllArgAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
-void setAllResultAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
-void setAllResultAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
-
-/// Return all of the attributes for the argument at 'index'.
-inline ArrayRef<NamedAttribute> getArgAttrs(Operation *op, unsigned index) {
-  auto argDict = getArgAttrDict(op, index);
-  return argDict ? argDict.getValue() : std::nullopt;
-}
-
-/// Return all of the attributes for the result at 'index'.
-inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
-  auto resultDict = getResultAttrDict(op, index);
-  return resultDict ? resultDict.getValue() : std::nullopt;
-}
+void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs);
+void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);
+void setAllResultAttrDicts(FunctionOpInterface op,
+                           ArrayRef<DictionaryAttr> attrs);
+void setAllResultAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);
 
 /// Insert the specified arguments and update the function type attribute.
 void insertFunctionArguments(FunctionOpInterface op,
@@ -110,20 +92,10 @@ TypeRange filterTypesOut(TypeRange types, const BitVector &indices,
 //===----------------------------------------------------------------------===//
 
 /// Set the attributes held by the argument at 'index'.
-template <typename ConcreteType>
-void setArgAttrs(ConcreteType op, unsigned index,
-                 ArrayRef<NamedAttribute> attributes) {
-  assert(index < op.getNumArguments() && "invalid argument number");
-  return detail::setArgResAttrDict(
-      op, getArgDictAttrName(), op.getNumArguments(), index,
-      DictionaryAttr::get(op->getContext(), attributes));
-}
-template <typename ConcreteType>
-void setArgAttrs(ConcreteType op, unsigned index, DictionaryAttr attributes) {
-  return detail::setArgResAttrDict(
-      op, getArgDictAttrName(), op.getNumArguments(), index,
-      attributes ? attributes : DictionaryAttr::get(op->getContext()));
-}
+void setArgAttrs(FunctionOpInterface op, unsigned index,
+                 ArrayRef<NamedAttribute> attributes);
+void setArgAttrs(FunctionOpInterface op, unsigned index,
+                 DictionaryAttr attributes);
 
 /// If the an attribute exists with the specified name, change it to the new
 /// value. Otherwise, add a new attribute with the specified name/value.
@@ -157,23 +129,10 @@ Attribute removeArgAttr(ConcreteType op, unsigned index, StringAttr name) {
 //===----------------------------------------------------------------------===//
 
 /// Set the attributes held by the result at 'index'.
-template <typename ConcreteType>
-void setResultAttrs(ConcreteType op, unsigned index,
-                    ArrayRef<NamedAttribute> attributes) {
-  assert(index < op.getNumResults() && "invalid result number");
-  return detail::setArgResAttrDict(
-      op, getResultDictAttrName(), op.getNumResults(), index,
-      DictionaryAttr::get(op->getContext(), attributes));
-}
-
-template <typename ConcreteType>
-void setResultAttrs(ConcreteType op, unsigned index,
-                    DictionaryAttr attributes) {
-  assert(index < op.getNumResults() && "invalid result number");
-  return detail::setArgResAttrDict(
-      op, getResultDictAttrName(), op.getNumResults(), index,
-      attributes ? attributes : DictionaryAttr::get(op->getContext()));
-}
+void setResultAttrs(FunctionOpInterface op, unsigned index,
+                    ArrayRef<NamedAttribute> attributes);
+void setResultAttrs(FunctionOpInterface op, unsigned index,
+                    DictionaryAttr attributes);
 
 /// If the an attribute exists with the specified name, change it to the new
 /// value. Otherwise, add a new attribute with the specified name/value.
@@ -213,9 +172,8 @@ LogicalResult verifyTrait(ConcreteOp op) {
     unsigned numArgs = op.getNumArguments();
     if (allArgAttrs.size() != numArgs) {
       return op.emitOpError()
-             << "expects argument attribute array `" << getArgDictAttrName()
-             << "` to have the same number of elements as the number of "
-                "function arguments, got "
+             << "expects argument attribute array to have the same number of "
+                "elements as the number of function arguments, got "
              << allArgAttrs.size() << ", but expected " << numArgs;
     }
     for (unsigned i = 0; i != numArgs; ++i) {
@@ -245,9 +203,8 @@ LogicalResult verifyTrait(ConcreteOp op) {
     unsigned numResults = op.getNumResults();
     if (allResultAttrs.size() != numResults) {
       return op.emitOpError()
-             << "expects result attribute array `" << getResultDictAttrName()
-             << "` to have the same number of elements as the number of "
-                "function results, got "
+             << "expects result attribute array to have the same number of "
+                "elements as the number of function results, got "
              << allResultAttrs.size() << ", but expected " << numResults;
     }
     for (unsigned i = 0; i != numResults; ++i) {

diff  --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td
index e86057aa7ec2f..0e8a3addfa7ee 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.td
+++ b/mlir/include/mlir/IR/FunctionInterfaces.td
@@ -59,6 +59,42 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
       result attributes.
     }],
     "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>,
+
+    InterfaceMethod<[{
+      Get the array of argument attribute dictionaries. The method should return
+      an array attribute containing only dictionary attributes equal in number
+      to the number of function arguments. Alternatively, the method can return
+      null to indicate that the function has no argument attributes.
+    }],
+    "::mlir::ArrayAttr", "getArgAttrsAttr">,
+    InterfaceMethod<[{
+      Get the array of result attribute dictionaries. The method should return
+      an array attribute containing only dictionary attributes equal in number
+      to the number of function results. Alternatively, the method can return
+      null to indicate that the function has no result attributes.
+    }],
+    "::mlir::ArrayAttr", "getResAttrsAttr">,
+    InterfaceMethod<[{
+      Set the array of argument attribute dictionaries.
+    }],
+    "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>,
+    InterfaceMethod<[{
+      Set the array of result attribute dictionaries.
+    }],
+    "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>,
+    InterfaceMethod<[{
+      Remove the array of argument attribute dictionaries. This is the same as
+      setting all argument attributes to an empty dictionary. The method should
+      return the removed attribute.
+    }],
+    "::mlir::Attribute", "removeArgAttrsAttr">,
+    InterfaceMethod<[{
+      Remove the array of result attribute dictionaries. This is the same as
+      setting all result attributes to an empty dictionary. The method should
+      return the removed attribute.
+    }],
+    "::mlir::Attribute", "removeResAttrsAttr">,
+
     InterfaceMethod<[{
       Returns the function argument types based exclusively on
       the type (to allow for this method may be called on function
@@ -250,20 +286,6 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
       function_interface_impl::setFunctionType(this->getOperation(), newType);
     }
 
-    // FIXME: These functions should be removed in favor of just forwarding to
-    // the derived operation, which should already have these defined
-    // (via ODS).
-
-    /// Returns the name of the attribute used for function argument attributes.
-    static StringRef getArgDictAttrName() {
-      return function_interface_impl::getArgDictAttrName();
-    }
-
-    /// Returns the name of the attribute used for function argument attributes.
-    static StringRef getResultDictAttrName() {
-      return function_interface_impl::getResultDictAttrName();
-    }
-
     //===------------------------------------------------------------------===//
     // Argument and Result Handling
     //===------------------------------------------------------------------===//
@@ -405,10 +427,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
 
     /// Return an ArrayAttr containing all argument attribute dictionaries of
     /// this function, or nullptr if no arguments have attributes.
-    ArrayAttr getAllArgAttrs() {
-      return this->getOperation()->template getAttrOfType<ArrayAttr>(
-          getArgDictAttrName());
-    }
+    ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); }
+
     /// Return all argument attributes of this function.
     void getAllArgAttrs(SmallVectorImpl<DictionaryAttr> &result) {
       if (ArrayAttr argAttrs = getAllArgAttrs()) {
@@ -460,7 +480,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
     }
     void setAllArgAttrs(ArrayAttr attributes) {
       assert(attributes.size() == $_op.getNumArguments());
-      this->getOperation()->setAttr(getArgDictAttrName(), attributes);
+      $_op.setArgAttrsAttr(attributes);
     }
 
     /// If the an attribute exists with the specified name, change it to the new
@@ -496,10 +516,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
 
     /// Return an ArrayAttr containing all result attribute dictionaries of this
     /// function, or nullptr if no result have attributes.
-    ArrayAttr getAllResultAttrs() {
-      return this->getOperation()->template getAttrOfType<ArrayAttr>(
-          getResultDictAttrName());
-    }
+    ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); }
+
     /// Return all result attributes of this function.
     void getAllResultAttrs(SmallVectorImpl<DictionaryAttr> &result) {
       if (ArrayAttr argAttrs = getAllResultAttrs()) {
@@ -553,7 +571,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
     }
     void setAllResultAttrs(ArrayAttr attributes) {
       assert(attributes.size() == $_op.getNumResults());
-      this->getOperation()->setAttr(getResultDictAttrName(), attributes);
+      $_op.setResAttrsAttr(attributes);
     }
 
     /// If the an attribute exists with the specified name, change it to the new

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 2affd9ae3e031..400f67162aaf9 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1524,6 +1524,8 @@ def TypeArrayAttr : TypedArrayAttrBase<TypeAttr, "type array attribute"> {
 }
 def IndexListArrayAttr :
   TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
+def DictArrayAttr :
+  TypedArrayAttrBase<DictionaryAttr, "Array of dictionary attributes">;
 
 // Attributes containing symbol references.
 def SymbolRefAttr : Attr<CPred<"$_self.isa<::mlir::SymbolRefAttr>()">,

diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 9f522aaa49f92..0cd024e4ae89b 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -66,8 +66,8 @@ static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs,
         attr.getName() == func.getFunctionTypeAttrName() ||
         attr.getName() == "func.varargs" ||
         (filterArgAndResAttrs &&
-         (attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
-          attr.getName() == FunctionOpInterface::getResultDictAttrName())))
+         (attr.getName() == func.getArgAttrsAttrName() ||
+          attr.getName() == func.getResAttrsAttrName())))
       continue;
     result.push_back(attr);
   }
@@ -90,18 +90,19 @@ static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
 static void
 prependResAttrsToArgAttrs(OpBuilder &builder,
                           SmallVectorImpl<NamedAttribute> &attributes,
-                          size_t numArguments) {
+                          func::FuncOp func) {
+  size_t numArguments = func.getNumArguments();
   auto allAttrs = SmallVector<Attribute>(
       numArguments + 1, DictionaryAttr::get(builder.getContext()));
   NamedAttribute *argAttrs = nullptr;
   for (auto *it = attributes.begin(); it != attributes.end();) {
-    if (it->getName() == FunctionOpInterface::getArgDictAttrName()) {
+    if (it->getName() == func.getArgAttrsAttrName()) {
       auto arrayAttrs = it->getValue().cast<ArrayAttr>();
       assert(arrayAttrs.size() == numArguments &&
              "Number of arg attrs and args should match");
       std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1);
       argAttrs = it;
-    } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) {
+    } else if (it->getName() == func.getResAttrsAttrName()) {
       auto arrayAttrs = it->getValue().cast<ArrayAttr>();
       assert(!arrayAttrs.empty() && "expected array to be non-empty");
       allAttrs[0] = (arrayAttrs.size() == 1)
@@ -113,9 +114,8 @@ prependResAttrsToArgAttrs(OpBuilder &builder,
     it++;
   }
 
-  auto newArgAttrs =
-      builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
-                           builder.getArrayAttr(allAttrs));
+  auto newArgAttrs = builder.getNamedAttr(func.getArgAttrsAttrName(),
+                                          builder.getArrayAttr(allAttrs));
   if (!argAttrs) {
     attributes.emplace_back(newArgAttrs);
     return;
@@ -141,7 +141,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
   auto [wrapperFuncType, resultIsNowArg] =
       typeConverter.convertFunctionTypeCWrapper(type);
   if (resultIsNowArg)
-    prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments());
+    prependResAttrsToArgAttrs(rewriter, attributes, funcOp);
   auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
       wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false,
@@ -205,7 +205,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
 
   if (resultIsNowArg)
-    prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
+    prependResAttrsToArgAttrs(builder, attributes, funcOp);
   // Create the auxiliary function.
   auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
@@ -309,8 +309,8 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
               ? resAttrDicts
               : rewriter.getArrayAttr(
                     {wrapAsStructAttrs(rewriter, resAttrDicts)});
-      attributes.push_back(rewriter.getNamedAttr(
-          FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
+      attributes.push_back(
+          rewriter.getNamedAttr(funcOp.getResAttrsAttrName(), newResAttrDicts));
     }
     if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
       SmallVector<Attribute, 4> newArgAttrs(
@@ -353,9 +353,8 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
           newArgAttrs[mapping->inputNo + j] =
               DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
       }
-      attributes.push_back(
-          rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
-                                rewriter.getArrayAttr(newArgAttrs)));
+      attributes.push_back(rewriter.getNamedAttr(
+          funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(newArgAttrs)));
     }
     for (const auto &pair : llvm::enumerate(attributes)) {
       if (pair.value().getName() == "llvm.linkage") {

diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 064bf525db238..54acc373018c0 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -340,8 +340,9 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
-                                                /*resultAttrs=*/std::nullopt);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
+      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
 
 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -352,12 +353,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   return function_interface_impl::parseFunctionOp(
       parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Check that the result type of async.func is not void and must be

diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index fc9bd115e2223..7bb3663cc43be 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -251,8 +251,9 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
-                                                /*resultAttrs=*/std::nullopt);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
+      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
 
 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -263,12 +264,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   return function_interface_impl::parseFunctionOp(
       parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 /// Clone the internal blocks from this function into dest and all attributes

diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 80db6461ecc55..9ea1b11593694 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -934,8 +934,9 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(type));
 
-  function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
-                                                resultAttrs);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
 
   // Parse workgroup memory attributions.
   if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
@@ -996,7 +997,8 @@ void GPUFuncOp::print(OpAsmPrinter &p) {
   function_interface_impl::printFunctionAttributes(
       p, *this,
       {getNumWorkgroupAttributionsAttrName(),
-       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName()});
+       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
+       getArgAttrsAttrName(), getResAttrsAttrName()});
   p << ' ';
   p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
 }

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 6b428a171d79e..cff547ecda07e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2006,8 +2006,9 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
 
   assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
          "expected as many argument attribute lists as arguments");
-  function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs,
-                                                /*resultAttrs=*/std::nullopt);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 // Builds an LLVM function type from the given lists of input and output types.
@@ -2095,8 +2096,9 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
     return failure();
-  function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result,
-                                                entryArgs, resultAttrs);
+  function_interface_impl::addArgAndResultAttrs(
+      parser.getBuilder(), result, entryArgs, resultAttrs,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 
   auto *body = result.addRegion();
   OptionalParseResult parseResult =
@@ -2131,7 +2133,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
                                                   isVarArg(), resTypes);
   function_interface_impl::printFunctionAttributes(
       p, *this,
-      {getFunctionTypeAttrName(), getLinkageAttrName(), getCConvAttrName()});
+      {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
+       getLinkageAttrName(), getCConvAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = getBody();

diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
index 27c613088df56..31ed5ad3b9878 100644
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -153,12 +153,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   return function_interface_impl::parseFunctionOp(
       parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//
@@ -316,12 +318,14 @@ ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
 
   return function_interface_impl::parseFunctionOp(
       parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void SubgraphOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
index 28fc4dbf97aea..2cc282d25a262 100644
--- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
+++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
@@ -221,12 +221,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   return function_interface_impl::parseFunctionOp(
       parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3ce3913f2814b..3341b5e367561 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2396,8 +2396,9 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   // Add the attributes to the function arguments.
   assert(resultAttrs.size() == resultTypes.size());
-  function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
-                                                resultAttrs);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
 
   // Parse the optional function body.
   auto *body = result.addRegion();
@@ -2419,7 +2420,8 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
   function_interface_impl::printFunctionAttributes(
       printer, *this,
       {spirv::attributeName<spirv::FunctionControl>(),
-       getFunctionTypeAttrName(), getFunctionControlAttrName()});
+       getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
+       getFunctionControlAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = this->getBody();

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 30c5f56f13886..28ac98e5ab8a5 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1300,8 +1300,9 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
-                                                /*resultAttrs=*/std::nullopt);
+  function_interface_impl::addArgAndResultAttrs(
+      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
+      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
 
 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1312,12 +1313,14 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   return function_interface_impl::parseFunctionOp(
       parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp
index af692befb0fde..5ca6777cea18f 100644
--- a/mlir/lib/IR/FunctionImplementation.cpp
+++ b/mlir/lib/IR/FunctionImplementation.cpp
@@ -113,7 +113,7 @@ parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
   return parser.parseRParen();
 }
 
-ParseResult mlir::function_interface_impl::parseFunctionSignature(
+ParseResult function_interface_impl::parseFunctionSignature(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
     SmallVectorImpl<Type> &resultTypes,
@@ -125,9 +125,10 @@ ParseResult mlir::function_interface_impl::parseFunctionSignature(
   return success();
 }
 
-void mlir::function_interface_impl::addArgAndResultAttrs(
+void function_interface_impl::addArgAndResultAttrs(
     Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
-    ArrayRef<DictionaryAttr> resultAttrs) {
+    ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
+    StringAttr resAttrsName) {
   auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
     return attrs && !attrs.empty();
   };
@@ -142,28 +143,28 @@ void mlir::function_interface_impl::addArgAndResultAttrs(
 
   // Add the attributes to the function arguments.
   if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
-    result.addAttribute(function_interface_impl::getArgDictAttrName(),
-                        getArrayAttr(argAttrs));
+    result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
 
   // Add the attributes to the function results.
   if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
-    result.addAttribute(function_interface_impl::getResultDictAttrName(),
-                        getArrayAttr(resultAttrs));
+    result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
 }
 
-void mlir::function_interface_impl::addArgAndResultAttrs(
+void function_interface_impl::addArgAndResultAttrs(
     Builder &builder, OperationState &result,
-    ArrayRef<OpAsmParser::Argument> args,
-    ArrayRef<DictionaryAttr> resultAttrs) {
+    ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
+    StringAttr argAttrsName, StringAttr resAttrsName) {
   SmallVector<DictionaryAttr> argAttrs;
   for (const auto &arg : args)
     argAttrs.push_back(arg.attrs);
-  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
+  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
+                       resAttrsName);
 }
 
-ParseResult mlir::function_interface_impl::parseFunctionOp(
+ParseResult function_interface_impl::parseFunctionOp(
     OpAsmParser &parser, OperationState &result, bool allowVariadic,
-    StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder) {
+    StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
+    StringAttr argAttrsName, StringAttr resAttrsName) {
   SmallVector<OpAsmParser::Argument> entryArgs;
   SmallVector<DictionaryAttr> resultAttrs;
   SmallVector<Type> resultTypes;
@@ -220,7 +221,8 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
 
   // Add the attributes to the function arguments.
   assert(resultAttrs.size() == resultTypes.size());
-  addArgAndResultAttrs(builder, result, entryArgs, resultAttrs);
+  addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName,
+                       resAttrsName);
 
   // Parse the optional function body. The printer will not print the body if
   // its empty, so disallow parsing of empty body in the parser.
@@ -261,14 +263,14 @@ static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
     os << ')';
 }
 
-void mlir::function_interface_impl::printFunctionSignature(
-    OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
-    ArrayRef<Type> resultTypes) {
+void function_interface_impl::printFunctionSignature(
+    OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
+    bool isVariadic, ArrayRef<Type> resultTypes) {
   Region &body = op->getRegion(0);
   bool isExternal = body.empty();
 
   p << '(';
-  ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+  ArrayAttr argAttrs = op.getArgAttrsAttr();
   for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
     if (i > 0)
       p << ", ";
@@ -295,26 +297,23 @@ void mlir::function_interface_impl::printFunctionSignature(
 
   if (!resultTypes.empty()) {
     p.getStream() << " -> ";
-    auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+    auto resultAttrs = op.getResAttrsAttr();
     printFunctionResultList(p, resultTypes, resultAttrs);
   }
 }
 
-void mlir::function_interface_impl::printFunctionAttributes(
+void function_interface_impl::printFunctionAttributes(
     OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
   // Print out function attributes, if present.
-  SmallVector<StringRef, 2> ignoredAttrs = {SymbolTable::getSymbolAttrName(),
-                                            getArgDictAttrName(),
-                                            getResultDictAttrName()};
+  SmallVector<StringRef, 8> ignoredAttrs = {SymbolTable::getSymbolAttrName()};
   ignoredAttrs.append(elided.begin(), elided.end());
 
   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
 }
 
-void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
-                                                    FunctionOpInterface op,
-                                                    bool isVariadic,
-                                                    StringRef typeAttrName) {
+void function_interface_impl::printFunctionOp(
+    OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
+    StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) {
   // Print the operation and the function name.
   auto funcName =
       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
@@ -329,7 +328,8 @@ void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
   ArrayRef<Type> argTypes = op.getArgumentTypes();
   ArrayRef<Type> resultTypes = op.getResultTypes();
   printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
-  printFunctionAttributes(p, op, {visibilityAttrName, typeAttrName});
+  printFunctionAttributes(
+      p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName});
   // Print the body if this is not an external function.
   Region &body = op->getRegion(0);
   if (!body.empty()) {

diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp
index 9ba830366056c..347fb15f8fbe7 100644
--- a/mlir/lib/IR/FunctionInterfaces.cpp
+++ b/mlir/lib/IR/FunctionInterfaces.cpp
@@ -24,27 +24,104 @@ static bool isEmptyAttrDict(Attribute attr) {
   return attr.cast<DictionaryAttr>().empty();
 }
 
-DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op,
-                                                             unsigned index) {
-  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op,
+                                                       unsigned index) {
+  ArrayAttr attrs = op.getArgAttrsAttr();
   DictionaryAttr argAttrs =
       attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
   return argAttrs;
 }
 
 DictionaryAttr
-mlir::function_interface_impl::getResultAttrDict(Operation *op,
-                                                 unsigned index) {
-  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+function_interface_impl::getResultAttrDict(FunctionOpInterface op,
+                                           unsigned index) {
+  ArrayAttr attrs = op.getResAttrsAttr();
   DictionaryAttr resAttrs =
       attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
   return resAttrs;
 }
 
-void mlir::function_interface_impl::detail::setArgResAttrDict(
-    Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
-    DictionaryAttr attrs) {
-  ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
+ArrayRef<NamedAttribute>
+function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) {
+  auto argDict = getArgAttrDict(op, index);
+  return argDict ? argDict.getValue() : std::nullopt;
+}
+
+ArrayRef<NamedAttribute>
+function_interface_impl::getResultAttrs(FunctionOpInterface op,
+                                        unsigned index) {
+  auto resultDict = getResultAttrDict(op, index);
+  return resultDict ? resultDict.getValue() : std::nullopt;
+}
+
+/// Get either the argument or result attributes array.
+template <bool isArg>
+static ArrayAttr getArgResAttrs(FunctionOpInterface op) {
+  if constexpr (isArg)
+    return op.getArgAttrsAttr();
+  else
+    return op.getResAttrsAttr();
+}
+
+/// Set either the argument or result attributes array.
+template <bool isArg>
+static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) {
+  if constexpr (isArg)
+    op.setArgAttrsAttr(attrs);
+  else
+    op.setResAttrsAttr(attrs);
+}
+
+/// Erase either the argument or result attributes array.
+template <bool isArg>
+static void removeArgResAttrs(FunctionOpInterface op) {
+  if constexpr (isArg)
+    op.removeArgAttrsAttr();
+  else
+    op.removeResAttrsAttr();
+}
+
+/// Set all of the argument or result attribute dictionaries for a function.
+template <bool isArg>
+static void setAllArgResAttrDicts(FunctionOpInterface op,
+                                  ArrayRef<Attribute> attrs) {
+  if (llvm::all_of(attrs, isEmptyAttrDict))
+    removeArgResAttrs<isArg>(op);
+  else
+    setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs));
+}
+
+void function_interface_impl::setAllArgAttrDicts(
+    FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
+  setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+}
+
+void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op,
+                                                 ArrayRef<Attribute> attrs) {
+  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
+    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
+  });
+  setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(wrappedAttrs));
+}
+
+void function_interface_impl::setAllResultAttrDicts(
+    FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
+  setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+}
+
+void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op,
+                                                    ArrayRef<Attribute> attrs) {
+  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
+    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
+  });
+  setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(wrappedAttrs));
+}
+
+/// Update the given index into an argument or result attribute dictionary.
+template <bool isArg>
+static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
+                              unsigned index, DictionaryAttr attrs) {
+  ArrayAttr allAttrs = getArgResAttrs<isArg>(op);
   if (!allAttrs) {
     if (attrs.empty())
       return;
@@ -53,7 +130,7 @@ void mlir::function_interface_impl::detail::setArgResAttrDict(
     SmallVector<Attribute, 8> newAttrs(numTotalIndices,
                                        DictionaryAttr::get(op->getContext()));
     newAttrs[index] = attrs;
-    op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
+    setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
     return;
   }
   // Check to see if the attribute is different from what we already have.
@@ -65,53 +142,51 @@ void mlir::function_interface_impl::detail::setArgResAttrDict(
   ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
   if (attrs.empty() &&
       llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
-      llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
-    op->removeAttr(attrName);
-    return;
-  }
+      llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict))
+    return removeArgResAttrs<isArg>(op);
 
   // Otherwise, create a new attribute array with the updated dictionary.
   SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
   newAttrs[index] = attrs;
-  op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
+  setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
 }
 
-/// Set all of the argument or result attribute dictionaries for a function.
-static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
-                                  ArrayRef<Attribute> attrs) {
-  if (llvm::all_of(attrs, isEmptyAttrDict))
-    op->removeAttr(attrName);
-  else
-    op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
+void function_interface_impl::setArgAttrs(FunctionOpInterface op,
+                                          unsigned index,
+                                          ArrayRef<NamedAttribute> attributes) {
+  assert(index < op.getNumArguments() && "invalid argument number");
+  return setArgResAttrDict</*isArg=*/true>(
+      op, op.getNumArguments(), index,
+      DictionaryAttr::get(op->getContext(), attributes));
 }
 
-void mlir::function_interface_impl::setAllArgAttrDicts(
-    Operation *op, ArrayRef<DictionaryAttr> attrs) {
-  setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
-}
-void mlir::function_interface_impl::setAllArgAttrDicts(
-    Operation *op, ArrayRef<Attribute> attrs) {
-  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
-    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
-  });
-  setAllArgResAttrDicts(op, getArgDictAttrName(),
-                        llvm::to_vector<8>(wrappedAttrs));
+void function_interface_impl::setArgAttrs(FunctionOpInterface op,
+                                          unsigned index,
+                                          DictionaryAttr attributes) {
+  return setArgResAttrDict</*isArg=*/true>(
+      op, op.getNumArguments(), index,
+      attributes ? attributes : DictionaryAttr::get(op->getContext()));
 }
 
-void mlir::function_interface_impl::setAllResultAttrDicts(
-    Operation *op, ArrayRef<DictionaryAttr> attrs) {
-  setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+void function_interface_impl::setResultAttrs(
+    FunctionOpInterface op, unsigned index,
+    ArrayRef<NamedAttribute> attributes) {
+  assert(index < op.getNumResults() && "invalid result number");
+  return setArgResAttrDict</*isArg=*/false>(
+      op, op.getNumResults(), index,
+      DictionaryAttr::get(op->getContext(), attributes));
 }
-void mlir::function_interface_impl::setAllResultAttrDicts(
-    Operation *op, ArrayRef<Attribute> attrs) {
-  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
-    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
-  });
-  setAllArgResAttrDicts(op, getResultDictAttrName(),
-                        llvm::to_vector<8>(wrappedAttrs));
+
+void function_interface_impl::setResultAttrs(FunctionOpInterface op,
+                                             unsigned index,
+                                             DictionaryAttr attributes) {
+  assert(index < op.getNumResults() && "invalid result number");
+  return setArgResAttrDict</*isArg=*/false>(
+      op, op.getNumResults(), index,
+      attributes ? attributes : DictionaryAttr::get(op->getContext()));
 }
 
-void mlir::function_interface_impl::insertFunctionArguments(
+void function_interface_impl::insertFunctionArguments(
     FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
     ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
     unsigned originalNumArgs, Type newType) {
@@ -128,7 +203,7 @@ void mlir::function_interface_impl::insertFunctionArguments(
   Block &entry = op->getRegion(0).front();
 
   // Update the argument attributes of the function.
-  auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+  ArrayAttr oldArgAttrs = op.getArgAttrsAttr();
   if (oldArgAttrs || !argAttrs.empty()) {
     SmallVector<DictionaryAttr, 4> newArgAttrs;
     newArgAttrs.reserve(originalNumArgs + argIndices.size());
@@ -157,7 +232,7 @@ void mlir::function_interface_impl::insertFunctionArguments(
     entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
 }
 
-void mlir::function_interface_impl::insertFunctionResults(
+void function_interface_impl::insertFunctionResults(
     FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
     TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
     unsigned originalNumResults, Type newType) {
@@ -171,7 +246,7 @@ void mlir::function_interface_impl::insertFunctionResults(
   // - Result attrs.
 
   // Update the result attributes of the function.
-  auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+  ArrayAttr oldResultAttrs = op.getResAttrsAttr();
   if (oldResultAttrs || !resultAttrs.empty()) {
     SmallVector<DictionaryAttr, 4> newResultAttrs;
     newResultAttrs.reserve(originalNumResults + resultIndices.size());
@@ -199,7 +274,7 @@ void mlir::function_interface_impl::insertFunctionResults(
   op.setFunctionTypeAttr(TypeAttr::get(newType));
 }
 
-void mlir::function_interface_impl::eraseFunctionArguments(
+void function_interface_impl::eraseFunctionArguments(
     FunctionOpInterface op, const BitVector &argIndices, Type newType) {
   // There are 3 things that need to be updated:
   // - Function type.
@@ -208,7 +283,7 @@ void mlir::function_interface_impl::eraseFunctionArguments(
   Block &entry = op->getRegion(0).front();
 
   // Update the argument attributes of the function.
-  if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
+  if (ArrayAttr argAttrs = op.getArgAttrsAttr()) {
     SmallVector<DictionaryAttr, 4> newArgAttrs;
     newArgAttrs.reserve(argAttrs.size());
     for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
@@ -222,14 +297,14 @@ void mlir::function_interface_impl::eraseFunctionArguments(
   entry.eraseArguments(argIndices);
 }
 
-void mlir::function_interface_impl::eraseFunctionResults(
+void function_interface_impl::eraseFunctionResults(
     FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
   // There are 2 things that need to be updated:
   // - Function type.
   // - Result attrs.
 
   // Update the result attributes of the function.
-  if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
+  if (ArrayAttr resAttrs = op.getResAttrsAttr()) {
     SmallVector<DictionaryAttr, 4> newResultAttrs;
     newResultAttrs.reserve(resAttrs.size());
     for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
@@ -242,7 +317,7 @@ void mlir::function_interface_impl::eraseFunctionResults(
   op.setFunctionTypeAttr(TypeAttr::get(newType));
 }
 
-TypeRange mlir::function_interface_impl::insertTypesInto(
+TypeRange function_interface_impl::insertTypesInto(
     TypeRange oldTypes, ArrayRef<unsigned> indices, TypeRange newTypes,
     SmallVectorImpl<Type> &storage) {
   assert(indices.size() == newTypes.size() &&
@@ -261,7 +336,7 @@ TypeRange mlir::function_interface_impl::insertTypesInto(
   return storage;
 }
 
-TypeRange mlir::function_interface_impl::filterTypesOut(
+TypeRange function_interface_impl::filterTypesOut(
     TypeRange types, const BitVector &indices, SmallVectorImpl<Type> &storage) {
   if (indices.none())
     return types;
@@ -276,8 +351,8 @@ TypeRange mlir::function_interface_impl::filterTypesOut(
 // Function type signature.
 //===----------------------------------------------------------------------===//
 
-void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op,
-                                                    Type newType) {
+void function_interface_impl::setFunctionType(FunctionOpInterface op,
+                                              Type newType) {
   unsigned oldNumArgs = op.getNumArguments();
   unsigned oldNumResults = op.getNumResults();
   op.setFunctionTypeAttr(TypeAttr::get(newType));
@@ -285,35 +360,31 @@ void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op,
   unsigned newNumResults = op.getNumResults();
 
   // Functor used to update the argument and result attributes of the function.
-  auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
-                          unsigned newCount, auto setAttrFn) {
+  auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) {
+    constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>;
+
     if (oldCount == newCount)
       return;
     // The new type has no arguments/results, just drop the attribute.
-    if (newCount == 0) {
-      op->removeAttr(attrName);
-      return;
-    }
-    ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
+    if (newCount == 0)
+      return removeArgResAttrs<isArgVal>(op);
+    ArrayAttr attrs = getArgResAttrs<isArgVal>(op);
     if (!attrs)
       return;
 
     // The new type has less arguments/results, take the first N attributes.
     if (newCount < oldCount)
-      return setAttrFn(op, attrs.getValue().take_front(newCount));
+      return setAllArgResAttrDicts<isArgVal>(
+          op, attrs.getValue().take_front(newCount));
 
     // Otherwise, the new type has more arguments/results. Initialize the new
     // arguments/results with empty attributes.
     SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
     newAttrs.resize(newCount);
-    setAttrFn(op, newAttrs);
+    setAllArgResAttrDicts<isArgVal>(op, newAttrs);
   };
 
   // Update the argument and result attributes.
-  updateAttrFn(
-      getArgDictAttrName(), oldNumArgs, newNumArgs,
-      [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); });
-  updateAttrFn(
-      getResultDictAttrName(), oldNumResults, newNumResults,
-      [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
+  updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs);
+  updateAttrFn(std::false_type{}, oldNumResults, newNumResults);
 }

diff  --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir
index a72abad61f400..d995689ebb8d0 100644
--- a/mlir/test/IR/invalid-func-op.mlir
+++ b/mlir/test/IR/invalid-func-op.mlir
@@ -96,20 +96,11 @@ func.func private @invalid_symbol_type_attr() attributes { function_type = "x" }
 
 // -----
 
-// expected-error@+1 {{argument attribute array `arg_attrs` to have the same number of elements as the number of function arguments}}
+// expected-error@+1 {{argument attribute array to have the same number of elements as the number of function arguments}}
 func.func private @invalid_arg_attrs() attributes { arg_attrs = [{}] }
 
 // -----
 
-// expected-error@+1 {{expects argument attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
-func.func private @invalid_arg_attrs(i32) attributes { arg_attrs = [10] }
 
-// -----
-
-// expected-error@+1 {{result attribute array `res_attrs` to have the same number of elements as the number of function results}}
+// expected-error@+1 {{result attribute array to have the same number of elements as the number of function results}}
 func.func private @invalid_res_attrs() attributes { res_attrs = [{}] }
-
-// -----
-
-// expected-error@+1 {{expects result attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
-func.func private @invalid_res_attrs() -> i32 attributes { res_attrs = [10] }


        


More information about the Mlir-commits mailing list