[Mlir-commits] [mlir] 0c40af6 - [mlir] First-party modeling of LLVM types
Alex Zinenko
llvmlistbot at llvm.org
Mon Aug 3 06:45:37 PDT 2020
Author: Alex Zinenko
Date: 2020-08-03T15:45:29+02:00
New Revision: 0c40af6b594f6eb2dcd43cdb2bc2f4584ec8ca15
URL: https://github.com/llvm/llvm-project/commit/0c40af6b594f6eb2dcd43cdb2bc2f4584ec8ca15
DIFF: https://github.com/llvm/llvm-project/commit/0c40af6b594f6eb2dcd43cdb2bc2f4584ec8ca15.diff
LOG: [mlir] First-party modeling of LLVM types
The current modeling of LLVM IR types in MLIR is based on the LLVMType class
that wraps a raw `llvm::Type *` and delegates uniquing, printing and parsing to
LLVM itself. This model makes thread-safe type manipulation hard and is being
progressively replaced with a cleaner MLIR model that replicates the type
system. Introduce a set of classes reflecting the LLVM IR type system in MLIR
instead of wrapping the existing types. These are currently introduced as
separate classes without affecting the dialect flow, and are exercised through
a test dialect. Once feature parity is reached, the old implementation will be
gradually substituted with the new one.
Depends On D84171
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D84339
Added:
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
mlir/test/Dialect/LLVMIR/types-invalid.mlir
mlir/test/Dialect/LLVMIR/types.mlir
mlir/test/lib/Dialect/LLVMIR/CMakeLists.txt
mlir/test/lib/Dialect/LLVMIR/LLVMTypeTestDialect.cpp
Modified:
mlir/include/mlir/IR/DialectImplementation.h
mlir/lib/Dialect/LLVMIR/CMakeLists.txt
mlir/lib/Parser/DialectSymbolParser.cpp
mlir/test/lib/Dialect/CMakeLists.txt
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
new file mode 100644
index 000000000000..6764f9815c3f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -0,0 +1,470 @@
+//===- LLVMTypes.h - MLIR LLVM dialect types --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the types for the LLVM dialect in MLIR. These MLIR types
+// correspond to the LLVM IR type system.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
+#define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
+
+#include "mlir/IR/Types.h"
+
+namespace llvm {
+class ElementCount;
+} // namespace llvm
+
+namespace mlir {
+
+class DialectAsmParser;
+class DialectAsmPrinter;
+
+namespace LLVM {
+namespace detail {
+struct LLVMFunctionTypeStorage;
+struct LLVMIntegerTypeStorage;
+struct LLVMPointerTypeStorage;
+struct LLVMStructTypeStorage;
+struct LLVMTypeAndSizeStorage;
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// LLVMTypeNew.
+//===----------------------------------------------------------------------===//
+
+/// Base class for LLVM dialect types.
+///
+/// The LLVM dialect in MLIR fully reflects the LLVM IR type system, providing a
+/// separate MLIR type for each LLVM IR type. All types are represented as
+/// separate subclasses and are compatible with the isa/cast infrastructure. For
+/// convenience, the base class is meant to provide most of the APIs available
+/// on llvm::Type in addition to MLIR-compatible APIs.
+/// NOTE(review): as declared here, the class only adds opaque-pointer
+/// conversion and isa/cast support; the llvm::Type-like APIs are presumably
+/// still to be added.
+///
+/// The LLVM dialect type system is closed: parametric types can only refer to
+/// other LLVM dialect types. This is consistent with LLVM IR and enables a more
+/// concise pretty-printing format.
+///
+/// Similarly to other MLIR types, LLVM dialect types are owned by the MLIR
+/// context, have an immutable identifier (for most types except identified
+/// structs, the entire type is the identifier) and are thread-safe.
+class LLVMTypeNew : public Type {
+public:
+ /// Kinds of LLVM dialect types, numbered starting right after
+ /// FIRST_LLVM_TYPE. The relative order of the enumerators matters: the
+ /// FIRST_*/LAST_* markers at the end delimit contiguous ranges of kinds.
+ enum Kind {
+ // Keep non-parametric types contiguous in the enum.
+ VoidType = FIRST_LLVM_TYPE + 1,
+ HalfType,
+ BFloatType,
+ FloatType,
+ DoubleType,
+ FP128Type,
+ X86FP80Type,
+ PPCFP128Type,
+ X86MMXType,
+ LabelType,
+ TokenType,
+ MetadataType,
+ // End of non-parametric types.
+ FunctionType,
+ IntegerType,
+ PointerType,
+ FixedVectorType,
+ ScalableVectorType,
+ ArrayType,
+ StructType,
+ // All kinds covered by LLVMTypeNew: [VoidType, StructType].
+ FIRST_NEW_LLVM_TYPE = VoidType,
+ LAST_NEW_LLVM_TYPE = StructType,
+ // Non-parametric ("trivial") kinds: [VoidType, MetadataType].
+ FIRST_TRIVIAL_TYPE = VoidType,
+ LAST_TRIVIAL_TYPE = MetadataType
+ };
+
+ /// Inherit base constructors.
+ using Type::Type;
+
+ /// Support for PointerLikeTypeTraits: converts to and from the opaque
+ /// pointer to the underlying type storage.
+ using Type::getAsOpaquePointer;
+ static LLVMTypeNew getFromOpaquePointer(const void *ptr) {
+ return LLVMTypeNew(static_cast<ImplType *>(const_cast<void *>(ptr)));
+ }
+
+ /// Support for isa/cast: every kind in [FIRST_NEW_LLVM_TYPE,
+ /// LAST_NEW_LLVM_TYPE] is an LLVMTypeNew.
+ static bool kindof(unsigned kind) {
+ return FIRST_NEW_LLVM_TYPE <= kind && kind <= LAST_NEW_LLVM_TYPE;
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Trivial types.
+//===----------------------------------------------------------------------===//
+
+// Batch-define trivial types. Each generated class is a non-parametric type
+// built from its kind alone (`Base::get(context, Kind)` with the parameterless
+// `TypeStorage`), matches exactly one kind in isa/cast, and is obtained with
+// `ClassName::get(context)`.
+#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, Kind) \
+ class ClassName \
+ : public Type::TypeBase<ClassName, LLVMTypeNew, TypeStorage> { \
+ public: \
+ using Base::Base; \
+ static bool kindof(unsigned kind) { return kind == Kind; } \
+ static ClassName get(MLIRContext *context) { \
+ return Base::get(context, Kind); \
+ } \
+ }
+
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, LLVMTypeNew::VoidType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType, LLVMTypeNew::HalfType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType, LLVMTypeNew::BFloatType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType, LLVMTypeNew::FloatType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType, LLVMTypeNew::DoubleType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type, LLVMTypeNew::FP128Type);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type, LLVMTypeNew::X86FP80Type);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, LLVMTypeNew::PPCFP128Type);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, LLVMTypeNew::X86MMXType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, LLVMTypeNew::TokenType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, LLVMTypeNew::LabelType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, LLVMTypeNew::MetadataType);
+
+// The helper macro is local to this header.
+#undef DEFINE_TRIVIAL_LLVM_TYPE
+
+//===----------------------------------------------------------------------===//
+// LLVMArrayType.
+//===----------------------------------------------------------------------===//
+
+/// LLVM dialect array type. It is an aggregate type representing consecutive
+/// elements in memory, parameterized by the number of elements and the element
+/// type. Its storage kind (element type + size) is shared with the fixed and
+/// scalable vector types.
+class LLVMArrayType : public Type::TypeBase<LLVMArrayType, LLVMTypeNew,
+ detail::LLVMTypeAndSizeStorage> {
+public:
+ /// Inherit base constructors.
+ using Base::Base;
+
+ /// Support for isa/cast.
+ static bool kindof(unsigned kind) { return kind == LLVMTypeNew::ArrayType; }
+
+ /// Gets or creates an instance of LLVM dialect array type containing
+ /// `numElements` of `elementType`, in the same context as `elementType`.
+ static LLVMArrayType get(LLVMTypeNew elementType, unsigned numElements);
+
+ /// Returns the element type of the array.
+ LLVMTypeNew getElementType();
+
+ /// Returns the number of elements in the array type.
+ unsigned getNumElements();
+};
+
+//===----------------------------------------------------------------------===//
+// LLVMFunctionType.
+//===----------------------------------------------------------------------===//
+
+/// LLVM dialect function type. It consists of a single return type (unlike MLIR
+/// which can have multiple), a list of parameter types and can optionally be
+/// variadic.
+class LLVMFunctionType
+ : public Type::TypeBase<LLVMFunctionType, LLVMTypeNew,
+ detail::LLVMFunctionTypeStorage> {
+public:
+ /// Inherit base constructors.
+ using Base::Base;
+
+ /// Support for isa/cast.
+ static bool kindof(unsigned kind) {
+ return kind == LLVMTypeNew::FunctionType;
+ }
+
+ /// Gets or creates an instance of LLVM dialect function type returning
+ /// `result`, taking parameters of `arguments` types and optionally variadic,
+ /// in the same context as the `result` type.
+ static LLVMFunctionType get(LLVMTypeNew result,
+ ArrayRef<LLVMTypeNew> arguments,
+ bool isVarArg = false);
+
+ /// Returns the result type of the function.
+ LLVMTypeNew getReturnType();
+
+ /// Returns the number of parameters of the function. The variadic `...`
+ /// marker is not a parameter and is reported by isVarArg instead.
+ unsigned getNumParams();
+
+ /// Returns the `i`-th parameter type of the function. Asserts on
+ /// out-of-bounds.
+ LLVMTypeNew getParamType(unsigned i);
+
+ /// Returns whether the function is variadic.
+ bool isVarArg();
+
+ /// Returns the list of parameter types of the function; `params` is a
+ /// shorthand for `getParams`.
+ ArrayRef<LLVMTypeNew> getParams();
+ ArrayRef<LLVMTypeNew> params() { return getParams(); }
+};
+
+//===----------------------------------------------------------------------===//
+// LLVMIntegerType.
+//===----------------------------------------------------------------------===//
+
+/// LLVM dialect signless integer type parameterized by bitwidth. Unlike the
+/// trivial types, it carries its bitwidth in dedicated storage.
+class LLVMIntegerType : public Type::TypeBase<LLVMIntegerType, LLVMTypeNew,
+ detail::LLVMIntegerTypeStorage> {
+public:
+ /// Inherit base constructor.
+ using Base::Base;
+
+ /// Support for isa/cast.
+ static bool kindof(unsigned kind) { return kind == LLVMTypeNew::IntegerType; }
+
+ /// Gets or creates an instance of the integer type of the specified
+ /// `bitwidth` in the given context.
+ static LLVMIntegerType get(MLIRContext *ctx, unsigned bitwidth);
+
+ /// Returns the bitwidth of this integer type.
+ unsigned getBitWidth();
+};
+
+//===----------------------------------------------------------------------===//
+// LLVMPointerType.
+//===----------------------------------------------------------------------===//
+
+/// LLVM dialect pointer type. This type typically represents a reference to an
+/// object in memory. It is parameterized by the element (pointee) type and the
+/// address space.
+class LLVMPointerType : public Type::TypeBase<LLVMPointerType, LLVMTypeNew,
+ detail::LLVMPointerTypeStorage> {
+public:
+ /// Inherit base constructors.
+ using Base::Base;
+
+ /// Support for isa/cast.
+ static bool kindof(unsigned kind) { return kind == LLVMTypeNew::PointerType; }
+
+ /// Gets or creates an instance of LLVM dialect pointer type pointing to an
+ /// object of `pointee` type in the given address space (0 unless specified).
+ /// The pointer type is created in the same context as `pointee`.
+ static LLVMPointerType get(LLVMTypeNew pointee, unsigned addressSpace = 0);
+
+ /// Returns the pointed-to type.
+ LLVMTypeNew getElementType();
+
+ /// Returns the address space of the pointer.
+ unsigned getAddressSpace();
+};
+
+//===----------------------------------------------------------------------===//
+// LLVMStructType.
+//===----------------------------------------------------------------------===//
+
+/// LLVM dialect structure type representing a collection of different-typed
+/// elements manipulated together. Structs can optionally be packed, meaning
+/// that their elements immediately follow each other in memory without
+/// accounting for potential alignment.
+///
+/// Structure types can be identified (named) or literal. Literal structures
+/// are uniquely represented by the list of types they contain and packedness.
+/// Literal structure types are immutable after construction.
+///
+/// Identified structures are uniquely represented by their name, a string. They
+/// have a mutable component, consisting of the list of types they contain,
+/// the packedness and the opacity bits. Identified structs can be created
+/// without providing the lists of element types, making them suitable to
+/// represent recursive, i.e. self-referring, structures. Identified structs
+/// without body are considered opaque. For such structs, one can set the body.
+/// Identified structs can be created as intentionally-opaque, implying that the
+/// caller does not intend to ever set the body (e.g. forward-declarations of
+/// structs from another module) and wants to disallow further modification of
+/// the body. For intentionally-opaque structs or non-opaque structs with the
+/// body, one is not allowed to set another body (however, one can set exactly
+/// the same body).
+///
+/// Note that the packedness of the struct takes part in the uniquing of literal
+/// structs, but not in the uniquing of identified structs.
+class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMTypeNew,
+ detail::LLVMStructTypeStorage> {
+public:
+ /// Inherit base constructors.
+ using Base::Base;
+
+ /// Support for isa/cast.
+ static bool kindof(unsigned kind) { return kind == LLVMTypeNew::StructType; }
+
+ /// Gets or creates an identified struct with the given name in the provided
+ /// context. Note that unlike llvm::StructType::create, this function will
+ /// _NOT_ rename a struct in case a struct with the same name already exists
+ /// in the context. Instead, it will just return the existing struct,
+ /// similarly to the rest of MLIR type ::get methods.
+ static LLVMStructType getIdentified(MLIRContext *context, StringRef name);
+
+ /// Gets or creates a literal struct with the given body in the provided
+ /// context.
+ static LLVMStructType getLiteral(MLIRContext *context,
+ ArrayRef<LLVMTypeNew> types,
+ bool isPacked = false);
+
+ /// Gets or creates an intentionally-opaque identified struct. Such a struct
+ /// cannot have its body set. To create an opaque struct with a mutable body,
+ /// use `getIdentified`. Note that unlike llvm::StructType::create, this
+ /// function will _NOT_ rename a struct in case a struct with the same name
+ /// already exists in the context. Instead, it will just return the existing
+ /// struct, similarly to the rest of MLIR type ::get methods.
+ /// NOTE(review): unlike getIdentified/getLiteral, the name comes before the
+ /// context here.
+ static LLVMStructType getOpaque(StringRef name, MLIRContext *context);
+
+ /// Set the body of an identified struct. Returns failure if the body could
+ /// not be set, e.g. if the struct already has a body or if it was marked as
+ /// intentionally opaque. This might happen in a multi-threaded context when a
+ /// different thread modified the struct after it was created. Most callers
+ /// are likely to assert this always succeeds, but it is possible to implement
+ /// a local renaming scheme based on the result of this call.
+ LogicalResult setBody(ArrayRef<LLVMTypeNew> types, bool isPacked);
+
+ /// Checks if a struct is packed.
+ bool isPacked();
+
+ /// Checks if a struct is identified (named) rather than literal.
+ bool isIdentified();
+
+ /// Checks if a struct is opaque.
+ bool isOpaque();
+
+ /// Returns the name of an identified struct.
+ StringRef getName();
+
+ /// Returns the list of element types contained in a non-opaque struct.
+ ArrayRef<LLVMTypeNew> getBody();
+};
+
+//===----------------------------------------------------------------------===//
+// LLVMVectorType.
+//===----------------------------------------------------------------------===//
+
+/// LLVM dialect vector type, represents a sequence of elements that can be
+/// processed as one, typically in SIMD context. This is a base class for fixed
+/// and scalable vectors: it has no storage or `get` of its own, and isa/cast
+/// to it succeeds for either kind of vector.
+class LLVMVectorType : public LLVMTypeNew {
+public:
+ /// Inherit base constructor.
+ using LLVMTypeNew::LLVMTypeNew;
+
+ /// Support for isa/cast.
+ static bool kindof(unsigned kind) {
+ return kind == LLVMTypeNew::FixedVectorType ||
+ kind == LLVMTypeNew::ScalableVectorType;
+ }
+
+ /// Returns the element type of the vector.
+ LLVMTypeNew getElementType();
+
+ /// Returns the number of elements in the vector, as an llvm::ElementCount
+ /// (only forward-declared in this header).
+ llvm::ElementCount getElementCount();
+};
+
+//===----------------------------------------------------------------------===//
+// LLVMFixedVectorType.
+//===----------------------------------------------------------------------===//
+
+/// LLVM dialect fixed vector type, represents a sequence of elements of known
+/// length that can be processed as one.
+class LLVMFixedVectorType
+ : public Type::TypeBase<LLVMFixedVectorType, LLVMVectorType,
+ detail::LLVMTypeAndSizeStorage> {
+public:
+ /// Inherit base constructor.
+ using Base::Base;
+
+ /// Support for isa/cast.
+ static bool kindof(unsigned kind) {
+ return kind == LLVMTypeNew::FixedVectorType;
+ }
+
+ /// Gets or creates a fixed vector type containing `numElements` of
+ /// `elementType` in the same context as `elementType`.
+ static LLVMFixedVectorType get(LLVMTypeNew elementType, unsigned numElements);
+
+ /// Returns the (exact) number of elements in the fixed vector.
+ unsigned getNumElements();
+};
+
+//===----------------------------------------------------------------------===//
+// LLVMScalableVectorType.
+//===----------------------------------------------------------------------===//
+
+/// LLVM dialect scalable vector type, represents a sequence of elements of
+/// unknown length that is known to be divisible by some constant. These
+/// elements can be processed as one in SIMD context.
+class LLVMScalableVectorType
+ : public Type::TypeBase<LLVMScalableVectorType, LLVMVectorType,
+ detail::LLVMTypeAndSizeStorage> {
+public:
+ /// Inherit base constructor.
+ using Base::Base;
+
+ /// Support for isa/cast.
+ static bool kindof(unsigned kind) {
+ return kind == LLVMTypeNew::ScalableVectorType;
+ }
+
+ /// Gets or creates a scalable vector type containing a non-zero multiple of
+ /// `minNumElements` of `elementType` in the same context as `elementType`.
+ static LLVMScalableVectorType get(LLVMTypeNew elementType,
+ unsigned minNumElements);
+
+ /// Returns the scaling factor of the number of elements in the vector. The
+ /// vector contains at least the resulting number of elements, or any non-zero
+ /// multiple of this number.
+ unsigned getMinNumElements();
+};
+
+//===----------------------------------------------------------------------===//
+// Printing and parsing.
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Parses an LLVM dialect type using the dialect's custom syntax.
+/// NOTE(review): presumably returns a null type on failure, like the other
+/// parsing helpers of the dialect; confirm against the definition.
+LLVMTypeNew parseType(DialectAsmParser &parser);
+
+/// Prints an LLVM dialect type to the stream of `printer`. Nested LLVM dialect
+/// types are printed without repeating the dialect prefix.
+void printType(LLVMTypeNew type, DialectAsmPrinter &printer);
+} // namespace detail
+
+} // namespace LLVM
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Support for hashing and containers.
+//===----------------------------------------------------------------------===//
+
+namespace llvm {
+
+// LLVMTypeNew instances hash just like pointers: the empty and tombstone keys
+// wrap the corresponding `void *` sentinels, while hashing and equality defer
+// to mlir::hash_value and operator== on the type.
+template <> struct DenseMapInfo<mlir::LLVM::LLVMTypeNew> {
+ static mlir::LLVM::LLVMTypeNew getEmptyKey() {
+ void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+ return mlir::LLVM::LLVMTypeNew(
+ static_cast<mlir::LLVM::LLVMTypeNew::ImplType *>(pointer));
+ }
+ static mlir::LLVM::LLVMTypeNew getTombstoneKey() {
+ void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+ return mlir::LLVM::LLVMTypeNew(
+ static_cast<mlir::LLVM::LLVMTypeNew::ImplType *>(pointer));
+ }
+ static unsigned getHashValue(mlir::LLVM::LLVMTypeNew val) {
+ return mlir::hash_value(val);
+ }
+ static bool isEqual(mlir::LLVM::LLVMTypeNew lhs,
+ mlir::LLVM::LLVMTypeNew rhs) {
+ return lhs == rhs;
+ }
+};
+
+// LLVMTypeNew behaves like a pointer similarly to mlir::Type: it round-trips
+// through its opaque storage pointer and advertises the same number of free
+// low bits as mlir::Type.
+template <> struct PointerLikeTypeTraits<mlir::LLVM::LLVMTypeNew> {
+ static inline void *getAsVoidPointer(mlir::LLVM::LLVMTypeNew type) {
+ return const_cast<void *>(type.getAsOpaquePointer());
+ }
+ static inline mlir::LLVM::LLVMTypeNew getFromVoidPointer(void *ptr) {
+ return mlir::LLVM::LLVMTypeNew::getFromOpaquePointer(ptr);
+ }
+ static constexpr int NumLowBitsAvailable =
+ PointerLikeTypeTraits<mlir::Type>::NumLowBitsAvailable;
+};
+
+} // namespace llvm
+
+#endif // MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index e2d7e2c409c4..c478b200b5d9 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -203,6 +203,9 @@ class DialectAsmParser {
/// Parse a `=` token if present.
virtual ParseResult parseOptionalEqual() = 0;
+ /// Parse a quoted string token if present. Returns failure when the next
+ /// token is not a string; on success, `*string` receives the string.
+ /// NOTE(review): confirm whether quotes/escape sequences are processed.
+ virtual ParseResult parseOptionalString(StringRef *string) = 0;
+
/// Parse a given keyword.
ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
auto loc = getCurrentLocation();
@@ -323,6 +326,9 @@ class DialectAsmParser {
return success();
}
+ /// Parse a type if present, storing it in `result`.
+ /// NOTE(review): presumably the returned OptionalParseResult has no value
+ /// when no type is present, and holds failure if a present type is invalid.
+ virtual OptionalParseResult parseOptionalType(Type &result) = 0;
+
/// Parse a 'x' separated dimension list. This populates the dimension list,
/// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on
/// `?` otherwise.
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index e858a0a70c73..ff6560305cb8 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -2,6 +2,8 @@ add_subdirectory(Transforms)
add_mlir_dialect_library(MLIRLLVMIR
IR/LLVMDialect.cpp
+ IR/LLVMTypes.cpp
+ IR/LLVMTypeSyntax.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
new file mode 100644
index 000000000000..d272297525c1
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -0,0 +1,477 @@
+//===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/SetVector.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+//===----------------------------------------------------------------------===//
+// Printing.
+//===----------------------------------------------------------------------===//
+
+/// Generic recursive printer; declared ahead because the struct, array/vector
+/// and function printers below call back into it for their nested types.
+static void printTypeImpl(llvm::raw_ostream &os, LLVMTypeNew type,
+ llvm::SetVector<StringRef> &stack);
+
+/// Maps the kind of `type` to the keyword its textual form starts with. Fixed
+/// and scalable vectors share `vec`; integers only get the `i` prefix here,
+/// their bitwidth is appended by the caller.
+static StringRef getTypeKeyword(LLVMTypeNew type) {
+  switch (type.getKind()) {
+  // Parametric types.
+  case LLVMTypeNew::StructType:
+    return "struct";
+  case LLVMTypeNew::ArrayType:
+    return "array";
+  case LLVMTypeNew::ScalableVectorType:
+  case LLVMTypeNew::FixedVectorType:
+    return "vec";
+  case LLVMTypeNew::PointerType:
+    return "ptr";
+  case LLVMTypeNew::IntegerType:
+    return "i";
+  case LLVMTypeNew::FunctionType:
+    return "func";
+
+  // Non-parametric types.
+  case LLVMTypeNew::MetadataType:
+    return "metadata";
+  case LLVMTypeNew::LabelType:
+    return "label";
+  case LLVMTypeNew::TokenType:
+    return "token";
+  case LLVMTypeNew::X86MMXType:
+    return "x86_mmx";
+  case LLVMTypeNew::PPCFP128Type:
+    return "ppc_fp128";
+  case LLVMTypeNew::X86FP80Type:
+    return "x86_fp80";
+  case LLVMTypeNew::FP128Type:
+    return "fp128";
+  case LLVMTypeNew::DoubleType:
+    return "double";
+  case LLVMTypeNew::FloatType:
+    return "float";
+  case LLVMTypeNew::BFloatType:
+    return "bfloat";
+  case LLVMTypeNew::HalfType:
+    return "half";
+  case LLVMTypeNew::VoidType:
+    return "void";
+  }
+  llvm_unreachable("unhandled type kind");
+}
+
+/// Prints the body of a structure type: `opaque` for an opaque identified
+/// struct, otherwise an optional `packed ` marker followed by the parenthesized,
+/// comma-separated element types. While the elements are printed, the name of
+/// an identified struct is kept on `stack` so that self-references print only
+/// the name instead of recursing indefinitely.
+static void printStructTypeBody(llvm::raw_ostream &os, LLVMStructType type,
+                                llvm::SetVector<StringRef> &stack) {
+  if (type.isIdentified() && type.isOpaque()) {
+    os << "opaque";
+    return;
+  }
+
+  if (type.isPacked())
+    os << "packed ";
+
+  os << '(';
+  // Put the current type on the stack to avoid infinite recursion. Only pop it
+  // afterwards if this call actually pushed it: SetVector::insert does nothing
+  // for a name that is already present (e.g. when a caller prints the body of
+  // a struct that is itself being printed or parsed), and an unconditional
+  // pop_back would then drop an unrelated enclosing struct from the stack.
+  bool pushed = type.isIdentified() && stack.insert(type.getName());
+  llvm::interleaveComma(type.getBody(), os, [&](LLVMTypeNew subtype) {
+    printTypeImpl(os, subtype, stack);
+  });
+  if (pushed)
+    stack.pop_back();
+  os << ')';
+}
+
+/// Prints a structure type as `<"name", body>` (identified) or `<body>`
+/// (literal). `stack` holds the names of the identified structs currently
+/// being printed. A struct whose name is already on it is a reference back to
+/// an enclosing struct and is printed as `<"name">` only, which keeps the
+/// output of recursive types finite.
+static void printStructType(llvm::raw_ostream &os, LLVMStructType type,
+                            llvm::SetVector<StringRef> &stack) {
+  os << '<';
+  if (type.isIdentified()) {
+    StringRef name = type.getName();
+    os << '"' << name << '"';
+    bool refersToEnclosingStruct = stack.count(name) != 0;
+    if (refersToEnclosingStruct) {
+      os << '>';
+      return;
+    }
+    os << ", ";
+  }
+  printStructTypeBody(os, type, stack);
+  os << '>';
+}
+
+/// Prints `<N x element>` for a type with a statically known element count
+/// (arrays and fixed vectors); `SizedTy` must provide getNumElements() and
+/// getElementType().
+template <typename SizedTy>
+static void printArrayOrVectorType(llvm::raw_ostream &os, SizedTy type,
+                                   llvm::SetVector<StringRef> &stack) {
+  auto count = type.getNumElements();
+  os << '<';
+  os << count;
+  os << " x ";
+  printTypeImpl(os, type.getElementType(), stack);
+  os << '>';
+}
+
+/// Prints a function type as `<result (params)>`. For a variadic function,
+/// `...` is appended as the last entry of the parameter list.
+static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType,
+                              llvm::SetVector<StringRef> &stack) {
+  os << '<';
+  printTypeImpl(os, funcType.getReturnType(), stack);
+  os << " (";
+  const char *separator = "";
+  for (LLVMTypeNew param : funcType.getParams()) {
+    os << separator;
+    printTypeImpl(os, param, stack);
+    separator = ", ";
+  }
+  if (funcType.isVarArg())
+    os << (funcType.getNumParams() == 0 ? "..." : ", ...");
+  os << ")>";
+}
+
+/// Recursively prints `type` as its keyword followed by its parameters, if any.
+/// Because the LLVM dialect type system is closed, nested types are printed
+/// without the dialect prefix. Identified structs whose name is on `stack`
+/// (i.e. that enclose the current position) are printed by name only, which
+/// terminates recursive types. Sibling references are expanded every time:
+///   struct<"a", (ptr<struct<"a">>)>
+///   struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
+///                ptr<struct<"b", (ptr<struct<"c">>)>>)>
+/// Here the body of "b" appears twice.
+static void printTypeImpl(llvm::raw_ostream &os, LLVMTypeNew type,
+                          llvm::SetVector<StringRef> &stack) {
+  if (!type) {
+    os << "<<NULL-TYPE>>";
+    return;
+  }
+
+  unsigned kind = type.getKind();
+  os << getTypeKeyword(type);
+
+  // Non-parametric types are fully described by their keyword.
+  bool isTrivial = kind >= LLVMTypeNew::FIRST_TRIVIAL_TYPE &&
+                   kind <= LLVMTypeNew::LAST_TRIVIAL_TYPE;
+  if (isTrivial)
+    return;
+
+  switch (kind) {
+  case LLVMTypeNew::IntegerType:
+    // The bitwidth directly follows the `i` keyword, e.g. `i32`.
+    os << type.cast<LLVMIntegerType>().getBitWidth();
+    return;
+
+  case LLVMTypeNew::PointerType: {
+    auto ptrType = type.cast<LLVMPointerType>();
+    os << '<';
+    printTypeImpl(os, ptrType.getElementType(), stack);
+    // The default address space 0 is left implicit.
+    unsigned addressSpace = ptrType.getAddressSpace();
+    if (addressSpace != 0)
+      os << ", " << addressSpace;
+    os << '>';
+    return;
+  }
+
+  case LLVMTypeNew::ArrayType:
+    printArrayOrVectorType(os, type.cast<LLVMArrayType>(), stack);
+    return;
+
+  case LLVMTypeNew::FixedVectorType:
+    printArrayOrVectorType(os, type.cast<LLVMFixedVectorType>(), stack);
+    return;
+
+  case LLVMTypeNew::ScalableVectorType: {
+    auto vectorType = type.cast<LLVMScalableVectorType>();
+    os << "<? x " << vectorType.getMinNumElements() << " x ";
+    printTypeImpl(os, vectorType.getElementType(), stack);
+    os << '>';
+    return;
+  }
+
+  case LLVMTypeNew::StructType:
+    printStructType(os, type.cast<LLVMStructType>(), stack);
+    return;
+
+  default:
+    // Only function types remain; cast<> checks this in assert-enabled builds.
+    printFunctionType(os, type.cast<LLVMFunctionType>(), stack);
+    return;
+  }
+}
+
+void mlir::LLVM::detail::printType(LLVMTypeNew type,
+                                   DialectAsmPrinter &printer) {
+  // Names of the identified structs enclosing the current print position;
+  // nothing encloses the top-level type.
+  llvm::SetVector<StringRef> enclosingStructs;
+  printTypeImpl(printer.getStream(), type, enclosingStructs);
+}
+
+//===----------------------------------------------------------------------===//
+// Parsing.
+//===----------------------------------------------------------------------===//
+
+/// Generic recursive parser; declared ahead because the parsers of parametric
+/// types below call back into it for their nested types. Used as a null-type
+/// on failure by the chaining helper right after it.
+static LLVMTypeNew parseTypeImpl(DialectAsmParser &parser,
+ llvm::SetVector<StringRef> &stack);
+
+/// ParseResult-returning wrapper around parseTypeImpl so that type parsing can
+/// be chained with other parser calls via `||`. `result` is always assigned;
+/// it holds a null type when parsing failed.
+static ParseResult parseTypeImpl(DialectAsmParser &parser,
+                                 llvm::SetVector<StringRef> &stack,
+                                 LLVMTypeNew &result) {
+  result = parseTypeImpl(parser, stack);
+  if (!result)
+    return failure();
+  return success();
+}
+
+/// Parses an LLVM dialect function type.
+///   llvm-type  ::= `func<` llvm-type `(` param-list? `)>`
+///   param-list ::= (llvm-type `,`)* (llvm-type | `...`)
+/// A trailing `...` makes the function variadic. Returns a null type on
+/// failure, after a diagnostic has been emitted.
+static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
+                                          llvm::SetVector<StringRef> &stack) {
+  LLVMTypeNew returnType;
+  if (parser.parseLess() || parseTypeImpl(parser, stack, returnType) ||
+      parser.parseLParen())
+    return LLVMFunctionType();
+
+  // Function type without arguments.
+  if (succeeded(parser.parseOptionalRParen())) {
+    if (succeeded(parser.parseGreater()))
+      return LLVMFunctionType::get(returnType, {}, /*isVarArg=*/false);
+    return LLVMFunctionType();
+  }
+
+  // Parse the comma-separated argument types; `...` may only appear as the
+  // last entry and marks the function as variadic.
+  SmallVector<LLVMTypeNew, 8> argTypes;
+  bool isVarArg = false;
+  do {
+    if (succeeded(parser.parseOptionalEllipsis())) {
+      isVarArg = true;
+      break;
+    }
+
+    argTypes.push_back(parseTypeImpl(parser, stack));
+    if (!argTypes.back())
+      return LLVMFunctionType();
+  } while (succeeded(parser.parseOptionalComma()));
+
+  // The closing `)>` is mandatory at this point. Use the non-optional parsers
+  // so that malformed input (e.g. `func<void (i32 i32)>`) reports a diagnostic
+  // instead of failing silently, consistently with the no-argument case above.
+  if (parser.parseRParen() || parser.parseGreater())
+    return LLVMFunctionType();
+  return LLVMFunctionType::get(returnType, argTypes, isVarArg);
+}
+
+/// Parses an LLVM dialect pointer type; an omitted address space means 0.
+///   llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
+static LLVMPointerType parsePointerType(DialectAsmParser &parser,
+                                        llvm::SetVector<StringRef> &stack) {
+  LLVMTypeNew pointeeType;
+  if (parser.parseLess() || parseTypeImpl(parser, stack, pointeeType))
+    return LLVMPointerType();
+
+  unsigned addressSpace = 0;
+  if (succeeded(parser.parseOptionalComma())) {
+    if (failed(parser.parseInteger(addressSpace)))
+      return LLVMPointerType();
+  }
+
+  if (failed(parser.parseGreater()))
+    return LLVMPointerType();
+  return LLVMPointerType::get(pointeeType, addressSpace);
+}
+
+/// Parses an LLVM dialect vector type, either fixed (`vec<N x T>`) or scalable
+/// (`vec<? x N x T>`).
+///   llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
+static LLVMVectorType parseVectorType(DialectAsmParser &parser,
+                                      llvm::SetVector<StringRef> &stack) {
+  SmallVector<int64_t, 2> shape;
+  llvm::SMLoc shapeLoc;
+  LLVMTypeNew elementType;
+  if (parser.parseLess() || parser.getCurrentLocation(&shapeLoc) ||
+      parser.parseDimensionList(shape, /*allowDynamic=*/true) ||
+      parseTypeImpl(parser, stack, elementType) || parser.parseGreater())
+    return LLVMVectorType();
+
+  // The generic dimension list accepts more than vectors do (-1 stands for
+  // `?`). Only a single static size (fixed vector) or `?` followed by a
+  // static size (scalable vector) is valid.
+  bool isFixed = shape.size() == 1 && shape[0] != -1;
+  bool isScalable = shape.size() == 2 && shape[0] == -1 && shape[1] != -1;
+  if (!isFixed && !isScalable) {
+    parser.emitError(shapeLoc)
+        << "expected '? x <integer> x <type>' or '<integer> x <type>'";
+    return LLVMVectorType();
+  }
+
+  if (isScalable)
+    return LLVMScalableVectorType::get(elementType, shape[1]);
+  return LLVMFixedVectorType::get(elementType, shape[0]);
+}
+
+/// Parses an LLVM dialect array type with a single static size.
+///   llvm-type ::= `array<` integer `x` llvm-type `>`
+static LLVMArrayType parseArrayType(DialectAsmParser &parser,
+                                    llvm::SetVector<StringRef> &stack) {
+  llvm::SMLoc sizeLoc;
+  SmallVector<int64_t, 1> sizes;
+  LLVMTypeNew elementType;
+  bool parseFailed =
+      parser.parseLess() || parser.getCurrentLocation(&sizeLoc) ||
+      parser.parseDimensionList(sizes, /*allowDynamic=*/false) ||
+      parseTypeImpl(parser, stack, elementType) || parser.parseGreater();
+  if (parseFailed)
+    return LLVMArrayType();
+
+  // `?` sizes are already rejected by parseDimensionList (allowDynamic is
+  // false); only the number of dimensions remains to be checked here.
+  // NOTE(review): the message mentions `?` although arrays only accept a
+  // static size; kept verbatim since tests may match on it.
+  if (sizes.size() != 1) {
+    parser.emitError(sizeLoc) << "expected ? x <type>";
+    return LLVMArrayType();
+  }
+  return LLVMArrayType::get(elementType, sizes[0]);
+}
+
+/// Attempts to set the body of an identified structure type. Reports a parsing
+/// error at `subtypesLoc` in case of failure, uses `stack` to make sure the
+/// types printed in the error message look like they did when parsed.
+/// Returns `type` on success and a null struct type on failure.
+static LLVMStructType trySetStructBody(LLVMStructType type,
+                                       ArrayRef<LLVMTypeNew> subtypes,
+                                       bool isPacked, DialectAsmParser &parser,
+                                       llvm::SMLoc subtypesLoc,
+                                       llvm::SetVector<StringRef> &stack) {
+  if (succeeded(type.setBody(subtypes, isPacked)))
+    return type;
+
+  // The struct was already initialized (with a different body or as
+  // intentionally opaque); render what it currently holds for the note.
+  std::string currentBody;
+  llvm::raw_string_ostream currentBodyStream(currentBody);
+  printStructTypeBody(currentBodyStream, type, stack);
+  auto diag = parser.emitError(subtypesLoc)
+              << "identified type already used with a different body";
+  diag.attachNote() << "existing body: " << currentBodyStream.str();
+  return LLVMStructType();
+}
+
+/// Parses an LLVM dialect structure type.
+///   llvm-type ::= `struct<` (string-literal `,`)? `packed`?
+///                 `(` llvm-type-list `)` `>`
+///               | `struct<` string-literal `>`
+///               | `struct<` string-literal `, opaque>`
+/// The bare `struct<` string-literal `>` form is only accepted as a
+/// self-reference, i.e. while the identifier is already on `stack`.
+static LLVMStructType parseStructType(DialectAsmParser &parser,
+                                      llvm::SetVector<StringRef> &stack) {
+  MLIRContext *ctx = parser.getBuilder().getContext();
+
+  if (failed(parser.parseLess()))
+    return LLVMStructType();
+
+  // If we are parsing a self-reference to a recursive struct, i.e. the parsing
+  // stack already contains a struct with the same identifier, bail out after
+  // the name.
+  StringRef name;
+  bool isIdentified = succeeded(parser.parseOptionalString(&name));
+  if (isIdentified) {
+    if (stack.count(name)) {
+      if (failed(parser.parseGreater()))
+        return LLVMStructType();
+      return LLVMStructType::getIdentified(ctx, name);
+    }
+    if (failed(parser.parseComma()))
+      return LLVMStructType();
+  }
+
+  // Handle intentionally opaque structs.
+  llvm::SMLoc kwLoc = parser.getCurrentLocation();
+  if (succeeded(parser.parseOptionalKeyword("opaque"))) {
+    if (!isIdentified) {
+      parser.emitError(kwLoc, "only identified structs can be opaque");
+      return LLVMStructType();
+    }
+    if (failed(parser.parseGreater()))
+      return LLVMStructType();
+    auto type = LLVMStructType::getOpaque(name, ctx);
+    if (!type.isOpaque()) {
+      parser.emitError(kwLoc, "redeclaring defined struct as opaque");
+      return LLVMStructType();
+    }
+    return type;
+  }
+
+  // Check for packedness.
+  bool isPacked = succeeded(parser.parseOptionalKeyword("packed"));
+  if (failed(parser.parseLParen()))
+    return LLVMStructType();
+
+  // Fast path for structs with zero subtypes.
+  if (succeeded(parser.parseOptionalRParen())) {
+    if (failed(parser.parseGreater()))
+      return LLVMStructType();
+    if (!isIdentified)
+      return LLVMStructType::getLiteral(ctx, {}, isPacked);
+    auto type = LLVMStructType::getIdentified(ctx, name);
+    return trySetStructBody(type, {}, isPacked, parser, kwLoc, stack);
+  }
+
+  // Parse subtypes. For identified structs, put the identifier of the struct on
+  // the stack to support self-references in the recursive calls. The
+  // identifier is removed again on every exit path, including parse errors,
+  // so the stack only ever lists structs whose bodies are being parsed.
+  SmallVector<LLVMTypeNew, 4> subtypes;
+  llvm::SMLoc subtypesLoc = parser.getCurrentLocation();
+  if (isIdentified)
+    stack.insert(name);
+  bool subtypesParsed = true;
+  do {
+    LLVMTypeNew type = parseTypeImpl(parser, stack);
+    if (!type) {
+      subtypesParsed = false;
+      break;
+    }
+    subtypes.push_back(type);
+  } while (succeeded(parser.parseOptionalComma()));
+  if (isIdentified)
+    stack.pop_back();
+  if (!subtypesParsed)
+    return LLVMStructType();
+
+  if (parser.parseRParen() || parser.parseGreater())
+    return LLVMStructType();
+
+  // Construct the struct with body.
+  if (!isIdentified)
+    return LLVMStructType::getLiteral(ctx, subtypes, isPacked);
+  auto type = LLVMStructType::getIdentified(ctx, name);
+  return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc, stack);
+}
+
+/// Parses one of the LLVM dialect types. `stack` holds the identifiers of the
+/// identified structs whose bodies are currently being parsed and is threaded
+/// through to the parsers of parameterized types.
+static LLVMTypeNew parseTypeImpl(DialectAsmParser &parser,
+                                 llvm::SetVector<StringRef> &stack) {
+  MLIRContext *ctx = parser.getBuilder().getContext();
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+
+  // Integer types (i[1-9][0-9]*) are lexed as literals rather than keywords,
+  // so the keyword dispatch below never sees them. Parse them as built-in
+  // types instead and accept only signless integers.
+  Type builtinType;
+  OptionalParseResult builtinParse = parser.parseOptionalType(builtinType);
+  if (builtinParse.hasValue()) {
+    if (failed(*builtinParse))
+      return LLVMTypeNew();
+    if (!builtinType.isSignlessInteger()) {
+      parser.emitError(typeLoc) << "unexpected type, expected i* or keyword";
+      return LLVMTypeNew();
+    }
+    return LLVMIntegerType::get(ctx, builtinType.getIntOrFloatBitWidth());
+  }
+
+  StringRef keyword;
+  if (failed(parser.parseKeyword(&keyword)))
+    return LLVMTypeNew();
+
+  // Types without parameters.
+  if (keyword == "void")
+    return LLVMVoidType::get(ctx);
+  if (keyword == "half")
+    return LLVMHalfType::get(ctx);
+  if (keyword == "bfloat")
+    return LLVMBFloatType::get(ctx);
+  if (keyword == "float")
+    return LLVMFloatType::get(ctx);
+  if (keyword == "double")
+    return LLVMDoubleType::get(ctx);
+  if (keyword == "fp128")
+    return LLVMFP128Type::get(ctx);
+  if (keyword == "x86_fp80")
+    return LLVMX86FP80Type::get(ctx);
+  if (keyword == "ppc_fp128")
+    return LLVMPPCFP128Type::get(ctx);
+  if (keyword == "x86_mmx")
+    return LLVMX86MMXType::get(ctx);
+  if (keyword == "token")
+    return LLVMTokenType::get(ctx);
+  if (keyword == "label")
+    return LLVMLabelType::get(ctx);
+  if (keyword == "metadata")
+    return LLVMMetadataType::get(ctx);
+
+  // Parameterized types have dedicated parsers.
+  if (keyword == "func")
+    return parseFunctionType(parser, stack);
+  if (keyword == "ptr")
+    return parsePointerType(parser, stack);
+  if (keyword == "vec")
+    return parseVectorType(parser, stack);
+  if (keyword == "array")
+    return parseArrayType(parser, stack);
+  if (keyword == "struct")
+    return parseStructType(parser, stack);
+
+  parser.emitError(typeLoc) << "unknown LLVM type: " << keyword;
+  return LLVMTypeNew();
+}
+
+/// Entry point of the LLVM dialect type parser. Parsing starts with no
+/// identified structs in progress.
+LLVMTypeNew mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
+  llvm::SetVector<StringRef> structsInProgress;
+  LLVMTypeNew parsed = parseTypeImpl(parser, structsInProgress);
+  return parsed;
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
new file mode 100644
index 000000000000..3540091e90a3
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -0,0 +1,163 @@
+//===- LLVMTypes.cpp - MLIR LLVM Dialect types ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the types for the LLVM dialect in MLIR. These MLIR types
+// correspond to the LLVM IR type system.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TypeDetail.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/TypeSupport.h"
+
+#include "llvm/Support/TypeSize.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+//===----------------------------------------------------------------------===//
+// Array type.
+
+/// Returns the uniqued array type of `numElements` values of `elementType`,
+/// created in the element type's context. The element type must be non-null.
+LLVMArrayType LLVMArrayType::get(LLVMTypeNew elementType,
+                                 unsigned numElements) {
+  assert(elementType && "expected non-null subtype");
+  MLIRContext *ctx = elementType.getContext();
+  return Base::get(ctx, LLVMTypeNew::ArrayType, elementType, numElements);
+}
+
+/// Returns the type of the array elements.
+LLVMTypeNew LLVMArrayType::getElementType() {
+  return getImpl()->elementType;
+}
+
+/// Returns the number of elements in the array.
+unsigned LLVMArrayType::getNumElements() {
+  return getImpl()->numElements;
+}
+
+//===----------------------------------------------------------------------===//
+// Function type.
+
+/// Returns the uniqued function type with the given result type, parameter
+/// types and variadic flag, created in the result type's context. The result
+/// type must be non-null.
+LLVMFunctionType LLVMFunctionType::get(LLVMTypeNew result,
+                                       ArrayRef<LLVMTypeNew> arguments,
+                                       bool isVarArg) {
+  assert(result && "expected non-null result");
+  MLIRContext *ctx = result.getContext();
+  return Base::get(ctx, LLVMTypeNew::FunctionType, result, arguments,
+                   isVarArg);
+}
+
+/// Returns the type produced by the function.
+LLVMTypeNew LLVMFunctionType::getReturnType() {
+  return getImpl()->getReturnType();
+}
+
+/// Returns true if the function accepts a variable number of arguments.
+bool LLVMFunctionType::isVarArg() { return getImpl()->isVariadic(); }
+
+/// Returns all parameter types, in order.
+ArrayRef<LLVMTypeNew> LLVMFunctionType::getParams() {
+  return getImpl()->getArgumentTypes();
+}
+
+/// Returns the number of (non-variadic) parameters.
+unsigned LLVMFunctionType::getNumParams() { return getParams().size(); }
+
+/// Returns the type of the `i`-th parameter; `i` must be in range.
+LLVMTypeNew LLVMFunctionType::getParamType(unsigned i) {
+  return getParams()[i];
+}
+
+//===----------------------------------------------------------------------===//
+// Integer type.
+
+/// Returns the uniqued LLVM integer type of `bitwidth` bits in `ctx`.
+LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) {
+  return Base::get(ctx, LLVMTypeNew::IntegerType, bitwidth);
+}
+
+/// Returns the number of bits this integer type was created with.
+unsigned LLVMIntegerType::getBitWidth() {
+  return getImpl()->bitwidth;
+}
+
+//===----------------------------------------------------------------------===//
+// Pointer type.
+
+/// Returns the uniqued pointer type to `pointee` in `addressSpace`, created in
+/// the pointee's context. The pointee type must be non-null.
+LLVMPointerType LLVMPointerType::get(LLVMTypeNew pointee,
+                                     unsigned addressSpace) {
+  assert(pointee && "expected non-null subtype");
+  MLIRContext *ctx = pointee.getContext();
+  return Base::get(ctx, LLVMTypeNew::PointerType, pointee, addressSpace);
+}
+
+/// Returns the type being pointed to.
+LLVMTypeNew LLVMPointerType::getElementType() {
+  return getImpl()->pointeeType;
+}
+
+/// Returns the address space the pointer lives in.
+unsigned LLVMPointerType::getAddressSpace() {
+  return getImpl()->addressSpace;
+}
+
+//===----------------------------------------------------------------------===//
+// Struct type.
+
+/// Returns the identified struct named `name`. The struct is created without a
+/// body if it does not exist yet; its body can be set later with `setBody`.
+LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
+                                             StringRef name) {
+  return Base::get(context, LLVMTypeNew::StructType, name, /*opaque=*/false);
+}
+
+/// Returns the literal (unnamed) struct with the given element types and
+/// packedness. Literal structs are uniqued by their contents.
+LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
+                                          ArrayRef<LLVMTypeNew> types,
+                                          bool isPacked) {
+  return Base::get(context, LLVMTypeNew::StructType, types, isPacked);
+}
+
+/// Returns the identified struct named `name`, marking it as intentionally
+/// opaque if it is being created by this call. If a struct with this name
+/// already exists, it is returned unchanged; callers can check `isOpaque()`.
+LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
+  return Base::get(context, LLVMTypeNew::StructType, name, /*opaque=*/true);
+}
+
+/// Sets the body of an identified struct. Fails if the struct was already
+/// initialized with a different body or declared intentionally opaque.
+LogicalResult LLVMStructType::setBody(ArrayRef<LLVMTypeNew> types,
+                                      bool isPacked) {
+  assert(isIdentified() && "can only set bodies of identified structs");
+  return Base::mutate(types, isPacked);
+}
+
+bool LLVMStructType::isPacked() { return getImpl()->isPacked(); }
+bool LLVMStructType::isIdentified() { return getImpl()->isIdentified(); }
+
+/// Returns true for identified structs that are either intentionally opaque
+/// or do not have a body yet. Literal structs are never opaque: their storage
+/// never sets the "initialized" flag of the mutable component, so that flag
+/// must not be consulted for them.
+bool LLVMStructType::isOpaque() {
+  return getImpl()->isIdentified() &&
+         (getImpl()->isOpaque() || !getImpl()->isInitialized());
+}
+
+StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); }
+
+/// Returns the element types: the mutable body for identified structs, the
+/// uniquing type list for literal structs.
+ArrayRef<LLVMTypeNew> LLVMStructType::getBody() {
+  return isIdentified() ? getImpl()->getIdentifiedStructBody()
+                        : getImpl()->getTypeList();
+}
+
+//===----------------------------------------------------------------------===//
+// Vector types.
+
+/// Returns the vector element type. Fixed and scalable vectors share the same
+/// storage class, so it can be accessed without knowing the concrete kind.
+LLVMTypeNew LLVMVectorType::getElementType() {
+  auto *storage = static_cast<detail::LLVMTypeAndSizeStorage *>(impl);
+  return storage->elementType;
+}
+
+/// Returns the element count, flagged as scalable for scalable vectors.
+llvm::ElementCount LLVMVectorType::getElementCount() {
+  auto *storage = static_cast<detail::LLVMTypeAndSizeStorage *>(impl);
+  bool scalable = this->isa<LLVMScalableVectorType>();
+  return llvm::ElementCount(storage->numElements, scalable);
+}
+
+/// Returns the uniqued fixed-size vector of `numElements` values of
+/// `elementType`. The element type must be non-null.
+LLVMFixedVectorType LLVMFixedVectorType::get(LLVMTypeNew elementType,
+                                             unsigned numElements) {
+  assert(elementType && "expected non-null subtype");
+  MLIRContext *ctx = elementType.getContext();
+  return Base::get(ctx, LLVMTypeNew::FixedVectorType, elementType,
+                   numElements)
+      .cast<LLVMFixedVectorType>();
+}
+
+/// Returns the exact number of elements of the fixed-size vector.
+unsigned LLVMFixedVectorType::getNumElements() {
+  return getImpl()->numElements;
+}
+
+/// Returns the uniqued scalable vector with at least `minNumElements` values
+/// of `elementType`. The element type must be non-null.
+LLVMScalableVectorType LLVMScalableVectorType::get(LLVMTypeNew elementType,
+                                                   unsigned minNumElements) {
+  assert(elementType && "expected non-null subtype");
+  MLIRContext *ctx = elementType.getContext();
+  return Base::get(ctx, LLVMTypeNew::ScalableVectorType, elementType,
+                   minNumElements)
+      .cast<LLVMScalableVectorType>();
+}
+
+/// Returns the minimum number of elements of the scalable vector.
+unsigned LLVMScalableVectorType::getMinNumElements() {
+  return getImpl()->numElements;
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
new file mode 100644
index 000000000000..2b72e43e5164
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
@@ -0,0 +1,458 @@
+//===- TypeDetail.h - Details of MLIR LLVM dialect types --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains implementation details, such as storage structures, of
+// MLIR LLVM dialect types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_LLVMIR_IR_TYPEDETAIL_H
+#define DIALECT_LLVMIR_IR_TYPEDETAIL_H
+
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/Bitfields.h"
+#include "llvm/ADT/PointerIntPair.h"
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+
+//===----------------------------------------------------------------------===//
+// LLVMStructTypeStorage.
+//===----------------------------------------------------------------------===//
+
+/// Type storage for LLVM structure types.
+///
+/// Structures are uniqued using:
+/// - a bit indicating whether a struct is literal or identified;
+/// - for identified structs, in addition to the bit:
+/// - a string identifier;
+/// - for literal structs, in addition to the bit:
+/// - a list of contained types;
+/// - a bit indicating whether the literal struct is packed.
+///
+/// Identified structures only have a mutable component consisting of:
+/// - a list of contained types;
+/// - a bit indicating whether the identified struct is packed;
+/// - a bit indicating whether the identified struct is intentionally opaque;
+/// - a bit indicating whether the identified struct has been initialized.
+/// Uninitialized structs are considered opaque by the user, and can be mutated.
+/// Initialized and still opaque structs cannot be mutated.
+///
+/// The struct storage consists of:
+/// - immutable part:
+/// - a pointer to the first element of the key (character for identified
+/// structs, type for literal structs);
+/// - the number of elements in the key packed together with bits indicating
+/// whether a struct is literal or identified, and the packedness bit for
+/// literal structs only;
+/// - mutable part:
+/// - a pointer to the first contained type for identified structs only;
+/// - the number of contained types packed together with bits of the mutable
+/// component, for identified structs only.
+struct LLVMStructTypeStorage : public TypeStorage {
+public:
+ /// Construction/uniquing key class for LLVM dialect structure storage. Note
+ /// that this is a transient helper data structure that is NOT stored.
+ /// Therefore, it intentionally avoids bit manipulation and type erasure in
+ /// pointers to make manipulation more straightforward. Not all elements of
+ /// the key participate in uniquing, but all elements participate in
+ /// construction.
+ class Key {
+ public:
+ /// Constructs a key for an identified struct.
+ Key(StringRef name, bool opaque)
+ : name(name), identified(true), packed(false), opaque(opaque) {}
+ /// Constructs a key for a literal struct.
+ Key(ArrayRef<LLVMTypeNew> types, bool packed)
+ : types(types), identified(false), packed(packed), opaque(false) {}
+
+ /// Checks a specific property of the struct.
+ bool isIdentified() const { return identified; }
+ bool isPacked() const {
+ assert(!isIdentified() &&
+ "'packed' bit is not part of the key for identified stucts");
+ return packed;
+ }
+ bool isOpaque() const {
+ assert(isIdentified() &&
+ "'opaque' bit is meaningless on literal structs");
+ return opaque;
+ }
+
+ /// Returns the identifier of a key for identified structs.
+ StringRef getIdentifier() const {
+ assert(isIdentified() &&
+ "non-identified struct key canont have an identifier");
+ return name;
+ }
+
+ /// Returns the list of types contained in the key of a literal struct.
+ ArrayRef<LLVMTypeNew> getTypeList() const {
+ assert(!isIdentified() &&
+ "identified struct key cannot have a type list");
+ return types;
+ }
+
+ /// Returns the hash value of the key. This combines various flags into a
+ /// single value: the identified flag sets the first bit, and the packedness
+ /// flag sets the second bit. Opacity bit is only used for construction and
+ /// does not participate in uniquing.
+ llvm::hash_code hashValue() const {
+ constexpr static unsigned kIdentifiedHashFlag = 1;
+ constexpr static unsigned kPackedHashFlag = 2;
+
+ unsigned flags = 0;
+ if (isIdentified()) {
+ flags |= kIdentifiedHashFlag;
+ // Identified structs are hashed by name only; their body is mutable.
+ return llvm::hash_combine(flags, getIdentifier());
+ }
+ if (isPacked())
+ flags |= kPackedHashFlag;
+ return llvm::hash_combine(flags, getTypeList());
+ }
+
+ /// Compares two keys. Identified keys compare by identifier only; literal
+ /// keys compare by packedness and contained types.
+ bool operator==(const Key &other) const {
+ if (isIdentified())
+ return other.isIdentified() &&
+ other.getIdentifier().equals(getIdentifier());
+
+ return !other.isIdentified() && other.isPacked() == isPacked() &&
+ other.getTypeList() == getTypeList();
+ }
+
+ /// Copies dynamically-sized components of the key into the given allocator.
+ Key copyIntoAllocator(TypeStorageAllocator &allocator) const {
+ if (isIdentified())
+ return Key(allocator.copyInto(name), opaque);
+ return Key(allocator.copyInto(types), packed);
+ }
+
+ private:
+ ArrayRef<LLVMTypeNew> types;
+ StringRef name;
+ bool identified;
+ bool packed;
+ bool opaque;
+ };
+ using KeyTy = Key;
+
+ /// Returns the string identifier of an identified struct.
+ StringRef getIdentifier() const {
+ assert(isIdentified() && "requested identifier on a non-identified struct");
+ return StringRef(static_cast<const char *>(keyPtr), keySize());
+ }
+
+ /// Returns the list of types (partially) identifying a literal struct.
+ ArrayRef<LLVMTypeNew> getTypeList() const {
+ // If this triggers, use getIdentifiedStructBody() instead.
+ assert(!isIdentified() && "requested typelist on an identified struct");
+ return ArrayRef<LLVMTypeNew>(static_cast<const LLVMTypeNew *>(keyPtr),
+ keySize());
+ }
+
+ /// Returns the list of types contained in an identified struct.
+ ArrayRef<LLVMTypeNew> getIdentifiedStructBody() const {
+ // If this triggers, use getTypeList() instead.
+ assert(isIdentified() &&
+ "requested struct body on a non-identified struct");
+ return ArrayRef<LLVMTypeNew>(identifiedBodyArray, identifiedBodySize());
+ }
+
+ /// Checks whether the struct is identified.
+ bool isIdentified() const {
+ return llvm::Bitfield::get<KeyFlagIdentified>(keySizeAndFlags);
+ }
+
+ /// Checks whether the struct is packed (both literal and identified structs).
+ bool isPacked() const {
+ return isIdentified() ? llvm::Bitfield::get<MutableFlagPacked>(
+ identifiedBodySizeAndFlags)
+ : llvm::Bitfield::get<KeyFlagPacked>(keySizeAndFlags);
+ }
+
+ /// Checks whether a struct is marked as intentionally opaque (an
+ /// uninitialized struct is also considered opaque by the user, call
+ /// isInitialized to check that). For literal structs the mutable flags are
+ /// never set, so this returns false for them.
+ bool isOpaque() const {
+ return llvm::Bitfield::get<MutableFlagOpaque>(identifiedBodySizeAndFlags);
+ }
+
+ /// Checks whether an identified struct has been explicitly initialized either
+ /// by setting its body or by marking it as intentionally opaque. For literal
+ /// structs the mutable flags are never set, so this returns false for them.
+ bool isInitialized() const {
+ return llvm::Bitfield::get<MutableFlagInitialized>(
+ identifiedBodySizeAndFlags);
+ }
+
+ /// Constructs the storage from the given key. This sets up the uniquing key
+ /// components and optionally the mutable component if the construction key
+ /// has the relevant information. In the latter case, the struct is considered
+ /// as initialized and can no longer be mutated.
+ LLVMStructTypeStorage(const KeyTy &key) {
+ if (!key.isIdentified()) {
+ // Literal struct: the key is the (allocator-owned) type list itself.
+ ArrayRef<LLVMTypeNew> types = key.getTypeList();
+ keyPtr = static_cast<const void *>(types.data());
+ setKeySize(types.size());
+ llvm::Bitfield::set<KeyFlagPacked>(keySizeAndFlags, key.isPacked());
+ return;
+ }
+
+ // Identified struct: the key is the (allocator-owned) name.
+ StringRef name = key.getIdentifier();
+ keyPtr = static_cast<const void *>(name.data());
+ setKeySize(name.size());
+ llvm::Bitfield::set<KeyFlagIdentified>(keySizeAndFlags, true);
+
+ // If the struct is being constructed directly as opaque, mark it as
+ // initialized.
+ llvm::Bitfield::set<MutableFlagInitialized>(identifiedBodySizeAndFlags,
+ key.isOpaque());
+ llvm::Bitfield::set<MutableFlagOpaque>(identifiedBodySizeAndFlags,
+ key.isOpaque());
+ }
+
+ /// Hook into the type uniquing infrastructure.
+ bool operator==(const KeyTy &other) const { return getKey() == other; };
+ static llvm::hash_code hashKey(const KeyTy &key) { return key.hashValue(); }
+ static LLVMStructTypeStorage *construct(TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+ // Copy the name / type list into the allocator before the storage keeps
+ // pointers to it.
+ return new (allocator.allocate<LLVMStructTypeStorage>())
+ LLVMStructTypeStorage(key.copyIntoAllocator(allocator));
+ }
+
+ /// Sets the body of an identified struct. If the struct is already
+ /// initialized, succeeds only if the body is equal to the current body. Fails
+ /// if the struct is marked as intentionally opaque. The struct will be marked
+ /// as initialized as a result of this operation and can no longer be changed.
+ LogicalResult mutate(TypeStorageAllocator &allocator,
+ ArrayRef<LLVMTypeNew> body, bool packed) {
+ if (!isIdentified())
+ return failure();
+ if (isInitialized())
+ return success(!isOpaque() && body == getIdentifiedStructBody() &&
+ packed == isPacked());
+
+ llvm::Bitfield::set<MutableFlagInitialized>(identifiedBodySizeAndFlags,
+ true);
+ llvm::Bitfield::set<MutableFlagPacked>(identifiedBodySizeAndFlags, packed);
+
+ // The body must outlive the caller's array, so copy it into the allocator.
+ ArrayRef<LLVMTypeNew> typesInAllocator = allocator.copyInto(body);
+ identifiedBodyArray = typesInAllocator.data();
+ setIdentifiedBodySize(typesInAllocator.size());
+
+ return success();
+ }
+
+private:
+ /// Returns the number of elements in the key.
+ unsigned keySize() const {
+ return llvm::Bitfield::get<KeySize>(keySizeAndFlags);
+ }
+
+ /// Sets the number of elements in the key.
+ void setKeySize(unsigned value) {
+ llvm::Bitfield::set<KeySize>(keySizeAndFlags, value);
+ }
+
+ /// Returns the number of types contained in an identified struct.
+ unsigned identifiedBodySize() const {
+ return llvm::Bitfield::get<MutableSize>(identifiedBodySizeAndFlags);
+ }
+ /// Sets the number of types contained in an identified struct.
+ void setIdentifiedBodySize(unsigned value) {
+ llvm::Bitfield::set<MutableSize>(identifiedBodySizeAndFlags, value);
+ }
+
+ /// Returns the key for the current storage.
+ Key getKey() const {
+ if (isIdentified())
+ return Key(getIdentifier(), isOpaque());
+ return Key(getTypeList(), isPacked());
+ }
+
+ /// Bitfield elements for `keySizeAndFlags`:
+ /// - bit 0: identified key flag;
+ /// - bit 1: packed key flag;
+ /// - bits 2..bitwidth(unsigned): size of the key.
+ using KeyFlagIdentified =
+ llvm::Bitfield::Element<bool, /*Offset=*/0, /*Size=*/1>;
+ using KeyFlagPacked = llvm::Bitfield::Element<bool, /*Offset=*/1, /*Size=*/1>;
+ using KeySize =
+ llvm::Bitfield::Element<unsigned, /*Offset=*/2,
+ std::numeric_limits<unsigned>::digits - 2>;
+
+ /// Bitfield elements for `identifiedBodySizeAndFlags`:
+ /// - bit 0: opaque flag;
+ /// - bit 1: packed mutable flag;
+ /// - bit 2: initialized flag;
+ /// - bits 3..bitwidth(unsigned): size of the identified body.
+ using MutableFlagOpaque =
+ llvm::Bitfield::Element<bool, /*Offset=*/0, /*Size=*/1>;
+ using MutableFlagPacked =
+ llvm::Bitfield::Element<bool, /*Offset=*/1, /*Size=*/1>;
+ using MutableFlagInitialized =
+ llvm::Bitfield::Element<bool, /*Offset=*/2, /*Size=*/1>;
+ using MutableSize =
+ llvm::Bitfield::Element<unsigned, /*Offset=*/3,
+ std::numeric_limits<unsigned>::digits - 3>;
+
+ /// Pointer to the first element of the uniquing key.
+ // Note: cannot use PointerUnion because bump-ptr allocator does not guarantee
+ // address alignment.
+ const void *keyPtr = nullptr;
+
+ /// Pointer to the first type contained in an identified struct.
+ const LLVMTypeNew *identifiedBodyArray = nullptr;
+
+ /// Size of the uniquing key combined with identified/literal and
+ /// packedness bits. Must only be used through the Key* bitfields.
+ unsigned keySizeAndFlags = 0;
+
+ /// Number of the types contained in an identified struct combined with
+ /// mutable flags. Must only be used through the Mutable* bitfields.
+ unsigned identifiedBodySizeAndFlags = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// LLVMFunctionTypeStorage.
+//===----------------------------------------------------------------------===//
+
+/// Type storage for LLVM dialect function types. These are uniqued using the
+/// result type, the list of argument types and the vararg bit.
+struct LLVMFunctionTypeStorage : public TypeStorage {
+  using KeyTy = std::tuple<LLVMTypeNew, ArrayRef<LLVMTypeNew>, bool>;
+
+  /// Construct a storage from the given components. The list is expected to be
+  /// allocated in the context.
+  LLVMFunctionTypeStorage(LLVMTypeNew result, ArrayRef<LLVMTypeNew> arguments,
+                          bool variadic)
+      : argumentTypes(arguments) {
+    returnTypeAndVariadic.setPointerAndInt(result, variadic);
+  }
+
+  /// Hook into the type uniquing infrastructure. The argument list is copied
+  /// into the allocator so the storage does not reference caller memory.
+  static LLVMFunctionTypeStorage *construct(TypeStorageAllocator &allocator,
+                                            const KeyTy &key) {
+    return new (allocator.allocate<LLVMFunctionTypeStorage>())
+        LLVMFunctionTypeStorage(std::get<0>(key),
+                                allocator.copyInto(std::get<1>(key)),
+                                std::get<2>(key));
+  }
+
+  /// Hashes all key components. Returns the full `llvm::hash_code`, as
+  /// LLVMStructTypeStorage::hashKey does, instead of truncating it to
+  /// `unsigned`.
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    // LLVM doesn't like hashing bools in tuples.
+    return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
+                              static_cast<int>(std::get<2>(key)));
+  }
+
+  /// Compares all key components: result type, argument types, vararg bit.
+  bool operator==(const KeyTy &key) const {
+    return std::make_tuple(getReturnType(), getArgumentTypes(), isVariadic()) ==
+           key;
+  }
+
+  /// Returns the list of function argument types.
+  ArrayRef<LLVMTypeNew> getArgumentTypes() const { return argumentTypes; }
+
+  /// Checks whether the function type is variadic.
+  bool isVariadic() const { return returnTypeAndVariadic.getInt(); }
+
+  /// Returns the function result type.
+  LLVMTypeNew getReturnType() const {
+    return returnTypeAndVariadic.getPointer();
+  }
+
+private:
+  /// Function result type packed with the variadic bit.
+  llvm::PointerIntPair<LLVMTypeNew, 1, bool> returnTypeAndVariadic;
+  /// Argument types.
+  ArrayRef<LLVMTypeNew> argumentTypes;
+};
+
+//===----------------------------------------------------------------------===//
+// LLVMIntegerTypeStorage.
+//===----------------------------------------------------------------------===//
+
+/// Storage for LLVM dialect integer types. The bitwidth is the entire
+/// uniquing key.
+struct LLVMIntegerTypeStorage : public TypeStorage {
+  using KeyTy = unsigned;
+
+  LLVMIntegerTypeStorage(unsigned width) : bitwidth(width) {}
+
+  /// Allocates a new storage in `allocator` and initializes it from `key`.
+  static LLVMIntegerTypeStorage *construct(TypeStorageAllocator &allocator,
+                                           const KeyTy &key) {
+    auto *memory = allocator.allocate<LLVMIntegerTypeStorage>();
+    return new (memory) LLVMIntegerTypeStorage(key);
+  }
+
+  /// Two integer types are the same iff their bitwidths match.
+  bool operator==(const KeyTy &key) const { return bitwidth == key; }
+
+  /// Width of the integer type, in bits.
+  unsigned bitwidth;
+};
+
+//===----------------------------------------------------------------------===//
+// LLVMPointerTypeStorage.
+//===----------------------------------------------------------------------===//
+
+/// Storage for LLVM dialect pointer types, uniqued by the pair (pointee type,
+/// address space).
+struct LLVMPointerTypeStorage : public TypeStorage {
+  using KeyTy = std::tuple<LLVMTypeNew, unsigned>;
+
+  LLVMPointerTypeStorage(const KeyTy &key)
+      : pointeeType(std::get<0>(key)), addressSpace(std::get<1>(key)) {}
+
+  /// Allocates a new storage in `allocator` and initializes it from `key`.
+  static LLVMPointerTypeStorage *construct(TypeStorageAllocator &allocator,
+                                           const KeyTy &key) {
+    auto *memory = allocator.allocate<LLVMPointerTypeStorage>();
+    return new (memory) LLVMPointerTypeStorage(key);
+  }
+
+  /// Compares the pointee type first, then the address space.
+  bool operator==(const KeyTy &key) const {
+    return pointeeType == std::get<0>(key) && addressSpace == std::get<1>(key);
+  }
+
+  /// Type being pointed to.
+  LLVMTypeNew pointeeType;
+  /// Address space of the pointer.
+  unsigned addressSpace;
+};
+
+//===----------------------------------------------------------------------===//
+// LLVMTypeAndSizeStorage.
+//===----------------------------------------------------------------------===//
+
+/// Shared storage for LLVM dialect types described by an element type plus a
+/// count: arrays, fixed vectors and scalable vectors. Which of these a given
+/// instance represents is determined by the type kind, not by the storage.
+struct LLVMTypeAndSizeStorage : public TypeStorage {
+  using KeyTy = std::tuple<LLVMTypeNew, unsigned>;
+
+  LLVMTypeAndSizeStorage(const KeyTy &key)
+      : elementType(std::get<0>(key)), numElements(std::get<1>(key)) {}
+
+  /// Allocates a new storage in `allocator` and initializes it from `key`.
+  static LLVMTypeAndSizeStorage *construct(TypeStorageAllocator &allocator,
+                                           const KeyTy &key) {
+    auto *memory = allocator.allocate<LLVMTypeAndSizeStorage>();
+    return new (memory) LLVMTypeAndSizeStorage(key);
+  }
+
+  /// Compares the element type first, then the count.
+  bool operator==(const KeyTy &key) const {
+    return elementType == std::get<0>(key) && numElements == std::get<1>(key);
+  }
+
+  /// Type of the contained elements.
+  LLVMTypeNew elementType;
+  /// Number of elements (minimum number for scalable vectors).
+  unsigned numElements;
+};
+
+} // end namespace detail
+} // end namespace LLVM
+} // end namespace mlir
+
+#endif // DIALECT_LLVMIR_IR_TYPEDETAIL_H
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 1a7e2c5448c1..3b522a876f25 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -237,6 +237,17 @@ class CustomDialectAsmParser : public DialectAsmParser {
return success(parser.consumeIf(Token::star));
}
+  /// Parses a quoted string token if present. On success and when `string` is
+  /// non-null, it receives the token spelling with its first and last
+  /// characters (the quotes) dropped; no unescaping is performed here.
+  ParseResult parseOptionalString(StringRef *string) override {
+    if (!parser.getToken().is(Token::string))
+      return failure();
+
+    StringRef spelling = parser.getTokenSpelling();
+    parser.consumeToken();
+    if (string)
+      *string = spelling.drop_front().drop_back();
+    return success();
+  }
+
/// Returns if the current token corresponds to a keyword.
bool isCurrentTokenAKeyword() const {
return parser.getToken().is(Token::bare_identifier) ||
@@ -297,6 +308,10 @@ class CustomDialectAsmParser : public DialectAsmParser {
return parser.parseDimensionListRanked(dimensions, allowDynamic);
}
+  /// Parses a type if one is present at the current position, forwarding to
+  /// the underlying parser. `result` is only meaningful when the returned
+  /// optional holds a success value.
+  OptionalParseResult parseOptionalType(Type &result) override {
+    OptionalParseResult parsed = parser.parseOptionalType(result);
+    return parsed;
+  }
+
private:
/// The full symbol specification.
StringRef fullSpec;
diff --git a/mlir/test/Dialect/LLVMIR/types-invalid.mlir b/mlir/test/Dialect/LLVMIR/types-invalid.mlir
new file mode 100644
index 000000000000..bb281087412c
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/types-invalid.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt --allow-unregistered-dialect -split-input-file -verify-diagnostics %s
+
+// Reusing an identified struct name with a different body is rejected; the
+// note reports the body that was registered first.
+func @repeated_struct_name() {
+  "some.op"() : () -> !llvm2.struct<"a", (ptr<struct<"a">>)>
+  // expected-error @+2 {{identified type already used with a different body}}
+  // expected-note @+1 {{existing body: (ptr<struct<"a">>)}}
+  "some.op"() : () -> !llvm2.struct<"a", (i32)>
+}
+
+// -----
+
+// The packed flag is part of an identified struct's body: a name first
+// defined as packed cannot be reused with a non-packed body.
+func @repeated_struct_name_packed() {
+  "some.op"() : () -> !llvm2.struct<"a", packed (i32)>
+  // expected-error @+2 {{identified type already used with a different body}}
+  // expected-note @+1 {{existing body: packed (i32)}}
+  "some.op"() : () -> !llvm2.struct<"a", (i32)>
+}
+
+// -----
+
+// An identified struct first declared opaque cannot later be given a body,
+// not even an empty one.
+func @repeated_struct_opaque() {
+  "some.op"() : () -> !llvm2.struct<"a", opaque>
+  // expected-error @+2 {{identified type already used with a different body}}
+  // expected-note @+1 {{existing body: opaque}}
+  "some.op"() : () -> !llvm2.struct<"a", ()>
+}
+
+// -----
+
+// An identified struct first declared opaque cannot later be given a
+// non-empty body.
+func @repeated_struct_opaque_non_empty() {
+  "some.op"() : () -> !llvm2.struct<"a", opaque>
+  // expected-error @+2 {{identified type already used with a different body}}
+  // expected-note @+1 {{existing body: opaque}}
+  "some.op"() : () -> !llvm2.struct<"a", (i32, i32)>
+}
+
+// -----
+
+// The reverse direction: an identified struct that already has a defined
+// (here empty) body cannot be redeclared as opaque.
+func @repeated_struct_opaque_redefinition() {
+ "some.op"() : () -> !llvm2.struct<"a", ()>
+ // expected-error @+1 {{redeclaring defined struct as opaque}}
+ "some.op"() : () -> !llvm2.struct<"a", opaque>
+}
+
+// -----
+
+// Opaque bodies are only allowed for identified (named) structs; a literal
+// struct cannot be opaque.
+func @struct_literal_opaque() {
+ // expected-error @+1 {{only identified structs can be opaque}}
+ "some.op"() : () -> !llvm2.struct<opaque>
+}
+
+// -----
+
+// A name such as `f32` after the dialect prefix is rejected: the parser only
+// accepts an `i<width>` integer or a type keyword at this position.
+func @unexpected_type() {
+ // expected-error @+1 {{unexpected type, expected i* or keyword}}
+ "some.op"() : () -> !llvm2.f32
+}
+
+// -----
+
+// An identifier starting with `i` that is not a valid integer width is
+// reported as an unknown LLVM type.
+// NOTE: this chunk reuses the function name @unexpected_type; that is legal
+// only because -split-input-file parses each chunk as a separate module.
+func @unexpected_type() {
+ // expected-error @+1 {{unknown LLVM type}}
+ "some.op"() : () -> !llvm2.ifoo
+}
+
+// -----
+
+// An explicitly opaque identified struct cannot later be given an empty body.
+// NOTE(review): this case is currently identical to @repeated_struct_opaque
+// in this file; consider dropping it or making it exercise a distinct path.
+func @explicitly_opaque_struct() {
+  "some.op"() : () -> !llvm2.struct<"a", opaque>
+  // expected-error @+2 {{identified type already used with a different body}}
+  // expected-note @+1 {{existing body: opaque}}
+  "some.op"() : () -> !llvm2.struct<"a", ()>
+}
+
+// -----
+
+// A vector must spell out its element count: a `?` marker that is not
+// followed by an integer count is rejected.
+func @dynamic_vector() {
+ // expected-error @+1 {{expected '? x <integer> x <type>' or '<integer> x <type>'}}
+ "some.op"() : () -> !llvm2.vec<? x float>
+}
+
+// -----
+
+// Only a single leading `?` marker is accepted, and it must be followed by an
+// integer count; a repeated `?` is rejected.
+func @dynamic_scalable_vector() {
+ // expected-error @+1 {{expected '? x <integer> x <type>' or '<integer> x <type>'}}
+ "some.op"() : () -> !llvm2.vec<? x ? x float>
+}
+
+// -----
+
+// Vectors take exactly one element count; a multi-dimensional shape such as
+// `4 x 4 x i32` is rejected.
+func @unscalable_vector() {
+ // expected-error @+1 {{expected '? x <integer> x <type>' or '<integer> x <type>'}}
+ "some.op"() : () -> !llvm2.vec<4 x 4 x i32>
+}
+
diff --git a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir
new file mode 100644
index 000000000000..7ce606fe8c6a
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/types.mlir
@@ -0,0 +1,184 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file | mlir-opt -allow-unregistered-dialect | FileCheck %s
+
+// Round-trip of each primitive (non-parametric) type keyword registered by the
+// llvm2 test dialect.
+// CHECK-LABEL: @primitive
+func @primitive() {
+ // CHECK: !llvm2.void
+ "some.op"() : () -> !llvm2.void
+ // CHECK: !llvm2.half
+ "some.op"() : () -> !llvm2.half
+ // CHECK: !llvm2.bfloat
+ "some.op"() : () -> !llvm2.bfloat
+ // CHECK: !llvm2.float
+ "some.op"() : () -> !llvm2.float
+ // CHECK: !llvm2.double
+ "some.op"() : () -> !llvm2.double
+ // CHECK: !llvm2.fp128
+ "some.op"() : () -> !llvm2.fp128
+ // CHECK: !llvm2.x86_fp80
+ "some.op"() : () -> !llvm2.x86_fp80
+ // CHECK: !llvm2.ppc_fp128
+ "some.op"() : () -> !llvm2.ppc_fp128
+ // CHECK: !llvm2.x86_mmx
+ "some.op"() : () -> !llvm2.x86_mmx
+ // CHECK: !llvm2.token
+ "some.op"() : () -> !llvm2.token
+ // CHECK: !llvm2.label
+ "some.op"() : () -> !llvm2.label
+ // CHECK: !llvm2.metadata
+ "some.op"() : () -> !llvm2.metadata
+ return
+}
+
+// Function types: void and non-void results, zero or more parameters, and
+// variadic `...` both alone and after fixed parameters.
+// CHECK-LABEL: @func
+func @func() {
+ // CHECK: !llvm2.func<void ()>
+ "some.op"() : () -> !llvm2.func<void ()>
+ // CHECK: !llvm2.func<void (i32)>
+ "some.op"() : () -> !llvm2.func<void (i32)>
+ // CHECK: !llvm2.func<i32 ()>
+ "some.op"() : () -> !llvm2.func<i32 ()>
+ // CHECK: !llvm2.func<i32 (half, bfloat, float, double)>
+ "some.op"() : () -> !llvm2.func<i32 (half, bfloat, float, double)>
+ // CHECK: !llvm2.func<i32 (i32, i32)>
+ "some.op"() : () -> !llvm2.func<i32 (i32, i32)>
+ // CHECK: !llvm2.func<void (...)>
+ "some.op"() : () -> !llvm2.func<void (...)>
+ // CHECK: !llvm2.func<void (i32, i32, ...)>
+ "some.op"() : () -> !llvm2.func<void (i32, i32, ...)>
+ return
+}
+
+// Integer types with common widths and arbitrary widths (i57, i129).
+// CHECK-LABEL: @integer
+func @integer() {
+ // CHECK: !llvm2.i1
+ "some.op"() : () -> !llvm2.i1
+ // CHECK: !llvm2.i8
+ "some.op"() : () -> !llvm2.i8
+ // CHECK: !llvm2.i16
+ "some.op"() : () -> !llvm2.i16
+ // CHECK: !llvm2.i32
+ "some.op"() : () -> !llvm2.i32
+ // CHECK: !llvm2.i64
+ "some.op"() : () -> !llvm2.i64
+ // CHECK: !llvm2.i57
+ "some.op"() : () -> !llvm2.i57
+ // CHECK: !llvm2.i129
+ "some.op"() : () -> !llvm2.i129
+ return
+}
+
+// Pointer types: nested pointers and explicit address spaces. An explicit
+// address space of 0 is printed back without the address-space suffix.
+// CHECK-LABEL: @ptr
+func @ptr() {
+ // CHECK: !llvm2.ptr<i8>
+ "some.op"() : () -> !llvm2.ptr<i8>
+ // CHECK: !llvm2.ptr<float>
+ "some.op"() : () -> !llvm2.ptr<float>
+ // CHECK: !llvm2.ptr<ptr<i8>>
+ "some.op"() : () -> !llvm2.ptr<ptr<i8>>
+ // CHECK: !llvm2.ptr<ptr<ptr<ptr<ptr<i8>>>>>
+ "some.op"() : () -> !llvm2.ptr<ptr<ptr<ptr<ptr<i8>>>>>
+ // CHECK: !llvm2.ptr<i8>
+ "some.op"() : () -> !llvm2.ptr<i8, 0>
+ // CHECK: !llvm2.ptr<i8, 1>
+ "some.op"() : () -> !llvm2.ptr<i8, 1>
+ // CHECK: !llvm2.ptr<i8, 42>
+ "some.op"() : () -> !llvm2.ptr<i8, 42>
+ // CHECK: !llvm2.ptr<ptr<i8, 42>, 9>
+ "some.op"() : () -> !llvm2.ptr<ptr<i8, 42>, 9>
+ return
+}
+
+// Vector types: fixed-size `N x T` and the `? x N x T` form, with integer,
+// floating-point and pointer elements.
+// CHECK-LABEL: @vec
+func @vec() {
+ // CHECK: !llvm2.vec<4 x i32>
+ "some.op"() : () -> !llvm2.vec<4 x i32>
+ // CHECK: !llvm2.vec<4 x float>
+ "some.op"() : () -> !llvm2.vec<4 x float>
+ // CHECK: !llvm2.vec<? x 4 x i32>
+ "some.op"() : () -> !llvm2.vec<? x 4 x i32>
+ // CHECK: !llvm2.vec<? x 8 x half>
+ "some.op"() : () -> !llvm2.vec<? x 8 x half>
+ // CHECK: !llvm2.vec<4 x ptr<i8>>
+ "some.op"() : () -> !llvm2.vec<4 x ptr<i8>>
+ return
+}
+
+// Array types, including arrays of address-space pointers and nested arrays.
+// CHECK-LABEL: @array
+func @array() {
+ // CHECK: !llvm2.array<10 x i32>
+ "some.op"() : () -> !llvm2.array<10 x i32>
+ // CHECK: !llvm2.array<8 x float>
+ "some.op"() : () -> !llvm2.array<8 x float>
+ // CHECK: !llvm2.array<10 x ptr<i32, 4>>
+ "some.op"() : () -> !llvm2.array<10 x ptr<i32, 4>>
+ // CHECK: !llvm2.array<10 x array<4 x float>>
+ "some.op"() : () -> !llvm2.array<10 x array<4 x float>>
+ return
+}
+
+// Literal (unnamed) structs, packed and non-packed, including every nesting
+// combination of packed and non-packed structs. Each CHECK/op pair below is
+// distinct: the previous duplicate pairs added no coverage, and the
+// packed-inside-packed combination was missing.
+// CHECK-LABEL: @literal_struct
+func @literal_struct() {
+  // CHECK: !llvm2.struct<()>
+  "some.op"() : () -> !llvm2.struct<()>
+  // CHECK: !llvm2.struct<(i32)>
+  "some.op"() : () -> !llvm2.struct<(i32)>
+  // CHECK: !llvm2.struct<(float, i32)>
+  "some.op"() : () -> !llvm2.struct<(float, i32)>
+  // CHECK: !llvm2.struct<(struct<(i32)>)>
+  "some.op"() : () -> !llvm2.struct<(struct<(i32)>)>
+  // CHECK: !llvm2.struct<(i32, struct<(i32)>, float)>
+  "some.op"() : () -> !llvm2.struct<(i32, struct<(i32)>, float)>
+
+  // CHECK: !llvm2.struct<packed ()>
+  "some.op"() : () -> !llvm2.struct<packed ()>
+  // CHECK: !llvm2.struct<packed (i32)>
+  "some.op"() : () -> !llvm2.struct<packed (i32)>
+  // CHECK: !llvm2.struct<packed (float, i32)>
+  "some.op"() : () -> !llvm2.struct<packed (float, i32)>
+  // CHECK: !llvm2.struct<packed (struct<(i32)>)>
+  "some.op"() : () -> !llvm2.struct<packed (struct<(i32)>)>
+  // CHECK: !llvm2.struct<packed (i32, struct<(i32, i1)>, float)>
+  "some.op"() : () -> !llvm2.struct<packed (i32, struct<(i32, i1)>, float)>
+
+  // CHECK: !llvm2.struct<(struct<packed (i32)>)>
+  "some.op"() : () -> !llvm2.struct<(struct<packed (i32)>)>
+  // CHECK: !llvm2.struct<packed (struct<packed (i32)>)>
+  "some.op"() : () -> !llvm2.struct<packed (struct<packed (i32)>)>
+  return
+}
+
+// Identified (named) structs: empty, opaque, packed and non-packed bodies,
+// names containing spaces and punctuation, self- and mutually-recursive
+// references, a reference to the "unpacked" struct defined earlier in this
+// function, and identified structs nested in arrays and pointers.
+// CHECK-LABEL: @identified_struct
+func @identified_struct() {
+ // CHECK: !llvm2.struct<"empty", ()>
+ "some.op"() : () -> !llvm2.struct<"empty", ()>
+ // CHECK: !llvm2.struct<"opaque", opaque>
+ "some.op"() : () -> !llvm2.struct<"opaque", opaque>
+ // CHECK: !llvm2.struct<"long", (i32, struct<(i32, i1)>, float, ptr<func<void ()>>)>
+ "some.op"() : () -> !llvm2.struct<"long", (i32, struct<(i32, i1)>, float, ptr<func<void ()>>)>
+ // CHECK: !llvm2.struct<"self-recursive", (ptr<struct<"self-recursive">>)>
+ "some.op"() : () -> !llvm2.struct<"self-recursive", (ptr<struct<"self-recursive">>)>
+ // CHECK: !llvm2.struct<"unpacked", (i32)>
+ "some.op"() : () -> !llvm2.struct<"unpacked", (i32)>
+ // CHECK: !llvm2.struct<"packed", packed (i32)>
+ "some.op"() : () -> !llvm2.struct<"packed", packed (i32)>
+ // CHECK: !llvm2.struct<"name with spaces and !^$@$#", packed (i32)>
+ "some.op"() : () -> !llvm2.struct<"name with spaces and !^$@$#", packed (i32)>
+
+ // CHECK: !llvm2.struct<"mutually-a", (ptr<struct<"mutually-b", (ptr<struct<"mutually-a">, 3>)>>)>
+ "some.op"() : () -> !llvm2.struct<"mutually-a", (ptr<struct<"mutually-b", (ptr<struct<"mutually-a">, 3>)>>)>
+ // CHECK: !llvm2.struct<"mutually-b", (ptr<struct<"mutually-a", (ptr<struct<"mutually-b">>)>, 3>)>
+ "some.op"() : () -> !llvm2.struct<"mutually-b", (ptr<struct<"mutually-a", (ptr<struct<"mutually-b">>)>, 3>)>
+ // CHECK: !llvm2.struct<"referring-another", (ptr<struct<"unpacked", (i32)>>)>
+ "some.op"() : () -> !llvm2.struct<"referring-another", (ptr<struct<"unpacked", (i32)>>)>
+
+ // CHECK: !llvm2.struct<"struct-of-arrays", (array<10 x i32>)>
+ "some.op"() : () -> !llvm2.struct<"struct-of-arrays", (array<10 x i32>)>
+ // CHECK: !llvm2.array<10 x struct<"array-of-structs", (i32)>>
+ "some.op"() : () -> !llvm2.array<10 x struct<"array-of-structs", (i32)>>
+ // CHECK: !llvm2.ptr<struct<"ptr-to-struct", (i8)>>
+ "some.op"() : () -> !llvm2.ptr<struct<"ptr-to-struct", (i8)>>
+ return
+}
+
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 9008b86314be..36a18f79a8cb 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(Affine)
+add_subdirectory(LLVMIR)
add_subdirectory(SPIRV)
add_subdirectory(Test)
diff --git a/mlir/test/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/test/lib/Dialect/LLVMIR/CMakeLists.txt
new file mode 100644
index 000000000000..2a42bc697485
--- /dev/null
+++ b/mlir/test/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -0,0 +1,14 @@
+
+# Test-only library providing the "llvm2" dialect that exercises the new
+# first-party LLVM type classes; EXCLUDE_FROM_LIBMLIR keeps it out of libMLIR.
+add_mlir_library(MLIRLLVMTypeTestDialect
+ LLVMTypeTestDialect.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialect
+ MLIRIR
+ MLIRLLVMIR
+ )
diff --git a/mlir/test/lib/Dialect/LLVMIR/LLVMTypeTestDialect.cpp b/mlir/test/lib/Dialect/LLVMIR/LLVMTypeTestDialect.cpp
new file mode 100644
index 000000000000..8ac1ef0a8c17
--- /dev/null
+++ b/mlir/test/lib/Dialect/LLVMIR/LLVMTypeTestDialect.cpp
@@ -0,0 +1,52 @@
+// Test-only dialect that registers the first-party LLVM type classes under the
+// "llvm2" namespace, so they can be exercised by mlir-opt round-trip tests
+// without touching the production LLVM dialect.
+//
+// This is a source file that is never #included, so it deliberately carries
+// no include guard.
+
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Dialect.h"
+
+namespace mlir {
+namespace LLVM {
+namespace {
+/// Dialect owning the new LLVM type classes. Textual parsing and printing are
+/// delegated to detail::parseType / detail::printType.
+class LLVMDialectNewTypes : public Dialect {
+public:
+  explicit LLVMDialectNewTypes(MLIRContext *ctx)
+      : Dialect(getDialectNamespace(), ctx) {
+    // clang-format off
+    addTypes<LLVMVoidType,
+             LLVMHalfType,
+             LLVMBFloatType,
+             LLVMFloatType,
+             LLVMDoubleType,
+             LLVMFP128Type,
+             LLVMX86FP80Type,
+             LLVMPPCFP128Type,
+             LLVMX86MMXType,
+             LLVMTokenType,
+             LLVMLabelType,
+             LLVMMetadataType,
+             LLVMFunctionType,
+             LLVMIntegerType,
+             LLVMPointerType,
+             LLVMFixedVectorType,
+             LLVMScalableVectorType,
+             LLVMArrayType,
+             LLVMStructType>();
+    // clang-format on
+  }
+
+  /// Dialect prefix used in textual IR, e.g. `!llvm2.i32`.
+  static StringRef getDialectNamespace() { return "llvm2"; }
+
+  /// Parses the type body that follows the `!llvm2.` prefix.
+  Type parseType(DialectAsmParser &parser) const override {
+    return detail::parseType(parser);
+  }
+
+  /// Prints a type owned by this dialect. The cast assumes every registered
+  /// type is an LLVMTypeNew (presumably true for all classes listed in
+  /// addTypes above; cast<> asserts otherwise).
+  void printType(Type type, DialectAsmPrinter &printer) const override {
+    detail::printType(type.cast<LLVMTypeNew>(), printer);
+  }
+};
+} // namespace
+} // namespace LLVM
+
+/// Registers the "llvm2" test dialect; invoked from mlir-opt's
+/// registerTestPasses().
+void registerLLVMTypeTestDialect() {
+  mlir::registerDialect<LLVM::LLVMDialectNewTypes>();
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 483dcfec0c0f..f52c5f41b22b 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -13,6 +13,7 @@ set(LLVM_LINK_COMPONENTS
if(MLIR_INCLUDE_TESTS)
set(test_libs
MLIRAffineTransformsTestPasses
+ MLIRLLVMTypeTestDialect
MLIRSPIRVTestPasses
MLIRTestDialect
MLIRTestIR
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 620c5871a420..05fba34092cb 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -31,6 +31,7 @@ namespace mlir {
// Defined in the test directory, no public header.
void registerConvertToTargetEnvPass();
void registerInliner();
+void registerLLVMTypeTestDialect();
void registerMemRefBoundCheck();
void registerPassManagerTestPass();
void registerPatternsTestPass();
@@ -39,10 +40,9 @@ void registerSideEffectTestPasses();
void registerSimpleParametricTilingPass();
void registerSymbolTestPasses();
void registerTestAffineDataCopyPass();
-void registerTestAllReduceLoweringPass();
void registerTestAffineLoopUnswitchingPass();
+void registerTestAllReduceLoweringPass();
void registerTestBufferPlacementPreparationPass();
-void registerTestLoopPermutationPass();
void registerTestCallGraphPass();
void registerTestConstantFold();
void registerTestConvertGPUKernelToCubinPass();
@@ -51,12 +51,14 @@ void registerTestDominancePass();
void registerTestExpandTanhPass();
void registerTestFunc();
void registerTestGpuMemoryPromotionPass();
+void registerTestGpuParallelLoopMappingPass();
void registerTestInterfaces();
void registerTestLinalgHoisting();
void registerTestLinalgTransforms();
void registerTestLivenessPass();
void registerTestLoopFusion();
void registerTestLoopMappingPass();
+void registerTestLoopPermutationPass();
void registerTestLoopUnrollingPass();
void registerTestMatchers();
void registerTestMemRefDependenceCheck();
@@ -65,7 +67,6 @@ void registerTestOpaqueLoc();
void registerTestPreparationPassWithAllowedMemrefResults();
void registerTestRecursiveTypesPass();
void registerTestReducer();
-void registerTestGpuParallelLoopMappingPass();
void registerTestSpirvEntryPointABIPass();
void registerTestSCFUtilsPass();
void registerTestVectorConversions();
@@ -104,6 +105,7 @@ static cl::opt<bool> allowUnregisteredDialects(
void registerTestPasses() {
registerConvertToTargetEnvPass();
registerInliner();
+ registerLLVMTypeTestDialect();
registerMemRefBoundCheck();
registerPassManagerTestPass();
registerPatternsTestPass();
More information about the Mlir-commits
mailing list