[flang-commits] [flang] [mlir] [mlir] share argument attributes interface between calls and callables (PR #123176)

via flang-commits flang-commits at lists.llvm.org
Mon Jan 27 05:14:04 PST 2025


https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/123176

>From b652f02b7178b2c3d35759770c7b20aca1ede6c4 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Wed, 15 Jan 2025 09:01:44 -0800
Subject: [PATCH 1/6] [mlir] share argument attributes interface between calls
 and callables

---
 mlir/docs/Interfaces.md                       |  16 +-
 .../mlir/Interfaces/CallImplementation.h      |  91 +++++++++
 .../include/mlir/Interfaces/CallInterfaces.td |  99 +++++-----
 .../mlir/Interfaces/FunctionImplementation.h  |  24 +--
 mlir/lib/Dialect/Async/IR/Async.cpp           |   2 +-
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           |   2 +-
 mlir/lib/Dialect/Func/IR/FuncOps.cpp          |   2 +-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |   2 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    |   4 +-
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |   2 +-
 mlir/lib/Dialect/Shape/IR/Shape.cpp           |   2 +-
 mlir/lib/Interfaces/CMakeLists.txt            |  21 +--
 mlir/lib/Interfaces/CallImplementation.cpp    | 178 ++++++++++++++++++
 .../lib/Interfaces/FunctionImplementation.cpp | 148 +--------------
 14 files changed, 359 insertions(+), 234 deletions(-)
 create mode 100644 mlir/include/mlir/Interfaces/CallImplementation.h
 create mode 100644 mlir/lib/Interfaces/CallImplementation.cpp

diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index 51747db546bb76..195f58dfbc23b3 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -753,7 +753,15 @@ interface section goes as follows:
     -   (`C++ class` -- `ODS class`(if applicable))
 
 ##### CallInterfaces
-
+*   `OpWithArgumentAttributesInterface`  - Used to represent operations that may
+    carry argument and result attributes. It is inherited by both
+    CallOpInterface and CallableOpInterface.
+    -   `ArrayAttr getArgAttrsAttr()`
+    -   `ArrayAttr getResAttrsAttr()`
+    -   `void setArgAttrsAttr(ArrayAttr)`
+    -   `void setResAttrsAttr(ArrayAttr)`
+    -   `Attribute removeArgAttrsAttr()`
+    -   `Attribute removeResAttrsAttr()`
 *   `CallOpInterface` - Used to represent operations like 'call'
     -   `CallInterfaceCallable getCallableForCallee()`
     -   `void setCalleeFromCallable(CallInterfaceCallable)`
@@ -761,12 +769,6 @@ interface section goes as follows:
     -   `Region * getCallableRegion()`
     -   `ArrayRef<Type> getArgumentTypes()`
     -   `ArrayRef<Type> getResultsTypes()`
-    -   `ArrayAttr getArgAttrsAttr()`
-    -   `ArrayAttr getResAttrsAttr()`
-    -   `void setArgAttrsAttr(ArrayAttr)`
-    -   `void setResAttrsAttr(ArrayAttr)`
-    -   `Attribute removeArgAttrsAttr()`
-    -   `Attribute removeResAttrsAttr()`
 
 ##### RegionKindInterfaces
 
diff --git a/mlir/include/mlir/Interfaces/CallImplementation.h b/mlir/include/mlir/Interfaces/CallImplementation.h
new file mode 100644
index 00000000000000..85e47f6b3dbbb9
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/CallImplementation.h
@@ -0,0 +1,91 @@
+//===- CallImplementation.h - Call and Callable Op utilities ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides utility functions for implementing call-like and
+// callable-like operations, in particular, parsing, printing and verification
+// components common to these operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_CALLIMPLEMENTATION_H
+#define MLIR_INTERFACES_CALLIMPLEMENTATION_H
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+
+namespace mlir {
+
+class OpWithArgumentAttributesInterface;
+
+namespace call_interface_impl {
+
+/// Parse a function or call result list.
+///
+///   function-result-list ::= function-result-list-parens
+///                          | non-function-type
+///   function-result-list-parens ::= `(` `)`
+///                                 | `(` function-result-list-no-parens `)`
+///   function-result-list-no-parens ::= function-result (`,` function-result)*
+///   function-result ::= type attribute-dict?
+///
+ParseResult
+parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+                        SmallVectorImpl<DictionaryAttr> &resultAttrs);
+
+/// Parses a function signature using `parser`. This does not deal with function
+/// signatures containing SSA region arguments (to parse these signatures, use
+/// function_interface_impl::parseFunctionSignature). When
+/// `mustParseEmptyResult` is true, `-> ()` is expected when there is no result.
+///
+///   no-ssa-function-signature ::= `(` no-ssa-function-arg-list `)`
+///                               -> function-result-list
+///   no-ssa-function-arg-list  ::= no-ssa-function-arg
+///                               (`,` no-ssa-function-arg)*
+///   no-ssa-function-arg       ::= type attribute-dict?
+ParseResult parseFunctionSignature(OpAsmParser &parser,
+                                   SmallVectorImpl<Type> &argTypes,
+                                   SmallVectorImpl<DictionaryAttr> &argAttrs,
+                                   SmallVectorImpl<Type> &resultTypes,
+                                   SmallVectorImpl<DictionaryAttr> &resultAttrs,
+                                   bool mustParseEmptyResult = true);
+
+/// Print a function signature for a call or callable operation. If a body
+/// region is provided, the SSA arguments are printed in the signature. When
+/// `printEmptyResult` is false, `-> function-result-list` is omitted when
+/// `resultTypes` is empty.
+///
+///   function-signature     ::= ssa-function-signature
+///                            | no-ssa-function-signature
+///   ssa-function-signature ::= `(` ssa-function-arg-list `)`
+///                            -> function-result-list
+///   ssa-function-arg-list  ::= ssa-function-arg (`,` ssa-function-arg)*
+///   ssa-function-arg       ::= `%`name `:` type attribute-dict?
+void printFunctionSignature(OpAsmPrinter &p,
+                            OpWithArgumentAttributesInterface op,
+                            TypeRange argTypes, bool isVariadic,
+                            TypeRange resultTypes, Region *body = nullptr,
+                            bool printEmptyResult = true);
+
+/// Adds argument and result attributes, provided as `argAttrs` and
+/// `resultAttrs` arguments, to the list of operation attributes in `result`.
+/// Each list is stored as an array attribute of dictionaries, under the name
+/// `argAttrsName` or `resAttrsName`, only if it has a non-empty dictionary.
+void addArgAndResultAttrs(Builder &builder, OperationState &result,
+                          ArrayRef<DictionaryAttr> argAttrs,
+                          ArrayRef<DictionaryAttr> resultAttrs,
+                          StringAttr argAttrsName, StringAttr resAttrsName);
+void addArgAndResultAttrs(Builder &builder, OperationState &result,
+                          ArrayRef<OpAsmParser::Argument> args,
+                          ArrayRef<DictionaryAttr> resultAttrs,
+                          StringAttr argAttrsName, StringAttr resAttrsName);
+
+} // namespace call_interface_impl
+
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_CALLIMPLEMENTATION_H
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index c6002da0d491ce..80912a9762187e 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -17,12 +17,65 @@
 
 include "mlir/IR/OpBase.td"
 
+
+/// Interface for operations with argument and result attributes (both
+/// call-like and callable operations).
+def OpWithArgumentAttributesInterface : OpInterface<"OpWithArgumentAttributesInterface"> {
+  let description = [{
+    A call-like or callable operation that may have argument/result attributes.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Get the array of argument attribute dictionaries. The method should
+        return an array attribute containing only dictionary attributes equal in
+        number to the number of arguments. Alternatively, the method can
+        return null to indicate that the operation has no argument attributes.
+      }],
+      "::mlir::ArrayAttr", "getArgAttrsAttr", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+    InterfaceMethod<[{
+        Get the array of result attribute dictionaries. The method should return
+        an array attribute containing only dictionary attributes equal in number
+        to the number of results. Alternatively, the method can return
+        null to indicate that the operation has no result attributes.
+      }],
+      "::mlir::ArrayAttr", "getResAttrsAttr", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+    InterfaceMethod<[{
+      Set the array of argument attribute dictionaries.
+    }],
+    "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
+    InterfaceMethod<[{
+      Set the array of result attribute dictionaries.
+    }],
+    "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
+    InterfaceMethod<[{
+      Remove the array of argument attribute dictionaries. This is the same as
+      setting all argument attributes to an empty dictionary. The method should
+      return the removed attribute.
+    }],
+    "::mlir::Attribute", "removeArgAttrsAttr", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+    InterfaceMethod<[{
+      Remove the array of result attribute dictionaries. This is the same as
+      setting all result attributes to an empty dictionary. The method should
+      return the removed attribute.
+    }],
+    "::mlir::Attribute", "removeResAttrsAttr", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+  ];
+}
+
 // `CallInterfaceCallable`: This is a type used to represent a single callable
 // region. A callable is either a symbol, or an SSA value, that is referenced by
 // a call-like operation. This represents the destination of the call.
 
 /// Interface for call-like operations.
-def CallOpInterface : OpInterface<"CallOpInterface"> {
+def CallOpInterface : OpInterface<"CallOpInterface",
+    [OpWithArgumentAttributesInterface]> {
   let description = [{
     A call-like operation is one that transfers control from one sub-routine to
     another. These operations may be traditional direct calls `call @foo`, or
@@ -85,7 +138,8 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
 }
 
 /// Interface for callable operations.
-def CallableOpInterface : OpInterface<"CallableOpInterface"> {
+def CallableOpInterface : OpInterface<"CallableOpInterface",
+    [OpWithArgumentAttributesInterface]> {
   let description = [{
     A callable operation is one who represents a potential sub-routine, and may
     be a target for a call-like operation (those providing the CallOpInterface
@@ -113,47 +167,6 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
       allow for this method may be called on function declarations).
     }],
     "::llvm::ArrayRef<::mlir::Type>", "getResultTypes">,
-
-    InterfaceMethod<[{
-        Get the array of argument attribute dictionaries. The method should
-        return an array attribute containing only dictionary attributes equal in
-        number to the number of region arguments. Alternatively, the method can
-        return null to indicate that the region has no argument attributes.
-      }],
-      "::mlir::ArrayAttr", "getArgAttrsAttr", (ins),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
-    InterfaceMethod<[{
-        Get the array of result attribute dictionaries. The method should return
-        an array attribute containing only dictionary attributes equal in number
-        to the number of region results. Alternatively, the method can return
-        null to indicate that the region has no result attributes.
-      }],
-      "::mlir::ArrayAttr", "getResAttrsAttr", (ins),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
-    InterfaceMethod<[{
-      Set the array of argument attribute dictionaries.
-    }],
-    "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
-    InterfaceMethod<[{
-      Set the array of result attribute dictionaries.
-    }],
-    "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
-    InterfaceMethod<[{
-      Remove the array of argument attribute dictionaries. This is the same as
-      setting all argument attributes to an empty dictionary. The method should
-      return the removed attribute.
-    }],
-    "::mlir::Attribute", "removeArgAttrsAttr", (ins),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
-    InterfaceMethod<[{
-      Remove the array of result attribute dictionaries. This is the same as
-      setting all result attributes to an empty dictionary. The method should
-      return the removed attribute.
-    }],
-    "::mlir::Attribute", "removeResAttrsAttr", (ins),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
   ];
 }
 
diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h
index a5e6963e4e666f..ae20533ef4b87c 100644
--- a/mlir/include/mlir/Interfaces/FunctionImplementation.h
+++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h
@@ -16,6 +16,7 @@
 #define MLIR_IR_FUNCTIONIMPLEMENTATION_H_
 
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/CallImplementation.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 
 namespace mlir {
@@ -33,19 +34,6 @@ class VariadicFlag {
   bool variadic;
 };
 
-/// Adds argument and result attributes, provided as `argAttrs` and
-/// `resultAttrs` arguments, to the list of operation attributes in `result`.
-/// Internally, argument and result attributes are stored as dict attributes
-/// with special names given by getResultAttrName, getArgumentAttrName.
-void addArgAndResultAttrs(Builder &builder, OperationState &result,
-                          ArrayRef<DictionaryAttr> argAttrs,
-                          ArrayRef<DictionaryAttr> resultAttrs,
-                          StringAttr argAttrsName, StringAttr resAttrsName);
-void addArgAndResultAttrs(Builder &builder, OperationState &result,
-                          ArrayRef<OpAsmParser::Argument> args,
-                          ArrayRef<DictionaryAttr> resultAttrs,
-                          StringAttr argAttrsName, StringAttr resAttrsName);
-
 /// Callback type for `parseFunctionOp`, the callback should produce the
 /// type that will be associated with a function-like operation from lists of
 /// function arguments and results, VariadicFlag indicates whether the function
@@ -84,9 +72,13 @@ void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
 
 /// Prints the signature of the function-like operation `op`. Assumes `op` has
 /// is a FunctionOpInterface and has passed verification.
-void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
-                            ArrayRef<Type> argTypes, bool isVariadic,
-                            ArrayRef<Type> resultTypes);
+inline void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
+                                   ArrayRef<Type> argTypes, bool isVariadic,
+                                   ArrayRef<Type> resultTypes) {
+  call_interface_impl::printFunctionSignature(p, op, argTypes, isVariadic,
+                                              resultTypes, &op->getRegion(0),
+                                              /*printEmptyResult=*/false);
+}
 
 /// Prints the list of function prefixed with the "attributes" keyword. The
 /// attributes with names listed in "elided" as well as those used by the
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index a3e3f80954efce..d3bb250bb8ab9d 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -308,7 +308,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index fdc21d6c6e24b9..6af17087ced968 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -529,7 +529,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index a490b4c3c4ab43..ba7b84f27d6a8a 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -190,7 +190,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 49209229259a73..8b85c0829acfec 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1487,7 +1487,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(type));
 
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
       getResAttrsAttrName(result.name));
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index ef5f1b069b40a3..ef1e0222e05f06 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2510,7 +2510,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
 
   assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
          "expected as many argument attribute lists as arguments");
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
@@ -2636,7 +2636,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
     return failure();
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       parser.getBuilder(), result, entryArgs, resultAttrs,
       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 26559c1321db5e..870359ce55301c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -940,7 +940,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   // Add the attributes to the function arguments.
   assert(resultAttrs.size() == resultTypes.size());
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
       getResAttrsAttrName(result.name));
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 65efc88e9c4033..2200af0f67a862 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1297,7 +1297,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index d3b7bf65ad3e73..76e2d921c2a9f5 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -1,4 +1,5 @@
 set(LLVM_OPTIONAL_SOURCES
+  CallImplementation.cpp
   CallInterfaces.cpp
   CastInterfaces.cpp
   ControlFlowInterfaces.cpp
@@ -24,8 +25,10 @@ set(LLVM_OPTIONAL_SOURCES
   )
 
 function(add_mlir_interface_library name)
+  cmake_parse_arguments(ARG "" "" "EXTRA_SOURCE" ${ARGN})
   add_mlir_library(MLIR${name}
     ${name}.cpp
+    ${ARG_EXTRA_SOURCE}
 
     ADDITIONAL_HEADER_DIRS
     ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
@@ -39,28 +42,14 @@ function(add_mlir_interface_library name)
 endfunction(add_mlir_interface_library)
 
 
-add_mlir_interface_library(CallInterfaces)
+add_mlir_interface_library(CallInterfaces EXTRA_SOURCE CallImplementation.cpp)
 add_mlir_interface_library(CastInterfaces)
 add_mlir_interface_library(ControlFlowInterfaces)
 add_mlir_interface_library(CopyOpInterface)
 add_mlir_interface_library(DataLayoutInterfaces)
 add_mlir_interface_library(DerivedAttributeOpInterface)
 add_mlir_interface_library(DestinationStyleOpInterface)
-
-add_mlir_library(MLIRFunctionInterfaces
-  FunctionInterfaces.cpp
-  FunctionImplementation.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
-
-  DEPENDS
-  MLIRFunctionInterfacesIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-)
-
+add_mlir_interface_library(FunctionInterfaces EXTRA_SOURCE FunctionImplementation.cpp)
 add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 
diff --git a/mlir/lib/Interfaces/CallImplementation.cpp b/mlir/lib/Interfaces/CallImplementation.cpp
new file mode 100644
index 00000000000000..85eca609d8dc8d
--- /dev/null
+++ b/mlir/lib/Interfaces/CallImplementation.cpp
@@ -0,0 +1,178 @@
+//===- CallImplementation.cpp - Call and Callable Op utilities ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/CallImplementation.h"
+#include "mlir/IR/Builders.h"
+
+using namespace mlir;
+
+static ParseResult
+parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
+                     SmallVectorImpl<DictionaryAttr> &attrs) {
+  // Parse a comma-separated list of `type attribute-dict?` entries.
+  return parser.parseCommaSeparatedList([&]() -> ParseResult {
+    types.emplace_back();
+    attrs.emplace_back();
+    NamedAttrList attrList;
+    if (parser.parseType(types.back()) ||
+        parser.parseOptionalAttrDict(attrList))
+      return failure();
+    attrs.back() = attrList.getDictionary(parser.getContext());
+    return success();
+  });
+}
+
+ParseResult call_interface_impl::parseFunctionResultList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
+  if (failed(parser.parseOptionalLParen())) {
+    // We already know that there is no `(`, so parse a type.
+    // Because there is no `(`, it cannot be a function type.
+    Type ty;
+    if (parser.parseType(ty))
+      return failure();
+    resultTypes.push_back(ty);
+    resultAttrs.emplace_back();
+    return success();
+  }
+
+  // Special case for an empty set of parens.
+  if (succeeded(parser.parseOptionalRParen()))
+    return success();
+  if (parseTypeAndAttrList(parser, resultTypes, resultAttrs))
+    return failure();
+  return parser.parseRParen();
+}
+
+ParseResult call_interface_impl::parseFunctionSignature(
+    OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
+    SmallVectorImpl<DictionaryAttr> &argAttrs,
+    SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
+  // Parse arguments.
+  if (parser.parseLParen())
+    return failure();
+  if (failed(parser.parseOptionalRParen())) {
+    if (parseTypeAndAttrList(parser, argTypes, argAttrs))
+      return failure();
+    if (parser.parseRParen())
+      return failure();
+  }
+  // Parse results.
+  if (succeeded(parser.parseOptionalArrow()))
+    return call_interface_impl::parseFunctionResultList(parser, resultTypes,
+                                                        resultAttrs);
+  if (mustParseEmptyResult)
+    return failure();
+  return success();
+}
+
+/// Print a function result list. The provided `attrs` must either be null, or
+/// contain a set of DictionaryAttrs of the same arity as `types`.
+static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
+                                    ArrayAttr attrs) {
+  assert(!types.empty() && "Should not be called for empty result list.");
+  assert((!attrs || attrs.size() == types.size()) &&
+         "Invalid number of attributes.");
+
+  auto &os = p.getStream();
+  bool needsParens = types.size() > 1 || llvm::isa<FunctionType>(types[0]) ||
+                     (attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty());
+  if (needsParens)
+    os << '(';
+  llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
+    p.printType(types[i]);
+    if (attrs)
+      p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
+  });
+  if (needsParens)
+    os << ')';
+}
+
+void call_interface_impl::printFunctionSignature(
+    OpAsmPrinter &p, OpWithArgumentAttributesInterface op, TypeRange argTypes,
+    bool isVariadic, TypeRange resultTypes, Region *body,
+    bool printEmptyResult) {
+  bool isExternal = !body || body->empty();
+  ArrayAttr argAttrs = op.getArgAttrsAttr();
+  ArrayAttr resultAttrs = op.getResAttrsAttr();
+  if (!isExternal && !isVariadic && !argAttrs && !resultAttrs &&
+      printEmptyResult) {
+    p.printFunctionalType(argTypes, resultTypes);
+    return;
+  }
+
+  p << '(';
+  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
+    if (i > 0)
+      p << ", ";
+
+    if (!isExternal) {
+      ArrayRef<NamedAttribute> attrs;
+      if (argAttrs)
+        attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
+      p.printRegionArgument(body->getArgument(i), attrs);
+    } else {
+      p.printType(argTypes[i]);
+      if (argAttrs)
+        p.printOptionalAttrDict(
+            llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
+    }
+  }
+
+  if (isVariadic) {
+    if (!argTypes.empty())
+      p << ", ";
+    p << "...";
+  }
+
+  p << ')';
+
+  if (!resultTypes.empty()) {
+    p << " -> ";
+    printFunctionResultList(p, resultTypes, resultAttrs);
+  } else if (printEmptyResult) {
+    p << " -> ()";
+  }
+}
+
+void call_interface_impl::addArgAndResultAttrs(
+    Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
+    ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
+    StringAttr resAttrsName) {
+  auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
+    return attrs && !attrs.empty();
+  };
+  // Convert the specified array of dictionary attrs (which may have null
+  // entries) to an ArrayAttr of dictionaries.
+  auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
+    SmallVector<Attribute> attrs;
+    for (auto &dict : dictAttrs)
+      attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
+    return builder.getArrayAttr(attrs);
+  };
+
+  // Add the attributes to the function arguments.
+  if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
+    result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
+
+  // Add the attributes to the function results.
+  if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
+    result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
+}
+
+void call_interface_impl::addArgAndResultAttrs(
+    Builder &builder, OperationState &result,
+    ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
+    StringAttr argAttrsName, StringAttr resAttrsName) {
+  SmallVector<DictionaryAttr> argAttrs;
+  for (const auto &arg : args)
+    argAttrs.push_back(arg.attrs);
+  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
+                       resAttrsName);
+}
diff --git a/mlir/lib/Interfaces/FunctionImplementation.cpp b/mlir/lib/Interfaces/FunctionImplementation.cpp
index 988feee665fea6..80174d1fefb559 100644
--- a/mlir/lib/Interfaces/FunctionImplementation.cpp
+++ b/mlir/lib/Interfaces/FunctionImplementation.cpp
@@ -70,49 +70,6 @@ parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
       });
 }
 
-/// Parse a function result list.
-///
-///   function-result-list ::= function-result-list-parens
-///                          | non-function-type
-///   function-result-list-parens ::= `(` `)`
-///                                 | `(` function-result-list-no-parens `)`
-///   function-result-list-no-parens ::= function-result (`,` function-result)*
-///   function-result ::= type attribute-dict?
-///
-static ParseResult
-parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
-                        SmallVectorImpl<DictionaryAttr> &resultAttrs) {
-  if (failed(parser.parseOptionalLParen())) {
-    // We already know that there is no `(`, so parse a type.
-    // Because there is no `(`, it cannot be a function type.
-    Type ty;
-    if (parser.parseType(ty))
-      return failure();
-    resultTypes.push_back(ty);
-    resultAttrs.emplace_back();
-    return success();
-  }
-
-  // Special case for an empty set of parens.
-  if (succeeded(parser.parseOptionalRParen()))
-    return success();
-
-  // Parse individual function results.
-  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
-        resultTypes.emplace_back();
-        resultAttrs.emplace_back();
-        NamedAttrList attrs;
-        if (parser.parseType(resultTypes.back()) ||
-            parser.parseOptionalAttrDict(attrs))
-          return failure();
-        resultAttrs.back() = attrs.getDictionary(parser.getContext());
-        return success();
-      }))
-    return failure();
-
-  return parser.parseRParen();
-}
-
 ParseResult function_interface_impl::parseFunctionSignature(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
@@ -121,46 +78,11 @@ ParseResult function_interface_impl::parseFunctionSignature(
   if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic))
     return failure();
   if (succeeded(parser.parseOptionalArrow()))
-    return parseFunctionResultList(parser, resultTypes, resultAttrs);
+    return call_interface_impl::parseFunctionResultList(parser, resultTypes,
+                                                        resultAttrs);
   return success();
 }
 
-void function_interface_impl::addArgAndResultAttrs(
-    Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
-    ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
-    StringAttr resAttrsName) {
-  auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
-    return attrs && !attrs.empty();
-  };
-  // Convert the specified array of dictionary attrs (which may have null
-  // entries) to an ArrayAttr of dictionaries.
-  auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
-    SmallVector<Attribute> attrs;
-    for (auto &dict : dictAttrs)
-      attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
-    return builder.getArrayAttr(attrs);
-  };
-
-  // Add the attributes to the function arguments.
-  if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
-    result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
-
-  // Add the attributes to the function results.
-  if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
-    result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
-}
-
-void function_interface_impl::addArgAndResultAttrs(
-    Builder &builder, OperationState &result,
-    ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
-    StringAttr argAttrsName, StringAttr resAttrsName) {
-  SmallVector<DictionaryAttr> argAttrs;
-  for (const auto &arg : args)
-    argAttrs.push_back(arg.attrs);
-  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
-                       resAttrsName);
-}
-
 ParseResult function_interface_impl::parseFunctionOp(
     OpAsmParser &parser, OperationState &result, bool allowVariadic,
     StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
@@ -221,8 +143,8 @@ ParseResult function_interface_impl::parseFunctionOp(
 
   // Add the attributes to the function arguments.
   assert(resultAttrs.size() == resultTypes.size());
-  addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName,
-                       resAttrsName);
+  call_interface_impl::addArgAndResultAttrs(
+      builder, result, entryArgs, resultAttrs, argAttrsName, resAttrsName);
 
   // Parse the optional function body. The printer will not print the body if
   // its empty, so disallow parsing of empty body in the parser.
@@ -241,68 +163,6 @@ ParseResult function_interface_impl::parseFunctionOp(
   return success();
 }
 
-/// Print a function result list. The provided `attrs` must either be null, or
-/// contain a set of DictionaryAttrs of the same arity as `types`.
-static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
-                                    ArrayAttr attrs) {
-  assert(!types.empty() && "Should not be called for empty result list.");
-  assert((!attrs || attrs.size() == types.size()) &&
-         "Invalid number of attributes.");
-
-  auto &os = p.getStream();
-  bool needsParens = types.size() > 1 || llvm::isa<FunctionType>(types[0]) ||
-                     (attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty());
-  if (needsParens)
-    os << '(';
-  llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
-    p.printType(types[i]);
-    if (attrs)
-      p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
-  });
-  if (needsParens)
-    os << ')';
-}
-
-void function_interface_impl::printFunctionSignature(
-    OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
-    bool isVariadic, ArrayRef<Type> resultTypes) {
-  Region &body = op->getRegion(0);
-  bool isExternal = body.empty();
-
-  p << '(';
-  ArrayAttr argAttrs = op.getArgAttrsAttr();
-  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
-    if (i > 0)
-      p << ", ";
-
-    if (!isExternal) {
-      ArrayRef<NamedAttribute> attrs;
-      if (argAttrs)
-        attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
-      p.printRegionArgument(body.getArgument(i), attrs);
-    } else {
-      p.printType(argTypes[i]);
-      if (argAttrs)
-        p.printOptionalAttrDict(
-            llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
-    }
-  }
-
-  if (isVariadic) {
-    if (!argTypes.empty())
-      p << ", ";
-    p << "...";
-  }
-
-  p << ')';
-
-  if (!resultTypes.empty()) {
-    p.getStream() << " -> ";
-    auto resultAttrs = op.getResAttrsAttr();
-    printFunctionResultList(p, resultTypes, resultAttrs);
-  }
-}
-
 void function_interface_impl::printFunctionAttributes(
     OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
   // Print out function attributes, if present.

>From 3328f681ad735cae0a5e2d06ab685bb4710fb994 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Thu, 16 Jan 2025 08:59:05 -0800
Subject: [PATCH 2/6] rename and change inheritance level

---
 .../mlir/Interfaces/CallImplementation.h      |  5 +---
 .../include/mlir/Interfaces/CallInterfaces.td | 29 +++++++------------
 .../mlir/Interfaces/FunctionInterfaces.td     |  2 +-
 mlir/lib/Interfaces/CallImplementation.cpp    |  2 +-
 mlir/lib/Transforms/Utils/InliningUtils.cpp   | 24 ++++++++-------
 5 files changed, 28 insertions(+), 34 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/CallImplementation.h b/mlir/include/mlir/Interfaces/CallImplementation.h
index 85e47f6b3dbbb9..2edc081bddf478 100644
--- a/mlir/include/mlir/Interfaces/CallImplementation.h
+++ b/mlir/include/mlir/Interfaces/CallImplementation.h
@@ -20,8 +20,6 @@
 
 namespace mlir {
 
-class OpWithArgumentAttributesInterface;
-
 namespace call_interface_impl {
 
 /// Parse a function or call result list.
@@ -65,8 +63,7 @@ ParseResult parseFunctionSignature(OpAsmParser &parser,
 ///                            -> function-result-list
 ///   ssa-function-arg-list  ::= ssa-function-arg (`,` ssa-function-arg)*
 ///   ssa-function-arg       ::= `%`name `:` type attribute-dict?
-void printFunctionSignature(OpAsmPrinter &p,
-                            OpWithArgumentAttributesInterface op,
+void printFunctionSignature(OpAsmPrinter &p, ArgumentAttributesOpInterface op,
                             TypeRange argTypes, bool isVariadic,
                             TypeRange resultTypes, Region *body = nullptr,
                             bool printEmptyResult = true);
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index 80912a9762187e..1f2398387c044e 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -20,9 +20,10 @@ include "mlir/IR/OpBase.td"
 
 /// Interface for operations with arguments attributes (both call-like
 /// and callable operations).
-def OpWithArgumentAttributesInterface : OpInterface<"OpWithArgumentAttributesInterface"> {
+def ArgumentAttributesOpInterface : OpInterface<"ArgumentAttributesOpInterface"> {
   let description = [{
-    A call-like or callable operation that may define attributes for its arguments. 
+    A call-like or callable operation that can hold attributes for its arguments
+    and results.
   }];
   let cppNamespace = "::mlir";
   let methods = [
@@ -32,40 +33,34 @@ def OpWithArgumentAttributesInterface : OpInterface<"OpWithArgumentAttributesInt
         number to the number of arguments. Alternatively, the method can
         return null to indicate that the region has no argument attributes.
       }],
-      "::mlir::ArrayAttr", "getArgAttrsAttr", (ins),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+      "::mlir::ArrayAttr", "getArgAttrsAttr">,
     InterfaceMethod<[{
         Get the array of result attribute dictionaries. The method should return
         an array attribute containing only dictionary attributes equal in number
         to the number of results. Alternatively, the method can return
         null to indicate that the region has no result attributes.
       }],
-      "::mlir::ArrayAttr", "getResAttrsAttr", (ins),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+      "::mlir::ArrayAttr", "getResAttrsAttr">,
     InterfaceMethod<[{
       Set the array of argument attribute dictionaries.
     }],
-    "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
+    "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>,
     InterfaceMethod<[{
       Set the array of result attribute dictionaries.
     }],
-    "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
+    "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>,
     InterfaceMethod<[{
       Remove the array of argument attribute dictionaries. This is the same as
       setting all argument attributes to an empty dictionary. The method should
       return the removed attribute.
     }],
-    "::mlir::Attribute", "removeArgAttrsAttr", (ins),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+    "::mlir::Attribute", "removeArgAttrsAttr">,
     InterfaceMethod<[{
       Remove the array of result attribute dictionaries. This is the same as
       setting all result attributes to an empty dictionary. The method should
       return the removed attribute.
     }],
-    "::mlir::Attribute", "removeResAttrsAttr", (ins),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+    "::mlir::Attribute", "removeResAttrsAttr">
   ];
 }
 
@@ -74,8 +69,7 @@ def OpWithArgumentAttributesInterface : OpInterface<"OpWithArgumentAttributesInt
 // a call-like operation. This represents the destination of the call.
 
 /// Interface for call-like operations.
-def CallOpInterface : OpInterface<"CallOpInterface",
-    [OpWithArgumentAttributesInterface]> {
+def CallOpInterface : OpInterface<"CallOpInterface"> {
   let description = [{
     A call-like operation is one that transfers control from one sub-routine to
     another. These operations may be traditional direct calls `call @foo`, or
@@ -138,8 +132,7 @@ def CallOpInterface : OpInterface<"CallOpInterface",
 }
 
 /// Interface for callable operations.
-def CallableOpInterface : OpInterface<"CallableOpInterface",
-    [OpWithArgumentAttributesInterface]> {
+def CallableOpInterface : OpInterface<"CallableOpInterface"> {
   let description = [{
     A callable operation is one who represents a potential sub-routine, and may
     be a target for a call-like operation (those providing the CallOpInterface
diff --git a/mlir/include/mlir/Interfaces/FunctionInterfaces.td b/mlir/include/mlir/Interfaces/FunctionInterfaces.td
index 697f951748c675..616785837e1452 100644
--- a/mlir/include/mlir/Interfaces/FunctionInterfaces.td
+++ b/mlir/include/mlir/Interfaces/FunctionInterfaces.td
@@ -22,7 +22,7 @@ include "mlir/Interfaces/CallInterfaces.td"
 //===----------------------------------------------------------------------===//
 
 def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
-    Symbol, CallableOpInterface
+    Symbol, CallableOpInterface, ArgumentAttributesOpInterface
   ]> {
   let cppNamespace = "::mlir";
   let description = [{
diff --git a/mlir/lib/Interfaces/CallImplementation.cpp b/mlir/lib/Interfaces/CallImplementation.cpp
index 85eca609d8dc8d..974e779e32d30b 100644
--- a/mlir/lib/Interfaces/CallImplementation.cpp
+++ b/mlir/lib/Interfaces/CallImplementation.cpp
@@ -95,7 +95,7 @@ static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
 }
 
 void call_interface_impl::printFunctionSignature(
-    OpAsmPrinter &p, OpWithArgumentAttributesInterface op, TypeRange argTypes,
+    OpAsmPrinter &p, ArgumentAttributesOpInterface op, TypeRange argTypes,
     bool isVariadic, TypeRange resultTypes, Region *body,
     bool printEmptyResult) {
   bool isExternal = !body || body->empty();
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 0cae63c58ca7be..57a7931b56085a 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -193,11 +193,13 @@ static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
   SmallVector<DictionaryAttr> argAttrs(
       callable.getCallableRegion()->getNumArguments(),
       builder.getDictionaryAttr({}));
-  if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) {
-    assert(arrayAttr.size() == argAttrs.size());
-    for (auto [idx, attr] : llvm::enumerate(arrayAttr))
-      argAttrs[idx] = cast<DictionaryAttr>(attr);
-  }
+  if (auto argAttrsOpInterface =
+          dyn_cast<ArgumentAttributesOpInterface>(callable.getOperation()))
+    if (ArrayAttr arrayAttr = argAttrsOpInterface.getArgAttrsAttr()) {
+      assert(arrayAttr.size() == argAttrs.size());
+      for (auto [idx, attr] : llvm::enumerate(arrayAttr))
+        argAttrs[idx] = cast<DictionaryAttr>(attr);
+    }
 
   // Run the argument attribute handler for the given argument and attribute.
   for (auto [blockArg, argAttr] :
@@ -218,11 +220,13 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
   // Unpack the result attributes if there are any.
   SmallVector<DictionaryAttr> resAttrs(results.size(),
                                        builder.getDictionaryAttr({}));
-  if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) {
-    assert(arrayAttr.size() == resAttrs.size());
-    for (auto [idx, attr] : llvm::enumerate(arrayAttr))
-      resAttrs[idx] = cast<DictionaryAttr>(attr);
-  }
+  if (auto argAttrsOpInterface =
+          dyn_cast<ArgumentAttributesOpInterface>(callable.getOperation()))
+    if (ArrayAttr arrayAttr = argAttrsOpInterface.getResAttrsAttr()) {
+      assert(arrayAttr.size() == resAttrs.size());
+      for (auto [idx, attr] : llvm::enumerate(arrayAttr))
+        resAttrs[idx] = cast<DictionaryAttr>(attr);
+    }
 
   // Run the result attribute handler for the given result and attribute.
   SmallVector<DictionaryAttr> resultAttributes;

>From 43cd7041311ed77f25404c5bafd291f9c0567d0c Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 21 Jan 2025 01:58:44 -0800
Subject: [PATCH 3/6] PR comments: remove new interface and added files

Share methods via listconcat instead.
Make methods mandatory (no default impl).
Add arg_attrs and res_attrs to all concrete operations inheriting from
CallOpInterface and CallableOpInterface.
---
 .../flang/Optimizer/Dialect/CUF/CUFOps.td     |   4 +-
 .../include/flang/Optimizer/Dialect/FIROps.td |   4 +
 flang/lib/Lower/ConvertCall.cpp               |  12 +-
 flang/lib/Optimizer/CodeGen/TargetRewrite.cpp |   2 +
 .../Optimizer/Transforms/AbstractResult.cpp   |   3 +
 .../Transforms/PolymorphicOpConversion.cpp    |   5 +-
 mlir/docs/Interfaces.md                       |  15 +-
 .../include/mlir/Dialect/Async/IR/AsyncOps.td |   8 +-
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   |   8 +-
 mlir/include/mlir/Dialect/Func/IR/FuncOps.td  |  31 ++-
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |   6 +-
 .../Dialect/SPIRV/IR/SPIRVControlFlowOps.td   |  13 +-
 .../mlir/Dialect/Transform/IR/TransformOps.td |   4 +-
 .../mlir/Interfaces/CallImplementation.h      |  88 ---------
 mlir/include/mlir/Interfaces/CallInterfaces.h |  61 ++++++
 .../include/mlir/Interfaces/CallInterfaces.td |  17 +-
 .../mlir/Interfaces/FunctionImplementation.h  |   8 +-
 .../mlir/Interfaces/FunctionInterfaces.td     |   2 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    |  15 +-
 mlir/lib/Interfaces/CMakeLists.txt            |  21 ++-
 mlir/lib/Interfaces/CallImplementation.cpp    | 178 ------------------
 mlir/lib/Interfaces/CallInterfaces.cpp        | 169 +++++++++++++++++
 mlir/lib/Transforms/Utils/InliningUtils.cpp   |  24 +--
 mlir/test/lib/Dialect/Test/TestOps.td         |  19 +-
 24 files changed, 387 insertions(+), 330 deletions(-)
 delete mode 100644 mlir/include/mlir/Interfaces/CallImplementation.h
 delete mode 100644 mlir/lib/Interfaces/CallImplementation.cpp

diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index a270e69b394104..c1021da0cfb213 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -207,7 +207,9 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
     I32:$block_z,
     Optional<I32>:$bytes,
     Optional<I32>:$stream,
-    Variadic<AnyType>:$args
+    Variadic<AnyType>:$args,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
 
   let assemblyFormat = [{
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 5f0f0b48e892b9..8dbc9df9f553de 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2432,6 +2432,8 @@ def fir_CallOp : fir_Op<"call",
   let arguments = (ins
     OptionalAttr<SymbolRefAttr>:$callee,
     Variadic<AnyType>:$args,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs,
     OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs,
     DefaultValuedAttr<Arith_FastMathAttr,
                       "::mlir::arith::FastMathFlags::none">:$fastmath
@@ -2518,6 +2520,8 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
     fir_ClassType:$object,
     Variadic<AnyType>:$args,
     OptionalAttr<I32Attr>:$pass_arg_pos,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs,
     OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs
   );
 
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 40cd106e630182..7ca2baf0193ce1 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -594,7 +594,8 @@ Fortran::lower::genCallOpAndResult(
 
     builder.create<cuf::KernelLaunchOp>(
         loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, grid_z,
-        block_x, block_y, block_z, bytes, stream, operands);
+        block_x, block_y, block_z, bytes, stream, operands,
+        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
     callNumResults = 0;
   } else if (caller.requireDispatchCall()) {
     // Procedure call requiring a dynamic dispatch. Call is created with
@@ -621,7 +622,8 @@ Fortran::lower::genCallOpAndResult(
       dispatch = builder.create<fir::DispatchOp>(
           loc, funcType.getResults(), builder.getStringAttr(procName),
           caller.getInputs()[*passArg], operands,
-          builder.getI32IntegerAttr(*passArg), procAttrs);
+          builder.getI32IntegerAttr(*passArg), /*arg_attrs=*/nullptr,
+          /*res_attrs=*/nullptr, procAttrs);
     } else {
       // NOPASS
       const Fortran::evaluate::Component *component =
@@ -636,7 +638,8 @@ Fortran::lower::genCallOpAndResult(
         passObject = builder.create<fir::LoadOp>(loc, passObject);
       dispatch = builder.create<fir::DispatchOp>(
           loc, funcType.getResults(), builder.getStringAttr(procName),
-          passObject, operands, nullptr, procAttrs);
+          passObject, operands, nullptr, /*arg_attrs=*/nullptr,
+          /*res_attrs=*/nullptr, procAttrs);
     }
     callNumResults = dispatch.getNumResults();
     if (callNumResults != 0)
@@ -644,7 +647,8 @@ Fortran::lower::genCallOpAndResult(
   } else {
     // Standard procedure call with fir.call.
     auto call = builder.create<fir::CallOp>(
-        loc, funcType.getResults(), funcSymbolAttr, operands, procAttrs);
+        loc, funcType.getResults(), funcSymbolAttr, operands,
+        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs);
 
     callNumResults = call.getNumResults();
     if (callNumResults != 0)
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index b0b9499557e2b7..010cce3681d3ff 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -518,6 +518,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
 
     llvm::SmallVector<mlir::Value, 1> newCallResults;
+    // TODO: propagate/update call argument and result attributes.
     if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
       auto newCall = rewriter->create<A>(
           loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
@@ -557,6 +558,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
           loc, newResTys, rewriter->getStringAttr(callOp.getMethod()),
           callOp.getOperands()[0], newOpers,
           rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift),
+          /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
           callOp.getProcedureAttrsAttr());
       if (wrap)
         newCallResults.push_back((*wrap)(dispatchOp.getOperation()));
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index b0327cc10e9de6..f8badfa639f949 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -147,6 +147,7 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
       newResultTypes.emplace_back(getVoidPtrType(result.getContext()));
 
     Op newOp;
+    // TODO: propagate argument and result attributes (need to be shifted).
     // fir::CallOp specific handling.
     if constexpr (std::is_same_v<Op, fir::CallOp>) {
       if (op.getCallee()) {
@@ -189,9 +190,11 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
       if (op.getPassArgPos())
         passArgPos =
             rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift);
+      // TODO: propagate argument and result attributes (need to be shifted).
       newOp = rewriter.create<fir::DispatchOp>(
           loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
           op.getOperands()[0], newOperands, passArgPos,
+          /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
           op.getProcedureAttrsAttr());
     }
 
diff --git a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp
index 070889a284f481..0c78a878cdc536 100644
--- a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp
@@ -205,8 +205,9 @@ struct DispatchOpConv : public OpConversionPattern<fir::DispatchOp> {
     // Make the call.
     llvm::SmallVector<mlir::Value> args{funcPtr};
     args.append(dispatch.getArgs().begin(), dispatch.getArgs().end());
-    rewriter.replaceOpWithNewOp<fir::CallOp>(dispatch, resTypes, nullptr, args,
-                                             dispatch.getProcedureAttrsAttr());
+    rewriter.replaceOpWithNewOp<fir::CallOp>(
+        dispatch, resTypes, nullptr, args, dispatch.getArgAttrsAttr(),
+        dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr());
     return mlir::success();
   }
 
diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index 195f58dfbc23b3..b7e9e64d23d77d 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -753,22 +753,25 @@ interface section goes as follows:
     -   (`C++ class` -- `ODS class`(if applicable))
 
 ##### CallInterfaces
-*   `OpWithArgumentAttributesInterface`  - Used to represent operations that may
-    carry argument and result attributes. It is inherited by both
-    CallOpInterface and CallableOpInterface.
+*   `CallOpInterface` - Used to represent operations like 'call'
+    -   `CallInterfaceCallable getCallableForCallee()`
+    -   `void setCalleeFromCallable(CallInterfaceCallable)`
     -   `ArrayAttr getArgAttrsAttr()`
     -   `ArrayAttr getResAttrsAttr()`
     -   `void setArgAttrsAttr(ArrayAttr)`
     -   `void setResAttrsAttr(ArrayAttr)`
     -   `Attribute removeArgAttrsAttr()`
     -   `Attribute removeResAttrsAttr()`
-*   `CallOpInterface` - Used to represent operations like 'call'
-    -   `CallInterfaceCallable getCallableForCallee()`
-    -   `void setCalleeFromCallable(CallInterfaceCallable)`
 *   `CallableOpInterface` - Used to represent the target callee of call.
     -   `Region * getCallableRegion()`
     -   `ArrayRef<Type> getArgumentTypes()`
     -   `ArrayRef<Type> getResultsTypes()`
+    -   `ArrayAttr getArgAttrsAttr()`
+    -   `ArrayAttr getResAttrsAttr()`
+    -   `void setArgAttrsAttr(ArrayAttr)`
+    -   `void setResAttrsAttr(ArrayAttr)`
+    -   `Attribute removeArgAttrsAttr()`
+    -   `Attribute removeResAttrsAttr()`
 
 ##### RegionKindInterfaces
 
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index a08f5d6e714ef3..3d29d5bc7dbb68 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -208,7 +208,13 @@ def Async_CallOp : Async_Op<"call",
     ```
   }];
 
-  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
+  let arguments = (ins
+    FlatSymbolRefAttr:$callee,
+    Variadic<AnyType>:$operands,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
+  );
+
   let results = (outs Variadic<Async_AnyValueOrTokenType>);
 
   let builders = [
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index b16f5a8619fe7b..11c48dcf325c8b 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -533,7 +533,13 @@ def EmitC_CallOp : EmitC_Op<"call",
     %2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32
     ```
   }];
-  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<EmitCType>:$operands);
+  let arguments = (ins 
+    FlatSymbolRefAttr:$callee,
+    Variadic<EmitCType>:$operands,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
+  );
+
   let results = (outs Variadic<EmitCType>);
 
   let builders = [
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
index 4da0efcb13ddf5..cdaeb6461afb4e 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -49,8 +49,14 @@ def CallOp : Func_Op<"call",
     ```
   }];
 
-  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands,
-                       UnitAttr:$no_inline);
+  let arguments = (ins
+    FlatSymbolRefAttr:$callee,
+    Variadic<AnyType>:$operands,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs, 
+    UnitAttr:$no_inline
+  );
+
   let results = (outs Variadic<AnyType>);
 
   let builders = [
@@ -73,6 +79,18 @@ def CallOp : Func_Op<"call",
       CArg<"ValueRange", "{}">:$operands), [{
       build($_builder, $_state, StringAttr::get($_builder.getContext(), callee),
             results, operands);
+    }]>,
+    OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
+      CArg<"ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, callee, results, operands);
+    }]>,
+    OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
+      CArg<"ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, callee, results, operands);
+    }]>,
+    OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
+      CArg<"ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, callee, results, operands);
     }]>];
 
   let extraClassDeclaration = [{
@@ -136,8 +154,13 @@ def CallIndirectOp : Func_Op<"call_indirect", [
     ```
   }];
 
-  let arguments = (ins FunctionType:$callee,
-                       Variadic<AnyType>:$callee_operands);
+  let arguments = (ins
+    FunctionType:$callee,
+    Variadic<AnyType>:$callee_operands,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
+  );
+
   let results = (outs Variadic<AnyType>:$results);
 
   let builders = [
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b2281536aa40b6..ee6e10efed4f16 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -633,6 +633,8 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
                    OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
                    OptionalAttr<FlatSymbolRefAttr>:$callee,
                    Variadic<LLVM_Type>:$callee_operands,
+                   OptionalAttr<DictArrayAttr>:$arg_attrs,
+                   OptionalAttr<DictArrayAttr>:$res_attrs,
                    Variadic<LLVM_Type>:$normalDestOperands,
                    Variadic<LLVM_Type>:$unwindDestOperands,
                    OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
@@ -755,7 +757,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                   VariadicOfVariadic<LLVM_Type,
                                      "op_bundle_sizes">:$op_bundle_operands,
                   DenseI32ArrayAttr:$op_bundle_sizes,
-                  OptionalAttr<ArrayAttr>:$op_bundle_tags);
+                  OptionalAttr<ArrayAttr>:$op_bundle_tags,
+                  OptionalAttr<DictArrayAttr>:$arg_attrs,
+                  OptionalAttr<DictArrayAttr>:$res_attrs);
   // Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
   let arguments = !con(args, aliasAttrs);
   let results = (outs Optional<LLVM_Type>:$result);
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index 991e753d1b3593..cc2f0e4962d8a8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -214,13 +214,24 @@ def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [
 
   let arguments = (ins
     FlatSymbolRefAttr:$callee,
-    Variadic<SPIRV_Type>:$arguments
+    Variadic<SPIRV_Type>:$arguments,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
 
   let results = (outs
     Optional<SPIRV_Type>:$return_value
   );
 
+  let builders = [
+    OpBuilder<(ins "Type":$returnType, "FlatSymbolRefAttr":$callee,
+      "ValueRange":$arguments),
+    [{
+      build($_builder, $_state, returnType, callee, arguments,
+            /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
+    }]>
+  ];
+
   let autogenSerialization = 0;
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 77ed6b322451e1..e4eb67c8e14cec 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -886,7 +886,9 @@ def IncludeOp : TransformDialectOp<"include",
 
   let arguments = (ins SymbolRefAttr:$target,
                        FailurePropagationMode:$failure_propagation_mode,
-                       Variadic<Transform_AnyHandleOrParamType>:$operands);
+                       Variadic<Transform_AnyHandleOrParamType>:$operands,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs);
   let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
 
   let assemblyFormat =
diff --git a/mlir/include/mlir/Interfaces/CallImplementation.h b/mlir/include/mlir/Interfaces/CallImplementation.h
deleted file mode 100644
index 2edc081bddf478..00000000000000
--- a/mlir/include/mlir/Interfaces/CallImplementation.h
+++ /dev/null
@@ -1,88 +0,0 @@
-//===- CallImplementation.h - Call and Callable Op utilities ----*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file provides utility functions for implementing call-like and
-// callable-like operations, in particular, parsing, printing and verification
-// components common to these operations.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_INTERFACES_CALLIMPLEMENTATION_H
-#define MLIR_INTERFACES_CALLIMPLEMENTATION_H
-
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/Interfaces/CallInterfaces.h"
-
-namespace mlir {
-
-namespace call_interface_impl {
-
-/// Parse a function or call result list.
-///
-///   function-result-list ::= function-result-list-parens
-///                          | non-function-type
-///   function-result-list-parens ::= `(` `)`
-///                                 | `(` function-result-list-no-parens `)`
-///   function-result-list-no-parens ::= function-result (`,` function-result)*
-///   function-result ::= type attribute-dict?
-///
-ParseResult
-parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
-                        SmallVectorImpl<DictionaryAttr> &resultAttrs);
-
-/// Parses a function signature using `parser`. This does not deal with function
-/// signatures containing SSA region arguments (to parse these signatures, use
-/// function_interface_impl::parseFunctionSignature). When
-/// `mustParseEmptyResult`, `-> ()` is expected when there is no result type.
-///
-///   no-ssa-function-signature ::= `(` no-ssa-function-arg-list `)`
-///                               -> function-result-list
-///   no-ssa-function-arg-list  ::= no-ssa-function-arg
-///                               (`,` no-ssa-function-arg)*
-///   no-ssa-function-arg       ::= type attribute-dict?
-ParseResult parseFunctionSignature(OpAsmParser &parser,
-                                   SmallVectorImpl<Type> &argTypes,
-                                   SmallVectorImpl<DictionaryAttr> &argAttrs,
-                                   SmallVectorImpl<Type> &resultTypes,
-                                   SmallVectorImpl<DictionaryAttr> &resultAttrs,
-                                   bool mustParseEmptyResult = true);
-
-/// Print a function signature for a call or callable operation. If a body
-/// region is provided, the SSA arguments are printed in the signature. When
-/// `printEmptyResult` is false, `-> function-result-list` is omitted when
-/// `resultTypes` is empty.
-///
-///   function-signature     ::= ssa-function-signature
-///                            | no-ssa-function-signature
-///   ssa-function-signature ::= `(` ssa-function-arg-list `)`
-///                            -> function-result-list
-///   ssa-function-arg-list  ::= ssa-function-arg (`,` ssa-function-arg)*
-///   ssa-function-arg       ::= `%`name `:` type attribute-dict?
-void printFunctionSignature(OpAsmPrinter &p, ArgumentAttributesOpInterface op,
-                            TypeRange argTypes, bool isVariadic,
-                            TypeRange resultTypes, Region *body = nullptr,
-                            bool printEmptyResult = true);
-
-/// Adds argument and result attributes, provided as `argAttrs` and
-/// `resultAttrs` arguments, to the list of operation attributes in `result`.
-/// Internally, argument and result attributes are stored as dict attributes
-/// with special names given by getResultAttrName, getArgumentAttrName.
-void addArgAndResultAttrs(Builder &builder, OperationState &result,
-                          ArrayRef<DictionaryAttr> argAttrs,
-                          ArrayRef<DictionaryAttr> resultAttrs,
-                          StringAttr argAttrsName, StringAttr resAttrsName);
-void addArgAndResultAttrs(Builder &builder, OperationState &result,
-                          ArrayRef<OpAsmParser::Argument> args,
-                          ArrayRef<DictionaryAttr> resultAttrs,
-                          StringAttr argAttrsName, StringAttr resAttrsName);
-
-} // namespace call_interface_impl
-
-} // namespace mlir
-
-#endif // MLIR_INTERFACES_CALLIMPLEMENTATION_H
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h
index 0020c19333d103..2bf3a3ca5f8a89 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_INTERFACES_CALLINTERFACES_H
 #define MLIR_INTERFACES_CALLINTERFACES_H
 
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "llvm/ADT/PointerUnion.h"
 
@@ -35,6 +36,66 @@ namespace call_interface_impl {
 Operation *resolveCallable(CallOpInterface call,
                            SymbolTableCollection *symbolTable = nullptr);
 
+/// Parse a function or call result list.
+///
+///   function-result-list ::= function-result-list-parens
+///                          | non-function-type
+///   function-result-list-parens ::= `(` `)`
+///                                 | `(` function-result-list-no-parens `)`
+///   function-result-list-no-parens ::= function-result (`,` function-result)*
+///   function-result ::= type attribute-dict?
+///
+ParseResult
+parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+                        SmallVectorImpl<DictionaryAttr> &resultAttrs);
+
+/// Parses a function signature using `parser`. This does not deal with function
+/// signatures containing SSA region arguments (to parse these signatures, use
+/// function_interface_impl::parseFunctionSignature). When
+/// `mustParseEmptyResult`, `-> ()` is expected when there is no result type.
+///
+///   no-ssa-function-signature ::= `(` no-ssa-function-arg-list `)`
+///                               -> function-result-list
+///   no-ssa-function-arg-list  ::= no-ssa-function-arg
+///                               (`,` no-ssa-function-arg)*
+///   no-ssa-function-arg       ::= type attribute-dict?
+ParseResult parseFunctionSignature(OpAsmParser &parser,
+                                   SmallVectorImpl<Type> &argTypes,
+                                   SmallVectorImpl<DictionaryAttr> &argAttrs,
+                                   SmallVectorImpl<Type> &resultTypes,
+                                   SmallVectorImpl<DictionaryAttr> &resultAttrs,
+                                   bool mustParseEmptyResult = true);
+
+/// Print a function signature for a call or callable operation. If a body
+/// region is provided, the SSA arguments are printed in the signature. When
+/// `printEmptyResult` is false, `-> function-result-list` is omitted when
+/// `resultTypes` is empty.
+///
+///   function-signature     ::= ssa-function-signature
+///                            | no-ssa-function-signature
+///   ssa-function-signature ::= `(` ssa-function-arg-list `)`
+///                            -> function-result-list
+///   ssa-function-arg-list  ::= ssa-function-arg (`,` ssa-function-arg)*
+///   ssa-function-arg       ::= `%`name `:` type attribute-dict?
+void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes,
+                            ArrayAttr argAttrs, bool isVariadic,
+                            TypeRange resultTypes, ArrayAttr resultAttrs,
+                            Region *body = nullptr,
+                            bool printEmptyResult = true);
+
+/// Adds the argument and result attributes given by `argAttrs` (or by the
+/// attributes of `args`) and `resultAttrs` to `result`. Each list is stored as
+/// an array of dictionaries under `argAttrsName`/`resAttrsName`, and only when
+/// at least one of its dictionaries is non-empty.
+void addArgAndResultAttrs(Builder &builder, OperationState &result,
+                          ArrayRef<DictionaryAttr> argAttrs,
+                          ArrayRef<DictionaryAttr> resultAttrs,
+                          StringAttr argAttrsName, StringAttr resAttrsName);
+void addArgAndResultAttrs(Builder &builder, OperationState &result,
+                          ArrayRef<OpAsmParser::Argument> args,
+                          ArrayRef<DictionaryAttr> resultAttrs,
+                          StringAttr argAttrsName, StringAttr resAttrsName);
+
 } // namespace call_interface_impl
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index 1f2398387c044e..9955c6862b5328 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -20,13 +20,8 @@ include "mlir/IR/OpBase.td"
 
 /// Interface for operations with arguments attributes (both call-like
 /// and callable operations).
-def ArgumentAttributesOpInterface : OpInterface<"ArgumentAttributesOpInterface"> {
-  let description = [{
-    A call-like or callable operation that can hold attributes for its arguments
-    and results.
-  }];
-  let cppNamespace = "::mlir";
-  let methods = [
+def ArgumentAttributesMethods {
+  list<InterfaceMethod> methods = [
     InterfaceMethod<[{
         Get the array of argument attribute dictionaries. The method should
         return an array attribute containing only dictionary attributes equal in
@@ -78,7 +73,7 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
   }];
   let cppNamespace = "::mlir";
 
-  let methods = [
+  let methods = !listconcat(ArgumentAttributesMethods.methods, [
     InterfaceMethod<[{
         Returns the callee of this call-like operation. A `callee` is either a
         reference to a symbol, via SymbolRefAttr, or a reference to a defined
@@ -128,7 +123,7 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
         return ::mlir::call_interface_impl::resolveCallable($_op);
       }]
     >
-  ];
+  ]);
 }
 
 /// Interface for callable operations.
@@ -143,7 +138,7 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
   }];
   let cppNamespace = "::mlir";
 
-  let methods = [
+  let methods = !listconcat(ArgumentAttributesMethods.methods, [
     InterfaceMethod<[{
         Returns the region on the current operation that is callable. This may
         return null in the case of an external callable object, e.g. an external
@@ -160,7 +155,7 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
       allow for this method may be called on function declarations).
     }],
     "::llvm::ArrayRef<::mlir::Type>", "getResultTypes">,
-  ];
+  ]);
 }
 
 #endif // MLIR_INTERFACES_CALLINTERFACES
diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h
index ae20533ef4b87c..7f7b30962c145f 100644
--- a/mlir/include/mlir/Interfaces/FunctionImplementation.h
+++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h
@@ -16,7 +16,6 @@
 #define MLIR_IR_FUNCTIONIMPLEMENTATION_H_
 
 #include "mlir/IR/OpImplementation.h"
-#include "mlir/Interfaces/CallImplementation.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 
 namespace mlir {
@@ -75,9 +74,10 @@ void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
 inline void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
                                    ArrayRef<Type> argTypes, bool isVariadic,
                                    ArrayRef<Type> resultTypes) {
-  call_interface_impl::printFunctionSignature(p, op, argTypes, isVariadic,
-                                              resultTypes, &op->getRegion(0),
-                                              /*printEmptyResult=*/false);
+  call_interface_impl::printFunctionSignature(
+      p, argTypes, op.getArgAttrsAttr(), isVariadic, resultTypes,
+      op.getResAttrsAttr(), &op->getRegion(0),
+      /*printEmptyResult=*/false);
 }
 
 /// Prints the list of function prefixed with the "attributes" keyword. The
diff --git a/mlir/include/mlir/Interfaces/FunctionInterfaces.td b/mlir/include/mlir/Interfaces/FunctionInterfaces.td
index 616785837e1452..697f951748c675 100644
--- a/mlir/include/mlir/Interfaces/FunctionInterfaces.td
+++ b/mlir/include/mlir/Interfaces/FunctionInterfaces.td
@@ -22,7 +22,7 @@ include "mlir/Interfaces/CallInterfaces.td"
 //===----------------------------------------------------------------------===//
 
 def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
-    Symbol, CallableOpInterface, ArgumentAttributesOpInterface
+    Symbol, CallableOpInterface
   ]> {
   let cppNamespace = "::mlir";
   let description = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index ef1e0222e05f06..0b0a3c533e4040 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1033,6 +1033,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
         /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
         /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
+        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1060,6 +1061,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
         /*convergent=*/nullptr,
         /*no_unwind=*/nullptr, /*will_return=*/nullptr,
         /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
+        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
         /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1073,6 +1075,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
         /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
+        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
@@ -1087,6 +1090,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
         /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
         /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
+        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
 }
 
@@ -1527,7 +1531,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
         getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
-        normalOps, unwindOps, nullptr, nullptr, {}, {}, normal, unwind);
+        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps,
+        nullptr, nullptr, {}, {}, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1535,8 +1540,9 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
                      ValueRange normalOps, Block *unwind,
                      ValueRange unwindOps) {
   build(builder, state, tys,
-        /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
-        nullptr, {}, {}, normal, unwind);
+        /*var_callee_type=*/nullptr, callee, ops, /*arg_attrs=*/nullptr,
+        /*res_attrs=*/nullptr, normalOps, unwindOps, nullptr, nullptr, {}, {},
+        normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state,
@@ -1544,7 +1550,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state,
                      ValueRange ops, Block *normal, ValueRange normalOps,
                      Block *unwind, ValueRange unwindOps) {
   build(builder, state, getCallOpResultTypes(calleeType),
-        getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
+        getCallOpVarCalleeType(calleeType), callee, ops,
+        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps,
         nullptr, nullptr, {}, {}, normal, unwind);
 }
 
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 76e2d921c2a9f5..d3b7bf65ad3e73 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -1,5 +1,4 @@
 set(LLVM_OPTIONAL_SOURCES
-  CallImplementation.cpp
   CallInterfaces.cpp
   CastInterfaces.cpp
   ControlFlowInterfaces.cpp
@@ -25,10 +24,8 @@ set(LLVM_OPTIONAL_SOURCES
   )
 
 function(add_mlir_interface_library name)
-  cmake_parse_arguments(ARG "" "" "EXTRA_SOURCE" ${ARGN})
   add_mlir_library(MLIR${name}
     ${name}.cpp
-    ${ARG_EXTRA_SOURCE}
 
     ADDITIONAL_HEADER_DIRS
     ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
@@ -42,14 +39,28 @@ function(add_mlir_interface_library name)
 endfunction(add_mlir_interface_library)
 
 
-add_mlir_interface_library(CallInterfaces EXTRA_SOURCE CallImplementation.cpp)
+add_mlir_interface_library(CallInterfaces)
 add_mlir_interface_library(CastInterfaces)
 add_mlir_interface_library(ControlFlowInterfaces)
 add_mlir_interface_library(CopyOpInterface)
 add_mlir_interface_library(DataLayoutInterfaces)
 add_mlir_interface_library(DerivedAttributeOpInterface)
 add_mlir_interface_library(DestinationStyleOpInterface)
-add_mlir_interface_library(FunctionInterfaces EXTRA_SOURCE FunctionImplementation.cpp)
+
+add_mlir_library(MLIRFunctionInterfaces
+  FunctionInterfaces.cpp
+  FunctionImplementation.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
+
+  DEPENDS
+  MLIRFunctionInterfacesIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+)
+
 add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 
diff --git a/mlir/lib/Interfaces/CallImplementation.cpp b/mlir/lib/Interfaces/CallImplementation.cpp
deleted file mode 100644
index 974e779e32d30b..00000000000000
--- a/mlir/lib/Interfaces/CallImplementation.cpp
+++ /dev/null
@@ -1,178 +0,0 @@
-//===- CallImplementation.pp - Call and Callable Op utilities -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Interfaces/CallImplementation.h"
-#include "mlir/IR/Builders.h"
-
-using namespace mlir;
-
-static ParseResult
-parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
-                     SmallVectorImpl<DictionaryAttr> &attrs) {
-  // Parse individual function results.
-  return parser.parseCommaSeparatedList([&]() -> ParseResult {
-    types.emplace_back();
-    attrs.emplace_back();
-    NamedAttrList attrList;
-    if (parser.parseType(types.back()) ||
-        parser.parseOptionalAttrDict(attrList))
-      return failure();
-    attrs.back() = attrList.getDictionary(parser.getContext());
-    return success();
-  });
-}
-
-ParseResult call_interface_impl::parseFunctionResultList(
-    OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
-  if (failed(parser.parseOptionalLParen())) {
-    // We already know that there is no `(`, so parse a type.
-    // Because there is no `(`, it cannot be a function type.
-    Type ty;
-    if (parser.parseType(ty))
-      return failure();
-    resultTypes.push_back(ty);
-    resultAttrs.emplace_back();
-    return success();
-  }
-
-  // Special case for an empty set of parens.
-  if (succeeded(parser.parseOptionalRParen()))
-    return success();
-  if (parseTypeAndAttrList(parser, resultTypes, resultAttrs))
-    return failure();
-  return parser.parseRParen();
-}
-
-ParseResult call_interface_impl::parseFunctionSignature(
-    OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
-    SmallVectorImpl<DictionaryAttr> &argAttrs,
-    SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
-  // Parse arguments.
-  if (parser.parseLParen())
-    return failure();
-  if (failed(parser.parseOptionalRParen())) {
-    if (parseTypeAndAttrList(parser, argTypes, argAttrs))
-      return failure();
-    if (parser.parseRParen())
-      return failure();
-  }
-  // Parse results.
-  if (succeeded(parser.parseOptionalArrow()))
-    return call_interface_impl::parseFunctionResultList(parser, resultTypes,
-                                                        resultAttrs);
-  if (mustParseEmptyResult)
-    return failure();
-  return success();
-}
-
-/// Print a function result list. The provided `attrs` must either be null, or
-/// contain a set of DictionaryAttrs of the same arity as `types`.
-static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
-                                    ArrayAttr attrs) {
-  assert(!types.empty() && "Should not be called for empty result list.");
-  assert((!attrs || attrs.size() == types.size()) &&
-         "Invalid number of attributes.");
-
-  auto &os = p.getStream();
-  bool needsParens = types.size() > 1 || llvm::isa<FunctionType>(types[0]) ||
-                     (attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty());
-  if (needsParens)
-    os << '(';
-  llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
-    p.printType(types[i]);
-    if (attrs)
-      p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
-  });
-  if (needsParens)
-    os << ')';
-}
-
-void call_interface_impl::printFunctionSignature(
-    OpAsmPrinter &p, ArgumentAttributesOpInterface op, TypeRange argTypes,
-    bool isVariadic, TypeRange resultTypes, Region *body,
-    bool printEmptyResult) {
-  bool isExternal = !body || body->empty();
-  ArrayAttr argAttrs = op.getArgAttrsAttr();
-  ArrayAttr resultAttrs = op.getResAttrsAttr();
-  if (!isExternal && !isVariadic && !argAttrs && !resultAttrs &&
-      printEmptyResult) {
-    p.printFunctionalType(argTypes, resultTypes);
-    return;
-  }
-
-  p << '(';
-  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
-    if (i > 0)
-      p << ", ";
-
-    if (!isExternal) {
-      ArrayRef<NamedAttribute> attrs;
-      if (argAttrs)
-        attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
-      p.printRegionArgument(body->getArgument(i), attrs);
-    } else {
-      p.printType(argTypes[i]);
-      if (argAttrs)
-        p.printOptionalAttrDict(
-            llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
-    }
-  }
-
-  if (isVariadic) {
-    if (!argTypes.empty())
-      p << ", ";
-    p << "...";
-  }
-
-  p << ')';
-
-  if (!resultTypes.empty()) {
-    p << " -> ";
-    printFunctionResultList(p, resultTypes, resultAttrs);
-  } else if (printEmptyResult) {
-    p << " -> ()";
-  }
-}
-
-void call_interface_impl::addArgAndResultAttrs(
-    Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
-    ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
-    StringAttr resAttrsName) {
-  auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
-    return attrs && !attrs.empty();
-  };
-  // Convert the specified array of dictionary attrs (which may have null
-  // entries) to an ArrayAttr of dictionaries.
-  auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
-    SmallVector<Attribute> attrs;
-    for (auto &dict : dictAttrs)
-      attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
-    return builder.getArrayAttr(attrs);
-  };
-
-  // Add the attributes to the function arguments.
-  if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
-    result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
-
-  // Add the attributes to the function results.
-  if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
-    result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
-}
-
-void call_interface_impl::addArgAndResultAttrs(
-    Builder &builder, OperationState &result,
-    ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
-    StringAttr argAttrsName, StringAttr resAttrsName) {
-  SmallVector<DictionaryAttr> argAttrs;
-  for (const auto &arg : args)
-    argAttrs.push_back(arg.attrs);
-  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
-                       resAttrsName);
-}
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index da0ca0e24630f0..7e4790addeada1 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -7,9 +7,178 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/IR/Builders.h"
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Argument and result attributes utilities
+//===----------------------------------------------------------------------===//
+
+static ParseResult
+parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
+                     SmallVectorImpl<DictionaryAttr> &attrs) {
+  // Parse comma-separated `type attribute-dict?` entries (args or results).
+  return parser.parseCommaSeparatedList([&]() -> ParseResult {
+    types.emplace_back();
+    attrs.emplace_back();
+    NamedAttrList attrList;
+    if (parser.parseType(types.back()) ||
+        parser.parseOptionalAttrDict(attrList))
+      return failure();
+    attrs.back() = attrList.getDictionary(parser.getContext());
+    return success();
+  });
+}
+
+ParseResult call_interface_impl::parseFunctionResultList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
+  if (failed(parser.parseOptionalLParen())) {
+    // Without a leading `(` the result list is a single type, and that type
+    // cannot be a function type (which would itself start with `(`).
+    Type ty;
+    if (parser.parseType(ty))
+      return failure();
+    resultTypes.push_back(ty);
+    resultAttrs.emplace_back(); // Keep one attribute slot per result type.
+    return success();
+  }
+
+  // An empty pair of parentheses denotes an empty result list.
+  if (succeeded(parser.parseOptionalRParen()))
+    return success();
+  if (parseTypeAndAttrList(parser, resultTypes, resultAttrs))
+    return failure();
+  return parser.parseRParen();
+}
+
+ParseResult call_interface_impl::parseFunctionSignature(
+    OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
+    SmallVectorImpl<DictionaryAttr> &argAttrs,
+    SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
+  // Parse the parenthesized argument list; `()` yields no arguments.
+  if (parser.parseLParen())
+    return failure();
+  if (failed(parser.parseOptionalRParen())) {
+    if (parseTypeAndAttrList(parser, argTypes, argAttrs))
+      return failure();
+    if (parser.parseRParen())
+      return failure();
+  }
+  // Parse results; a missing `->` is diagnosed when the arrow is mandatory.
+  if (succeeded(parser.parseOptionalArrow()))
+    return call_interface_impl::parseFunctionResultList(parser, resultTypes,
+                                                        resultAttrs);
+  if (mustParseEmptyResult) // Never fail without telling the user why.
+    return parser.emitError(parser.getCurrentLocation(), "expected '->'");
+  return success();
+}
+
+/// Print a function result list, parenthesized unless it is one attribute-less
+/// non-function type. `attrs` is null or has one DictionaryAttr per type.
+static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
+                                    ArrayAttr attrs) {
+  assert(!types.empty() && "Should not be called for empty result list.");
+  assert((!attrs || attrs.size() == types.size()) &&
+         "Invalid number of attributes.");
+
+  auto &os = p.getStream();
+  bool needsParens = types.size() > 1 || llvm::isa<FunctionType>(types[0]) ||
+                     (attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty());
+  if (needsParens)
+    os << '(';
+  llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
+    p.printType(types[i]);
+    if (attrs)
+      p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
+  });
+  if (needsParens)
+    os << ')';
+}
+
+void call_interface_impl::printFunctionSignature(
+    OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
+    TypeRange resultTypes, ArrayAttr resultAttrs, Region *body,
+    bool printEmptyResult) {
+  bool isExternal = !body || body->empty(); // No SSA arguments to print.
+  if (isExternal && !isVariadic && !argAttrs && !resultAttrs &&
+      printEmptyResult) {
+    p.printFunctionalType(argTypes, resultTypes); // Plain `(T...) -> R` form.
+    return;
+  }
+
+  p << '(';
+  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
+    if (i > 0)
+      p << ", ";
+
+    if (!isExternal) { // With a body, print `%name: type {attrs}`.
+      ArrayRef<NamedAttribute> attrs;
+      if (argAttrs)
+        attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
+      p.printRegionArgument(body->getArgument(i), attrs);
+    } else { // Without a body, print `type {attrs}` only.
+      p.printType(argTypes[i]);
+      if (argAttrs)
+        p.printOptionalAttrDict(
+            llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
+    }
+  }
+
+  if (isVariadic) { // The ellipsis always comes last in the argument list.
+    if (!argTypes.empty())
+      p << ", ";
+    p << "...";
+  }
+
+  p << ')';
+
+  if (!resultTypes.empty()) {
+    p << " -> ";
+    printFunctionResultList(p, resultTypes, resultAttrs);
+  } else if (printEmptyResult) { // Otherwise the arrow is omitted entirely.
+    p << " -> ()";
+  }
+}
+
+void call_interface_impl::addArgAndResultAttrs(
+    Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
+    ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
+    StringAttr resAttrsName) {
+  auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
+    return attrs && !attrs.empty();
+  };
+  // Convert the specified array of dictionary attrs (which may have null
+  // entries) to an ArrayAttr of dictionaries, using an empty one for nulls.
+  auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
+    SmallVector<Attribute> attrs;
+    for (auto &dict : dictAttrs)
+      attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
+    return builder.getArrayAttr(attrs);
+  };
+
+  // Record argument attributes only if some argument has a non-empty dict.
+  if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
+    result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
+
+  // Likewise, record result attributes only if some result has any.
+  if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
+    result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
+}
+
+void call_interface_impl::addArgAndResultAttrs(
+    Builder &builder, OperationState &result,
+    ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
+    StringAttr argAttrsName, StringAttr resAttrsName) {
+  SmallVector<DictionaryAttr> argAttrs;
+  for (const auto &arg : args)
+    argAttrs.push_back(arg.attrs); // Extract each parsed argument's dict.
+  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
+                       resAttrsName);
+}
+
 //===----------------------------------------------------------------------===//
 // CallOpInterface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 57a7931b56085a..0cae63c58ca7be 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -193,13 +193,11 @@ static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
   SmallVector<DictionaryAttr> argAttrs(
       callable.getCallableRegion()->getNumArguments(),
       builder.getDictionaryAttr({}));
-  if (auto argAttrsOpInterface =
-          dyn_cast<ArgumentAttributesOpInterface>(callable.getOperation()))
-    if (ArrayAttr arrayAttr = argAttrsOpInterface.getArgAttrsAttr()) {
-      assert(arrayAttr.size() == argAttrs.size());
-      for (auto [idx, attr] : llvm::enumerate(arrayAttr))
-        argAttrs[idx] = cast<DictionaryAttr>(attr);
-    }
+  if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) {
+    assert(arrayAttr.size() == argAttrs.size());
+    for (auto [idx, attr] : llvm::enumerate(arrayAttr))
+      argAttrs[idx] = cast<DictionaryAttr>(attr);
+  }
 
   // Run the argument attribute handler for the given argument and attribute.
   for (auto [blockArg, argAttr] :
@@ -220,13 +218,11 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
   // Unpack the result attributes if there are any.
   SmallVector<DictionaryAttr> resAttrs(results.size(),
                                        builder.getDictionaryAttr({}));
-  if (auto argAttrsOpInterface =
-          dyn_cast<ArgumentAttributesOpInterface>(callable.getOperation()))
-    if (ArrayAttr arrayAttr = argAttrsOpInterface.getResAttrsAttr()) {
-      assert(arrayAttr.size() == resAttrs.size());
-      for (auto [idx, attr] : llvm::enumerate(arrayAttr))
-        resAttrs[idx] = cast<DictionaryAttr>(attr);
-    }
+  if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) {
+    assert(arrayAttr.size() == resAttrs.size());
+    for (auto [idx, attr] : llvm::enumerate(arrayAttr))
+      resAttrs[idx] = cast<DictionaryAttr>(attr);
+  }
 
   // Run the result attribute handler for the given result and attribute.
   SmallVector<DictionaryAttr> resultAttributes;
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0b1f22b3ee9323..dc812c584674a9 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -550,7 +550,12 @@ def TestCallOp : TEST_Op<"call", [DeclareOpInterfaceMethods<SymbolUserOpInterfac
 
 def ConversionCallOp : TEST_Op<"conversion_call_op",
     [CallOpInterface]> {
-  let arguments = (ins Variadic<AnyType>:$arg_operands, SymbolRefAttr:$callee);
+  let arguments = (ins
+    Variadic<AnyType>:$arg_operands,
+    SymbolRefAttr:$callee,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
+  );
   let results = (outs Variadic<AnyType>);
 
   let extraClassDeclaration = [{
@@ -611,6 +616,10 @@ def ConversionFuncOp : TEST_Op<"conversion_func_op", [FunctionOpInterface]> {
 def FunctionalRegionOp : TEST_Op<"functional_region_op",
     [CallableOpInterface]> {
   let regions = (region AnyRegion:$body);
+  let arguments = (ins
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
+  );
   let results = (outs FunctionType);
 
   let extraClassDeclaration = [{
@@ -3287,7 +3296,9 @@ def TestCallAndStoreOp : TEST_Op<"call_and_store",
     SymbolRefAttr:$callee,
     Arg<AnyMemRef, "", [MemWrite]>:$address,
     Variadic<AnyType>:$callee_operands,
-    BoolAttr:$store_before_call
+    BoolAttr:$store_before_call,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let results = (outs
     Variadic<AnyType>:$results
@@ -3302,7 +3313,9 @@ def TestCallOnDeviceOp : TEST_Op<"call_on_device",
   let arguments = (ins
     SymbolRefAttr:$callee,
     Variadic<AnyType>:$forwarded_operands,
-    AnyType:$non_forwarded_device_operand
+    AnyType:$non_forwarded_device_operand,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
   );
   let results = (outs
     Variadic<AnyType>:$results

>From 803451d255a569b56b36feea937aeccb93e1dec9 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 21 Jan 2025 05:29:00 -0800
Subject: [PATCH 4/6] update toy dialect too

---
 mlir/examples/toy/Ch4/include/toy/Ops.td |  7 ++++++-
 mlir/examples/toy/Ch5/include/toy/Ops.td |  7 ++++++-
 mlir/examples/toy/Ch6/include/toy/Ops.td |  7 ++++++-
 mlir/examples/toy/Ch7/include/toy/Ops.td | 10 ++++++++--
 mlir/examples/toy/Ch7/mlir/Dialect.cpp   |  5 +++--
 mlir/examples/toy/Ch7/mlir/MLIRGen.cpp   |  3 +--
 6 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td
index 075fd1a9cd4738..4441e48ca53c0b 100644
--- a/mlir/examples/toy/Ch4/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch4/include/toy/Ops.td
@@ -215,7 +215,12 @@ def GenericCallOp : Toy_Op<"generic_call",
 
   // The generic call operation takes a symbol reference attribute as the
   // callee, and inputs for the call.
-  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+  let arguments = (ins
+    FlatSymbolRefAttr:$callee,
+    Variadic<F64Tensor>:$inputs,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
+  );
 
   // The generic call operation returns a single value of TensorType.
   let results = (outs F64Tensor);
diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td
index ec6762ff406e87..5b7c966de6f088 100644
--- a/mlir/examples/toy/Ch5/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch5/include/toy/Ops.td
@@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call",
 
   // The generic call operation takes a symbol reference attribute as the
   // callee, and inputs for the call.
-  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+  let arguments = (ins
+    FlatSymbolRefAttr:$callee,
+    Variadic<F64Tensor>:$inputs,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
+  );
 
   // The generic call operation returns a single value of TensorType.
   let results = (outs F64Tensor);
diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td
index a52bebc8b67b86..fdbc239a171dfa 100644
--- a/mlir/examples/toy/Ch6/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch6/include/toy/Ops.td
@@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call",
 
   // The generic call operation takes a symbol reference attribute as the
   // callee, and inputs for the call.
-  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+  let arguments = (ins
+    FlatSymbolRefAttr:$callee,
+    Variadic<F64Tensor>:$inputs,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
+  );
 
   // The generic call operation returns a single value of TensorType.
   let results = (outs F64Tensor);
diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
index cfd6859eb27bff..71ab7b0aeebb9f 100644
--- a/mlir/examples/toy/Ch7/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -237,7 +237,12 @@ def GenericCallOp : Toy_Op<"generic_call",
 
   // The generic call operation takes a symbol reference attribute as the
   // callee, and inputs for the call.
-  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<Toy_Type>:$inputs);
+  let arguments = (ins
+    FlatSymbolRefAttr:$callee,
+    Variadic<Toy_Type>:$inputs,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs
+  );
 
   // The generic call operation returns a single value of TensorType or
   // StructType.
@@ -250,7 +255,8 @@ def GenericCallOp : Toy_Op<"generic_call",
 
   // Add custom build methods for the generic call operation.
   let builders = [
-    OpBuilder<(ins "StringRef":$callee, "ArrayRef<Value>":$arguments)>
+    OpBuilder<(ins "Type":$result_type, "StringRef":$callee,
+              "ArrayRef<Value>":$arguments)>
   ];
 }
 
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 7e030ffc5488c9..8dd5e3956c8f17 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -350,9 +350,10 @@ void FuncOp::print(mlir::OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
-                          StringRef callee, ArrayRef<mlir::Value> arguments) {
+                          mlir::Type resultType, StringRef callee,
+                          ArrayRef<mlir::Value> arguments) {
   // Generic call always returns an unranked Tensor initially.
-  state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
+  state.addTypes(resultType);
   state.addOperands(arguments);
   state.addAttribute("callee",
                      mlir::SymbolRefAttr::get(builder.getContext(), callee));
diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
index 090e5ff9146041..e554e375209f1c 100644
--- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
@@ -535,8 +535,7 @@ class MLIRGenImpl {
     }
     mlir::toy::FuncOp calledFunc = calledFuncIt->second;
     return builder.create<GenericCallOp>(
-        location, calledFunc.getFunctionType().getResult(0),
-        mlir::SymbolRefAttr::get(builder.getContext(), callee), operands);
+        location, calledFunc.getFunctionType().getResult(0), callee, operands);
   }
 
   /// Emit a print expression. It emits specific operations for two builtins:

>From b48b0dfd68c885af655c182a632681d6cd01f4fd Mon Sep 17 00:00:00 2001
From: jeanPerier <jperier at nvidia.com>
Date: Mon, 27 Jan 2025 13:47:19 +0100
Subject: [PATCH 5/6] Apply suggestions from code review

Co-authored-by: Tobias Gysi <tobias.gysi at nextsilicon.com>
---
 mlir/examples/toy/Ch7/mlir/Dialect.cpp         | 1 -
 mlir/include/mlir/Interfaces/CallInterfaces.td | 4 ++--
 mlir/lib/Interfaces/CallInterfaces.cpp         | 4 ++--
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 8dd5e3956c8f17..76858a761dbc12 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -352,7 +352,6 @@ void FuncOp::print(mlir::OpAsmPrinter &p) {
 void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                           mlir::Type resultType, StringRef callee,
                           ArrayRef<mlir::Value> arguments) {
-  // Generic call always returns an unranked Tensor initially.
   state.addTypes(resultType);
   state.addOperands(arguments);
   state.addAttribute("callee",
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index 9955c6862b5328..697852afb9b7c3 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -26,14 +26,14 @@ def ArgumentAttributesMethods {
         Get the array of argument attribute dictionaries. The method should
         return an array attribute containing only dictionary attributes equal in
         number to the number of arguments. Alternatively, the method can
-        return null to indicate that the region has no argument attributes.
+        return null to indicate that there are no argument attributes.
       }],
       "::mlir::ArrayAttr", "getArgAttrsAttr">,
     InterfaceMethod<[{
         Get the array of result attribute dictionaries. The method should return
         an array attribute containing only dictionary attributes equal in number
         to the number of results. Alternatively, the method can return
-        null to indicate that the region has no result attributes.
+        null to indicate that there are no result attributes.
       }],
       "::mlir::ArrayAttr", "getResAttrsAttr">,
     InterfaceMethod<[{
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index 7e4790addeada1..e8ed4b339a0cba 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -159,11 +159,11 @@ void call_interface_impl::addArgAndResultAttrs(
     return builder.getArrayAttr(attrs);
   };
 
-  // Add the attributes to the function arguments.
+  // Add the attributes to the operation arguments.
   if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
     result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
 
-  // Add the attributes to the function results.
+  // Add the attributes to the operation results.
   if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
     result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
 }

>From 9c7ee7145c67e2424af8b58c2f25f2e34fc3586f Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Mon, 27 Jan 2025 05:13:22 -0800
Subject: [PATCH 6/6] rename parseFunctionSignature to
 parseFunctionSignatureWithArguments

---
 mlir/include/mlir/Interfaces/FunctionImplementation.h | 10 +++++-----
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp                |  2 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp            |  2 +-
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp                |  2 +-
 mlir/lib/Interfaces/FunctionImplementation.cpp        |  6 +++---
 5 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h
index 7f7b30962c145f..374c2c534f87d6 100644
--- a/mlir/include/mlir/Interfaces/FunctionImplementation.h
+++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h
@@ -45,11 +45,11 @@ using FuncTypeBuilder = function_ref<Type(
 /// indicates whether functions with variadic arguments are supported. The
 /// trailing arguments are populated by this function with names, types,
 /// attributes and locations of the arguments and those of the results.
-ParseResult
-parseFunctionSignature(OpAsmParser &parser, bool allowVariadic,
-                       SmallVectorImpl<OpAsmParser::Argument> &arguments,
-                       bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
-                       SmallVectorImpl<DictionaryAttr> &resultAttrs);
+ParseResult parseFunctionSignatureWithArguments(
+    OpAsmParser &parser, bool allowVariadic,
+    SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
+    SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs);
 
 /// Parser implementation for function-like operations.  Uses
 /// `funcTypeBuilder` to construct the custom function type given lists of
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 8b85c0829acfec..d9c119862710bd 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1467,7 +1467,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   auto signatureLocation = parser.getCurrentLocation();
-  if (failed(function_interface_impl::parseFunctionSignature(
+  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
           parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
           resultAttrs)))
     return failure();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0b0a3c533e4040..a6e996f3fb810d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2602,7 +2602,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   auto signatureLocation = parser.getCurrentLocation();
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes) ||
-      function_interface_impl::parseFunctionSignature(
+      function_interface_impl::parseFunctionSignatureWithArguments(
           parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
           resultAttrs))
     return failure();
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 870359ce55301c..b613724421305d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -917,7 +917,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   // Parse the function signature.
   bool isVariadic = false;
-  if (function_interface_impl::parseFunctionSignature(
+  if (function_interface_impl::parseFunctionSignatureWithArguments(
           parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
           resultAttrs))
     return failure();
diff --git a/mlir/lib/Interfaces/FunctionImplementation.cpp b/mlir/lib/Interfaces/FunctionImplementation.cpp
index 80174d1fefb559..90f32896e81813 100644
--- a/mlir/lib/Interfaces/FunctionImplementation.cpp
+++ b/mlir/lib/Interfaces/FunctionImplementation.cpp
@@ -70,7 +70,7 @@ parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
       });
 }
 
-ParseResult function_interface_impl::parseFunctionSignature(
+ParseResult function_interface_impl::parseFunctionSignatureWithArguments(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
     SmallVectorImpl<Type> &resultTypes,
@@ -104,8 +104,8 @@ ParseResult function_interface_impl::parseFunctionOp(
   // Parse the function signature.
   SMLoc signatureLocation = parser.getCurrentLocation();
   bool isVariadic = false;
-  if (parseFunctionSignature(parser, allowVariadic, entryArgs, isVariadic,
-                             resultTypes, resultAttrs))
+  if (parseFunctionSignatureWithArguments(parser, allowVariadic, entryArgs,
+                                          isVariadic, resultTypes, resultAttrs))
     return failure();
 
   std::string errorMessage;



More information about the flang-commits mailing list