[Mlir-commits] [mlir] 53b946a - [mlir] Refactor the representation of function-like argument/result attributes.

River Riddle llvmlistbot at llvm.org
Fri May 7 19:40:20 PDT 2021


Author: River Riddle
Date: 2021-05-07T19:32:31-07:00
New Revision: 53b946aa636a31e9243b8c5bf1703a1f6eae798e

URL: https://github.com/llvm/llvm-project/commit/53b946aa636a31e9243b8c5bf1703a1f6eae798e
DIFF: https://github.com/llvm/llvm-project/commit/53b946aa636a31e9243b8c5bf1703a1f6eae798e.diff

LOG: [mlir] Refactor the representation of function-like argument/result attributes.

The current design uses a separate top-level attribute for each argument/result attribute dictionary, with the name of the entry being something like "arg0". This provides for a somewhat sparse design, but ends up being much more expensive (from a runtime perspective) in practice. The design requires building a string every time we look up the dictionary for a specific arg/result, and also requires N attribute lookups when collecting all of the arg/result attribute dictionaries.

This revision restructures the design to instead use one ArrayAttr that contains all of the attribute dictionaries for the arguments and another that contains all of the attribute dictionaries for the results. This design reduces the number of attribute name lookups to one, and allows for O(1) lookup of an individual element's dictionary. The major downside is that we can end up with larger memory usage, as the ArrayAttr contains an entry for each element even if that element has no attributes. If the memory usage becomes problematic, we can experiment with a sparser structure that still provides most of the wins in this revision.

This dropped the compilation time of a somewhat large TensorFlow model from ~650 seconds to ~400 seconds.

Differential Revision: https://reviews.llvm.org/D102035

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/FunctionImplementation.h
    mlir/include/mlir/IR/FunctionSupport.h
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/lib/IR/BuiltinDialect.cpp
    mlir/lib/IR/FunctionImplementation.cpp
    mlir/lib/IR/FunctionSupport.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/Dialect/LLVMIR/func.mlir
    mlir/test/IR/invalid-func-op.mlir
    mlir/test/IR/test-func-set-type.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 5bd6956e75c0..a29a22a9989b 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -213,16 +213,6 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,
           GPUDialect::getKernelFuncAttrName()) != nullptr;
     }
 
-    /// Change the type of this function in place. This is an extremely
-    /// dangerous operation and it is up to the caller to ensure that this is
-    /// legal for this function, and to restore invariants:
-    ///  - the entry block args must be updated to match the function params.
-    ///  - the argument/result attributes may need an update: if the new type
-    ///  has less parameters we drop the extra attributes, if there are more
-    ///  parameters they won't have any attributes.
-    // TODO: consider removing this function thanks to rewrite patterns.
-    void setType(FunctionType newType);
-
     /// Returns the number of buffers located in the workgroup memory.
     unsigned getNumWorkgroupAttributions() {
       return (*this)->getAttrOfType<IntegerAttr>(

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index c248ad5822f8..05dbc6bd7230 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -300,7 +300,7 @@ def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> {
   }];
   let parameters = (ins ArrayRefParameter<"NamedAttribute", "">:$value);
   let builders = [
-    AttrBuilder<(ins "ArrayRef<NamedAttribute>":$value)>
+    AttrBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "llvm::None">:$value)>
   ];
   let extraClassDeclaration = [{
     using ValueType = ArrayRef<NamedAttribute>;

diff  --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h
index c19100c55219..cb7776fca355 100644
--- a/mlir/include/mlir/IR/FunctionImplementation.h
+++ b/mlir/include/mlir/IR/FunctionImplementation.h
@@ -20,7 +20,7 @@
 
 namespace mlir {
 
-namespace impl {
+namespace function_like_impl {
 
 /// A named class for passing around the variadic flag.
 class VariadicFlag {
@@ -37,6 +37,9 @@ class VariadicFlag {
 /// `resultAttrs` arguments, to the list of operation attributes in `result`.
 /// Internally, argument and result attributes are stored as dict attributes
 /// with special names given by getResultAttrName, getArgumentAttrName.
+void addArgAndResultAttrs(Builder &builder, OperationState &result,
+                          ArrayRef<DictionaryAttr> argAttrs,
+                          ArrayRef<DictionaryAttr> resultAttrs);
 void addArgAndResultAttrs(Builder &builder, OperationState &result,
                           ArrayRef<NamedAttrList> argAttrs,
                           ArrayRef<NamedAttrList> resultAttrs);
@@ -103,7 +106,7 @@ void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs,
                              unsigned numResults,
                              ArrayRef<StringRef> elided = {});
 
-} // namespace impl
+} // namespace function_like_impl
 
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index a3be1a76143a..21d6e3724312 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -20,45 +20,41 @@
 
 namespace mlir {
 
-namespace impl {
+namespace function_like_impl {
 
 /// Return the name of the attribute used for function types.
 inline StringRef getTypeAttrName() { return "type"; }
 
-/// Return the name of the attribute used for function arguments.
-inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
-  out.clear();
-  return ("arg" + Twine(arg)).toStringRef(out);
-}
-
-/// Returns true if the given name is a valid argument attribute name.
-inline bool isArgAttrName(StringRef name) {
-  APInt unused;
-  return name.startswith("arg") &&
-         !name.drop_front(3).getAsInteger(/*Radix=*/10, unused);
-}
+/// Return the name of the attribute used for function argument attributes.
+inline StringRef getArgDictAttrName() { return "arg_attrs"; }
 
-/// Return the name of the attribute used for function results.
-inline StringRef getResultAttrName(unsigned arg, SmallVectorImpl<char> &out) {
-  out.clear();
-  return ("result" + Twine(arg)).toStringRef(out);
-}
+/// Return the name of the attribute used for function result attributes.
+inline StringRef getResultDictAttrName() { return "res_attrs"; }
 
 /// Returns the dictionary attribute corresponding to the argument at 'index'.
 /// If there are no argument attributes at 'index', a null attribute is
 /// returned.
-inline DictionaryAttr getArgAttrDict(Operation *op, unsigned index) {
-  SmallString<8> nameOut;
-  return op->getAttrOfType<DictionaryAttr>(getArgAttrName(index, nameOut));
-}
+DictionaryAttr getArgAttrDict(Operation *op, unsigned index);
 
 /// Returns the dictionary attribute corresponding to the result at 'index'.
 /// If there are no result attributes at 'index', a null attribute is
 /// returned.
-inline DictionaryAttr getResultAttrDict(Operation *op, unsigned index) {
-  SmallString<8> nameOut;
-  return op->getAttrOfType<DictionaryAttr>(getResultAttrName(index, nameOut));
-}
+DictionaryAttr getResultAttrDict(Operation *op, unsigned index);
+
+namespace detail {
+/// Update the given index into an argument or result attribute dictionary.
+void setArgResAttrDict(Operation *op, StringRef attrName,
+                       unsigned numTotalIndices, unsigned index,
+                       DictionaryAttr attrs);
+} // namespace detail
+
+/// Set all of the argument or result attribute dictionaries for a function. The
+/// size of `attrs` is expected to match the number of arguments/results of the
+/// given `op`.
+void setAllArgAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
+void setAllArgAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
+void setAllResultAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
+void setAllResultAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
 
 /// Return all of the attributes for the argument at 'index'.
 inline ArrayRef<NamedAttribute> getArgAttrs(Operation *op, unsigned index) {
@@ -87,7 +83,7 @@ void setFunctionType(Operation *op, FunctionType newType);
 /// Get a FunctionLike operation's body.
 Region &getFunctionBody(Operation *op);
 
-} // namespace impl
+} // namespace function_like_impl
 
 namespace OpTrait {
 
@@ -142,7 +138,7 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
   bool isExternal() { return empty(); }
 
   Region &getBody() {
-    return ::mlir::impl::getFunctionBody(this->getOperation());
+    return function_like_impl::getFunctionBody(this->getOperation());
   }
 
   /// Delete all blocks from this function.
@@ -194,7 +190,9 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
   //===--------------------------------------------------------------------===//
 
   /// Return the name of the attribute used for function types.
-  static StringRef getTypeAttrName() { return ::mlir::impl::getTypeAttrName(); }
+  static StringRef getTypeAttrName() {
+    return function_like_impl::getTypeAttrName();
+  }
 
   TypeAttr getTypeAttr() {
     return this->getOperation()->template getAttrOfType<TypeAttr>(
@@ -207,7 +205,7 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
   /// hide this one if the concrete class does not use FunctionType for the
   /// function type under the hood.
   FunctionType getType() {
-    return ::mlir::impl::getFunctionType(this->getOperation());
+    return function_like_impl::getFunctionType(this->getOperation());
   }
 
   /// Return the type of this function without the specified arguments and
@@ -277,8 +275,8 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
   void eraseArguments(ArrayRef<unsigned> argIndices) {
     unsigned originalNumArgs = getNumArguments();
     Type newType = getTypeWithoutArgsAndResults(argIndices, {});
-    ::mlir::impl::eraseFunctionArguments(this->getOperation(), argIndices,
-                                         originalNumArgs, newType);
+    function_like_impl::eraseFunctionArguments(this->getOperation(), argIndices,
+                                               originalNumArgs, newType);
   }
 
   /// Erase a single result at `resultIndex`.
@@ -289,8 +287,8 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
   void eraseResults(ArrayRef<unsigned> resultIndices) {
     unsigned originalNumResults = getNumResults();
     Type newType = getTypeWithoutArgsAndResults({}, resultIndices);
-    ::mlir::impl::eraseFunctionResults(this->getOperation(), resultIndices,
-                                       originalNumResults, newType);
+    function_like_impl::eraseFunctionResults(
+        this->getOperation(), resultIndices, originalNumResults, newType);
   }
 
   //===--------------------------------------------------------------------===//
@@ -306,14 +304,23 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
 
   /// Return all of the attributes for the argument at 'index'.
   ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
-    return ::mlir::impl::getArgAttrs(this->getOperation(), index);
+    return function_like_impl::getArgAttrs(this->getOperation(), index);
   }
 
-  /// Return all argument attributes of this function. If an argument does not
-  /// have any attributes, the corresponding entry in `result` is nullptr.
+  /// Return an ArrayAttr containing all argument attribute dictionaries of this
+  /// function, or nullptr if no arguments have attributes.
+  ArrayAttr getAllArgAttrs() {
+    return this->getOperation()->template getAttrOfType<ArrayAttr>(
+        function_like_impl::getArgDictAttrName());
+  }
+  /// Return all argument attributes of this function.
   void getAllArgAttrs(SmallVectorImpl<DictionaryAttr> &result) {
-    for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
-      result.emplace_back(getArgAttrDict(i));
+    if (ArrayAttr argAttrs = getAllArgAttrs()) {
+      auto argAttrRange = argAttrs.template getAsRange<DictionaryAttr>();
+      result.append(argAttrRange.begin(), argAttrRange.end());
+    } else {
+      result.resize(getNumArguments());
+    }
   }
 
   /// Return the specified attribute, if present, for the argument at 'index',
@@ -342,7 +349,19 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
   /// Set the attributes held by the argument at 'index'. `attributes` may be
   /// null, in which case any existing argument attributes are removed.
   void setArgAttrs(unsigned index, DictionaryAttr attributes);
-  void setAllArgAttrs(ArrayRef<DictionaryAttr> attributes);
+  void setAllArgAttrs(ArrayRef<DictionaryAttr> attributes) {
+    assert(attributes.size() == getNumArguments());
+    function_like_impl::setAllArgAttrDicts(this->getOperation(), attributes);
+  }
+  void setAllArgAttrs(ArrayRef<Attribute> attributes) {
+    assert(attributes.size() == getNumArguments());
+    function_like_impl::setAllArgAttrDicts(this->getOperation(), attributes);
+  }
+  void setAllArgAttrs(ArrayAttr attributes) {
+    assert(attributes.size() == getNumArguments());
+    this->getOperation()->setAttr(function_like_impl::getArgDictAttrName(),
+                                  attributes);
+  }
 
   /// If the an attribute exists with the specified name, change it to the new
   /// value. Otherwise, add a new attribute with the specified name/value.
@@ -370,14 +389,23 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
 
   /// Return all of the attributes for the result at 'index'.
   ArrayRef<NamedAttribute> getResultAttrs(unsigned index) {
-    return ::mlir::impl::getResultAttrs(this->getOperation(), index);
+    return function_like_impl::getResultAttrs(this->getOperation(), index);
   }
 
-  /// Return all result attributes of this function. If a result does not have
-  /// any attributes, the corresponding entry in `result` is nullptr.
+  /// Return an ArrayAttr containing all result attribute dictionaries of this
+  /// function, or nullptr if no results have attributes.
+  ArrayAttr getAllResultAttrs() {
+    return this->getOperation()->template getAttrOfType<ArrayAttr>(
+        function_like_impl::getResultDictAttrName());
+  }
+  /// Return all result attributes of this function.
   void getAllResultAttrs(SmallVectorImpl<DictionaryAttr> &result) {
-    for (unsigned i = 0, e = getNumResults(); i != e; ++i)
-      result.emplace_back(getResultAttrDict(i));
+    if (ArrayAttr argAttrs = getAllResultAttrs()) {
+      auto argAttrRange = argAttrs.template getAsRange<DictionaryAttr>();
+      result.append(argAttrRange.begin(), argAttrRange.end());
+    } else {
+      result.resize(getNumResults());
+    }
   }
 
   /// Return the specified attribute, if present, for the result at 'index',
@@ -402,10 +430,23 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
 
   /// Set the attributes held by the result at 'index'.
   void setResultAttrs(unsigned index, ArrayRef<NamedAttribute> attributes);
+
   /// Set the attributes held by the result at 'index'. `attributes` may be
   /// null, in which case any existing argument attributes are removed.
   void setResultAttrs(unsigned index, DictionaryAttr attributes);
-  void setAllResultAttrs(ArrayRef<DictionaryAttr> attributes);
+  void setAllResultAttrs(ArrayRef<DictionaryAttr> attributes) {
+    assert(attributes.size() == getNumResults());
+    function_like_impl::setAllResultAttrDicts(this->getOperation(), attributes);
+  }
+  void setAllResultAttrs(ArrayRef<Attribute> attributes) {
+    assert(attributes.size() == getNumResults());
+    function_like_impl::setAllResultAttrDicts(this->getOperation(), attributes);
+  }
+  void setAllResultAttrs(ArrayAttr attributes) {
+    assert(attributes.size() == getNumResults());
+    this->getOperation()->setAttr(function_like_impl::getResultDictAttrName(),
+                                  attributes);
+  }
 
   /// If the an attribute exists with the specified name, change it to the new
   /// value. Otherwise, add a new attribute with the specified name/value.
@@ -422,25 +463,12 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
   Attribute removeResultAttr(unsigned index, Identifier name);
 
 protected:
-  /// Returns the attribute entry name for the set of argument attributes at
-  /// 'index'.
-  static StringRef getArgAttrName(unsigned index, SmallVectorImpl<char> &out) {
-    return ::mlir::impl::getArgAttrName(index, out);
-  }
-
   /// Returns the dictionary attribute corresponding to the argument at 'index'.
   /// If there are no argument attributes at 'index', a null attribute is
   /// returned.
   DictionaryAttr getArgAttrDict(unsigned index) {
     assert(index < getNumArguments() && "invalid argument number");
-    return ::mlir::impl::getArgAttrDict(this->getOperation(), index);
-  }
-
-  /// Returns the attribute entry name for the set of result attributes at
-  /// 'index'.
-  static StringRef getResultAttrName(unsigned index,
-                                     SmallVectorImpl<char> &out) {
-    return ::mlir::impl::getResultAttrName(index, out);
+    return function_like_impl::getArgAttrDict(this->getOperation(), index);
   }
 
   /// Returns the dictionary attribute corresponding to the result at 'index'.
@@ -448,7 +476,7 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
   /// returned.
   DictionaryAttr getResultAttrDict(unsigned index) {
     assert(index < getNumResults() && "invalid result number");
-    return ::mlir::impl::getResultAttrDict(this->getOperation(), index);
+    return function_like_impl::getResultAttrDict(this->getOperation(), index);
   }
 
   /// Hook for concrete classes to verify that the type attribute respects
@@ -475,9 +503,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyBody() {
 
 template <typename ConcreteType>
 LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
-  MLIRContext *ctx = op->getContext();
   auto funcOp = cast<ConcreteType>(op);
-
   if (!funcOp.isTypeAttrValid())
     return funcOp.emitOpError("requires a type attribute '")
            << getTypeAttrName() << '\'';
@@ -485,35 +511,69 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
   if (failed(funcOp.verifyType()))
     return failure();
 
-  for (unsigned i = 0, e = funcOp.getNumArguments(); i != e; ++i) {
-    // Verify that all of the argument attributes are dialect attributes, i.e.
-    // that they contain a dialect prefix in their name.  Call the dialect, if
-    // registered, to verify the attributes themselves.
-    for (auto attr : funcOp.getArgAttrs(i)) {
-      if (!attr.first.strref().contains('.'))
-        return funcOp.emitOpError("arguments may only have dialect attributes");
-      auto dialectNamePair = attr.first.strref().split('.');
-      if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
-        if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
-                                                     /*argIndex=*/i, attr)))
-          return failure();
+  if (ArrayAttr allArgAttrs = funcOp.getAllArgAttrs()) {
+    unsigned numArgs = funcOp.getNumArguments();
+    if (allArgAttrs.size() != numArgs) {
+      return funcOp.emitOpError()
+             << "expects argument attribute array `"
+             << function_like_impl::getArgDictAttrName()
+             << "` to have the same number of elements as the number of "
+                "function arguments, got "
+             << allArgAttrs.size() << ", but expected " << numArgs;
+    }
+    for (unsigned i = 0; i != numArgs; ++i) {
+      DictionaryAttr argAttrs = allArgAttrs[i].dyn_cast<DictionaryAttr>();
+      if (!argAttrs) {
+        return funcOp.emitOpError() << "expects argument attribute dictionary "
+                                       "to be a DictionaryAttr, but got `"
+                                    << allArgAttrs[i] << "`";
+      }
+
+      // Verify that all of the argument attributes are dialect attributes, i.e.
+      // that they contain a dialect prefix in their name.  Call the dialect, if
+      // registered, to verify the attributes themselves.
+      for (auto attr : argAttrs) {
+        if (!attr.first.strref().contains('.'))
+          return funcOp.emitOpError(
+              "arguments may only have dialect attributes");
+        if (Dialect *dialect = attr.first.getDialect()) {
+          if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
+                                                       /*argIndex=*/i, attr)))
+            return failure();
+        }
       }
     }
   }
+  if (ArrayAttr allResultAttrs = funcOp.getAllResultAttrs()) {
+    unsigned numResults = funcOp.getNumResults();
+    if (allResultAttrs.size() != numResults) {
+      return funcOp.emitOpError()
+             << "expects result attribute array `"
+             << function_like_impl::getResultDictAttrName()
+             << "` to have the same number of elements as the number of "
+                "function results, got "
+             << allResultAttrs.size() << ", but expected " << numResults;
+    }
+    for (unsigned i = 0; i != numResults; ++i) {
+      DictionaryAttr resultAttrs = allResultAttrs[i].dyn_cast<DictionaryAttr>();
+      if (!resultAttrs) {
+        return funcOp.emitOpError() << "expects result attribute dictionary "
+                                       "to be a DictionaryAttr, but got `"
+                                    << allResultAttrs[i] << "`";
+      }
 
-  for (unsigned i = 0, e = funcOp.getNumResults(); i != e; ++i) {
-    // Verify that all of the result attributes are dialect attributes, i.e.
-    // that they contain a dialect prefix in their name.  Call the dialect, if
-    // registered, to verify the attributes themselves.
-    for (auto attr : funcOp.getResultAttrs(i)) {
-      if (!attr.first.strref().contains('.'))
-        return funcOp.emitOpError("results may only have dialect attributes");
-      auto dialectNamePair = attr.first.strref().split('.');
-      if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
-        if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
-                                                        /*resultIndex=*/i,
-                                                        attr)))
-          return failure();
+      // Verify that all of the result attributes are dialect attributes, i.e.
+      // that they contain a dialect prefix in their name.  Call the dialect, if
+      // registered, to verify the attributes themselves.
+      for (auto attr : resultAttrs) {
+        if (!attr.first.strref().contains('.'))
+          return funcOp.emitOpError("results may only have dialect attributes");
+        if (Dialect *dialect = attr.first.getDialect()) {
+          if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
+                                                          /*resultIndex=*/i,
+                                                          attr)))
+            return failure();
+        }
       }
     }
   }
@@ -551,7 +611,7 @@ Block *FunctionLike<ConcreteType>::addBlock() {
 
 template <typename ConcreteType>
 void FunctionLike<ConcreteType>::setType(FunctionType newType) {
-  ::mlir::impl::setFunctionType(this->getOperation(), newType);
+  function_like_impl::setFunctionType(this->getOperation(), newType);
 }
 
 //===----------------------------------------------------------------------===//
@@ -563,45 +623,19 @@ template <typename ConcreteType>
 void FunctionLike<ConcreteType>::setArgAttrs(
     unsigned index, ArrayRef<NamedAttribute> attributes) {
   assert(index < getNumArguments() && "invalid argument number");
-  SmallString<8> nameOut;
-  getArgAttrName(index, nameOut);
-
   Operation *op = this->getOperation();
-  if (attributes.empty())
-    return (void)op->removeAttr(nameOut);
-  op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes));
+  return function_like_impl::detail::setArgResAttrDict(
+      op, function_like_impl::getArgDictAttrName(), getNumArguments(), index,
+      DictionaryAttr::get(op->getContext(), attributes));
 }
 
 template <typename ConcreteType>
 void FunctionLike<ConcreteType>::setArgAttrs(unsigned index,
                                              DictionaryAttr attributes) {
-  assert(index < getNumArguments() && "invalid argument number");
-  SmallString<8> nameOut;
-  if (!attributes || attributes.empty())
-    this->getOperation()->removeAttr(getArgAttrName(index, nameOut));
-  else
-    return this->getOperation()->setAttr(getArgAttrName(index, nameOut),
-                                         attributes);
-}
-
-template <typename ConcreteType>
-void FunctionLike<ConcreteType>::setAllArgAttrs(
-    ArrayRef<DictionaryAttr> attributes) {
-  assert(attributes.size() == getNumArguments());
-  NamedAttrList attrs = this->getOperation()->getAttrs();
-
-  // Instead of calling setArgAttrs() multiple times, which rebuild the
-  // attribute dictionary every time, build a new list of attributes for the
-  // operation so that we rebuild the attribute dictionary in one shot.
-  SmallString<8> argAttrName;
-  for (unsigned i = 0, e = attributes.size(); i != e; ++i) {
-    StringRef attrName = getArgAttrName(i, argAttrName);
-    if (!attributes[i] || attributes[i].empty())
-      attrs.erase(attrName);
-    else
-      attrs.set(attrName, attributes[i]);
-  }
-  this->getOperation()->setAttrs(attrs);
+  Operation *op = this->getOperation();
+  return function_like_impl::detail::setArgResAttrDict(
+      op, function_like_impl::getArgDictAttrName(), getNumArguments(), index,
+      attributes ? attributes : DictionaryAttr::get(op->getContext()));
 }
 
 /// If the an attribute exists with the specified name, change it to the new
@@ -640,45 +674,20 @@ template <typename ConcreteType>
 void FunctionLike<ConcreteType>::setResultAttrs(
     unsigned index, ArrayRef<NamedAttribute> attributes) {
   assert(index < getNumResults() && "invalid result number");
-  SmallString<8> nameOut;
-  getResultAttrName(index, nameOut);
-
-  if (attributes.empty())
-    return (void)this->getOperation()->removeAttr(nameOut);
   Operation *op = this->getOperation();
-  op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes));
+  return function_like_impl::detail::setArgResAttrDict(
+      op, function_like_impl::getResultDictAttrName(), getNumResults(), index,
+      DictionaryAttr::get(op->getContext(), attributes));
 }
 
 template <typename ConcreteType>
 void FunctionLike<ConcreteType>::setResultAttrs(unsigned index,
                                                 DictionaryAttr attributes) {
   assert(index < getNumResults() && "invalid result number");
-  SmallString<8> nameOut;
-  if (!attributes || attributes.empty())
-    this->getOperation()->removeAttr(getResultAttrName(index, nameOut));
-  else
-    this->getOperation()->setAttr(getResultAttrName(index, nameOut),
-                                  attributes);
-}
-
-template <typename ConcreteType>
-void FunctionLike<ConcreteType>::setAllResultAttrs(
-    ArrayRef<DictionaryAttr> attributes) {
-  assert(attributes.size() == getNumResults());
-  NamedAttrList attrs = this->getOperation()->getAttrs();
-
-  // Instead of calling setResultAttrs() multiple times, which rebuild the
-  // attribute dictionary every time, build a new list of attributes for the
-  // operation so that we rebuild the attribute dictionary in one shot.
-  SmallString<8> resultAttrName;
-  for (unsigned i = 0, e = attributes.size(); i != e; ++i) {
-    StringRef attrName = getResultAttrName(i, resultAttrName);
-    if (!attributes[i] || attributes[i].empty())
-      attrs.erase(attrName);
-    else
-      attrs.set(attrName, attributes[i]);
-  }
-  this->getOperation()->setAttrs(attrs);
+  Operation *op = this->getOperation();
+  return function_like_impl::detail::setArgResAttrDict(
+      op, function_like_impl::getResultDictAttrName(), getNumResults(), index,
+      attributes ? attributes : DictionaryAttr::get(op->getContext()));
 }
 
 /// If the an attribute exists with the specified name, change it to the new

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 08339530f2d6..67f699a13d04 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -58,7 +58,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
   SmallVector<NamedAttribute, 4> attributes;
   for (const auto &attr : gpuFuncOp->getAttrs()) {
     if (attr.first == SymbolTable::getSymbolAttrName() ||
-        attr.first == impl::getTypeAttrName() ||
+        attr.first == function_like_impl::getTypeAttrName() ||
         attr.first == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
       continue;
     attributes.push_back(attr);

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 2066debb7d45..fa4bbff5bb32 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -195,7 +195,7 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                                llvm::None));
   for (const auto &namedAttr : funcOp->getAttrs()) {
-    if (namedAttr.first == impl::getTypeAttrName() ||
+    if (namedAttr.first == function_like_impl::getTypeAttrName() ||
         namedAttr.first == SymbolTable::getSymbolAttrName())
       continue;
     newFuncOp->setAttr(namedAttr.first, namedAttr.second);

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 5f94804b6252..3949cd475141 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1211,8 +1211,10 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
                                  SmallVectorImpl<NamedAttribute> &result) {
   for (const auto &attr : attrs) {
     if (attr.first == SymbolTable::getSymbolAttrName() ||
-        attr.first == impl::getTypeAttrName() || attr.first == "std.varargs" ||
-        (filterArgAttrs && impl::isArgAttrName(attr.first.strref())))
+        attr.first == function_like_impl::getTypeAttrName() ||
+        attr.first == "std.varargs" ||
+        (filterArgAttrs &&
+         attr.first == function_like_impl::getArgDictAttrName()))
       continue;
     result.push_back(attr);
   }
@@ -1395,19 +1397,19 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
     SmallVector<NamedAttribute, 4> attributes;
     filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
                          attributes);
-    for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
-      auto attr = impl::getArgAttrDict(funcOp, i);
-      if (!attr)
-        continue;
-
-      auto mapping = result.getInputMapping(i);
-      assert(mapping.hasValue() && "unexpected deletion of function argument");
-
-      SmallString<8> name;
-      for (size_t j = 0; j < mapping->size; ++j) {
-        impl::getArgAttrName(mapping->inputNo + j, name);
-        attributes.push_back(rewriter.getNamedAttr(name, attr));
+    if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
+      SmallVector<Attribute, 4> newArgAttrs(
+          llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
+      for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
+        auto mapping = result.getInputMapping(i);
+        assert(mapping.hasValue() &&
+               "unexpected deletion of function argument");
+        for (size_t j = 0; j < mapping->size; ++j)
+          newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
       }
+      attributes.push_back(
+          rewriter.getNamedAttr(function_like_impl::getArgDictAttrName(),
+                                rewriter.getArrayAttr(newArgAttrs)));
     }
 
     // Create an LLVM function, use external linkage by default until MLIR

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 1fa687f83f0d..1f081d896bfc 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -599,9 +599,9 @@ parseLaunchFuncOperands(OpAsmParser &parser,
     return success();
   SmallVector<NamedAttrList, 4> argAttrs;
   bool isVariadic = false;
-  return impl::parseFunctionArgumentList(parser, /*allowAttributes=*/false,
-                                         /*allowVariadic=*/false, argNames,
-                                         argTypes, argAttrs, isVariadic);
+  return function_like_impl::parseFunctionArgumentList(
+      parser, /*allowAttributes=*/false,
+      /*allowVariadic=*/false, argNames, argTypes, argAttrs, isVariadic);
 }
 
 static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
@@ -717,7 +717,7 @@ static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   auto signatureLocation = parser.getCurrentLocation();
-  if (failed(impl::parseFunctionSignature(
+  if (failed(function_like_impl::parseFunctionSignature(
           parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
           isVariadic, resultTypes, resultAttrs)))
     return failure();
@@ -756,7 +756,8 @@ static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) {
   // Parse attributes.
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
     return failure();
-  mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
+  function_like_impl::addArgAndResultAttrs(builder, result, argAttrs,
+                                           resultAttrs);
 
   // Parse the region. If no argument names were provided, take all names
   // (including those of attributions) from the entry block.
@@ -781,33 +782,22 @@ static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) {
   p.printSymbolName(op.getName());
 
   FunctionType type = op.getType();
-  impl::printFunctionSignature(p, op.getOperation(), type.getInputs(),
-                               /*isVariadic=*/false, type.getResults());
+  function_like_impl::printFunctionSignature(
+      p, op.getOperation(), type.getInputs(),
+      /*isVariadic=*/false, type.getResults());
 
   printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions());
   printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions());
   if (op.isKernel())
     p << ' ' << op.getKernelKeyword();
 
-  impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(),
-                                type.getNumResults(),
-                                {op.getNumWorkgroupAttributionsAttrName(),
-                                 GPUDialect::getKernelFuncAttrName()});
+  function_like_impl::printFunctionAttributes(
+      p, op.getOperation(), type.getNumInputs(), type.getNumResults(),
+      {op.getNumWorkgroupAttributionsAttrName(),
+       GPUDialect::getKernelFuncAttrName()});
   p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
 }
 
-void GPUFuncOp::setType(FunctionType newType) {
-  auto oldType = getType();
-  assert(newType.getNumResults() == oldType.getNumResults() &&
-         "unimplemented: changes to the number of results");
-
-  SmallVector<char, 16> nameBuf;
-  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
-    (*this)->removeAttr(getArgAttrName(i, nameBuf));
-
-  (*this)->setAttr(getTypeAttrName(), TypeAttr::get(newType));
-}
-
 /// Hook for FunctionLike verifier.
 LogicalResult GPUFuncOp::verifyType() {
   Type type = getTypeAttr().getValue();

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e1ad37e75d86..12e6ccc28aa4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1732,21 +1732,19 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
   if (argAttrs.empty())
     return;
 
-  unsigned numInputs = type.cast<LLVMFunctionType>().getNumParams();
-  assert(numInputs == argAttrs.size() &&
+  assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
          "expected as many argument attribute lists as arguments");
-  SmallString<8> argAttrName;
-  for (unsigned i = 0; i < numInputs; ++i)
-    if (DictionaryAttr argDict = argAttrs[i])
-      result.addAttribute(getArgAttrName(i, argAttrName), argDict);
+  function_like_impl::addArgAndResultAttrs(builder, result, argAttrs,
+                                           /*resultAttrs=*/llvm::None);
 }
 
 // Builds an LLVM function type from the given lists of input and output types.
 // Returns a null type if any of the types provided are non-LLVM types, or if
 // there is more than one output type.
-static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
-                                  ArrayRef<Type> inputs, ArrayRef<Type> outputs,
-                                  impl::VariadicFlag variadicFlag) {
+static Type
+buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
+                      ArrayRef<Type> inputs, ArrayRef<Type> outputs,
+                      function_like_impl::VariadicFlag variadicFlag) {
   Builder &b = parser.getBuilder();
   if (outputs.size() > 1) {
     parser.emitError(loc, "failed to construct function type: expected zero or "
@@ -1803,22 +1801,23 @@ static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
   auto signatureLocation = parser.getCurrentLocation();
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes) ||
-      impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs,
-                                   argTypes, argAttrs, isVariadic, resultTypes,
-                                   resultAttrs))
+      function_like_impl::parseFunctionSignature(
+          parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs,
+          isVariadic, resultTypes, resultAttrs))
     return failure();
 
   auto type =
       buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
-                            impl::VariadicFlag(isVariadic));
+                            function_like_impl::VariadicFlag(isVariadic));
   if (!type)
     return failure();
-  result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type));
+  result.addAttribute(function_like_impl::getTypeAttrName(),
+                      TypeAttr::get(type));
 
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
     return failure();
-  impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs,
-                             resultAttrs);
+  function_like_impl::addArgAndResultAttrs(parser.getBuilder(), result,
+                                           argAttrs, resultAttrs);
 
   auto *body = result.addRegion();
   OptionalParseResult parseResult = parser.parseOptionalRegion(
@@ -1846,9 +1845,10 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
   if (!returnType.isa<LLVMVoidType>())
     resTypes.push_back(returnType);
 
-  impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes);
-  impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(),
-                                {getLinkageAttrName()});
+  function_like_impl::printFunctionSignature(p, op, argTypes, op.isVarArg(),
+                                             resTypes);
+  function_like_impl::printFunctionAttributes(
+      p, op, argTypes.size(), resTypes.size(), {getLinkageAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = op.body();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 4ca7da6fd22b..04b63535b6de 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -99,7 +99,7 @@ struct FunctionNonEntryBlockConversion : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.startRootUpdate(op);
-    Region &region = mlir::impl::getFunctionBody(op);
+    Region &region = function_like_impl::getFunctionBody(op);
     SmallVector<TypeConverter::SignatureConversion, 2> conversions;
 
     for (Block &block : llvm::drop_begin(region, 1)) {

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 31a92c351599..c74528c868c6 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1783,13 +1783,14 @@ static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) {
 
   // Parse the function signature.
   bool isVariadic = false;
-  if (impl::parseFunctionSignature(parser, /*allowVariadic=*/false, entryArgs,
-                                   argTypes, argAttrs, isVariadic, resultTypes,
-                                   resultAttrs))
+  if (function_like_impl::parseFunctionSignature(
+          parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
+          isVariadic, resultTypes, resultAttrs))
     return failure();
 
   auto fnType = builder.getFunctionType(argTypes, resultTypes);
-  state.addAttribute(impl::getTypeAttrName(), TypeAttr::get(fnType));
+  state.addAttribute(function_like_impl::getTypeAttrName(),
+                     TypeAttr::get(fnType));
 
   // Parse the optional function control keyword.
   spirv::FunctionControl fnControl;
@@ -1803,7 +1804,8 @@ static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) {
   // Add the attributes to the function arguments.
   assert(argAttrs.size() == argTypes.size());
   assert(resultAttrs.size() == resultTypes.size());
-  impl::addArgAndResultAttrs(builder, state, argAttrs, resultAttrs);
+  function_like_impl::addArgAndResultAttrs(builder, state, argAttrs,
+                                           resultAttrs);
 
   // Parse the optional function body.
   auto *body = state.addRegion();
@@ -1817,11 +1819,12 @@ static void print(spirv::FuncOp fnOp, OpAsmPrinter &printer) {
   printer << spirv::FuncOp::getOperationName() << " ";
   printer.printSymbolName(fnOp.sym_name());
   auto fnType = fnOp.getType();
-  impl::printFunctionSignature(printer, fnOp, fnType.getInputs(),
-                               /*isVariadic=*/false, fnType.getResults());
+  function_like_impl::printFunctionSignature(printer, fnOp, fnType.getInputs(),
+                                             /*isVariadic=*/false,
+                                             fnType.getResults());
   printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control())
           << "\"";
-  impl::printFunctionAttributes(
+  function_like_impl::printFunctionAttributes(
       printer, fnOp, fnType.getNumInputs(), fnType.getNumResults(),
       {spirv::attributeName<spirv::FunctionControl>()});
 

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 5a590216678b..6e807a72ac1c 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -582,7 +582,7 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
 
   // Copy over all attributes other than the function name and type.
   for (const auto &namedAttr : funcOp->getAttrs()) {
-    if (namedAttr.first != impl::getTypeAttrName() &&
+    if (namedAttr.first != function_like_impl::getTypeAttrName() &&
         namedAttr.first != SymbolTable::getSymbolAttrName())
       newFuncOp->setAttr(namedAttr.first, namedAttr.second);
   }

diff  --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index e1706f2c9315..728443e8b64f 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -106,27 +106,25 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  SmallString<8> argAttrName;
-  for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
-    if (DictionaryAttr argDict = argAttrs[i])
-      state.addAttribute(getArgAttrName(i, argAttrName), argDict);
+  function_like_impl::addArgAndResultAttrs(builder, state, argAttrs,
+                                           /*resultAttrs=*/llvm::None);
 }
 
 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
   auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
-                          ArrayRef<Type> results, impl::VariadicFlag,
-                          std::string &) {
+                          ArrayRef<Type> results,
+                          function_like_impl::VariadicFlag, std::string &) {
     return builder.getFunctionType(argTypes, results);
   };
 
-  return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/false,
-                                   buildFuncType);
+  return function_like_impl::parseFunctionLikeOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 static void print(FuncOp op, OpAsmPrinter &p) {
   FunctionType fnType = op.getType();
-  impl::printFunctionLikeOp(p, op, fnType.getInputs(), /*isVariadic=*/false,
-                            fnType.getResults());
+  function_like_impl::printFunctionLikeOp(
+      p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
 }
 
 static LogicalResult verify(FuncOp op) {
@@ -170,30 +168,39 @@ void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
 /// to cloned sub-values with the corresponding value that is copied, and adds
 /// those mappings to the mapper.
 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
-  FunctionType newType = getType();
+  // Create the new function.
+  FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
 
   // If the function has a body, then the user might be deleting arguments to
   // the function by specifying them in the mapper. If so, we don't add the
   // argument to the input type vector.
-  bool isExternalFn = isExternal();
-  if (!isExternalFn) {
-    SmallVector<Type, 4> inputTypes;
-    inputTypes.reserve(newType.getNumInputs());
-    for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
+  if (!isExternal()) {
+    FunctionType oldType = getType();
+
+    unsigned oldNumArgs = oldType.getNumInputs();
+    SmallVector<Type, 4> newInputs;
+    newInputs.reserve(oldNumArgs);
+    for (unsigned i = 0; i != oldNumArgs; ++i)
       if (!mapper.contains(getArgument(i)))
-        inputTypes.push_back(newType.getInput(i));
-    newType = FunctionType::get(getContext(), inputTypes, newType.getResults());
+        newInputs.push_back(oldType.getInput(i));
+
+    /// If any of the arguments were dropped, update the type and drop any
+    /// necessary argument attributes.
+    if (newInputs.size() != oldNumArgs) {
+      newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
+                                        oldType.getResults()));
+
+      if (ArrayAttr argAttrs = getAllArgAttrs()) {
+        SmallVector<Attribute> newArgAttrs;
+        newArgAttrs.reserve(newInputs.size());
+        for (unsigned i = 0; i != oldNumArgs; ++i)
+          if (!mapper.contains(getArgument(i)))
+            newArgAttrs.push_back(argAttrs[i]);
+        newFunc.setAllArgAttrs(newArgAttrs);
+      }
+    }
   }
 
-  // Create the new function.
-  FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
-  newFunc.setType(newType);
-
-  /// Set the argument attributes for arguments that aren't being replaced.
-  for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i)
-    if (isExternalFn || !mapper.contains(getArgument(i)))
-      newFunc.setArgAttrs(destI++, getArgAttrs(i));
-
   /// Clone the current function into the new one and return it.
   cloneInto(newFunc, mapper);
   return newFunc;

diff  --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp
index 4bec1684b5ee..aadf5456126b 100644
--- a/mlir/lib/IR/FunctionImplementation.cpp
+++ b/mlir/lib/IR/FunctionImplementation.cpp
@@ -13,7 +13,7 @@
 
 using namespace mlir;
 
-ParseResult mlir::impl::parseFunctionArgumentList(
+ParseResult mlir::function_like_impl::parseFunctionArgumentList(
     OpAsmParser &parser, bool allowAttributes, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
@@ -125,7 +125,7 @@ parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
 /// indicates whether functions with variadic arguments are supported. The
 /// trailing arguments are populated by this function with names, types and
 /// attributes of the arguments and those of the results.
-ParseResult mlir::impl::parseFunctionSignature(
+ParseResult mlir::function_like_impl::parseFunctionSignature(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
@@ -140,29 +140,69 @@ ParseResult mlir::impl::parseFunctionSignature(
   return success();
 }

-void mlir::impl::addArgAndResultAttrs(Builder &builder, OperationState &result,
-                                      ArrayRef<NamedAttrList> argAttrs,
-                                      ArrayRef<NamedAttrList> resultAttrs) {
-  // Add the attributes to the function arguments.
-  SmallString<8> attrNameBuf;
-  for (unsigned i = 0, e = argAttrs.size(); i != e; ++i)
-    if (!argAttrs[i].empty())
-      result.addAttribute(getArgAttrName(i, attrNameBuf),
-                          builder.getDictionaryAttr(argAttrs[i]));
+/// Implementation of `addArgAndResultAttrs` that is attribute list type
+/// agnostic.
+template <typename AttrListT, typename AttrArrayBuildFnT>
+static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result,
+                                     ArrayRef<AttrListT> argAttrs,
+                                     ArrayRef<AttrListT> resultAttrs,
+                                     AttrArrayBuildFnT &&buildAttrArrayFn) {
+  auto nonEmptyAttrsFn = [](const AttrListT &attrs) { return !attrs.empty(); };

+  // Add the attributes to the function arguments.
+  if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) {
+    ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs));
+    result.addAttribute(function_like_impl::getArgDictAttrName(), attrDicts);
+  }
   // Add the attributes to the function results.
-  for (unsigned i = 0, e = resultAttrs.size(); i != e; ++i)
-    if (!resultAttrs[i].empty())
-      result.addAttribute(getResultAttrName(i, attrNameBuf),
-                          builder.getDictionaryAttr(resultAttrs[i]));
+  if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) {
+    ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs));
+    result.addAttribute(function_like_impl::getResultDictAttrName(), attrDicts);
+  }
+}
+
+void mlir::function_like_impl::addArgAndResultAttrs(
+    Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
+    ArrayRef<DictionaryAttr> resultAttrs) {
+  // Callers such as FuncOp::build and LLVMFuncOp::build may pass null
+  // dictionaries for elements that have no attributes. Replace them with empty
+  // dictionaries: the shared implementation calls `empty()` on each element,
+  // and readers of the array `cast<DictionaryAttr>()` every entry.
+  DictionaryAttr emptyDict = DictionaryAttr::get(builder.getContext());
+  auto normalizeFn = [&](ArrayRef<DictionaryAttr> attrs) {
+    return llvm::to_vector<8>(llvm::map_range(
+        attrs, [&](DictionaryAttr dict) { return dict ? dict : emptyDict; }));
+  };
+  SmallVector<DictionaryAttr, 8> normalizedArgAttrs = normalizeFn(argAttrs);
+  SmallVector<DictionaryAttr, 8> normalizedResultAttrs =
+      normalizeFn(resultAttrs);
+
+  auto buildFn = [](ArrayRef<DictionaryAttr> attrs) {
+    return ArrayRef<Attribute>(attrs.data(), attrs.size());
+  };
+  addArgAndResultAttrsImpl(builder, result,
+                           ArrayRef<DictionaryAttr>(normalizedArgAttrs),
+                           ArrayRef<DictionaryAttr>(normalizedResultAttrs),
+                           buildFn);
+}
+void mlir::function_like_impl::addArgAndResultAttrs(
+    Builder &builder, OperationState &result, ArrayRef<NamedAttrList> argAttrs,
+    ArrayRef<NamedAttrList> resultAttrs) {
+  MLIRContext *context = builder.getContext();
+  auto buildFn = [=](ArrayRef<NamedAttrList> attrs) {
+    return llvm::to_vector<8>(
+        llvm::map_range(attrs, [=](const NamedAttrList &attrList) -> Attribute {
+          return attrList.getDictionary(context);
+        }));
+  };
+  addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
 }

 /// Parser implementation for function-like operations.  Uses `funcTypeBuilder`
 /// to construct the custom function type given lists of input and output types.
-ParseResult
-mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
-                                bool allowVariadic,
-                                mlir::impl::FuncTypeBuilder funcTypeBuilder) {
+ParseResult mlir::function_like_impl::parseFunctionLikeOp(
+    OpAsmParser &parser, OperationState &result, bool allowVariadic,
+    FuncTypeBuilder funcTypeBuilder) {
   SmallVector<OpAsmParser::OperandType, 4> entryArgs;
   SmallVector<NamedAttrList, 4> argAttrs;
   SmallVector<NamedAttrList, 4> resultAttrs;
@@ -187,13 +211,14 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
     return failure();
 
   std::string errorMessage;
-  if (auto type = funcTypeBuilder(builder, argTypes, resultTypes,
-                                  impl::VariadicFlag(isVariadic), errorMessage))
-    result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
-  else
+  Type type = funcTypeBuilder(builder, argTypes, resultTypes,
+                              VariadicFlag(isVariadic), errorMessage);
+  if (!type) {
     return parser.emitError(signatureLocation)
            << "failed to construct function type"
            << (errorMessage.empty() ? "" : ": ") << errorMessage;
+  }
+  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
 
   // If function attributes are present, parse them.
   NamedAttrList parsedAttributes;
@@ -236,35 +261,38 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
   return success();
 }
 
-// Print a function result list.
+/// Print a function result list. The provided `attrs` must either be null, or
+/// contain a set of DictionaryAttrs of the same arity as `types`.
 static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
-                                    ArrayRef<ArrayRef<NamedAttribute>> attrs) {
+                                    ArrayAttr attrs) {
   assert(!types.empty() && "Should not be called for empty result list.");
+  assert((!attrs || attrs.size() == types.size()) &&
+         "Invalid number of attributes.");
+
   auto &os = p.getStream();
-  bool needsParens =
-      types.size() > 1 || types[0].isa<FunctionType>() || !attrs[0].empty();
+  bool needsParens = types.size() > 1 || types[0].isa<FunctionType>() ||
+                     (attrs && !attrs[0].cast<DictionaryAttr>().empty());
   if (needsParens)
     os << '(';
-  llvm::interleaveComma(
-      llvm::zip(types, attrs), os,
-      [&](const std::tuple<Type, ArrayRef<NamedAttribute>> &t) {
-        p.printType(std::get<0>(t));
-        p.printOptionalAttrDict(std::get<1>(t));
-      });
+  llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
+    p.printType(types[i]);
+    if (attrs)
+      p.printOptionalAttrDict(attrs[i].cast<DictionaryAttr>().getValue());
+  });
   if (needsParens)
     os << ')';
 }
 
 /// Print the signature of the function-like operation `op`.  Assumes `op` has
 /// the FunctionLike trait and passed the verification.
-void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
-                                        ArrayRef<Type> argTypes,
-                                        bool isVariadic,
-                                        ArrayRef<Type> resultTypes) {
+void mlir::function_like_impl::printFunctionSignature(
+    OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
+    ArrayRef<Type> resultTypes) {
   Region &body = op->getRegion(0);
   bool isExternal = body.empty();
 
   p << '(';
+  ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
   for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
     if (i > 0)
       p << ", ";
@@ -275,7 +303,8 @@ void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
     }
 
     p.printType(argTypes[i]);
-    p.printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i));
+    if (argAttrs)
+      p.printOptionalAttrDict(argAttrs[i].cast<DictionaryAttr>().getValue());
   }
 
   if (isVariadic) {
@@ -288,9 +317,7 @@ void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
 
   if (!resultTypes.empty()) {
     p.getStream() << " -> ";
-    SmallVector<ArrayRef<NamedAttribute>, 4> resultAttrs;
-    for (int i = 0, e = resultTypes.size(); i < e; ++i)
-      resultAttrs.push_back(::mlir::impl::getResultAttrs(op, i));
+    auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
     printFunctionResultList(p, resultTypes, resultAttrs);
   }
 }
@@ -300,39 +327,25 @@ void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
 /// function-like operation internally are not printed. Nothing is printed
 /// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and
 /// passed the verification.
-void mlir::impl::printFunctionAttributes(OpAsmPrinter &p, Operation *op,
-                                         unsigned numInputs,
-                                         unsigned numResults,
-                                         ArrayRef<StringRef> elided) {
+void mlir::function_like_impl::printFunctionAttributes(
+    OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults,
+    ArrayRef<StringRef> elided) {
   // Print out function attributes, if present.
   SmallVector<StringRef, 2> ignoredAttrs = {
-      ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()};
+      ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(),
+      getArgDictAttrName(), getResultDictAttrName()};
   ignoredAttrs.append(elided.begin(), elided.end());
 
-  SmallString<8> attrNameBuf;
-
-  // Ignore any argument attributes.
-  std::vector<SmallString<8>> argAttrStorage;
-  for (unsigned i = 0; i != numInputs; ++i)
-    if (op->getAttr(getArgAttrName(i, attrNameBuf)))
-      argAttrStorage.emplace_back(attrNameBuf);
-  ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end());
-
-  // Ignore any result attributes.
-  std::vector<SmallString<8>> resultAttrStorage;
-  for (unsigned i = 0; i != numResults; ++i)
-    if (op->getAttr(getResultAttrName(i, attrNameBuf)))
-      resultAttrStorage.emplace_back(attrNameBuf);
-  ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end());
-
   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
 }
 
 /// Printer implementation for function-like operations.  Accepts lists of
 /// argument and result types to use while printing.
-void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
-                                     ArrayRef<Type> argTypes, bool isVariadic,
-                                     ArrayRef<Type> resultTypes) {
+void mlir::function_like_impl::printFunctionLikeOp(OpAsmPrinter &p,
+                                                   Operation *op,
+                                                   ArrayRef<Type> argTypes,
+                                                   bool isVariadic,
+                                                   ArrayRef<Type> resultTypes) {
   // Print the operation and the function name.
   auto funcName =
       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())

diff  --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp
index 347ea155d5e7..2538271373f2 100644
--- a/mlir/lib/IR/FunctionSupport.cpp
+++ b/mlir/lib/IR/FunctionSupport.cpp
@@ -31,103 +31,199 @@ inline void iterateIndicesExcept(unsigned totalIndices,
 // Function Arguments and Results.
 //===----------------------------------------------------------------------===//
 
-void mlir::impl::eraseFunctionArguments(Operation *op,
-                                        ArrayRef<unsigned> argIndices,
-                                        unsigned originalNumArgs,
-                                        Type newType) {
+static bool isEmptyAttrDict(Attribute attr) {
+  return attr.cast<DictionaryAttr>().empty();
+}
+
+DictionaryAttr mlir::function_like_impl::getArgAttrDict(Operation *op,
+                                                        unsigned index) {
+  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+  DictionaryAttr argAttrs =
+      attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
+  return (argAttrs && !argAttrs.empty()) ? argAttrs : DictionaryAttr();
+}
+
+DictionaryAttr mlir::function_like_impl::getResultAttrDict(Operation *op,
+                                                           unsigned index) {
+  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+  DictionaryAttr resAttrs =
+      attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
+  return (resAttrs && !resAttrs.empty()) ? resAttrs : DictionaryAttr();
+}
+
+void mlir::function_like_impl::detail::setArgResAttrDict(
+    Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
+    DictionaryAttr attrs) {
+  ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
+  if (!allAttrs) {
+    if (attrs.empty())
+      return;
+
+    // If this attribute is not empty, we need to create a new attribute array.
+    SmallVector<Attribute, 8> newAttrs(numTotalIndices,
+                                       DictionaryAttr::get(op->getContext()));
+    newAttrs[index] = attrs;
+    op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
+    return;
+  }
+  // Check to see if the attribute is different from what we already have.
+  if (allAttrs[index] == attrs)
+    return;
+
+  // If it is, check to see if the attribute array would now contain only empty
+  // dictionaries.
+  ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
+  if (attrs.empty() &&
+      llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
+      llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
+    op->removeAttr(attrName);
+    return;
+  }
+
+  // Otherwise, create a new attribute array with the updated dictionary.
+  SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
+  newAttrs[index] = attrs;
+  op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
+}
+
+/// Set all of the argument or result attribute dictionaries for a function.
+static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
+                                  ArrayRef<Attribute> attrs) {
+  if (llvm::all_of(attrs, isEmptyAttrDict))
+    op->removeAttr(attrName);
+  else
+    op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
+}
+
+void mlir::function_like_impl::setAllArgAttrDicts(
+    Operation *op, ArrayRef<DictionaryAttr> attrs) {
+  setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+}
+void mlir::function_like_impl::setAllArgAttrDicts(Operation *op,
+                                                  ArrayRef<Attribute> attrs) {
+  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
+    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
+  });
+  setAllArgResAttrDicts(op, getArgDictAttrName(),
+                        llvm::to_vector<8>(wrappedAttrs));
+}
+
+void mlir::function_like_impl::setAllResultAttrDicts(
+    Operation *op, ArrayRef<DictionaryAttr> attrs) {
+  setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+}
+void mlir::function_like_impl::setAllResultAttrDicts(
+    Operation *op, ArrayRef<Attribute> attrs) {
+  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
+    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
+  });
+  setAllArgResAttrDicts(op, getResultDictAttrName(),
+                        llvm::to_vector<8>(wrappedAttrs));
+}
+
+void mlir::function_like_impl::eraseFunctionArguments(
+    Operation *op, ArrayRef<unsigned> argIndices, unsigned originalNumArgs,
+    Type newType) {
   // There are 3 things that need to be updated:
   // - Function type.
   // - Arg attrs.
   // - Block arguments of entry block.
   Block &entry = op->getRegion(0).front();
-  SmallString<8> nameBuf;
-
-  // Collect arg attrs to set.
-  SmallVector<DictionaryAttr, 4> newArgAttrs;
-  iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
-    newArgAttrs.emplace_back(getArgAttrDict(op, i));
-  });
-
-  // Remove any arg attrs that are no longer needed.
-  for (unsigned i = newArgAttrs.size(), e = originalNumArgs; i < e; ++i)
-    op->removeAttr(getArgAttrName(i, nameBuf));
-
-  // Set the function type.
-  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
 
-  // Set the new arg attrs, or remove them if empty.
-  for (unsigned i = 0, e = newArgAttrs.size(); i != e; ++i) {
-    auto nameAttr = getArgAttrName(i, nameBuf);
-    if (newArgAttrs[i] && !newArgAttrs[i].empty())
-      op->setAttr(nameAttr, newArgAttrs[i]);
-    else
-      op->removeAttr(nameAttr);
+  // Update the argument attributes of the function.
+  if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
+    SmallVector<DictionaryAttr, 4> newArgAttrs;
+    newArgAttrs.reserve(argAttrs.size());
+    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
+      newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
+    });
+    setAllArgAttrDicts(op, newArgAttrs);
   }
 
-  // Update the entry block's arguments.
+  // Update the function type and any entry block arguments.
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
   entry.eraseArguments(argIndices);
 }
 
-void mlir::impl::eraseFunctionResults(Operation *op,
-                                      ArrayRef<unsigned> resultIndices,
-                                      unsigned originalNumResults,
-                                      Type newType) {
+void mlir::function_like_impl::eraseFunctionResults(
+    Operation *op, ArrayRef<unsigned> resultIndices,
+    unsigned originalNumResults, Type newType) {
   // There are 2 things that need to be updated:
   // - Function type.
   // - Result attrs.
-  SmallString<8> nameBuf;
-
-  // Collect result attrs to set.
-  SmallVector<DictionaryAttr, 4> newResultAttrs;
-  iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
-    newResultAttrs.emplace_back(getResultAttrDict(op, i));
-  });
 
-  // Remove any result attrs that are no longer needed.
-  for (unsigned i = newResultAttrs.size(), e = originalNumResults; i < e; ++i)
-    op->removeAttr(getResultAttrName(i, nameBuf));
+  // Update the result attributes of the function.
+  if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
+    SmallVector<DictionaryAttr, 4> newResultAttrs;
+    newResultAttrs.reserve(resAttrs.size());
+    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
+      newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
+    });
+    setAllResultAttrDicts(op, newResultAttrs);
+  }
 
-  // Set the function type.
+  // Update the function type.
   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
-
-  // Set the new result attrs, or remove them if empty.
-  for (unsigned i = 0, e = newResultAttrs.size(); i != e; ++i) {
-    auto nameAttr = getResultAttrName(i, nameBuf);
-    if (newResultAttrs[i] && !newResultAttrs[i].empty())
-      op->setAttr(nameAttr, newResultAttrs[i]);
-    else
-      op->removeAttr(nameAttr);
-  }
 }
 
 //===----------------------------------------------------------------------===//
 // Function type signature.
 //===----------------------------------------------------------------------===//
 
-FunctionType mlir::impl::getFunctionType(Operation *op) {
+FunctionType mlir::function_like_impl::getFunctionType(Operation *op) {
   assert(op->hasTrait<OpTrait::FunctionLike>());
-  return op->getAttrOfType<TypeAttr>(mlir::impl::getTypeAttrName())
+  return op->getAttrOfType<TypeAttr>(getTypeAttrName())
       .getValue()
       .cast<FunctionType>();
 }
 
-void mlir::impl::setFunctionType(Operation *op, FunctionType newType) {
+void mlir::function_like_impl::setFunctionType(Operation *op,
+                                               FunctionType newType) {
   assert(op->hasTrait<OpTrait::FunctionLike>());
-  SmallVector<char, 16> nameBuf;
   FunctionType oldType = getFunctionType(op);
-
-  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
-    op->removeAttr(getArgAttrName(i, nameBuf));
-  for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e; i++)
-    op->removeAttr(getResultAttrName(i, nameBuf));
   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+
+  // Functor used to update the argument and result attributes of the function.
+  auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
+                          unsigned newCount, auto setAttrFn) {
+    if (oldCount == newCount)
+      return;
+    // The new type has no arguments/results, just drop the attribute.
+    if (newCount == 0) {
+      op->removeAttr(attrName);
+      return;
+    }
+    ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
+    if (!attrs)
+      return;
+
+    // The new type has less arguments/results, take the first N attributes.
+    if (newCount < oldCount)
+      return setAttrFn(op, attrs.getValue().take_front(newCount));
+
+    // Otherwise, the new type has more arguments/results. Initialize the new
+    // arguments/results with empty attributes.
+    SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
+    newAttrs.resize(newCount);
+    setAttrFn(op, newAttrs);
+  };
+
+  // Update the argument and result attributes.
+  updateAttrFn(function_like_impl::getArgDictAttrName(), oldType.getNumInputs(),
+               newType.getNumInputs(), [&](Operation *op, auto &&attrs) {
+                 setAllArgAttrDicts(op, attrs);
+               });
+  updateAttrFn(
+      function_like_impl::getResultDictAttrName(), oldType.getNumResults(),
+      newType.getNumResults(),
+      [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
 }
 
 //===----------------------------------------------------------------------===//
 // Function body.
 //===----------------------------------------------------------------------===//
 
-Region &mlir::impl::getFunctionBody(Operation *op) {
+Region &mlir::function_like_impl::getFunctionBody(Operation *op) {
   assert(op->hasTrait<OpTrait::FunctionLike>());
   return op->getRegion(0);
 }

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c8bb22e10b86..00b006cca1b4 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2628,15 +2628,15 @@ struct FunctionLikeSignatureConversion : public ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    FunctionType type = mlir::impl::getFunctionType(op);
+    FunctionType type = function_like_impl::getFunctionType(op);
 
     // Convert the original function types.
     TypeConverter::SignatureConversion result(type.getNumInputs());
     SmallVector<Type, 1> newResults;
     if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
         failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
-        failed(rewriter.convertRegionTypes(&mlir::impl::getFunctionBody(op),
-                                           *typeConverter, &result)))
+        failed(rewriter.convertRegionTypes(
+            &function_like_impl::getFunctionBody(op), *typeConverter, &result)))
       return failure();
 
     // Update the function signature in-place.
@@ -2644,7 +2644,7 @@ struct FunctionLikeSignatureConversion : public ConversionPattern {
                                      result.getConvertedTypes(), newResults);
 
     rewriter.updateRootInPlace(
-        op, [&] { mlir::impl::setFunctionType(op, newType); });
+        op, [&] { function_like_impl::setFunctionType(op, newType); });
 
     return success();
   }

diff  --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index ab32af24f3b5..e52acf67745c 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -35,7 +35,7 @@ module {
   // CHECK: attributes {xxx = {yyy = 42 : i64}}
   "llvm.func"() ({
   }) {sym_name = "qux", type = !llvm.func<void (ptr<i64>, i64)>,
-      arg0 = {llvm.noalias = true}, xxx = {yyy = 42}} : () -> ()
+      arg_attrs = [{llvm.noalias = true}, {}], xxx = {yyy = 42}} : () -> ()
 
   // CHECK: llvm.func @roundtrip1()
   llvm.func @roundtrip1()

diff  --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir
index c2ceefeb3ee7..6b15a076f1d3 100644
--- a/mlir/test/IR/invalid-func-op.mlir
+++ b/mlir/test/IR/invalid-func-op.mlir
@@ -94,3 +94,22 @@ func private @invalid_symbol_name_attr() attributes { sym_name = "x" }
 // expected-error at +1 {{'type' is an inferred attribute and should not be specified in the explicit attribute dictionary}}
 func private @invalid_symbol_type_attr() attributes { type = "x" }
 
+// -----
+
+// expected-error at +1 {{argument attribute array `arg_attrs` to have the same number of elements as the number of function arguments}}
+func private @invalid_arg_attrs() attributes { arg_attrs = [{}] }
+
+// -----
+
+// expected-error at +1 {{expects argument attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
+func private @invalid_arg_attrs(i32) attributes { arg_attrs = [10] }
+
+// -----
+
+// expected-error at +1 {{result attribute array `res_attrs` to have the same number of elements as the number of function results}}
+func private @invalid_res_attrs() attributes { res_attrs = [{}] }
+
+// -----
+
+// expected-error at +1 {{expects result attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
+func private @invalid_res_attrs() -> i32 attributes { res_attrs = [10] }

diff  --git a/mlir/test/IR/test-func-set-type.mlir b/mlir/test/IR/test-func-set-type.mlir
index 05a1393a8436..42f56ae9ac98 100644
--- a/mlir/test/IR/test-func-set-type.mlir
+++ b/mlir/test/IR/test-func-set-type.mlir
@@ -9,7 +9,6 @@
 // Test case: The setType call needs to erase some arg attrs.
 
 // CHECK: func private @erase_arg(f32 {test.A})
-// CHECK-NOT: attributes{{.*arg[0-9]}}
 func private @t(f32)
 func private @erase_arg(%arg0: f32 {test.A}, %arg1: f32 {test.B})
 attributes {test.set_type_from = @t}
@@ -19,7 +18,6 @@ attributes {test.set_type_from = @t}
 // Test case: The setType call needs to erase some result attrs.
 
 // CHECK: func private @erase_result() -> (f32 {test.A})
-// CHECK-NOT: attributes{{.*result[0-9]}}
 func private @t() -> (f32)
 func private @erase_result() -> (f32 {test.A}, f32 {test.B})
 attributes {test.set_type_from = @t}


        


More information about the Mlir-commits mailing list