[Mlir-commits] [mlir] d959244 - [mlir][llvm] Move LLVMFunctionType to a TypeDef

Jeff Niu llvmlistbot at llvm.org
Fri Oct 21 15:13:19 PDT 2022


Author: Jeff Niu
Date: 2022-10-21T15:13:07-07:00
New Revision: d9592444cea11fcc9e8debe0f5eff331bdabbdc4

URL: https://github.com/llvm/llvm-project/commit/d9592444cea11fcc9e8debe0f5eff331bdabbdc4
DIFF: https://github.com/llvm/llvm-project/commit/d9592444cea11fcc9e8debe0f5eff331bdabbdc4.diff

LOG: [mlir][llvm] Move LLVMFunctionType to a TypeDef

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D136485

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index da24d4df8129..237082b44479 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -72,69 +72,6 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
 
 #undef DEFINE_TRIVIAL_LLVM_TYPE
 
-//===----------------------------------------------------------------------===//
-// LLVMFunctionType.
-//===----------------------------------------------------------------------===//
-
-/// LLVM dialect function type. It consists of a single return type (unlike MLIR
-/// which can have multiple), a list of parameter types and can optionally be
-/// variadic.
-class LLVMFunctionType : public Type::TypeBase<LLVMFunctionType, Type,
-                                               detail::LLVMFunctionTypeStorage,
-                                               SubElementTypeInterface::Trait> {
-public:
-  /// Inherit base constructors.
-  using Base::Base;
-  using Base::getChecked;
-
-  /// Checks if the given type can be used an argument in a function type.
-  static bool isValidArgumentType(Type type);
-
-  /// Checks if the given type can be used as a result in a function type.
-  static bool isValidResultType(Type type);
-
-  /// Returns whether the function is variadic.
-  bool isVarArg() const;
-
-  /// Gets or creates an instance of LLVM dialect function in the same context
-  /// as the `result` type.
-  static LLVMFunctionType get(Type result, ArrayRef<Type> arguments,
-                              bool isVarArg = false);
-  static LLVMFunctionType
-  getChecked(function_ref<InFlightDiagnostic()> emitError, Type result,
-             ArrayRef<Type> arguments, bool isVarArg = false);
-
-  /// Returns a clone of this function type with the given argument
-  /// and result types.
-  LLVMFunctionType clone(TypeRange inputs, TypeRange results) const;
-
-  /// Returns the result type of the function.
-  Type getReturnType() const;
-
-  /// Returns the result type of the function as an ArrayRef, enabling better
-  /// integration with generic MLIR utilities.
-  ArrayRef<Type> getReturnTypes() const;
-
-  /// Returns the number of arguments to the function.
-  unsigned getNumParams();
-
-  /// Returns `i`-th argument of the function. Asserts on out-of-bounds.
-  Type getParamType(unsigned i);
-
-  /// Returns a list of argument types of the function.
-  ArrayRef<Type> getParams() const;
-  ArrayRef<Type> params() { return getParams(); }
-
-  /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type result, ArrayRef<Type> arguments, bool);
-
-  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
-                                function_ref<void(Type)> walkTypesFn) const;
-  Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                   ArrayRef<Type> replTypes) const;
-};
-
 //===----------------------------------------------------------------------===//
 // LLVMPointerType.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index b347406f0521..6ddef17aa92a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -58,4 +58,64 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LLVMFunctionType
+//===----------------------------------------------------------------------===//
+
+def LLVMFunctionType : LLVMType<"LLVMFunction", "func", [
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+  let summary = "LLVM function type";
+  let description = [{
+    The `!llvm.func` is a function type. It consists of a single return type
+    (unlike MLIR which can have multiple), a list of parameter types and can
+    optionally be variadic.
+
+    Example:
+
+    ```mlir
+    !llvm.func<i32 (i32)>
+    ```
+  }];
+
+  let parameters = (ins "Type":$returnType, ArrayRefParameter<"Type">:$params,
+                        "bool":$varArg);
+  let assemblyFormat = [{
+    `<` custom<PrettyLLVMType>($returnType) ` ` `(`
+    custom<FunctionTypes>($params, $varArg) `>`
+  }];
+
+  let genVerifyDecl = 1;
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "Type":$result, "ArrayRef<Type>":$arguments,
+      CArg<"bool", "false">:$isVarArg)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Checks if the given type can be used as an argument in a function type.
+    static bool isValidArgumentType(Type type);
+
+    /// Checks if the given type can be used as a result in a function type.
+    static bool isValidResultType(Type type);
+
+    /// Returns whether the function is variadic.
+    bool isVarArg() const { return getVarArg(); }
+
+    /// Returns a clone of this function type with the given argument
+    /// and result types.
+    LLVMFunctionType clone(TypeRange inputs, TypeRange results) const;
+
+    /// Returns the result type of the function as an ArrayRef, enabling better
+    /// integration with generic MLIR utilities.
+    ArrayRef<Type> getReturnTypes() const;
+
+    /// Returns the number of arguments to the function.
+    unsigned getNumParams() const { return getParams().size(); }
+
+    /// Returns `i`-th argument of the function. Asserts on out-of-bounds.
+    Type getParamType(unsigned i) { return getParams()[i]; }
+  }];
+}
+
 #endif // LLVMTYPES_TD

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 72f37f9a1780..71ba80d0ce7e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2573,7 +2573,6 @@ void LLVMDialect::initialize() {
            LLVMTokenType,
            LLVMLabelType,
            LLVMMetadataType,
-           LLVMFunctionType,
            LLVMPointerType,
            LLVMFixedVectorType,
            LLVMScalableVectorType,

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 932477488cb5..566ef63110b3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -107,22 +107,6 @@ static void printVectorType(AsmPrinter &printer, TypeTy type) {
   printer << '>';
 }
 
-/// Prints a function type.
-static void printFunctionType(AsmPrinter &printer, LLVMFunctionType funcType) {
-  printer << '<';
-  dispatchPrint(printer, funcType.getReturnType());
-  printer << " (";
-  llvm::interleaveComma(
-      funcType.getParams(), printer.getStream(),
-      [&printer](Type subtype) { dispatchPrint(printer, subtype); });
-  if (funcType.isVarArg()) {
-    if (funcType.getNumParams() != 0)
-      printer << ", ";
-    printer << "...";
-  }
-  printer << ")>";
-}
-
 /// Prints the given LLVM dialect type recursively. This leverages closedness of
 /// the LLVM dialect type system to avoid printing the dialect prefix
 /// repeatedly. For recursive structures, only prints the name of the structure
@@ -171,7 +155,7 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
     return printStructType(printer, structType);
 
   if (auto funcType = type.dyn_cast<LLVMFunctionType>())
-    return printFunctionType(printer, funcType);
+    return funcType.print(printer);
 }
 
 //===----------------------------------------------------------------------===//
@@ -180,45 +164,6 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
 
 static ParseResult dispatchParse(AsmParser &parser, Type &type);
 
-/// Parses an LLVM dialect function type.
-///   llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
-static LLVMFunctionType parseFunctionType(AsmParser &parser) {
-  SMLoc loc = parser.getCurrentLocation();
-  Type returnType;
-  if (parser.parseLess() || dispatchParse(parser, returnType) ||
-      parser.parseLParen())
-    return LLVMFunctionType();
-
-  // Function type without arguments.
-  if (succeeded(parser.parseOptionalRParen())) {
-    if (succeeded(parser.parseGreater()))
-      return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None,
-                                                 /*isVarArg=*/false);
-    return LLVMFunctionType();
-  }
-
-  // Parse arguments.
-  SmallVector<Type, 8> argTypes;
-  do {
-    if (succeeded(parser.parseOptionalEllipsis())) {
-      if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
-        return LLVMFunctionType();
-      return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
-                                                 /*isVarArg=*/true);
-    }
-
-    Type arg;
-    if (dispatchParse(parser, arg))
-      return LLVMFunctionType();
-    argTypes.push_back(arg);
-  } while (succeeded(parser.parseOptionalComma()));
-
-  if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
-    return LLVMFunctionType();
-  return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
-                                             /*isVarArg=*/false);
-}
-
 /// Parses an LLVM dialect pointer type.
 ///   llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
 ///               | `ptr` (`<` integer `>`)?
@@ -445,7 +390,7 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
       .Case("token", [&] { return LLVMTokenType::get(ctx); })
       .Case("label", [&] { return LLVMLabelType::get(ctx); })
       .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
-      .Case("func", [&] { return parseFunctionType(parser); })
+      .Case("func", [&] { return LLVMFunctionType::parse(parser); })
       .Case("ptr", [&] { return parsePointerType(parser); })
       .Case("vec", [&] { return parseVectorType(parser); })
       .Case("array", [&] { return LLVMArrayType::parse(parser); })

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 10170b02c382..f55d2ae45c45 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -27,6 +27,70 @@ using namespace mlir::LLVM;
 
 constexpr const static unsigned kBitsInByte = 8;
 
+//===----------------------------------------------------------------------===//
+// custom<FunctionTypes>
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseFunctionTypes(AsmParser &p,
+                                      FailureOr<SmallVector<Type>> &params,
+                                      FailureOr<bool> &isVarArg) {
+  params.emplace();
+  isVarArg = false;
+  // `(` `)`
+  if (succeeded(p.parseOptionalRParen()))
+    return success();
+
+  // `(` `...` `)`
+  if (succeeded(p.parseOptionalEllipsis())) {
+    isVarArg = true;
+    return p.parseRParen();
+  }
+
+  // type (`,` type)* (`,` `...`)?
+  FailureOr<Type> type;
+  if (parsePrettyLLVMType(p, type))
+    return failure();
+  params->push_back(*type);
+  while (succeeded(p.parseOptionalComma())) {
+    if (succeeded(p.parseOptionalEllipsis())) {
+      isVarArg = true;
+      return p.parseRParen();
+    }
+    if (parsePrettyLLVMType(p, type))
+      return failure();
+    params->push_back(*type);
+  }
+  return p.parseRParen();
+}
+
+static void printFunctionTypes(AsmPrinter &p, ArrayRef<Type> params,
+                               bool isVarArg) {
+  llvm::interleaveComma(params, p,
+                        [&](Type type) { printPrettyLLVMType(p, type); });
+  if (isVarArg) {
+    if (!params.empty())
+      p << ", ";
+    p << "...";
+  }
+  p << ')';
+}
+
+//===----------------------------------------------------------------------===//
+// ODS-Generated Definitions
+//===----------------------------------------------------------------------===//
+
+/// These are unused for now.
+/// TODO: Move over to these once more types have been migrated to TypeDef.
+LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
+generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
+LLVM_ATTRIBUTE_UNUSED static LogicalResult
+generatedTypePrinter(Type def, AsmPrinter &printer);
+
+#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
+
 //===----------------------------------------------------------------------===//
 // LLVMArrayType
 //===----------------------------------------------------------------------===//
@@ -130,25 +194,8 @@ LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
   return get(results[0], llvm::to_vector(inputs), isVarArg());
 }
 
-Type LLVMFunctionType::getReturnType() const {
-  return getImpl()->getReturnType();
-}
 ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
-  return getImpl()->getReturnType();
-}
-
-unsigned LLVMFunctionType::getNumParams() {
-  return getImpl()->getArgumentTypes().size();
-}
-
-Type LLVMFunctionType::getParamType(unsigned i) {
-  return getImpl()->getArgumentTypes()[i];
-}
-
-bool LLVMFunctionType::isVarArg() const { return getImpl()->isVariadic(); }
-
-ArrayRef<Type> LLVMFunctionType::getParams() const {
-  return getImpl()->getArgumentTypes();
+  return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType;
 }
 
 LogicalResult
@@ -164,10 +211,14 @@ LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SubElementTypeInterface
+
 void LLVMFunctionType::walkImmediateSubElements(
     function_ref<void(Attribute)> walkAttrsFn,
     function_ref<void(Type)> walkTypesFn) const {
-  for (Type type : llvm::concat<const Type>(getReturnTypes(), getParams()))
+  walkTypesFn(getReturnType());
+  for (Type type : getParams())
     walkTypesFn(type);
 }
 
@@ -1005,22 +1056,6 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
       });
 }
 
-//===----------------------------------------------------------------------===//
-// ODS-Generated Definitions
-//===----------------------------------------------------------------------===//
-
-/// These are unused for now.
-/// TODO: Move over to these once more types have been migrated to TypeDef.
-LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
-generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
-LLVM_ATTRIBUTE_UNUSED static LogicalResult
-generatedTypePrinter(Type def, AsmPrinter &printer);
-
-#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
-
-#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
-
 //===----------------------------------------------------------------------===//
 // LLVMDialect
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
index 7ebfe8e69f8c..d13452f6e0af 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
+++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
@@ -321,62 +321,6 @@ struct LLVMStructTypeStorage : public TypeStorage {
   unsigned identifiedBodySizeAndFlags = 0;
 };
 
-//===----------------------------------------------------------------------===//
-// LLVMFunctionTypeStorage.
-//===----------------------------------------------------------------------===//
-
-/// Type storage for LLVM dialect function types. These are uniqued using the
-/// list of types they contain and the vararg bit.
-struct LLVMFunctionTypeStorage : public TypeStorage {
-  using KeyTy = std::tuple<Type, ArrayRef<Type>, bool>;
-
-  /// Construct a storage from the given components. The list is expected to be
-  /// allocated in the context.
-  LLVMFunctionTypeStorage(Type result, ArrayRef<Type> arguments, bool variadic)
-      : resultType(result), isVariadicFlag(variadic),
-        numArguments(arguments.size()), argumentTypes(arguments.data()) {}
-
-  /// Hook into the type uniquing infrastructure.
-  static LLVMFunctionTypeStorage *construct(TypeStorageAllocator &allocator,
-                                            const KeyTy &key) {
-    return new (allocator.allocate<LLVMFunctionTypeStorage>())
-        LLVMFunctionTypeStorage(std::get<0>(key),
-                                allocator.copyInto(std::get<1>(key)),
-                                std::get<2>(key));
-  }
-
-  static unsigned hashKey(const KeyTy &key) {
-    // LLVM doesn't like hashing bools in tuples.
-    return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
-                              static_cast<int>(std::get<2>(key)));
-  }
-
-  bool operator==(const KeyTy &key) const {
-    return std::make_tuple(getReturnType(), getArgumentTypes(), isVariadic()) ==
-           key;
-  }
-
-  /// Returns the list of function argument types.
-  ArrayRef<Type> getArgumentTypes() const {
-    return ArrayRef<Type>(argumentTypes, numArguments);
-  }
-
-  /// Checks whether the function type is variadic.
-  bool isVariadic() const { return isVariadicFlag; }
-
-  /// Returns the function result type.
-  const Type &getReturnType() const { return resultType; }
-
-private:
-  /// The result type of the function.
-  Type resultType;
-  /// Flag indicating if the function is variadic.
-  bool isVariadicFlag;
-  /// The argument types of the function.
-  unsigned numArguments;
-  const Type *argumentTypes;
-};
-
 //===----------------------------------------------------------------------===//
 // LLVMPointerTypeStorage.
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list