[Mlir-commits] [mlir] [mlir] share argument attributes interface between calls and callables (PR #123176)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 16 02:03:17 PST 2025
https://github.com/jeanPerier created https://github.com/llvm/llvm-project/pull/123176
First patch of this [RFC](https://discourse.llvm.org/t/mlir-rfc-adding-argument-and-result-attributes-to-llvm-call/84107) to add argument and result attributes on llvm.call (and llvm.invoke).
This patch extracts the core interface methods dealing with argument and result attributes from CallableOpInterface into a new interface, OpWithArgumentAttributesInterface, which is inherited by both CallOpInterface and CallableOpInterface.
This allows dialects to opt in to argument and result attributes on call operations by adding `arg_attrs` and `res_attrs` `OptionalAttr<DictArrayAttr>` attributes to the operation definition. They can then reuse the common "rich function signature" printing/parsing helpers if they want.
>From b652f02b7178b2c3d35759770c7b20aca1ede6c4 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Wed, 15 Jan 2025 09:01:44 -0800
Subject: [PATCH] [mlir] share argument attributes interface between calls and
callables
---
mlir/docs/Interfaces.md | 16 +-
.../mlir/Interfaces/CallImplementation.h | 91 +++++++++
.../include/mlir/Interfaces/CallInterfaces.td | 99 +++++-----
.../mlir/Interfaces/FunctionImplementation.h | 24 +--
mlir/lib/Dialect/Async/IR/Async.cpp | 2 +-
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 2 +-
mlir/lib/Dialect/Func/IR/FuncOps.cpp | 2 +-
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 2 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 4 +-
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 2 +-
mlir/lib/Dialect/Shape/IR/Shape.cpp | 2 +-
mlir/lib/Interfaces/CMakeLists.txt | 21 +--
mlir/lib/Interfaces/CallImplementation.cpp | 178 ++++++++++++++++++
.../lib/Interfaces/FunctionImplementation.cpp | 148 +--------------
14 files changed, 359 insertions(+), 234 deletions(-)
create mode 100644 mlir/include/mlir/Interfaces/CallImplementation.h
create mode 100644 mlir/lib/Interfaces/CallImplementation.cpp
diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index 51747db546bb76..195f58dfbc23b3 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -753,7 +753,15 @@ interface section goes as follows:
- (`C++ class` -- `ODS class`(if applicable))
##### CallInterfaces
-
+* `OpWithArgumentAttributesInterface` - Used to represent operations that may
+ carry argument and result attributes. It is inherited by both
+ CallOpInterface and CallableOpInterface.
+ - `ArrayAttr getArgAttrsAttr()`
+ - `ArrayAttr getResAttrsAttr()`
+ - `void setArgAttrsAttr(ArrayAttr)`
+ - `void setResAttrsAttr(ArrayAttr)`
+ - `Attribute removeArgAttrsAttr()`
+ - `Attribute removeResAttrsAttr()`
* `CallOpInterface` - Used to represent operations like 'call'
- `CallInterfaceCallable getCallableForCallee()`
- `void setCalleeFromCallable(CallInterfaceCallable)`
@@ -761,12 +769,6 @@ interface section goes as follows:
- `Region * getCallableRegion()`
- `ArrayRef<Type> getArgumentTypes()`
- `ArrayRef<Type> getResultsTypes()`
- - `ArrayAttr getArgAttrsAttr()`
- - `ArrayAttr getResAttrsAttr()`
- - `void setArgAttrsAttr(ArrayAttr)`
- - `void setResAttrsAttr(ArrayAttr)`
- - `Attribute removeArgAttrsAttr()`
- - `Attribute removeResAttrsAttr()`
##### RegionKindInterfaces
diff --git a/mlir/include/mlir/Interfaces/CallImplementation.h b/mlir/include/mlir/Interfaces/CallImplementation.h
new file mode 100644
index 00000000000000..85e47f6b3dbbb9
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/CallImplementation.h
@@ -0,0 +1,91 @@
+//===- CallImplementation.h - Call and Callable Op utilities ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides utility functions for implementing call-like and
+// callable-like operations, in particular, parsing, printing and verification
+// components common to these operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_CALLIMPLEMENTATION_H
+#define MLIR_INTERFACES_CALLIMPLEMENTATION_H
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+
+namespace mlir {
+
+class OpWithArgumentAttributesInterface;
+
+namespace call_interface_impl {
+
+/// Parse a function or call result list, appending one type to `resultTypes`
+/// and one attribute dictionary (null when none is written) to `resultAttrs`
+/// for each parsed result.
+///
+/// function-result-list ::= function-result-list-parens
+///                        | non-function-type
+/// function-result-list-parens ::= `(` `)`
+///                               | `(` function-result-list-no-parens `)`
+/// function-result-list-no-parens ::= function-result (`,` function-result)*
+/// function-result ::= type attribute-dict?
+///
+ParseResult
+parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<DictionaryAttr> &resultAttrs);
+
+/// Parses a function signature using `parser`. This does not deal with function
+/// signatures containing SSA region arguments (to parse these signatures, use
+/// function_interface_impl::parseFunctionSignature). When
+/// `mustParseEmptyResult` is true, the `->` and result list are mandatory, so
+/// `-> ()` is expected when there is no result type; when it is false, the
+/// whole `-> function-result-list` part may be omitted.
+///
+/// no-ssa-function-signature ::= `(` no-ssa-function-arg-list `)`
+///                               -> function-result-list
+/// no-ssa-function-arg-list ::= no-ssa-function-arg
+///                              (`,` no-ssa-function-arg)*
+/// no-ssa-function-arg ::= type attribute-dict?
+ParseResult parseFunctionSignature(OpAsmParser &parser,
+ SmallVectorImpl<Type> &argTypes,
+ SmallVectorImpl<DictionaryAttr> &argAttrs,
+ SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ bool mustParseEmptyResult = true);
+
+/// Print a function signature for a call or callable operation. If a non-empty
+/// body region is provided, the SSA arguments of that region are printed in the
+/// signature; otherwise only argument types are printed. Argument and result
+/// attributes are read from `op`. When `printEmptyResult` is false,
+/// `-> function-result-list` is omitted when `resultTypes` is empty.
+///
+/// function-signature ::= ssa-function-signature
+///                      | no-ssa-function-signature
+/// ssa-function-signature ::= `(` ssa-function-arg-list `)`
+///                            -> function-result-list
+/// ssa-function-arg-list ::= ssa-function-arg (`,` ssa-function-arg)*
+/// ssa-function-arg ::= `%`name `:` type attribute-dict?
+void printFunctionSignature(OpAsmPrinter &p,
+ OpWithArgumentAttributesInterface op,
+ TypeRange argTypes, bool isVariadic,
+ TypeRange resultTypes, Region *body = nullptr,
+ bool printEmptyResult = true);
+
+/// Adds argument and result attributes, provided as `argAttrs` and
+/// `resultAttrs` arguments, to the list of operation attributes in `result`.
+/// Each list is stored as an ArrayAttr of DictionaryAttrs under `argAttrsName`
+/// and `resAttrsName` respectively, with null entries replaced by empty
+/// dictionaries. A list is not added at all when every one of its entries is
+/// null or empty.
+void addArgAndResultAttrs(Builder &builder, OperationState &result,
+ ArrayRef<DictionaryAttr> argAttrs,
+ ArrayRef<DictionaryAttr> resultAttrs,
+ StringAttr argAttrsName, StringAttr resAttrsName);
+/// Same as above, except that the argument attributes are taken from the
+/// `attrs` field of each parsed argument in `args`.
+void addArgAndResultAttrs(Builder &builder, OperationState &result,
+ ArrayRef<OpAsmParser::Argument> args,
+ ArrayRef<DictionaryAttr> resultAttrs,
+ StringAttr argAttrsName, StringAttr resAttrsName);
+
+} // namespace call_interface_impl
+
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_CALLIMPLEMENTATION_H
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index c6002da0d491ce..80912a9762187e 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -17,12 +17,65 @@
include "mlir/IR/OpBase.td"
+
+/// Interface for operations with argument and result attributes (both
+/// call-like and callable operations).
+def OpWithArgumentAttributesInterface : OpInterface<"OpWithArgumentAttributesInterface"> {
+  let description = [{
+    A call-like or callable operation that may define attributes for its
+    arguments and results. This interface is inherited by both CallOpInterface
+    and CallableOpInterface. The default implementations report that the
+    operation has no argument or result attributes and ignore updates, so an
+    operation only stores such attributes if it overrides these methods.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Get the array of argument attribute dictionaries. The method should
+        return an array attribute containing only dictionary attributes equal in
+        number to the number of arguments. Alternatively, the method can
+        return null to indicate that the operation has no argument attributes.
+      }],
+      "::mlir::ArrayAttr", "getArgAttrsAttr", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+    InterfaceMethod<[{
+        Get the array of result attribute dictionaries. The method should return
+        an array attribute containing only dictionary attributes equal in number
+        to the number of results. Alternatively, the method can return
+        null to indicate that the operation has no result attributes.
+      }],
+      "::mlir::ArrayAttr", "getResAttrsAttr", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+    InterfaceMethod<[{
+        Set the array of argument attribute dictionaries. The default
+        implementation does nothing.
+      }],
+      "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
+    InterfaceMethod<[{
+        Set the array of result attribute dictionaries. The default
+        implementation does nothing.
+      }],
+      "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
+    InterfaceMethod<[{
+        Remove the array of argument attribute dictionaries. This is the same as
+        setting all argument attributes to an empty dictionary. The method should
+        return the removed attribute. The default implementation returns null.
+      }],
+      "::mlir::Attribute", "removeArgAttrsAttr", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+    InterfaceMethod<[{
+        Remove the array of result attribute dictionaries. This is the same as
+        setting all result attributes to an empty dictionary. The method should
+        return the removed attribute. The default implementation returns null.
+      }],
+      "::mlir::Attribute", "removeResAttrsAttr", (ins),
+      /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
+  ];
+}
+
// `CallInterfaceCallable`: This is a type used to represent a single callable
// region. A callable is either a symbol, or an SSA value, that is referenced by
// a call-like operation. This represents the destination of the call.
/// Interface for call-like operations.
-def CallOpInterface : OpInterface<"CallOpInterface"> {
+def CallOpInterface : OpInterface<"CallOpInterface",
+ [OpWithArgumentAttributesInterface]> {
let description = [{
A call-like operation is one that transfers control from one sub-routine to
another. These operations may be traditional direct calls `call @foo`, or
@@ -85,7 +138,8 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
}
/// Interface for callable operations.
-def CallableOpInterface : OpInterface<"CallableOpInterface"> {
+def CallableOpInterface : OpInterface<"CallableOpInterface",
+ [OpWithArgumentAttributesInterface]> {
let description = [{
A callable operation is one who represents a potential sub-routine, and may
be a target for a call-like operation (those providing the CallOpInterface
@@ -113,47 +167,6 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
allow for this method may be called on function declarations).
}],
"::llvm::ArrayRef<::mlir::Type>", "getResultTypes">,
-
- InterfaceMethod<[{
- Get the array of argument attribute dictionaries. The method should
- return an array attribute containing only dictionary attributes equal in
- number to the number of region arguments. Alternatively, the method can
- return null to indicate that the region has no argument attributes.
- }],
- "::mlir::ArrayAttr", "getArgAttrsAttr", (ins),
- /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
- InterfaceMethod<[{
- Get the array of result attribute dictionaries. The method should return
- an array attribute containing only dictionary attributes equal in number
- to the number of region results. Alternatively, the method can return
- null to indicate that the region has no result attributes.
- }],
- "::mlir::ArrayAttr", "getResAttrsAttr", (ins),
- /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
- InterfaceMethod<[{
- Set the array of argument attribute dictionaries.
- }],
- "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
- /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
- InterfaceMethod<[{
- Set the array of result attribute dictionaries.
- }],
- "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs),
- /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>,
- InterfaceMethod<[{
- Remove the array of argument attribute dictionaries. This is the same as
- setting all argument attributes to an empty dictionary. The method should
- return the removed attribute.
- }],
- "::mlir::Attribute", "removeArgAttrsAttr", (ins),
- /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
- InterfaceMethod<[{
- Remove the array of result attribute dictionaries. This is the same as
- setting all result attributes to an empty dictionary. The method should
- return the removed attribute.
- }],
- "::mlir::Attribute", "removeResAttrsAttr", (ins),
- /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>,
];
}
diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h
index a5e6963e4e666f..ae20533ef4b87c 100644
--- a/mlir/include/mlir/Interfaces/FunctionImplementation.h
+++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h
@@ -16,6 +16,7 @@
#define MLIR_IR_FUNCTIONIMPLEMENTATION_H_
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/CallImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
namespace mlir {
@@ -33,19 +34,6 @@ class VariadicFlag {
bool variadic;
};
-/// Adds argument and result attributes, provided as `argAttrs` and
-/// `resultAttrs` arguments, to the list of operation attributes in `result`.
-/// Internally, argument and result attributes are stored as dict attributes
-/// with special names given by getResultAttrName, getArgumentAttrName.
-void addArgAndResultAttrs(Builder &builder, OperationState &result,
- ArrayRef<DictionaryAttr> argAttrs,
- ArrayRef<DictionaryAttr> resultAttrs,
- StringAttr argAttrsName, StringAttr resAttrsName);
-void addArgAndResultAttrs(Builder &builder, OperationState &result,
- ArrayRef<OpAsmParser::Argument> args,
- ArrayRef<DictionaryAttr> resultAttrs,
- StringAttr argAttrsName, StringAttr resAttrsName);
-
/// Callback type for `parseFunctionOp`, the callback should produce the
/// type that will be associated with a function-like operation from lists of
/// function arguments and results, VariadicFlag indicates whether the function
@@ -84,9 +72,13 @@ void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
/// Prints the signature of the function-like operation `op`. Assumes `op` has
/// is a FunctionOpInterface and has passed verification.
-void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
- ArrayRef<Type> argTypes, bool isVariadic,
- ArrayRef<Type> resultTypes);
+inline void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
+ ArrayRef<Type> argTypes, bool isVariadic,
+ ArrayRef<Type> resultTypes) {
+  // Pass the function's first region so that SSA argument names are printed
+  // for definitions (an empty region is printed as a declaration). Empty result
+  // lists are not printed, matching the pre-existing function syntax that
+  // omits `-> ()`.
+ call_interface_impl::printFunctionSignature(p, op, argTypes, isVariadic,
+ resultTypes, &op->getRegion(0),
+ /*printEmptyResult=*/false);
+}
/// Prints the list of function prefixed with the "attributes" keyword. The
/// attributes with names listed in "elided" as well as those used by the
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index a3e3f80954efce..d3bb250bb8ab9d 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -308,7 +308,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
if (argAttrs.empty())
return;
assert(type.getNumInputs() == argAttrs.size());
- function_interface_impl::addArgAndResultAttrs(
+ call_interface_impl::addArgAndResultAttrs(
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index fdc21d6c6e24b9..6af17087ced968 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -529,7 +529,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
if (argAttrs.empty())
return;
assert(type.getNumInputs() == argAttrs.size());
- function_interface_impl::addArgAndResultAttrs(
+ call_interface_impl::addArgAndResultAttrs(
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}
diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index a490b4c3c4ab43..ba7b84f27d6a8a 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -190,7 +190,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
if (argAttrs.empty())
return;
assert(type.getNumInputs() == argAttrs.size());
- function_interface_impl::addArgAndResultAttrs(
+ call_interface_impl::addArgAndResultAttrs(
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 49209229259a73..8b85c0829acfec 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1487,7 +1487,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
result.addAttribute(getFunctionTypeAttrName(result.name),
TypeAttr::get(type));
- function_interface_impl::addArgAndResultAttrs(
+ call_interface_impl::addArgAndResultAttrs(
builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
getResAttrsAttrName(result.name));
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index ef5f1b069b40a3..ef1e0222e05f06 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2510,7 +2510,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
"expected as many argument attribute lists as arguments");
- function_interface_impl::addArgAndResultAttrs(
+ call_interface_impl::addArgAndResultAttrs(
builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
@@ -2636,7 +2636,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
return failure();
- function_interface_impl::addArgAndResultAttrs(
+ call_interface_impl::addArgAndResultAttrs(
parser.getBuilder(), result, entryArgs, resultAttrs,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 26559c1321db5e..870359ce55301c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -940,7 +940,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
// Add the attributes to the function arguments.
assert(resultAttrs.size() == resultTypes.size());
- function_interface_impl::addArgAndResultAttrs(
+ call_interface_impl::addArgAndResultAttrs(
builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
getResAttrsAttrName(result.name));
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 65efc88e9c4033..2200af0f67a862 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1297,7 +1297,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
if (argAttrs.empty())
return;
assert(type.getNumInputs() == argAttrs.size());
- function_interface_impl::addArgAndResultAttrs(
+ call_interface_impl::addArgAndResultAttrs(
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index d3b7bf65ad3e73..76e2d921c2a9f5 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -1,4 +1,5 @@
set(LLVM_OPTIONAL_SOURCES
+ CallImplementation.cpp
CallInterfaces.cpp
CastInterfaces.cpp
ControlFlowInterfaces.cpp
@@ -24,8 +25,10 @@ set(LLVM_OPTIONAL_SOURCES
)
function(add_mlir_interface_library name)
+ cmake_parse_arguments(ARG "" "" "EXTRA_SOURCE" ${ARGN})
add_mlir_library(MLIR${name}
${name}.cpp
+ ${ARG_EXTRA_SOURCE}
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
@@ -39,28 +42,14 @@ function(add_mlir_interface_library name)
endfunction(add_mlir_interface_library)
-add_mlir_interface_library(CallInterfaces)
+add_mlir_interface_library(CallInterfaces EXTRA_SOURCE CallImplementation.cpp)
add_mlir_interface_library(CastInterfaces)
add_mlir_interface_library(ControlFlowInterfaces)
add_mlir_interface_library(CopyOpInterface)
add_mlir_interface_library(DataLayoutInterfaces)
add_mlir_interface_library(DerivedAttributeOpInterface)
add_mlir_interface_library(DestinationStyleOpInterface)
-
-add_mlir_library(MLIRFunctionInterfaces
- FunctionInterfaces.cpp
- FunctionImplementation.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
-
- DEPENDS
- MLIRFunctionInterfacesIncGen
-
- LINK_LIBS PUBLIC
- MLIRIR
-)
-
+# FunctionImplementation.cpp calls call_interface_impl helpers that are
+# compiled into MLIRCallInterfaces (CallImplementation.cpp), so this library
+# must link against it. add_mlir_interface_library only accepts EXTRA_SOURCE and
+# cannot add link dependencies, so the library is declared explicitly.
+add_mlir_library(MLIRFunctionInterfaces
+  FunctionInterfaces.cpp
+  FunctionImplementation.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
+
+  DEPENDS
+  MLIRFunctionInterfacesIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRCallInterfaces
+  MLIRIR
+)
add_mlir_interface_library(InferIntRangeInterface)
add_mlir_interface_library(InferTypeOpInterface)
diff --git a/mlir/lib/Interfaces/CallImplementation.cpp b/mlir/lib/Interfaces/CallImplementation.cpp
new file mode 100644
index 00000000000000..85eca609d8dc8d
--- /dev/null
+++ b/mlir/lib/Interfaces/CallImplementation.cpp
@@ -0,0 +1,178 @@
+//===- CallImplementation.cpp - Call and Callable Op utilities ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/CallImplementation.h"
+#include "mlir/IR/Builders.h"
+
+using namespace mlir;
+
+/// Parses a comma-separated list of `type attribute-dict?` entries. For every
+/// entry, one type is appended to `types` and one attribute dictionary to
+/// `attrs`, so both vectors grow in lockstep. Used for both argument and
+/// result lists.
+static ParseResult
+parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
+                     SmallVectorImpl<DictionaryAttr> &attrs) {
+  auto parseOneEntry = [&]() -> ParseResult {
+    // Create both slots before parsing so the two vectors stay aligned.
+    Type &entryType = types.emplace_back();
+    DictionaryAttr &entryAttrs = attrs.emplace_back();
+    NamedAttrList parsedAttrs;
+    if (failed(parser.parseType(entryType)))
+      return failure();
+    if (failed(parser.parseOptionalAttrDict(parsedAttrs)))
+      return failure();
+    entryAttrs = parsedAttrs.getDictionary(parser.getContext());
+    return success();
+  };
+  return parser.parseCommaSeparatedList(parseOneEntry);
+}
+
+ParseResult call_interface_impl::parseFunctionResultList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
+  // Parenthesized form: either `()` or `(type attr-dict?, ...)`.
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (succeeded(parser.parseOptionalRParen()))
+      return success();
+    if (failed(parseTypeAndAttrList(parser, resultTypes, resultAttrs)))
+      return failure();
+    return parser.parseRParen();
+  }
+
+  // Bare form: a single result type without attributes. It cannot be a
+  // function type, because that would have started with `(`.
+  Type singleResult;
+  if (failed(parser.parseType(singleResult)))
+    return failure();
+  resultTypes.push_back(singleResult);
+  resultAttrs.push_back(DictionaryAttr());
+  return success();
+}
+
+/// Parses `(type attr-dict?, ...) -> result-list` without SSA names. When
+/// `mustParseEmptyResult` is true, the `->` and result list are mandatory and
+/// a diagnostic is emitted if they are missing.
+ParseResult call_interface_impl::parseFunctionSignature(
+    OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
+    SmallVectorImpl<DictionaryAttr> &argAttrs,
+    SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
+  // Parse arguments.
+  if (parser.parseLParen())
+    return failure();
+  if (failed(parser.parseOptionalRParen())) {
+    if (parseTypeAndAttrList(parser, argTypes, argAttrs))
+      return failure();
+    if (parser.parseRParen())
+      return failure();
+  }
+  // Parse results.
+  if (succeeded(parser.parseOptionalArrow()))
+    return call_interface_impl::parseFunctionResultList(parser, resultTypes,
+                                                        resultAttrs);
+  // A bare `failure()` here would abort op parsing with no diagnostic at all;
+  // tell the user what was expected instead.
+  if (mustParseEmptyResult)
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected '->' followed by a result list (use "
+                            "'-> ()' when there are no results)");
+  return success();
+}
+
+/// Prints the result list that follows `->` in a signature. Parentheses are
+/// emitted unless the list can be printed as one bare type; they are needed
+/// for several results, for a function-typed result, or when the first result
+/// has a non-empty attribute dictionary. `attrs` is either null or holds one
+/// DictionaryAttr per entry of `types`.
+static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
+                                    ArrayAttr attrs) {
+  assert(!types.empty() && "Should not be called for empty result list.");
+  assert((!attrs || attrs.size() == types.size()) &&
+         "Invalid number of attributes.");
+
+  auto &stream = p.getStream();
+  bool wrapInParens = types.size() > 1 || llvm::isa<FunctionType>(types[0]);
+  if (!wrapInParens && attrs)
+    wrapInParens = !llvm::cast<DictionaryAttr>(attrs[0]).empty();
+
+  if (wrapInParens)
+    stream << '(';
+  for (size_t idx = 0, count = types.size(); idx < count; ++idx) {
+    if (idx != 0)
+      stream << ", ";
+    p.printType(types[idx]);
+    if (attrs)
+      p.printOptionalAttrDict(
+          llvm::cast<DictionaryAttr>(attrs[idx]).getValue());
+  }
+  if (wrapInParens)
+    stream << ')';
+}
+
+/// Prints `(arg-list) -> result-list` for `op`. When `body` is non-null and
+/// non-empty, each argument is printed as its region SSA value together with
+/// its attributes; otherwise only argument types and attributes are printed.
+/// Argument and result attributes are read from `op`; when non-null they are
+/// expected to hold one DictionaryAttr per entry of `argTypes`/`resultTypes`.
+void call_interface_impl::printFunctionSignature(
+    OpAsmPrinter &p, OpWithArgumentAttributesInterface op, TypeRange argTypes,
+    bool isVariadic, TypeRange resultTypes, Region *body,
+    bool printEmptyResult) {
+  bool isExternal = !body || body->empty();
+  ArrayAttr argAttrs = op.getArgAttrsAttr();
+  ArrayAttr resultAttrs = op.getResAttrsAttr();
+  // Fast path: with no SSA names, attributes or variadic marker to print, the
+  // signature is just a functional type. This is only valid when there is no
+  // body, because printFunctionalType prints bare types and would drop the
+  // region argument names (previously the condition was inverted).
+  if (isExternal && !isVariadic && !argAttrs && !resultAttrs &&
+      printEmptyResult) {
+    p.printFunctionalType(argTypes, resultTypes);
+    return;
+  }
+
+  p << '(';
+  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
+    if (i > 0)
+      p << ", ";
+
+    if (!isExternal) {
+      ArrayRef<NamedAttribute> attrs;
+      if (argAttrs)
+        attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
+      p.printRegionArgument(body->getArgument(i), attrs);
+    } else {
+      p.printType(argTypes[i]);
+      if (argAttrs)
+        p.printOptionalAttrDict(
+            llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
+    }
+  }
+
+  if (isVariadic) {
+    if (!argTypes.empty())
+      p << ", ";
+    p << "...";
+  }
+
+  p << ')';
+
+  if (!resultTypes.empty()) {
+    p << " -> ";
+    printFunctionResultList(p, resultTypes, resultAttrs);
+  } else if (printEmptyResult) {
+    p << " -> ()";
+  }
+}
+
+void call_interface_impl::addArgAndResultAttrs(
+    Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
+    ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
+    StringAttr resAttrsName) {
+  // Stores `dicts` under `name` as an ArrayAttr, but only when at least one
+  // entry is a non-null, non-empty dictionary. Null entries become empty
+  // dictionaries so that positions stay aligned with arguments/results.
+  auto addIfAnyNonEmpty = [&](StringAttr name, ArrayRef<DictionaryAttr> dicts) {
+    bool hasContent = llvm::any_of(
+        dicts, [](DictionaryAttr dict) { return dict && !dict.empty(); });
+    if (!hasContent)
+      return;
+    SmallVector<Attribute> normalized;
+    normalized.reserve(dicts.size());
+    for (DictionaryAttr dict : dicts)
+      normalized.push_back(dict ? dict : builder.getDictionaryAttr({}));
+    result.addAttribute(name, builder.getArrayAttr(normalized));
+  };
+
+  addIfAnyNonEmpty(argAttrsName, argAttrs);
+  addIfAnyNonEmpty(resAttrsName, resultAttrs);
+}
+
+void call_interface_impl::addArgAndResultAttrs(
+    Builder &builder, OperationState &result,
+    ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
+    StringAttr argAttrsName, StringAttr resAttrsName) {
+  // Gather the dictionary attached to each parsed argument (null when none was
+  // written) and defer to the DictionaryAttr overload.
+  SmallVector<DictionaryAttr> perArgAttrs;
+  perArgAttrs.reserve(args.size());
+  for (const OpAsmParser::Argument &arg : args)
+    perArgAttrs.push_back(arg.attrs);
+  addArgAndResultAttrs(builder, result, perArgAttrs, resultAttrs, argAttrsName,
+                       resAttrsName);
+}
diff --git a/mlir/lib/Interfaces/FunctionImplementation.cpp b/mlir/lib/Interfaces/FunctionImplementation.cpp
index 988feee665fea6..80174d1fefb559 100644
--- a/mlir/lib/Interfaces/FunctionImplementation.cpp
+++ b/mlir/lib/Interfaces/FunctionImplementation.cpp
@@ -70,49 +70,6 @@ parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
});
}
-/// Parse a function result list.
-///
-/// function-result-list ::= function-result-list-parens
-/// | non-function-type
-/// function-result-list-parens ::= `(` `)`
-/// | `(` function-result-list-no-parens `)`
-/// function-result-list-no-parens ::= function-result (`,` function-result)*
-/// function-result ::= type attribute-dict?
-///
-static ParseResult
-parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs) {
- if (failed(parser.parseOptionalLParen())) {
- // We already know that there is no `(`, so parse a type.
- // Because there is no `(`, it cannot be a function type.
- Type ty;
- if (parser.parseType(ty))
- return failure();
- resultTypes.push_back(ty);
- resultAttrs.emplace_back();
- return success();
- }
-
- // Special case for an empty set of parens.
- if (succeeded(parser.parseOptionalRParen()))
- return success();
-
- // Parse individual function results.
- if (parser.parseCommaSeparatedList([&]() -> ParseResult {
- resultTypes.emplace_back();
- resultAttrs.emplace_back();
- NamedAttrList attrs;
- if (parser.parseType(resultTypes.back()) ||
- parser.parseOptionalAttrDict(attrs))
- return failure();
- resultAttrs.back() = attrs.getDictionary(parser.getContext());
- return success();
- }))
- return failure();
-
- return parser.parseRParen();
-}
-
ParseResult function_interface_impl::parseFunctionSignature(
OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
@@ -121,46 +78,11 @@ ParseResult function_interface_impl::parseFunctionSignature(
if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic))
return failure();
if (succeeded(parser.parseOptionalArrow()))
- return parseFunctionResultList(parser, resultTypes, resultAttrs);
+ return call_interface_impl::parseFunctionResultList(parser, resultTypes,
+ resultAttrs);
return success();
}
-void function_interface_impl::addArgAndResultAttrs(
- Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
- ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
- StringAttr resAttrsName) {
- auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
- return attrs && !attrs.empty();
- };
- // Convert the specified array of dictionary attrs (which may have null
- // entries) to an ArrayAttr of dictionaries.
- auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
- SmallVector<Attribute> attrs;
- for (auto &dict : dictAttrs)
- attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
- return builder.getArrayAttr(attrs);
- };
-
- // Add the attributes to the function arguments.
- if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
- result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
-
- // Add the attributes to the function results.
- if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
- result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
-}
-
-void function_interface_impl::addArgAndResultAttrs(
- Builder &builder, OperationState &result,
- ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
- StringAttr argAttrsName, StringAttr resAttrsName) {
- SmallVector<DictionaryAttr> argAttrs;
- for (const auto &arg : args)
- argAttrs.push_back(arg.attrs);
- addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
- resAttrsName);
-}
-
ParseResult function_interface_impl::parseFunctionOp(
OpAsmParser &parser, OperationState &result, bool allowVariadic,
StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
@@ -221,8 +143,8 @@ ParseResult function_interface_impl::parseFunctionOp(
// Add the attributes to the function arguments.
assert(resultAttrs.size() == resultTypes.size());
- addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName,
- resAttrsName);
+ call_interface_impl::addArgAndResultAttrs(
+ builder, result, entryArgs, resultAttrs, argAttrsName, resAttrsName);
// Parse the optional function body. The printer will not print the body if
// its empty, so disallow parsing of empty body in the parser.
@@ -241,68 +163,6 @@ ParseResult function_interface_impl::parseFunctionOp(
return success();
}
-/// Print a function result list. The provided `attrs` must either be null, or
-/// contain a set of DictionaryAttrs of the same arity as `types`.
-static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
- ArrayAttr attrs) {
- assert(!types.empty() && "Should not be called for empty result list.");
- assert((!attrs || attrs.size() == types.size()) &&
- "Invalid number of attributes.");
-
- auto &os = p.getStream();
- bool needsParens = types.size() > 1 || llvm::isa<FunctionType>(types[0]) ||
- (attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty());
- if (needsParens)
- os << '(';
- llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
- p.printType(types[i]);
- if (attrs)
- p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
- });
- if (needsParens)
- os << ')';
-}
-
-void function_interface_impl::printFunctionSignature(
- OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
- bool isVariadic, ArrayRef<Type> resultTypes) {
- Region &body = op->getRegion(0);
- bool isExternal = body.empty();
-
- p << '(';
- ArrayAttr argAttrs = op.getArgAttrsAttr();
- for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
- if (i > 0)
- p << ", ";
-
- if (!isExternal) {
- ArrayRef<NamedAttribute> attrs;
- if (argAttrs)
- attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
- p.printRegionArgument(body.getArgument(i), attrs);
- } else {
- p.printType(argTypes[i]);
- if (argAttrs)
- p.printOptionalAttrDict(
- llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
- }
- }
-
- if (isVariadic) {
- if (!argTypes.empty())
- p << ", ";
- p << "...";
- }
-
- p << ')';
-
- if (!resultTypes.empty()) {
- p.getStream() << " -> ";
- auto resultAttrs = op.getResAttrsAttr();
- printFunctionResultList(p, resultTypes, resultAttrs);
- }
-}
-
void function_interface_impl::printFunctionAttributes(
OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
// Print out function attributes, if present.
More information about the Mlir-commits
mailing list