[Mlir-commits] [mlir] d959244 - [mlir][llvm] Move LLVMFunctionType to a TypeDef
Jeff Niu
llvmlistbot at llvm.org
Fri Oct 21 15:13:19 PDT 2022
Author: Jeff Niu
Date: 2022-10-21T15:13:07-07:00
New Revision: d9592444cea11fcc9e8debe0f5eff331bdabbdc4
URL: https://github.com/llvm/llvm-project/commit/d9592444cea11fcc9e8debe0f5eff331bdabbdc4
DIFF: https://github.com/llvm/llvm-project/commit/d9592444cea11fcc9e8debe0f5eff331bdabbdc4.diff
LOG: [mlir][llvm] Move LLVMFunctionType to a TypeDef
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D136485
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index da24d4df8129..237082b44479 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -72,69 +72,6 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
#undef DEFINE_TRIVIAL_LLVM_TYPE
-//===----------------------------------------------------------------------===//
-// LLVMFunctionType.
-//===----------------------------------------------------------------------===//
-
-/// LLVM dialect function type. It consists of a single return type (unlike MLIR
-/// which can have multiple), a list of parameter types and can optionally be
-/// variadic.
-class LLVMFunctionType : public Type::TypeBase<LLVMFunctionType, Type,
- detail::LLVMFunctionTypeStorage,
- SubElementTypeInterface::Trait> {
-public:
- /// Inherit base constructors.
- using Base::Base;
- using Base::getChecked;
-
- /// Checks if the given type can be used an argument in a function type.
- static bool isValidArgumentType(Type type);
-
- /// Checks if the given type can be used as a result in a function type.
- static bool isValidResultType(Type type);
-
- /// Returns whether the function is variadic.
- bool isVarArg() const;
-
- /// Gets or creates an instance of LLVM dialect function in the same context
- /// as the `result` type.
- static LLVMFunctionType get(Type result, ArrayRef<Type> arguments,
- bool isVarArg = false);
- static LLVMFunctionType
- getChecked(function_ref<InFlightDiagnostic()> emitError, Type result,
- ArrayRef<Type> arguments, bool isVarArg = false);
-
- /// Returns a clone of this function type with the given argument
- /// and result types.
- LLVMFunctionType clone(TypeRange inputs, TypeRange results) const;
-
- /// Returns the result type of the function.
- Type getReturnType() const;
-
- /// Returns the result type of the function as an ArrayRef, enabling better
- /// integration with generic MLIR utilities.
- ArrayRef<Type> getReturnTypes() const;
-
- /// Returns the number of arguments to the function.
- unsigned getNumParams();
-
- /// Returns `i`-th argument of the function. Asserts on out-of-bounds.
- Type getParamType(unsigned i);
-
- /// Returns a list of argument types of the function.
- ArrayRef<Type> getParams() const;
- ArrayRef<Type> params() { return getParams(); }
-
- /// Verifies that the type about to be constructed is well-formed.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- Type result, ArrayRef<Type> arguments, bool);
-
- void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
- function_ref<void(Type)> walkTypesFn) const;
- Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
- ArrayRef<Type> replTypes) const;
-};
-
//===----------------------------------------------------------------------===//
// LLVMPointerType.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index b347406f0521..6ddef17aa92a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -58,4 +58,64 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [
}];
}
+//===----------------------------------------------------------------------===//
+// LLVMFunctionType
+//===----------------------------------------------------------------------===//
+
+def LLVMFunctionType : LLVMType<"LLVMFunction", "func", [
+ DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+ let summary = "LLVM function type";
+ let description = [{
+ The `!llvm.func` is a function type. It consists of a single return type
+ (unlike MLIR which can have multiple), a list of parameter types and can
+ optionally be variadic.
+
+ Example:
+
+ ```mlir
+ !llvm.func<i32 (i32)>
+ ```
+ }];
+
+ let parameters = (ins "Type":$returnType, ArrayRefParameter<"Type">:$params,
+ "bool":$varArg);
+ let assemblyFormat = [{
+ `<` custom<PrettyLLVMType>($returnType) ` ` `(`
+ custom<FunctionTypes>($params, $varArg) `>`
+ }];
+
+ let genVerifyDecl = 1;
+
+ let builders = [
+ TypeBuilderWithInferredContext<(ins
+ "Type":$result, "ArrayRef<Type>":$arguments,
+ CArg<"bool", "false">:$isVarArg)>
+ ];
+
+ let extraClassDeclaration = [{
+ /// Checks if the given type can be used an argument in a function type.
+ static bool isValidArgumentType(Type type);
+
+ /// Checks if the given type can be used as a result in a function type.
+ static bool isValidResultType(Type type);
+
+ /// Returns whether the function is variadic.
+ bool isVarArg() const { return getVarArg(); }
+
+ /// Returns a clone of this function type with the given argument
+ /// and result types.
+ LLVMFunctionType clone(TypeRange inputs, TypeRange results) const;
+
+ /// Returns the result type of the function as an ArrayRef, enabling better
+ /// integration with generic MLIR utilities.
+ ArrayRef<Type> getReturnTypes() const;
+
+ /// Returns the number of arguments to the function.
+ unsigned getNumParams() const { return getParams().size(); }
+
+ /// Returns `i`-th argument of the function. Asserts on out-of-bounds.
+ Type getParamType(unsigned i) { return getParams()[i]; }
+ }];
+}
+
#endif // LLVMTYPES_TD
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 72f37f9a1780..71ba80d0ce7e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2573,7 +2573,6 @@ void LLVMDialect::initialize() {
LLVMTokenType,
LLVMLabelType,
LLVMMetadataType,
- LLVMFunctionType,
LLVMPointerType,
LLVMFixedVectorType,
LLVMScalableVectorType,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 932477488cb5..566ef63110b3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -107,22 +107,6 @@ static void printVectorType(AsmPrinter &printer, TypeTy type) {
printer << '>';
}
-/// Prints a function type.
-static void printFunctionType(AsmPrinter &printer, LLVMFunctionType funcType) {
- printer << '<';
- dispatchPrint(printer, funcType.getReturnType());
- printer << " (";
- llvm::interleaveComma(
- funcType.getParams(), printer.getStream(),
- [&printer](Type subtype) { dispatchPrint(printer, subtype); });
- if (funcType.isVarArg()) {
- if (funcType.getNumParams() != 0)
- printer << ", ";
- printer << "...";
- }
- printer << ")>";
-}
-
/// Prints the given LLVM dialect type recursively. This leverages closedness of
/// the LLVM dialect type system to avoid printing the dialect prefix
/// repeatedly. For recursive structures, only prints the name of the structure
@@ -171,7 +155,7 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
return printStructType(printer, structType);
if (auto funcType = type.dyn_cast<LLVMFunctionType>())
- return printFunctionType(printer, funcType);
+ return funcType.print(printer);
}
//===----------------------------------------------------------------------===//
@@ -180,45 +164,6 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
static ParseResult dispatchParse(AsmParser &parser, Type &type);
-/// Parses an LLVM dialect function type.
-/// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
-static LLVMFunctionType parseFunctionType(AsmParser &parser) {
- SMLoc loc = parser.getCurrentLocation();
- Type returnType;
- if (parser.parseLess() || dispatchParse(parser, returnType) ||
- parser.parseLParen())
- return LLVMFunctionType();
-
- // Function type without arguments.
- if (succeeded(parser.parseOptionalRParen())) {
- if (succeeded(parser.parseGreater()))
- return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None,
- /*isVarArg=*/false);
- return LLVMFunctionType();
- }
-
- // Parse arguments.
- SmallVector<Type, 8> argTypes;
- do {
- if (succeeded(parser.parseOptionalEllipsis())) {
- if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
- return LLVMFunctionType();
- return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
- /*isVarArg=*/true);
- }
-
- Type arg;
- if (dispatchParse(parser, arg))
- return LLVMFunctionType();
- argTypes.push_back(arg);
- } while (succeeded(parser.parseOptionalComma()));
-
- if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
- return LLVMFunctionType();
- return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
- /*isVarArg=*/false);
-}
-
/// Parses an LLVM dialect pointer type.
/// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
/// | `ptr` (`<` integer `>`)?
@@ -445,7 +390,7 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
.Case("token", [&] { return LLVMTokenType::get(ctx); })
.Case("label", [&] { return LLVMLabelType::get(ctx); })
.Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
- .Case("func", [&] { return parseFunctionType(parser); })
+ .Case("func", [&] { return LLVMFunctionType::parse(parser); })
.Case("ptr", [&] { return parsePointerType(parser); })
.Case("vec", [&] { return parseVectorType(parser); })
.Case("array", [&] { return LLVMArrayType::parse(parser); })
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 10170b02c382..f55d2ae45c45 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -27,6 +27,70 @@ using namespace mlir::LLVM;
constexpr const static unsigned kBitsInByte = 8;
+//===----------------------------------------------------------------------===//
+// custom<FunctionTypes>
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseFunctionTypes(AsmParser &p,
+ FailureOr<SmallVector<Type>> &params,
+ FailureOr<bool> &isVarArg) {
+ params.emplace();
+ isVarArg = false;
+ // `(` `)`
+ if (succeeded(p.parseOptionalRParen()))
+ return success();
+
+ // `(` `...` `)`
+ if (succeeded(p.parseOptionalEllipsis())) {
+ isVarArg = true;
+ return p.parseRParen();
+ }
+
+ // type (`,` type)* (`,` `...`)?
+ FailureOr<Type> type;
+ if (parsePrettyLLVMType(p, type))
+ return failure();
+ params->push_back(*type);
+ while (succeeded(p.parseOptionalComma())) {
+ if (succeeded(p.parseOptionalEllipsis())) {
+ isVarArg = true;
+ return p.parseRParen();
+ }
+ if (parsePrettyLLVMType(p, type))
+ return failure();
+ params->push_back(*type);
+ }
+ return p.parseRParen();
+}
+
+static void printFunctionTypes(AsmPrinter &p, ArrayRef<Type> params,
+ bool isVarArg) {
+ llvm::interleaveComma(params, p,
+ [&](Type type) { printPrettyLLVMType(p, type); });
+ if (isVarArg) {
+ if (!params.empty())
+ p << ", ";
+ p << "...";
+ }
+ p << ')';
+}
+
+//===----------------------------------------------------------------------===//
+// ODS-Generated Definitions
+//===----------------------------------------------------------------------===//
+
+/// These are unused for now.
+/// TODO: Move over to these once more types have been migrated to TypeDef.
+LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
+generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
+LLVM_ATTRIBUTE_UNUSED static LogicalResult
+generatedTypePrinter(Type def, AsmPrinter &printer);
+
+#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
+
//===----------------------------------------------------------------------===//
// LLVMArrayType
//===----------------------------------------------------------------------===//
@@ -130,25 +194,8 @@ LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
return get(results[0], llvm::to_vector(inputs), isVarArg());
}
-Type LLVMFunctionType::getReturnType() const {
- return getImpl()->getReturnType();
-}
ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
- return getImpl()->getReturnType();
-}
-
-unsigned LLVMFunctionType::getNumParams() {
- return getImpl()->getArgumentTypes().size();
-}
-
-Type LLVMFunctionType::getParamType(unsigned i) {
- return getImpl()->getArgumentTypes()[i];
-}
-
-bool LLVMFunctionType::isVarArg() const { return getImpl()->isVariadic(); }
-
-ArrayRef<Type> LLVMFunctionType::getParams() const {
- return getImpl()->getArgumentTypes();
+ return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType;
}
LogicalResult
@@ -164,10 +211,14 @@ LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// SubElementTypeInterface
+
void LLVMFunctionType::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
- for (Type type : llvm::concat<const Type>(getReturnTypes(), getParams()))
+ walkTypesFn(getReturnType());
+ for (Type type : getParams())
walkTypesFn(type);
}
@@ -1005,22 +1056,6 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
});
}
-//===----------------------------------------------------------------------===//
-// ODS-Generated Definitions
-//===----------------------------------------------------------------------===//
-
-/// These are unused for now.
-/// TODO: Move over to these once more types have been migrated to TypeDef.
-LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
-generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
-LLVM_ATTRIBUTE_UNUSED static LogicalResult
-generatedTypePrinter(Type def, AsmPrinter &printer);
-
-#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
-
-#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
-
//===----------------------------------------------------------------------===//
// LLVMDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
index 7ebfe8e69f8c..d13452f6e0af 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
+++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
@@ -321,62 +321,6 @@ struct LLVMStructTypeStorage : public TypeStorage {
unsigned identifiedBodySizeAndFlags = 0;
};
-//===----------------------------------------------------------------------===//
-// LLVMFunctionTypeStorage.
-//===----------------------------------------------------------------------===//
-
-/// Type storage for LLVM dialect function types. These are uniqued using the
-/// list of types they contain and the vararg bit.
-struct LLVMFunctionTypeStorage : public TypeStorage {
- using KeyTy = std::tuple<Type, ArrayRef<Type>, bool>;
-
- /// Construct a storage from the given components. The list is expected to be
- /// allocated in the context.
- LLVMFunctionTypeStorage(Type result, ArrayRef<Type> arguments, bool variadic)
- : resultType(result), isVariadicFlag(variadic),
- numArguments(arguments.size()), argumentTypes(arguments.data()) {}
-
- /// Hook into the type uniquing infrastructure.
- static LLVMFunctionTypeStorage *construct(TypeStorageAllocator &allocator,
- const KeyTy &key) {
- return new (allocator.allocate<LLVMFunctionTypeStorage>())
- LLVMFunctionTypeStorage(std::get<0>(key),
- allocator.copyInto(std::get<1>(key)),
- std::get<2>(key));
- }
-
- static unsigned hashKey(const KeyTy &key) {
- // LLVM doesn't like hashing bools in tuples.
- return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
- static_cast<int>(std::get<2>(key)));
- }
-
- bool operator==(const KeyTy &key) const {
- return std::make_tuple(getReturnType(), getArgumentTypes(), isVariadic()) ==
- key;
- }
-
- /// Returns the list of function argument types.
- ArrayRef<Type> getArgumentTypes() const {
- return ArrayRef<Type>(argumentTypes, numArguments);
- }
-
- /// Checks whether the function type is variadic.
- bool isVariadic() const { return isVariadicFlag; }
-
- /// Returns the function result type.
- const Type &getReturnType() const { return resultType; }
-
-private:
- /// The result type of the function.
- Type resultType;
- /// Flag indicating if the function is variadic.
- bool isVariadicFlag;
- /// The argument types of the function.
- unsigned numArguments;
- const Type *argumentTypes;
-};
-
//===----------------------------------------------------------------------===//
// LLVMPointerTypeStorage.
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list