[Mlir-commits] [mlir] [mlir] share argument attributes interface between calls and callables (PR #123176)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 16 02:03:17 PST 2025


https://github.com/jeanPerier created https://github.com/llvm/llvm-project/pull/123176

First patch of this [RFC](https://discourse.llvm.org/t/mlir-rfc-adding-argument-and-result-attributes-to-llvm-call/84107) to add argument and result attributes on llvm.call (and llvm.invoke).

This patch extracts the core interface methods dealing with argument and result attributes from CallableOpInterface into a new interface, OpWithArgumentAttributesInterface, that is inherited by both CallOpInterface and CallableOpInterface.

This allows dialects to opt in to having parameter attributes on call arguments by adding `OptionalAttr<DictArrayAttr>` arguments named `arg_attrs` and `res_attrs` to the operation definition. They can then reuse the common "rich function signature" printing/parsing helpers if they want.

>From b652f02b7178b2c3d35759770c7b20aca1ede6c4 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Wed, 15 Jan 2025 09:01:44 -0800
Subject: [PATCH] [mlir] share argument attributes interface between calls and
 callables

---
 mlir/docs/Interfaces.md                       |  16 +-
 .../mlir/Interfaces/CallImplementation.h      |  91 +++++++++
 .../include/mlir/Interfaces/CallInterfaces.td |  99 +++++-----
 .../mlir/Interfaces/FunctionImplementation.h  |  24 +--
 mlir/lib/Dialect/Async/IR/Async.cpp           |   2 +-
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           |   2 +-
 mlir/lib/Dialect/Func/IR/FuncOps.cpp          |   2 +-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |   2 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    |   4 +-
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |   2 +-
 mlir/lib/Dialect/Shape/IR/Shape.cpp           |   2 +-
 mlir/lib/Interfaces/CMakeLists.txt            |  21 +--
 mlir/lib/Interfaces/CallImplementation.cpp    | 178 ++++++++++++++++++
 .../lib/Interfaces/FunctionImplementation.cpp | 148 +--------------
 14 files changed, 359 insertions(+), 234 deletions(-)
 create mode 100644 mlir/include/mlir/Interfaces/CallImplementation.h
 create mode 100644 mlir/lib/Interfaces/CallImplementation.cpp

diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index 51747db546bb76..195f58dfbc23b3 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -753,7 +753,15 @@ interface section goes as follows:
     -   (`C++ class` -- `ODS class`(if applicable))
 
 ##### CallInterfaces
-
+*   `OpWithArgumentAttributesInterface`  - Used to represent operations that may
+    carry argument and result attributes. It is inherited by both
+    CallOpInterface and CallableOpInterface.
+    -   `ArrayAttr getArgAttrsAttr()`
+    -   `ArrayAttr getResAttrsAttr()`
+    -   `void setArgAttrsAttr(ArrayAttr)`
+    -   `void setResAttrsAttr(ArrayAttr)`
+    -   `Attribute removeArgAttrsAttr()`
+    -   `Attribute removeResAttrsAttr()`
 *   `CallOpInterface` - Used to represent operations like 'call'
     -   `CallInterfaceCallable getCallableForCallee()`
     -   `void setCalleeFromCallable(CallInterfaceCallable)`
@@ -761,12 +769,6 @@ interface section goes as follows:
     -   `Region * getCallableRegion()`
     -   `ArrayRef<Type> getArgumentTypes()`
     -   `ArrayRef<Type> getResultsTypes()`
-    -   `ArrayAttr getArgAttrsAttr()`
-    -   `ArrayAttr getResAttrsAttr()`
-    -   `void setArgAttrsAttr(ArrayAttr)`
-    -   `void setResAttrsAttr(ArrayAttr)`
-    -   `Attribute removeArgAttrsAttr()`
-    -   `Attribute removeResAttrsAttr()`
 
 ##### RegionKindInterfaces
 
diff --git a/mlir/include/mlir/Interfaces/CallImplementation.h b/mlir/include/mlir/Interfaces/CallImplementation.h
new file mode 100644
index 00000000000000..85e47f6b3dbbb9
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/CallImplementation.h
@@ -0,0 +1,91 @@
+//===- CallImplementation.h - Call and Callable Op utilities ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides utility functions for implementing call-like and
+// callable-like operations, in particular, parsing, printing and verification
+// components common to these operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_CALLIMPLEMENTATION_H
+#define MLIR_INTERFACES_CALLIMPLEMENTATION_H
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+
+namespace mlir {
+
+class OpWithArgumentAttributesInterface;
+
+namespace call_interface_impl {
+
+/// Parse a function or call result list.
+///
+///   function-result-list ::= function-result-list-parens
+///                          | non-function-type
+///   function-result-list-parens ::= `(` `)`
+///                                 | `(` function-result-list-no-parens `)`
+///   function-result-list-no-parens ::= function-result (`,` function-result)*
+///   function-result ::= type attribute-dict?
+///
+ParseResult
+parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+                        SmallVectorImpl<DictionaryAttr> &resultAttrs);
+
+/// Parses a function signature using `parser`. This does not deal with function
+/// signatures containing SSA region arguments (to parse these signatures, use
+/// function_interface_impl::parseFunctionSignature). When
+/// `mustParseEmptyResult` is set, `-> ()` is required if there are no results.
+///
+///   no-ssa-function-signature ::= `(` no-ssa-function-arg-list `)`
+///                               -> function-result-list
+///   no-ssa-function-arg-list  ::= no-ssa-function-arg
+///                               (`,` no-ssa-function-arg)*
+///   no-ssa-function-arg       ::= type attribute-dict?
+ParseResult parseFunctionSignature(OpAsmParser &parser,
+                                   SmallVectorImpl<Type> &argTypes,
+                                   SmallVectorImpl<DictionaryAttr> &argAttrs,
+                                   SmallVectorImpl<Type> &resultTypes,
+                                   SmallVectorImpl<DictionaryAttr> &resultAttrs,
+                                   bool mustParseEmptyResult = true);
+
+/// Print a function signature for a call or callable operation. If a body
+/// region is provided, the SSA arguments are printed in the signature. When
+/// `printEmptyResult` is false, `-> function-result-list` is omitted when
+/// `resultTypes` is empty.
+///
+///   function-signature     ::= ssa-function-signature
+///                            | no-ssa-function-signature
+///   ssa-function-signature ::= `(` ssa-function-arg-list `)`
+///                            -> function-result-list
+///   ssa-function-arg-list  ::= ssa-function-arg (`,` ssa-function-arg)*
+///   ssa-function-arg       ::= `%`name `:` type attribute-dict?
+void printFunctionSignature(OpAsmPrinter &p,
+                            OpWithArgumentAttributesInterface op,
+                            TypeRange argTypes, bool isVariadic,
+                            TypeRange resultTypes, Region *body = nullptr,
+                            bool printEmptyResult = true);
+
+/// Adds argument and result attributes, provided as `argAttrs` and
+/// `resultAttrs` arguments, to the list of operation attributes in `result`.
+/// Argument and result attributes are each stored as an array of dictionary
+/// attributes under the names `argAttrsName` and `resAttrsName`, respectively.
+void addArgAndResultAttrs(Builder &builder, OperationState &result,
+                          ArrayRef<DictionaryAttr> argAttrs,
+                          ArrayRef<DictionaryAttr> resultAttrs,
+                          StringAttr argAttrsName, StringAttr resAttrsName);
+void addArgAndResultAttrs(Builder &builder, OperationState &result,
+                          ArrayRef<OpAsmParser::Argument> args,
+                          ArrayRef<DictionaryAttr> resultAttrs,
+                          StringAttr argAttrsName, StringAttr resAttrsName);
+
+} // namespace call_interface_impl
+
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_CALLIMPLEMENTATION_H
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index c6002da0d491ce..80912a9762187e 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -17,12 +17,65 @@
 
 include "mlir/IR/OpBase.td"
 
+
+/// Interface for operations with argument and result attributes (both
+/// call-like and callable operations).
+def OpWithArgumentAttributesInterface : OpInterface<"OpWithArgumentAttributesInterface"> {
+  let description = [{
+    A call-like or callable operation that may have argument/result attributes.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Get the array of argument attribute dictionaries. The method should
+        return an array attribute containing only dictionary attributes equal in
+        number to the number of arguments. Alternatively, the method can
+        return null to indicate that the operation has no argument attributes.
+      }],
+      "::mlir::ArrayAttr", "getArgAttrsAttr", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+    InterfaceMethod<[{
+        Get the array of result attribute dictionaries. The method should return
+        an array attribute containing only dictionary attributes equal in number
+        to the number of results. Alternatively, the method can return
+        null to indicate that the operation has no result attributes.
+      }],
+      "::mlir::ArrayAttr", "getResAttrsAttr", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+    InterfaceMethod<[{
+      Set the array of argument attribute dictionaries.
+    }],
+    "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
+    InterfaceMethod<[{
+      Set the array of result attribute dictionaries.
+    }],
+    "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
+    InterfaceMethod<[{
+      Remove the array of argument attribute dictionaries. This is the same as
+      setting all argument attributes to an empty dictionary. The method should
+      return the removed attribute.
+    }],
+    "::mlir::Attribute", "removeArgAttrsAttr", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+    InterfaceMethod<[{
+      Remove the array of result attribute dictionaries. This is the same as
+      setting all result attributes to an empty dictionary. The method should
+      return the removed attribute.
+    }],
+    "::mlir::Attribute", "removeResAttrsAttr", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+  ];
+}
+
 // `CallInterfaceCallable`: This is a type used to represent a single callable
 // region. A callable is either a symbol, or an SSA value, that is referenced by
 // a call-like operation. This represents the destination of the call.
 
 /// Interface for call-like operations.
-def CallOpInterface : OpInterface<"CallOpInterface"> {
+def CallOpInterface : OpInterface<"CallOpInterface",
+    [OpWithArgumentAttributesInterface]> {
   let description = [{
     A call-like operation is one that transfers control from one sub-routine to
     another. These operations may be traditional direct calls `call @foo`, or
@@ -85,7 +138,8 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
 }
 
 /// Interface for callable operations.
-def CallableOpInterface : OpInterface<"CallableOpInterface"> {
+def CallableOpInterface : OpInterface<"CallableOpInterface",
+    [OpWithArgumentAttributesInterface]> {
   let description = [{
     A callable operation is one who represents a potential sub-routine, and may
     be a target for a call-like operation (those providing the CallOpInterface
@@ -113,47 +167,6 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
       allow for this method may be called on function declarations).
     }],
     "::llvm::ArrayRef<::mlir::Type>", "getResultTypes">,
-
-    InterfaceMethod<[{
-        Get the array of argument attribute dictionaries. The method should
-        return an array attribute containing only dictionary attributes equal in
-        number to the number of region arguments. Alternatively, the method can
-        return null to indicate that the region has no argument attributes.
-      }],
-      "::mlir::ArrayAttr", "getArgAttrsAttr", (ins),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
-    InterfaceMethod<[{
-        Get the array of result attribute dictionaries. The method should return
-        an array attribute containing only dictionary attributes equal in number
-        to the number of region results. Alternatively, the method can return
-        null to indicate that the region has no result attributes.
-      }],
-      "::mlir::ArrayAttr", "getResAttrsAttr", (ins),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
-    InterfaceMethod<[{
-      Set the array of argument attribute dictionaries.
-    }],
-    "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
-    InterfaceMethod<[{
-      Set the array of result attribute dictionaries.
-    }],
-    "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
-    InterfaceMethod<[{
-      Remove the array of argument attribute dictionaries. This is the same as
-      setting all argument attributes to an empty dictionary. The method should
-      return the removed attribute.
-    }],
-    "::mlir::Attribute", "removeArgAttrsAttr", (ins),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
-    InterfaceMethod<[{
-      Remove the array of result attribute dictionaries. This is the same as
-      setting all result attributes to an empty dictionary. The method should
-      return the removed attribute.
-    }],
-    "::mlir::Attribute", "removeResAttrsAttr", (ins),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
   ];
 }
 
diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h
index a5e6963e4e666f..ae20533ef4b87c 100644
--- a/mlir/include/mlir/Interfaces/FunctionImplementation.h
+++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h
@@ -16,6 +16,7 @@
 #define MLIR_IR_FUNCTIONIMPLEMENTATION_H_
 
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/CallImplementation.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 
 namespace mlir {
@@ -33,19 +34,6 @@ class VariadicFlag {
   bool variadic;
 };
 
-/// Adds argument and result attributes, provided as `argAttrs` and
-/// `resultAttrs` arguments, to the list of operation attributes in `result`.
-/// Internally, argument and result attributes are stored as dict attributes
-/// with special names given by getResultAttrName, getArgumentAttrName.
-void addArgAndResultAttrs(Builder &builder, OperationState &result,
-                          ArrayRef<DictionaryAttr> argAttrs,
-                          ArrayRef<DictionaryAttr> resultAttrs,
-                          StringAttr argAttrsName, StringAttr resAttrsName);
-void addArgAndResultAttrs(Builder &builder, OperationState &result,
-                          ArrayRef<OpAsmParser::Argument> args,
-                          ArrayRef<DictionaryAttr> resultAttrs,
-                          StringAttr argAttrsName, StringAttr resAttrsName);
-
 /// Callback type for `parseFunctionOp`, the callback should produce the
 /// type that will be associated with a function-like operation from lists of
 /// function arguments and results, VariadicFlag indicates whether the function
@@ -84,9 +72,13 @@ void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
 
 /// Prints the signature of the function-like operation `op`. Assumes `op` has
 /// is a FunctionOpInterface and has passed verification.
-void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
-                            ArrayRef<Type> argTypes, bool isVariadic,
-                            ArrayRef<Type> resultTypes);
+inline void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
+                                   ArrayRef<Type> argTypes, bool isVariadic,
+                                   ArrayRef<Type> resultTypes) {
+  call_interface_impl::printFunctionSignature(p, op, argTypes, isVariadic,
+                                              resultTypes, &op->getRegion(0),
+                                              /*printEmptyResult=*/false);
+}
 
 /// Prints the list of function prefixed with the "attributes" keyword. The
 /// attributes with names listed in "elided" as well as those used by the
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index a3e3f80954efce..d3bb250bb8ab9d 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -308,7 +308,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index fdc21d6c6e24b9..6af17087ced968 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -529,7 +529,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index a490b4c3c4ab43..ba7b84f27d6a8a 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -190,7 +190,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 49209229259a73..8b85c0829acfec 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1487,7 +1487,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(type));
 
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
       getResAttrsAttrName(result.name));
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index ef5f1b069b40a3..ef1e0222e05f06 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2510,7 +2510,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
 
   assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
          "expected as many argument attribute lists as arguments");
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 }
@@ -2636,7 +2636,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
     return failure();
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       parser.getBuilder(), result, entryArgs, resultAttrs,
       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 26559c1321db5e..870359ce55301c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -940,7 +940,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   // Add the attributes to the function arguments.
   assert(resultAttrs.size() == resultTypes.size());
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
       getResAttrsAttrName(result.name));
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 65efc88e9c4033..2200af0f67a862 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1297,7 +1297,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index d3b7bf65ad3e73..76e2d921c2a9f5 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -1,4 +1,5 @@
 set(LLVM_OPTIONAL_SOURCES
+  CallImplementation.cpp
   CallInterfaces.cpp
   CastInterfaces.cpp
   ControlFlowInterfaces.cpp
@@ -24,8 +25,10 @@ set(LLVM_OPTIONAL_SOURCES
   )
 
 function(add_mlir_interface_library name)
+  cmake_parse_arguments(ARG "" "" "EXTRA_SOURCE" ${ARGN})
   add_mlir_library(MLIR${name}
     ${name}.cpp
+    ${ARG_EXTRA_SOURCE}
 
     ADDITIONAL_HEADER_DIRS
     ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
@@ -39,28 +42,14 @@ function(add_mlir_interface_library name)
 endfunction(add_mlir_interface_library)
 
 
-add_mlir_interface_library(CallInterfaces)
+add_mlir_interface_library(CallInterfaces EXTRA_SOURCE CallImplementation.cpp)
 add_mlir_interface_library(CastInterfaces)
 add_mlir_interface_library(ControlFlowInterfaces)
 add_mlir_interface_library(CopyOpInterface)
 add_mlir_interface_library(DataLayoutInterfaces)
 add_mlir_interface_library(DerivedAttributeOpInterface)
 add_mlir_interface_library(DestinationStyleOpInterface)
-
-add_mlir_library(MLIRFunctionInterfaces
-  FunctionInterfaces.cpp
-  FunctionImplementation.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
-
-  DEPENDS
-  MLIRFunctionInterfacesIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-)
-
+add_mlir_interface_library(FunctionInterfaces EXTRA_SOURCE FunctionImplementation.cpp)
 add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 
diff --git a/mlir/lib/Interfaces/CallImplementation.cpp b/mlir/lib/Interfaces/CallImplementation.cpp
new file mode 100644
index 00000000000000..85eca609d8dc8d
--- /dev/null
+++ b/mlir/lib/Interfaces/CallImplementation.cpp
@@ -0,0 +1,178 @@
+//===- CallImplementation.cpp - Call and Callable Op utilities ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/CallImplementation.h"
+#include "mlir/IR/Builders.h"
+
+using namespace mlir;
+
+static ParseResult
+parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
+                     SmallVectorImpl<DictionaryAttr> &attrs) {
+  // Parse each `type attribute-dict?` element of the comma-separated list.
+  return parser.parseCommaSeparatedList([&]() -> ParseResult {
+    types.emplace_back();
+    attrs.emplace_back();
+    NamedAttrList attrList;
+    if (parser.parseType(types.back()) ||
+        parser.parseOptionalAttrDict(attrList))
+      return failure();
+    attrs.back() = attrList.getDictionary(parser.getContext());
+    return success();
+  });
+}
+
+ParseResult call_interface_impl::parseFunctionResultList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
+  if (failed(parser.parseOptionalLParen())) {
+    // We already know that there is no `(`, so parse a type.
+    // Because there is no `(`, it cannot be a function type.
+    Type ty;
+    if (parser.parseType(ty))
+      return failure();
+    resultTypes.push_back(ty);
+    resultAttrs.emplace_back();
+    return success();
+  }
+
+  // Special case for an empty set of parens.
+  if (succeeded(parser.parseOptionalRParen()))
+    return success();
+  if (parseTypeAndAttrList(parser, resultTypes, resultAttrs))
+    return failure();
+  return parser.parseRParen();
+}
+
+ParseResult call_interface_impl::parseFunctionSignature(
+    OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
+    SmallVectorImpl<DictionaryAttr> &argAttrs,
+    SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
+  // Parse arguments.
+  if (parser.parseLParen())
+    return failure();
+  if (failed(parser.parseOptionalRParen())) {
+    if (parseTypeAndAttrList(parser, argTypes, argAttrs))
+      return failure();
+    if (parser.parseRParen())
+      return failure();
+  }
+  // Parse results.
+  if (succeeded(parser.parseOptionalArrow()))
+    return call_interface_impl::parseFunctionResultList(parser, resultTypes,
+                                                        resultAttrs);
+  if (mustParseEmptyResult)
+    return failure();
+  return success();
+}
+
+/// Print a function result list. The provided `attrs` must either be null, or
+/// contain a set of DictionaryAttrs of the same arity as `types`.
+static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
+                                    ArrayAttr attrs) {
+  assert(!types.empty() && "Should not be called for empty result list.");
+  assert((!attrs || attrs.size() == types.size()) &&
+         "Invalid number of attributes.");
+
+  auto &os = p.getStream();
+  bool needsParens = types.size() > 1 || llvm::isa<FunctionType>(types[0]) ||
+                     (attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty());
+  if (needsParens)
+    os << '(';
+  llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
+    p.printType(types[i]);
+    if (attrs)
+      p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
+  });
+  if (needsParens)
+    os << ')';
+}
+
+void call_interface_impl::printFunctionSignature(
+    OpAsmPrinter &p, OpWithArgumentAttributesInterface op, TypeRange argTypes,
+    bool isVariadic, TypeRange resultTypes, Region *body,
+    bool printEmptyResult) {
+  bool isExternal = !body || body->empty();
+  ArrayAttr argAttrs = op.getArgAttrsAttr();
+  ArrayAttr resultAttrs = op.getResAttrsAttr();
+  if (!isExternal && !isVariadic && !argAttrs && !resultAttrs &&
+      printEmptyResult) {
+    p.printFunctionalType(argTypes, resultTypes);
+    return;
+  }
+
+  p << '(';
+  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
+    if (i > 0)
+      p << ", ";
+
+    if (!isExternal) {
+      ArrayRef<NamedAttribute> attrs;
+      if (argAttrs)
+        attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
+      p.printRegionArgument(body->getArgument(i), attrs);
+    } else {
+      p.printType(argTypes[i]);
+      if (argAttrs)
+        p.printOptionalAttrDict(
+            llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
+    }
+  }
+
+  if (isVariadic) {
+    if (!argTypes.empty())
+      p << ", ";
+    p << "...";
+  }
+
+  p << ')';
+
+  if (!resultTypes.empty()) {
+    p << " -> ";
+    printFunctionResultList(p, resultTypes, resultAttrs);
+  } else if (printEmptyResult) {
+    p << " -> ()";
+  }
+}
+
+void call_interface_impl::addArgAndResultAttrs(
+    Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
+    ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
+    StringAttr resAttrsName) {
+  auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
+    return attrs && !attrs.empty();
+  };
+  // Convert the specified array of dictionary attrs (which may have null
+  // entries) to an ArrayAttr of dictionaries.
+  auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
+    SmallVector<Attribute> attrs;
+    for (auto &dict : dictAttrs)
+      attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
+    return builder.getArrayAttr(attrs);
+  };
+
+  // Add the attributes to the function arguments.
+  if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
+    result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
+
+  // Add the attributes to the function results.
+  if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
+    result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
+}
+
+void call_interface_impl::addArgAndResultAttrs(
+    Builder &builder, OperationState &result,
+    ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
+    StringAttr argAttrsName, StringAttr resAttrsName) {
+  SmallVector<DictionaryAttr> argAttrs;
+  for (const auto &arg : args)
+    argAttrs.push_back(arg.attrs);
+  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
+                       resAttrsName);
+}
diff --git a/mlir/lib/Interfaces/FunctionImplementation.cpp b/mlir/lib/Interfaces/FunctionImplementation.cpp
index 988feee665fea6..80174d1fefb559 100644
--- a/mlir/lib/Interfaces/FunctionImplementation.cpp
+++ b/mlir/lib/Interfaces/FunctionImplementation.cpp
@@ -70,49 +70,6 @@ parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
       });
 }
 
-/// Parse a function result list.
-///
-///   function-result-list ::= function-result-list-parens
-///                          | non-function-type
-///   function-result-list-parens ::= `(` `)`
-///                                 | `(` function-result-list-no-parens `)`
-///   function-result-list-no-parens ::= function-result (`,` function-result)*
-///   function-result ::= type attribute-dict?
-///
-static ParseResult
-parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
-                        SmallVectorImpl<DictionaryAttr> &resultAttrs) {
-  if (failed(parser.parseOptionalLParen())) {
-    // We already know that there is no `(`, so parse a type.
-    // Because there is no `(`, it cannot be a function type.
-    Type ty;
-    if (parser.parseType(ty))
-      return failure();
-    resultTypes.push_back(ty);
-    resultAttrs.emplace_back();
-    return success();
-  }
-
-  // Special case for an empty set of parens.
-  if (succeeded(parser.parseOptionalRParen()))
-    return success();
-
-  // Parse individual function results.
-  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
-        resultTypes.emplace_back();
-        resultAttrs.emplace_back();
-        NamedAttrList attrs;
-        if (parser.parseType(resultTypes.back()) ||
-            parser.parseOptionalAttrDict(attrs))
-          return failure();
-        resultAttrs.back() = attrs.getDictionary(parser.getContext());
-        return success();
-      }))
-    return failure();
-
-  return parser.parseRParen();
-}
-
 ParseResult function_interface_impl::parseFunctionSignature(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
@@ -121,46 +78,11 @@ ParseResult function_interface_impl::parseFunctionSignature(
   if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic))
     return failure();
   if (succeeded(parser.parseOptionalArrow()))
-    return parseFunctionResultList(parser, resultTypes, resultAttrs);
+    return call_interface_impl::parseFunctionResultList(parser, resultTypes,
+                                                        resultAttrs);
   return success();
 }
 
-void function_interface_impl::addArgAndResultAttrs(
-    Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
-    ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
-    StringAttr resAttrsName) {
-  auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
-    return attrs && !attrs.empty();
-  };
-  // Convert the specified array of dictionary attrs (which may have null
-  // entries) to an ArrayAttr of dictionaries.
-  auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
-    SmallVector<Attribute> attrs;
-    for (auto &dict : dictAttrs)
-      attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
-    return builder.getArrayAttr(attrs);
-  };
-
-  // Add the attributes to the function arguments.
-  if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
-    result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
-
-  // Add the attributes to the function results.
-  if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
-    result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
-}
-
-void function_interface_impl::addArgAndResultAttrs(
-    Builder &builder, OperationState &result,
-    ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
-    StringAttr argAttrsName, StringAttr resAttrsName) {
-  SmallVector<DictionaryAttr> argAttrs;
-  for (const auto &arg : args)
-    argAttrs.push_back(arg.attrs);
-  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
-                       resAttrsName);
-}
-
 ParseResult function_interface_impl::parseFunctionOp(
     OpAsmParser &parser, OperationState &result, bool allowVariadic,
     StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
@@ -221,8 +143,8 @@ ParseResult function_interface_impl::parseFunctionOp(
 
   // Add the attributes to the function arguments.
   assert(resultAttrs.size() == resultTypes.size());
-  addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName,
-                       resAttrsName);
+  call_interface_impl::addArgAndResultAttrs(
+      builder, result, entryArgs, resultAttrs, argAttrsName, resAttrsName);
 
   // Parse the optional function body. The printer will not print the body if
   // its empty, so disallow parsing of empty body in the parser.
@@ -241,68 +163,6 @@ ParseResult function_interface_impl::parseFunctionOp(
   return success();
 }
 
-/// Print a function result list. The provided `attrs` must either be null, or
-/// contain a set of DictionaryAttrs of the same arity as `types`.
-static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
-                                    ArrayAttr attrs) {
-  assert(!types.empty() && "Should not be called for empty result list.");
-  assert((!attrs || attrs.size() == types.size()) &&
-         "Invalid number of attributes.");
-
-  auto &os = p.getStream();
-  bool needsParens = types.size() > 1 || llvm::isa<FunctionType>(types[0]) ||
-                     (attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty());
-  if (needsParens)
-    os << '(';
-  llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
-    p.printType(types[i]);
-    if (attrs)
-      p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
-  });
-  if (needsParens)
-    os << ')';
-}
-
-void function_interface_impl::printFunctionSignature(
-    OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
-    bool isVariadic, ArrayRef<Type> resultTypes) {
-  Region &body = op->getRegion(0);
-  bool isExternal = body.empty();
-
-  p << '(';
-  ArrayAttr argAttrs = op.getArgAttrsAttr();
-  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
-    if (i > 0)
-      p << ", ";
-
-    if (!isExternal) {
-      ArrayRef<NamedAttribute> attrs;
-      if (argAttrs)
-        attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
-      p.printRegionArgument(body.getArgument(i), attrs);
-    } else {
-      p.printType(argTypes[i]);
-      if (argAttrs)
-        p.printOptionalAttrDict(
-            llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
-    }
-  }
-
-  if (isVariadic) {
-    if (!argTypes.empty())
-      p << ", ";
-    p << "...";
-  }
-
-  p << ')';
-
-  if (!resultTypes.empty()) {
-    p.getStream() << " -> ";
-    auto resultAttrs = op.getResAttrsAttr();
-    printFunctionResultList(p, resultTypes, resultAttrs);
-  }
-}
-
 void function_interface_impl::printFunctionAttributes(
     OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
   // Print out function attributes, if present.



More information about the Mlir-commits mailing list