[Mlir-commits] [mlir] c69c9e0 - [mlir] Remove LLVMType, LLVM dialect types now derive Type directly
Alex Zinenko
llvmlistbot at llvm.org
Tue Jan 5 08:37:03 PST 2021
Author: Alex Zinenko
Date: 2021-01-05T17:36:54+01:00
New Revision: c69c9e0f0fd2e5b72c4a1947822a4961b8630123
URL: https://github.com/llvm/llvm-project/commit/c69c9e0f0fd2e5b72c4a1947822a4961b8630123
DIFF: https://github.com/llvm/llvm-project/commit/c69c9e0f0fd2e5b72c4a1947822a4961b8630123.diff
LOG: [mlir] Remove LLVMType, LLVM dialect types now derive Type directly
BEGIN_PUBLIC
[mlir] Remove LLVMType, LLVM dialect types now derive Type directly
This class has become a simple `isa` hook with no proper functionality.
Removing it will allow us to eventually make the LLVM dialect type infrastructure
open, i.e., support non-LLVM types inside container types, which itself will
make the type conversion more progressive.
Introduce a function `LLVM::isCompatibleType` to be used instead of
`isa<LLVMType>`. For now, the two are strictly equivalent.
END_PUBLIC
Depends On D93681
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D93713
Added:
Modified:
mlir/docs/Tutorials/Toy/Ch-6.md
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/lib/Target/LLVMIR/TypeTranslation.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md
index 86805c10831a..c2211412e5c4 100644
--- a/mlir/docs/Tutorials/Toy/Ch-6.md
+++ b/mlir/docs/Tutorials/Toy/Ch-6.md
@@ -37,10 +37,11 @@ static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
- auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
- auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
- auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
- /*isVarArg=*/true);
+ auto llvmI32Ty = LLVM::LLVMIntegerType::get(context, 32);
+ auto llvmI8PtrTy =
+ LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8));
+ auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
+ /*isVarArg=*/true);
// Insert the printf function into the body of the parent module.
PatternRewriter::InsertionGuard insertGuard(rewriter);
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index d5c1e923fab9..686a8ca01d9e 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -34,7 +34,6 @@ class UnrankedMemRefType;
namespace LLVM {
class LLVMDialect;
-class LLVMType;
class LLVMPointerType;
} // namespace LLVM
@@ -71,8 +70,8 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
- LLVM::LLVMType convertFunctionSignature(FunctionType funcTy, bool isVariadic,
- SignatureConversion &result);
+ Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
+ SignatureConversion &result);
/// Convert a non-empty list of types to be returned from a function into a
/// supported LLVM IR type. In particular, if more than one value is
@@ -118,14 +117,14 @@ class LLVMTypeConverter : public TypeConverter {
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
- LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);
+ Type convertFunctionTypeCWrapper(FunctionType type);
/// Returns the data layout to use during and after conversion.
const llvm::DataLayout &getDataLayout() { return options.dataLayout; }
/// Gets the LLVM representation of the index type. The returned type is an
/// integer type with the size configured for this type converter.
- LLVM::LLVMType getIndexType();
+ Type getIndexType();
/// Gets the bitwidth of the index type when converted to LLVM.
unsigned getIndexTypeBitwidth() { return options.indexBitwidth; }
@@ -185,8 +184,8 @@ class LLVMTypeConverter : public TypeConverter {
/// - `!llvm.i64`, `!llvm.i64` (sizes),
/// - `!llvm.i64`, `!llvm.i64` (strides).
/// These types can be recomposed to a memref descriptor struct.
- SmallVector<LLVM::LLVMType, 5>
- getMemRefDescriptorFields(MemRefType type, bool unpackAggregates);
+ SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
+ bool unpackAggregates);
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
/// that will form the unranked memref descriptor. In particular, this list
@@ -197,7 +196,7 @@ class LLVMTypeConverter : public TypeConverter {
/// !llvm.i64 (rank)
/// !llvm<"i8*"> (type-erased pointer).
/// These types can be recomposed to a unranked memref descriptor struct.
- SmallVector<LLVM::LLVMType, 2> getUnrankedMemRefDescriptorFields();
+ SmallVector<Type, 2> getUnrankedMemRefDescriptorFields();
// Convert an unranked memref type to an LLVM type that captures the
// runtime rank and a pointer to the static ranked memref desc
@@ -417,31 +416,30 @@ class UnrankedMemRefDescriptor : public StructBuilder {
/// Builds IR extracting the allocated pointer from the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc,
- Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType);
+ Value memRefDescPtr, Type elemPtrPtrType);
/// Builds IR inserting the allocated pointer into the descriptor.
static void setAllocatedPtr(OpBuilder &builder, Location loc,
- Value memRefDescPtr,
- LLVM::LLVMType elemPtrPtrType,
+ Value memRefDescPtr, Type elemPtrPtrType,
Value allocatedPtr);
/// Builds IR extracting the aligned pointer from the descriptor.
static Value alignedPtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter, Value memRefDescPtr,
- LLVM::LLVMType elemPtrPtrType);
+ Type elemPtrPtrType);
/// Builds IR inserting the aligned pointer into the descriptor.
static void setAlignedPtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
- Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType,
+ Value memRefDescPtr, Type elemPtrPtrType,
Value alignedPtr);
/// Builds IR extracting the offset from the descriptor.
static Value offset(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter, Value memRefDescPtr,
- LLVM::LLVMType elemPtrPtrType);
+ Type elemPtrPtrType);
/// Builds IR inserting the offset into the descriptor.
static void setOffset(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter, Value memRefDescPtr,
- LLVM::LLVMType elemPtrPtrType, Value offset);
+ Type elemPtrPtrType, Value offset);
/// Builds IR extracting the pointer to the first element of the size array.
static Value sizeBasePtr(OpBuilder &builder, Location loc,
@@ -490,17 +488,17 @@ class ConvertToLLVMPattern : public ConversionPattern {
/// Gets the MLIR type wrapping the LLVM integer type whose bit width is
/// defined by the used type converter.
- LLVM::LLVMType getIndexType() const;
+ Type getIndexType() const;
/// Gets the MLIR type wrapping the LLVM integer type whose bit width
/// corresponds to that of a LLVM pointer type.
- LLVM::LLVMType getIntPtrType(unsigned addressSpace = 0) const;
+ Type getIntPtrType(unsigned addressSpace = 0) const;
/// Gets the MLIR type wrapping the LLVM void type.
- LLVM::LLVMType getVoidType() const;
+ Type getVoidType() const;
/// Get the MLIR type wrapping the LLVM i8* type.
- LLVM::LLVMType getVoidPtrType() const;
+ Type getVoidPtrType() const;
/// Create an LLVM dialect operation defining the given index constant.
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index ef5efc9c8281..3b7cd5ea9184 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -49,8 +49,8 @@ def LLVM_Dialect : Dialect {
// LLVM dialect type.
def LLVM_Type : DialectType<LLVM_Dialect,
- CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
- "LLVM dialect type">;
+ CPred<"::mlir::LLVM::isCompatibleType($_self)">,
+ "LLVM dialect-compatible type">;
// Type constraint accepting LLVM integer types.
def LLVM_AnyInteger : Type<
@@ -223,9 +223,9 @@ class ListIntSubst<string pattern, list<int> values> {
// or result in the operation.
def LLVM_IntrPatterns {
string operand =
- [{convertType(opInst.getOperand($0).getType().cast<LLVM::LLVMType>())}];
+ [{convertType(opInst.getOperand($0).getType())}];
string result =
- [{convertType(opInst.getResult($0).getType().cast<LLVM::LLVMType>())}];
+ [{convertType(opInst.getResult($0).getType())}];
string structResult =
[{convertType(opInst.getResult(0).getType().cast<LLVM::LLVMStructType>()
.getBody()[$0])}];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 4968b33f47a4..428ca6783afd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -49,9 +49,8 @@ def LLVM_VoidResultTypeOpBuilder :
OpBuilderDAG<(ins "Type":$resultType, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
- auto llvmType = resultType.dyn_cast<LLVMType>(); (void)llvmType;
- assert(llvmType && "result must be an LLVM type");
- assert(llvmType.isa<LLVMVoidType>() &&
+ assert(isCompatibleType(resultType) && "result must be an LLVM type");
+ assert(resultType.isa<LLVMVoidType>() &&
"for zero-result operands, only 'void' is accepted as result type");
build($_builder, $_state, operands, attributes);
}]>;
@@ -443,7 +442,7 @@ def LLVM_CallOp : LLVM_Op<"call"> {
OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
- LLVMType resultType = func.getType().getReturnType();
+ Type resultType = func.getType().getReturnType();
if (!resultType.isa<LLVM::LLVMVoidType>())
$_state.addTypes(resultType);
$_state.addAttribute("callee", $_builder.getSymbolRefAttr(func));
@@ -756,23 +755,21 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof"> {
}];
let builders = [
- OpBuilderDAG<(ins "LLVMType":$resType, "StringRef":$name,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
- [{
- $_state.addAttribute("global_name",$_builder.getSymbolRefAttr(name));
- $_state.addAttributes(attrs);
- $_state.addTypes(resType);}]>,
OpBuilderDAG<(ins "GlobalOp":$global,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
build($_builder, $_state,
LLVM::LLVMPointerType::get(global.getType(), global.addr_space()),
- global.sym_name(), attrs);}]>,
+ global.sym_name());
+ $_state.addAttributes(attrs);
+ }]>,
OpBuilderDAG<(ins "LLVMFuncOp":$func,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
build($_builder, $_state,
- LLVM::LLVMPointerType::get(func.getType()), func.getName(), attrs);}]>
+ LLVM::LLVMPointerType::get(func.getType()), func.getName());
+ $_state.addAttributes(attrs);
+ }]>
];
let extraClassDeclaration = [{
@@ -883,15 +880,15 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global",
let regions = (region AnyRegion:$initializer);
let builders = [
- OpBuilderDAG<(ins "LLVMType":$type, "bool":$isConstant, "Linkage":$linkage,
+ OpBuilderDAG<(ins "Type":$type, "bool":$isConstant, "Linkage":$linkage,
"StringRef":$name, "Attribute":$value, CArg<"unsigned", "0">:$addrSpace,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
];
let extraClassDeclaration = [{
/// Return the LLVM type of the global.
- LLVMType getType() {
- return type().cast<LLVMType>();
+ Type getType() {
+ return type();
}
/// Return the initializer attribute if it exists, or a null attribute.
Attribute getValueOrNull() {
@@ -957,7 +954,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func",
let skipDefaultBuilders = 1;
let builders = [
- OpBuilderDAG<(ins "StringRef":$name, "LLVMType":$type,
+ OpBuilderDAG<(ins "StringRef":$name, "Type":$type,
CArg<"Linkage", "Linkage::External">:$linkage,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)>
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 7c7731946ba8..6f78118a5b00 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -45,52 +45,13 @@ class LLVMFP128Type;
class LLVMX86FP80Type;
class LLVMIntegerType;
-//===----------------------------------------------------------------------===//
-// LLVMType.
-//===----------------------------------------------------------------------===//
-
-/// Base class for LLVM dialect types.
-///
-/// The LLVM dialect in MLIR fully reflects the LLVM IR type system, prodiving a
-/// separate MLIR type for each LLVM IR type. All types are represented as
-/// separate subclasses and are compatible with the isa/cast infrastructure.
-///
-/// The LLVM dialect type system is closed: parametric types can only refer to
-/// other LLVM dialect types. This is consistent with LLVM IR and enables a more
-/// concise pretty-printing format.
-///
-/// Similarly to other MLIR types, LLVM dialect types are owned by the MLIR
-/// context, have an immutable identifier (for most types except identified
-/// structs, the entire type is the identifier) and are thread-safe.
-///
-/// This class is a thin common base class for different types available in the
-/// LLVM dialect. It intentionally does not provide the API similar to
-/// llvm::Type to avoid confusion and highlight potentially expensive operations
-/// (e.g., type creation in MLIR takes a lock, so it's better to cache types).
-class LLVMType : public Type {
-public:
- /// Inherit base constructors.
- using Type::Type;
-
- /// Support for PointerLikeTypeTraits.
- using Type::getAsOpaquePointer;
- static LLVMType getFromOpaquePointer(const void *ptr) {
- return LLVMType(static_cast<ImplType *>(const_cast<void *>(ptr)));
- }
-
- /// Support for isa/cast.
- static bool classof(Type type);
-
- LLVMDialect &getDialect();
-};
-
//===----------------------------------------------------------------------===//
// Trivial types.
//===----------------------------------------------------------------------===//
// Batch-define trivial types.
#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName) \
- class ClassName : public Type::TypeBase<ClassName, LLVMType, TypeStorage> { \
+ class ClassName : public Type::TypeBase<ClassName, Type, TypeStorage> { \
public: \
using Base::Base; \
}
@@ -117,30 +78,30 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
/// LLVM dialect array type. It is an aggregate type representing consecutive
/// elements in memory, parameterized by the number of elements and the element
/// type.
-class LLVMArrayType : public Type::TypeBase<LLVMArrayType, LLVMType,
+class LLVMArrayType : public Type::TypeBase<LLVMArrayType, Type,
detail::LLVMTypeAndSizeStorage> {
public:
/// Inherit base constructors.
using Base::Base;
/// Checks if the given type can be used inside an array type.
- static bool isValidElementType(LLVMType type);
+ static bool isValidElementType(Type type);
/// Gets or creates an instance of LLVM dialect array type containing
/// `numElements` of `elementType`, in the same context as `elementType`.
- static LLVMArrayType get(LLVMType elementType, unsigned numElements);
- static LLVMArrayType getChecked(Location loc, LLVMType elementType,
+ static LLVMArrayType get(Type elementType, unsigned numElements);
+ static LLVMArrayType getChecked(Location loc, Type elementType,
unsigned numElements);
/// Returns the element type of the array.
- LLVMType getElementType();
+ Type getElementType();
/// Returns the number of elements in the array type.
unsigned getNumElements();
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verifyConstructionInvariants(Location loc,
- LLVMType elementType,
+ Type elementType,
unsigned numElements);
};
@@ -152,46 +113,46 @@ class LLVMArrayType : public Type::TypeBase<LLVMArrayType, LLVMType,
/// which can have multiple), a list of parameter types and can optionally be
/// variadic.
class LLVMFunctionType
- : public Type::TypeBase<LLVMFunctionType, LLVMType,
+ : public Type::TypeBase<LLVMFunctionType, Type,
detail::LLVMFunctionTypeStorage> {
public:
/// Inherit base constructors.
using Base::Base;
/// Checks if the given type can be used an argument in a function type.
- static bool isValidArgumentType(LLVMType type);
+ static bool isValidArgumentType(Type type);
/// Checks if the given type can be used as a result in a function type.
- static bool isValidResultType(LLVMType type);
+ static bool isValidResultType(Type type);
/// Returns whether the function is variadic.
bool isVarArg();
/// Gets or creates an instance of LLVM dialect function in the same context
/// as the `result` type.
- static LLVMFunctionType get(LLVMType result, ArrayRef<LLVMType> arguments,
+ static LLVMFunctionType get(Type result, ArrayRef<Type> arguments,
bool isVarArg = false);
- static LLVMFunctionType getChecked(Location loc, LLVMType result,
- ArrayRef<LLVMType> arguments,
+ static LLVMFunctionType getChecked(Location loc, Type result,
+ ArrayRef<Type> arguments,
bool isVarArg = false);
/// Returns the result type of the function.
- LLVMType getReturnType();
+ Type getReturnType();
/// Returns the number of arguments to the function.
unsigned getNumParams();
/// Returns `i`-th argument of the function. Asserts on out-of-bounds.
- LLVMType getParamType(unsigned i);
+ Type getParamType(unsigned i);
/// Returns a list of argument types of the function.
- ArrayRef<LLVMType> getParams();
- ArrayRef<LLVMType> params() { return getParams(); }
+ ArrayRef<Type> getParams();
+ ArrayRef<Type> params() { return getParams(); }
/// Verifies that the type about to be constructed is well-formed.
- static LogicalResult
- verifyConstructionInvariants(Location loc, LLVMType result,
- ArrayRef<LLVMType> arguments, bool);
+ static LogicalResult verifyConstructionInvariants(Location loc, Type result,
+ ArrayRef<Type> arguments,
+ bool);
};
//===----------------------------------------------------------------------===//
@@ -199,7 +160,7 @@ class LLVMFunctionType
//===----------------------------------------------------------------------===//
/// LLVM dialect signless integer type parameterized by bitwidth.
-class LLVMIntegerType : public Type::TypeBase<LLVMIntegerType, LLVMType,
+class LLVMIntegerType : public Type::TypeBase<LLVMIntegerType, Type,
detail::LLVMIntegerTypeStorage> {
public:
/// Inherit base constructor.
@@ -225,31 +186,31 @@ class LLVMIntegerType : public Type::TypeBase<LLVMIntegerType, LLVMType,
/// LLVM dialect pointer type. This type typically represents a reference to an
/// object in memory. It is parameterized by the element type and the address
/// space.
-class LLVMPointerType : public Type::TypeBase<LLVMPointerType, LLVMType,
+class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
detail::LLVMPointerTypeStorage> {
public:
/// Inherit base constructors.
using Base::Base;
/// Checks if the given type can have a pointer type pointing to it.
- static bool isValidElementType(LLVMType type);
+ static bool isValidElementType(Type type);
/// Gets or creates an instance of LLVM dialect pointer type pointing to an
/// object of `pointee` type in the given address space. The pointer type is
/// created in the same context as `pointee`.
- static LLVMPointerType get(LLVMType pointee, unsigned addressSpace = 0);
- static LLVMPointerType getChecked(Location loc, LLVMType pointee,
+ static LLVMPointerType get(Type pointee, unsigned addressSpace = 0);
+ static LLVMPointerType getChecked(Location loc, Type pointee,
unsigned addressSpace = 0);
/// Returns the pointed-to type.
- LLVMType getElementType();
+ Type getElementType();
/// Returns the address space of the pointer.
unsigned getAddressSpace();
/// Verifies that the type about to be constructed is well-formed.
- static LogicalResult verifyConstructionInvariants(Location loc,
- LLVMType pointee, unsigned);
+ static LogicalResult verifyConstructionInvariants(Location loc, Type pointee,
+ unsigned);
};
//===----------------------------------------------------------------------===//
@@ -280,14 +241,14 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, LLVMType,
///
/// Note that the packedness of the struct takes place in uniquing of literal
/// structs, but does not in uniquing of identified structs.
-class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
+class LLVMStructType : public Type::TypeBase<LLVMStructType, Type,
detail::LLVMStructTypeStorage> {
public:
/// Inherit base constructors.
using Base::Base;
/// Checks if the given type can be contained in a structure type.
- static bool isValidElementType(LLVMType type);
+ static bool isValidElementType(Type type);
/// Gets or creates an identified struct with the given name in the provided
/// context. Note that unlike llvm::StructType::create, this function will
@@ -302,16 +263,14 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
/// the struct by appending a `.` followed by a number to the name. Renaming
/// happens even if the existing struct has the same body.
static LLVMStructType getNewIdentified(MLIRContext *context, StringRef name,
- ArrayRef<LLVMType> elements,
+ ArrayRef<Type> elements,
bool isPacked = false);
/// Gets or creates a literal struct with the given body in the provided
/// context.
- static LLVMStructType getLiteral(MLIRContext *context,
- ArrayRef<LLVMType> types,
+ static LLVMStructType getLiteral(MLIRContext *context, ArrayRef<Type> types,
bool isPacked = false);
- static LLVMStructType getLiteralChecked(Location loc,
- ArrayRef<LLVMType> types,
+ static LLVMStructType getLiteralChecked(Location loc, ArrayRef<Type> types,
bool isPacked = false);
/// Gets or creates an intentionally-opaque identified struct. Such a struct
@@ -329,7 +288,7 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
/// different thread modified the struct after it was created. Most callers
/// are likely to assert this always succeeds, but it is possible to implement
/// a local renaming scheme based on the result of this call.
- LogicalResult setBody(ArrayRef<LLVMType> types, bool isPacked);
+ LogicalResult setBody(ArrayRef<Type> types, bool isPacked);
/// Checks if a struct is packed.
bool isPacked();
@@ -347,12 +306,12 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
StringRef getName();
/// Returns the list of element types contained in a non-opaque struct.
- ArrayRef<LLVMType> getBody();
+ ArrayRef<Type> getBody();
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verifyConstructionInvariants(Location, StringRef, bool);
- static LogicalResult
- verifyConstructionInvariants(Location loc, ArrayRef<LLVMType> types, bool);
+ static LogicalResult verifyConstructionInvariants(Location loc,
+ ArrayRef<Type> types, bool);
};
//===----------------------------------------------------------------------===//
@@ -362,26 +321,26 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
/// LLVM dialect vector type, represents a sequence of elements that can be
/// processed as one, typically in SIMD context. This is a base class for fixed
/// and scalable vectors.
-class LLVMVectorType : public LLVMType {
+class LLVMVectorType : public Type {
public:
/// Inherit base constructor.
- using LLVMType::LLVMType;
+ using Type::Type;
/// Support type casting functionality.
static bool classof(Type type);
/// Checks if the given type can be used in a vector type.
- static bool isValidElementType(LLVMType type);
+ static bool isValidElementType(Type type);
/// Returns the element type of the vector.
- LLVMType getElementType();
+ Type getElementType();
/// Returns the number of elements in the vector.
llvm::ElementCount getElementCount();
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verifyConstructionInvariants(Location loc,
- LLVMType elementType,
+ Type elementType,
unsigned numElements);
};
@@ -401,8 +360,8 @@ class LLVMFixedVectorType
/// Gets or creates a fixed vector type containing `numElements` of
/// `elementType` in the same context as `elementType`.
- static LLVMFixedVectorType get(LLVMType elementType, unsigned numElements);
- static LLVMFixedVectorType getChecked(Location loc, LLVMType elementType,
+ static LLVMFixedVectorType get(Type elementType, unsigned numElements);
+ static LLVMFixedVectorType getChecked(Location loc, Type elementType,
unsigned numElements);
/// Returns the number of elements in the fixed vector.
@@ -426,9 +385,8 @@ class LLVMScalableVectorType
/// Gets or creates a scalable vector type containing a non-zero multiple of
/// `minNumElements` of `elementType` in the same context as `elementType`.
- static LLVMScalableVectorType get(LLVMType elementType,
- unsigned minNumElements);
- static LLVMScalableVectorType getChecked(Location loc, LLVMType elementType,
+ static LLVMScalableVectorType get(Type elementType, unsigned minNumElements);
+ static LLVMScalableVectorType getChecked(Location loc, Type elementType,
unsigned minNumElements);
/// Returns the scaling factor of the number of elements in the vector. The
@@ -443,10 +401,10 @@ class LLVMScalableVectorType
namespace detail {
/// Parses an LLVM dialect type.
-LLVMType parseType(DialectAsmParser &parser);
+Type parseType(DialectAsmParser &parser);
/// Prints an LLVM Dialect type.
-void printType(LLVMType type, DialectAsmPrinter &printer);
+void printType(Type type, DialectAsmPrinter &printer);
} // namespace detail
//===----------------------------------------------------------------------===//
@@ -454,7 +412,30 @@ void printType(LLVMType type, DialectAsmPrinter &printer);
//===----------------------------------------------------------------------===//
/// Returns `true` if the given type is compatible with the LLVM dialect.
-inline bool isCompatibleType(Type type) { return type.isa<LLVMType>(); }
+inline bool isCompatibleType(Type type) {
+ // clang-format off
+ return type.isa<
+ LLVMArrayType,
+ LLVMBFloatType,
+ LLVMDoubleType,
+ LLVMFP128Type,
+ LLVMFloatType,
+ LLVMFunctionType,
+ LLVMHalfType,
+ LLVMIntegerType,
+ LLVMLabelType,
+ LLVMMetadataType,
+ LLVMPPCFP128Type,
+ LLVMPointerType,
+ LLVMStructType,
+ LLVMTokenType,
+ LLVMVectorType,
+ LLVMVoidType,
+ LLVMX86FP80Type,
+ LLVMX86MMXType
+ >();
+ // clang-format on
+}
inline bool isCompatibleFloatingPointType(Type type) {
return type.isa<LLVMHalfType, LLVMBFloatType, LLVMFloatType, LLVMDoubleType,
@@ -470,46 +451,4 @@ llvm::TypeSize getPrimitiveTypeSizeInBits(Type type);
} // namespace LLVM
} // namespace mlir
-//===----------------------------------------------------------------------===//
-// Support for hashing and containers.
-//===----------------------------------------------------------------------===//
-
-namespace llvm {
-
-// LLVMType instances hash just like pointers.
-template <>
-struct DenseMapInfo<mlir::LLVM::LLVMType> {
- static mlir::LLVM::LLVMType getEmptyKey() {
- void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
- return mlir::LLVM::LLVMType(
- static_cast<mlir::LLVM::LLVMType::ImplType *>(pointer));
- }
- static mlir::LLVM::LLVMType getTombstoneKey() {
- void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
- return mlir::LLVM::LLVMType(
- static_cast<mlir::LLVM::LLVMType::ImplType *>(pointer));
- }
- static unsigned getHashValue(mlir::LLVM::LLVMType val) {
- return mlir::hash_value(val);
- }
- static bool isEqual(mlir::LLVM::LLVMType lhs, mlir::LLVM::LLVMType rhs) {
- return lhs == rhs;
- }
-};
-
-// LLVMType behaves like a pointer similarly to mlir::Type.
-template <>
-struct PointerLikeTypeTraits<mlir::LLVM::LLVMType> {
- static inline void *getAsVoidPointer(mlir::LLVM::LLVMType type) {
- return const_cast<void *>(type.getAsOpaquePointer());
- }
- static inline mlir::LLVM::LLVMType getFromVoidPointer(void *ptr) {
- return mlir::LLVM::LLVMType::getFromOpaquePointer(ptr);
- }
- static constexpr int NumLowBitsAvailable =
- PointerLikeTypeTraits<mlir::Type>::NumLowBitsAvailable;
-};
-
-} // namespace llvm
-
#endif // MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index c6d2ded073e6..e073126450d6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -149,11 +149,11 @@ def ROCDL_MubufLoadOp :
LLVM_Type:$slc)>{
string llvmBuilder = [{
$res = createIntrinsicCall(builder,
- llvm::Intrinsic::amdgcn_buffer_load, {$rsrc, $vindex, $offset, $glc,
+ llvm::Intrinsic::amdgcn_buffer_load, {$rsrc, $vindex, $offset, $glc,
$slc}, {$_resultType});
}];
let parser = [{ return parseROCDLMubufLoadOp(parser, result); }];
- let printer = [{
+ let printer = [{
Operation *op = this->getOperation();
p << op->getName() << " " << op->getOperands()
<< " : " << op->getResultTypes();
@@ -169,7 +169,7 @@ def ROCDL_MubufStoreOp :
LLVM_Type:$glc,
LLVM_Type:$slc)>{
string llvmBuilder = [{
- auto vdataType = convertType(op.vdata().getType().cast<LLVM::LLVMType>());
+ auto vdataType = convertType(op.vdata().getType());
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex,
$offset, $glc, $slc}, {vdataType});
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 5259ed7fe182..7691adfeef14 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -104,7 +104,7 @@ class ModuleTranslation {
llvm::IRBuilder<> &builder);
/// Converts the type from MLIR LLVM dialect to LLVM.
- llvm::Type *convertType(LLVMType type);
+ llvm::Type *convertType(Type type);
static std::unique_ptr<llvm::Module>
prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
diff --git a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
index 71924b0c61ca..6b8e29ad7d9c 100644
--- a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
@@ -24,12 +24,11 @@ class Type;
namespace mlir {
+class Type;
class MLIRContext;
namespace LLVM {
-class LLVMType;
-
namespace detail {
class TypeToLLVMIRTranslatorImpl;
class TypeFromLLVMIRTranslatorImpl;
@@ -47,11 +46,10 @@ class TypeToLLVMIRTranslator {
/// that this will perform type conversion and store its results for future
/// uses.
// TODO: this should be removed when MLIR has proper data layout.
- unsigned getPreferredAlignment(LLVM::LLVMType type,
- const llvm::DataLayout &layout);
+ unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout);
/// Translates the given MLIR LLVM dialect type to LLVM IR.
- llvm::Type *translateType(LLVM::LLVMType type);
+ llvm::Type *translateType(Type type);
private:
/// Private implementation.
@@ -67,7 +65,7 @@ class TypeFromLLVMIRTranslator {
~TypeFromLLVMIRTranslator();
/// Translates the given LLVM IR type to the MLIR LLVM dialect.
- LLVM::LLVMType translateType(llvm::Type *type);
+ Type translateType(llvm::Type *type);
private:
/// Private implementation.
diff --git a/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
index 5742cd790e77..8a5790352263 100644
--- a/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
@@ -36,15 +36,14 @@ using VectorScaleOpLowering =
OneToOneConvertToLLVMPattern<VectorScaleOp, LLVM::vector_scale>;
// Extract an LLVM IR type from the LLVM IR dialect type.
-static LLVM::LLVMType unwrap(Type type) {
+static Type unwrap(Type type) {
if (!type)
return nullptr;
auto *mlirContext = type.getContext();
- auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
- if (!wrappedLLVMType)
+ if (!LLVM::isCompatibleType(type))
emitError(UnknownLoc::get(mlirContext),
"conversion resulted in a non-LLVM type");
- return wrappedLLVMType;
+ return type;
}
static Optional<Type>
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 517c8d2c6f56..93854c3cc05c 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -149,7 +149,7 @@ struct AsyncAPI {
}
// Auxiliary coroutine resume intrinsic wrapper.
- static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
+ static Type resumeFunctionType(MLIRContext *ctx) {
auto voidTy = LLVM::LLVMVoidType::get(ctx);
auto i8Ptr = opaquePointerType(ctx);
return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
@@ -203,13 +203,11 @@ static constexpr const char *kCoroEnd = "llvm.coro.end";
static constexpr const char *kCoroFree = "llvm.coro.free";
static constexpr const char *kCoroResume = "llvm.coro.resume";
-/// Adds an LLVM function declaration to a module.
static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
- StringRef name, LLVM::LLVMType ret,
- ArrayRef<LLVM::LLVMType> params) {
+ StringRef name, Type ret, ArrayRef<Type> params) {
if (module.lookupSymbol(name))
return;
- LLVM::LLVMType type = LLVM::LLVMFunctionType::get(ret, params);
+ Type type = LLVM::LLVMFunctionType::get(ret, params);
builder.create<LLVM::LLVMFuncOp>(name, type);
}
@@ -386,8 +384,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
// http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
auto sizeOf = [&](ValueType valueType) -> Value {
auto storedType = converter.convertType(valueType.getValueType());
- auto storagePtrType =
- LLVM::LLVMPointerType::get(storedType.cast<LLVM::LLVMType>());
+ auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
// %Size = getelementptr %T* null, int 1
// %SizeI = ptrtoint %T* %Size to i32
@@ -949,8 +946,7 @@ class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
// Cast from i8* to the pointer pointer to LLVM type.
auto llvmValueType = getTypeConverter()->convertType(valueType);
auto castedStorage = rewriter.create<LLVM::BitcastOp>(
- loc, LLVM::LLVMPointerType::get(llvmValueType.cast<LLVM::LLVMType>()),
- storage.getResult(0));
+ loc, LLVM::LLVMPointerType::get(llvmValueType), storage.getResult(0));
// Load from the async value storage.
auto loaded = rewriter.create<LLVM::LoadOp>(loc, castedStorage.getResult());
@@ -1015,9 +1011,7 @@ class YieldOpLowering : public ConversionPattern {
// Cast storage pointer to the yielded value type.
auto castedStorage = rewriter.create<LLVM::BitcastOp>(
- loc,
- LLVM::LLVMPointerType::get(
- yieldValue.getType().cast<LLVM::LLVMType>()),
+ loc, LLVM::LLVMPointerType::get(yieldValue.getType()),
storage.getResult(0));
// Store the yielded value into the async value storage.
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index 6859834de67f..cd16df12bfae 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -52,8 +52,8 @@ class GpuToLLVMConversionPass
class FunctionCallBuilder {
public:
- FunctionCallBuilder(StringRef functionName, LLVM::LLVMType returnType,
- ArrayRef<LLVM::LLVMType> argumentTypes)
+ FunctionCallBuilder(StringRef functionName, Type returnType,
+ ArrayRef<Type> argumentTypes)
: functionName(functionName),
functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
LLVM::CallOp create(Location loc, OpBuilder &builder,
@@ -73,15 +73,14 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
protected:
MLIRContext *context = &this->getTypeConverter()->getContext();
- LLVM::LLVMType llvmVoidType = LLVM::LLVMVoidType::get(context);
- LLVM::LLVMType llvmPointerType =
+ Type llvmVoidType = LLVM::LLVMVoidType::get(context);
+ Type llvmPointerType =
LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8));
- LLVM::LLVMType llvmPointerPointerType =
- LLVM::LLVMPointerType::get(llvmPointerType);
- LLVM::LLVMType llvmInt8Type = LLVM::LLVMIntegerType::get(context, 8);
- LLVM::LLVMType llvmInt32Type = LLVM::LLVMIntegerType::get(context, 32);
- LLVM::LLVMType llvmInt64Type = LLVM::LLVMIntegerType::get(context, 64);
- LLVM::LLVMType llvmIntPtrType = LLVM::LLVMIntegerType::get(
+ Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
+ Type llvmInt8Type = LLVM::LLVMIntegerType::get(context, 8);
+ Type llvmInt32Type = LLVM::LLVMIntegerType::get(context, 32);
+ Type llvmInt64Type = LLVM::LLVMIntegerType::get(context, 64);
+ Type llvmIntPtrType = LLVM::LLVMIntegerType::get(
context, this->getTypeConverter()->getPointerBitwidth(0));
FunctionCallBuilder moduleLoadCallBuilder = {
@@ -321,7 +320,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter) {
if (!llvm::all_of(operands, [](Value value) {
- return value.getType().isa<LLVM::LLVMType>();
+ return LLVM::isCompatibleType(value.getType());
}))
return rewriter.notifyMatchFailure(
op, "Cannot convert if operands aren't of LLVM type.");
@@ -511,10 +510,10 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
loc, launchOp.getOperands().take_back(numKernelOperands),
operands.take_back(numKernelOperands), builder);
auto numArguments = arguments.size();
- SmallVector<LLVM::LLVMType, 4> argumentTypes;
+ SmallVector<Type, 4> argumentTypes;
argumentTypes.reserve(numArguments);
for (auto argument : arguments)
- argumentTypes.push_back(argument.getType().cast<LLVM::LLVMType>());
+ argumentTypes.push_back(argument.getType());
auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
argumentTypes);
auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 3acc73415ef1..37c4469676c6 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -38,7 +38,7 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
uint64_t numElements = type.getNumElements();
auto elementType = typeConverter->convertType(type.getElementType())
- .template cast<LLVM::LLVMType>();
+ .template cast<Type>();
auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
std::string name = std::string(
llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
@@ -126,7 +126,7 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
// memory space and does not support `alloca`s with addrspace(5).
auto ptrType = LLVM::LLVMPointerType::get(
typeConverter->convertType(type.getElementType())
- .template cast<LLVM::LLVMType>(),
+ .template cast<Type>(),
AllocaAddrSpace);
Value numElements = rewriter.create<LLVM::ConstantOp>(
gpuFuncOp.getLoc(), int64Ty,
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 631eca5cd32d..5b98d3cee1fb 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -40,7 +40,6 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
using LLVM::LLVMFuncOp;
- using LLVM::LLVMType;
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
@@ -54,9 +53,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
for (Value operand : operands)
castedOperands.push_back(maybeCast(operand, rewriter));
- LLVMType resultType =
- castedOperands.front().getType().cast<LLVM::LLVMType>();
- LLVMType funcType = getFunctionType(resultType, castedOperands);
+ Type resultType = castedOperands.front().getType();
+ Type funcType = getFunctionType(resultType, castedOperands);
StringRef funcName = getFunctionName(
funcType.cast<LLVM::LLVMFunctionType>().getReturnType());
if (funcName.empty())
@@ -80,7 +78,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
private:
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
- LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
+ Type type = operand.getType();
if (!type.isa<LLVM::LLVMHalfType>())
return operand;
@@ -89,17 +87,15 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
operand);
}
- LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
- ArrayRef<Value> operands) const {
- using LLVM::LLVMType;
- SmallVector<LLVMType, 1> operandTypes;
+ Type getFunctionType(Type resultType, ArrayRef<Value> operands) const {
+ SmallVector<Type, 1> operandTypes;
for (Value operand : operands) {
- operandTypes.push_back(operand.getType().cast<LLVMType>());
+ operandTypes.push_back(operand.getType());
}
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}
- StringRef getFunctionName(LLVM::LLVMType type) const {
+ StringRef getFunctionName(Type type) const {
if (type.isa<LLVM::LLVMFloatType>())
return f32Func;
if (type.isa<LLVM::LLVMDoubleType>())
@@ -107,8 +103,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
return "";
}
- LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName,
- LLVM::LLVMType funcType,
+ LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
Operation *op) const {
using LLVM::LLVMFuncOp;
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index f747f519c66b..09877bfd19b3 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -56,7 +56,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
Location loc = op->getLoc();
gpu::ShuffleOpAdaptor adaptor(operands);
- auto valueTy = adaptor.value().getType().cast<LLVM::LLVMType>();
+ auto valueTy = adaptor.value().getType();
auto int32Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 32);
auto predTy = LLVM::LLVMIntegerType::get(rewriter.getContext(), 1);
auto resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 4b657d25f51e..6c88e5774239 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -65,7 +65,7 @@ class VulkanLaunchFuncToVulkanCallsPass
llvmInt64Type = LLVM::LLVMIntegerType::get(&getContext(), 64);
}
- LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
+ Type getMemRefType(uint32_t rank, Type elemenType) {
// According to the MLIR doc memref argument is converted into a
// pointer-to-struct argument of type:
// template <typename Elem, size_t Rank>
@@ -89,10 +89,10 @@ class VulkanLaunchFuncToVulkanCallsPass
llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
}
- LLVM::LLVMType getVoidType() { return llvmVoidType; }
- LLVM::LLVMType getPointerType() { return llvmPointerType; }
- LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
- LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
+ Type getVoidType() { return llvmVoidType; }
+ Type getPointerType() { return llvmPointerType; }
+ Type getInt32Type() { return llvmInt32Type; }
+ Type getInt64Type() { return llvmInt64Type; }
/// Creates an LLVM global for the given `name`.
Value createEntryPointNameConstant(StringRef name, Location loc,
@@ -128,10 +128,10 @@ class VulkanLaunchFuncToVulkanCallsPass
/// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
- uint32_t &rank, LLVM::LLVMType &type);
+ uint32_t &rank, Type &type);
/// Returns a string representation from the given `type`.
- StringRef stringifyType(LLVM::LLVMType type) {
+ StringRef stringifyType(Type type) {
if (type.isa<LLVM::LLVMFloatType>())
return "Float";
if (type.isa<LLVM::LLVMHalfType>())
@@ -152,11 +152,11 @@ class VulkanLaunchFuncToVulkanCallsPass
void runOnOperation() override;
private:
- LLVM::LLVMType llvmFloatType;
- LLVM::LLVMType llvmVoidType;
- LLVM::LLVMType llvmPointerType;
- LLVM::LLVMType llvmInt32Type;
- LLVM::LLVMType llvmInt64Type;
+ Type llvmFloatType;
+ Type llvmVoidType;
+ Type llvmPointerType;
+ Type llvmInt32Type;
+ Type llvmInt64Type;
// TODO: Use an associative array to support multiple vulkan launch calls.
std::pair<StringAttr, StringAttr> spirvAttributes;
@@ -230,7 +230,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
auto ptrToMemRefDescriptor = en.value();
uint32_t rank = 0;
- LLVM::LLVMType type;
+ Type type;
if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
cInterfaceVulkanLaunchCallOp.emitError()
<< "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
@@ -258,7 +258,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
}
LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
- Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
+ Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) {
auto llvmPtrDescriptorTy =
ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
if (!llvmPtrDescriptorTy)
@@ -324,12 +324,11 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
}
for (unsigned i = 1; i <= 3; i++) {
- SmallVector<LLVM::LLVMType, 5> types{
- LLVM::LLVMFloatType::get(&getContext()),
- LLVM::LLVMIntegerType::get(&getContext(), 32),
- LLVM::LLVMIntegerType::get(&getContext(), 16),
- LLVM::LLVMIntegerType::get(&getContext(), 8),
- LLVM::LLVMHalfType::get(&getContext())};
+ SmallVector<Type, 5> types{LLVM::LLVMFloatType::get(&getContext()),
+ LLVM::LLVMIntegerType::get(&getContext(), 32),
+ LLVM::LLVMIntegerType::get(&getContext(), 16),
+ LLVM::LLVMIntegerType::get(&getContext(), 8),
+ LLVM::LLVMHalfType::get(&getContext())};
for (auto type : types) {
std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
std::string(stringifyType(type));
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index c86ae710aac6..95e7c34a1710 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -67,11 +67,9 @@ using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
template <typename T>
-static LLVMType getPtrToElementType(T containerType,
- LLVMTypeConverter &lowering) {
- return lowering.convertType(containerType.getElementType())
- .template cast<LLVMType>()
- .getPointerTo();
+static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) {
+ return LLVMPointerType::get(
+ lowering.convertType(containerType.getElementType()));
}
/// Convert the given range descriptor type to the LLVMIR dialect.
@@ -84,8 +82,7 @@ static LLVMType getPtrToElementType(T containerType,
/// };
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
auto *context = t.getContext();
- auto int64Ty = converter.convertType(IntegerType::get(context, 64))
- .cast<LLVM::LLVMType>();
+ auto int64Ty = converter.convertType(IntegerType::get(context, 64));
return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty});
}
@@ -206,8 +203,7 @@ class SliceOpConversion : public ConvertOpToLLVMPattern<SliceOp> {
BaseViewConversionHelper baseDesc(adaptor.view());
auto memRefType = sliceOp.getBaseViewType();
- auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64))
- .cast<LLVM::LLVMType>();
+ auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64));
BaseViewConversionHelper desc(
typeConverter->convertType(sliceOp.getShapedType()));
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index c293dacf8ec5..8c42c2a7dc92 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -184,7 +184,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), newKernelFuncName,
LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
- ArrayRef<LLVM::LLVMType>()));
+ ArrayRef<Type>()));
rewriter.setInsertionPoint(launchOp);
}
@@ -234,7 +234,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
dstGlobal = rewriter.create<LLVM::GlobalOp>(
- loc, dstGlobalType.cast<LLVM::LLVMType>(),
+ loc, dstGlobalType,
/*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute());
rewriter.setInsertionPoint(launchOp);
}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 76a4e0b2e07e..2ebb24b5aaeb 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -65,7 +65,7 @@ static unsigned getBitWidth(Type type) {
}
/// Returns the bit width of LLVMType integer or vector.
-static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
+static unsigned getLLVMTypeBitWidth(Type type) {
auto vectorType = type.dyn_cast<LLVM::LLVMVectorType>();
return (vectorType ? vectorType.getElementType() : type)
.cast<LLVM::LLVMIntegerType>()
@@ -115,16 +115,15 @@ static Value createFPConstant(Location loc, Type srcType, Type dstType,
/// - `BitFieldSExtract`
/// - `BitFieldUExtract`
/// Truncates or extends the value. If the bitwidth of the value is the same as
-/// `dstType` bitwidth, the value remains unchanged.
-static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType,
+/// `llvmType` bitwidth, the value remains unchanged.
+static Value optionallyTruncateOrExtend(Location loc, Value value,
+ Type llvmType,
PatternRewriter &rewriter) {
auto srcType = value.getType();
- auto llvmType = dstType.cast<LLVM::LLVMType>();
unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
- unsigned valueBitWidth =
- srcType.isa<LLVM::LLVMType>()
- ? getLLVMTypeBitWidth(srcType.cast<LLVM::LLVMType>())
- : getBitWidth(srcType);
+ unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
+ ? getLLVMTypeBitWidth(srcType)
+ : getBitWidth(srcType);
if (valueBitWidth < targetBitWidth)
return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
@@ -193,7 +192,7 @@ convertStructTypeWithOffset(spirv::StructType type,
auto elementsVector = llvm::to_vector<8>(
llvm::map_range(type.getElementTypes(), [&](Type elementType) {
- return converter.convertType(elementType).cast<LLVM::LLVMType>();
+ return converter.convertType(elementType);
}));
return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
/*isPacked=*/false);
@@ -204,7 +203,7 @@ static Type convertStructTypePacked(spirv::StructType type,
LLVMTypeConverter &converter) {
auto elementsVector = llvm::to_vector<8>(
llvm::map_range(type.getElementTypes(), [&](Type elementType) {
- return converter.convertType(elementType).cast<LLVM::LLVMType>();
+ return converter.convertType(elementType);
}));
return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
/*isPacked=*/true);
@@ -255,8 +254,7 @@ static Optional<Type> convertArrayType(spirv::ArrayType type,
!(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
return llvm::None;
- auto llvmElementType =
- converter.convertType(elementType).cast<LLVM::LLVMType>();
+ auto llvmElementType = converter.convertType(elementType);
unsigned numElements = type.getNumElements();
return LLVM::LLVMArrayType::get(llvmElementType, numElements);
}
@@ -265,8 +263,7 @@ static Optional<Type> convertArrayType(spirv::ArrayType type,
/// modelled at the moment.
static Type convertPointerType(spirv::PointerType type,
TypeConverter &converter) {
- auto pointeeType =
- converter.convertType(type.getPointeeType()).cast<LLVM::LLVMType>();
+ auto pointeeType = converter.convertType(type.getPointeeType());
return LLVM::LLVMPointerType::get(pointeeType);
}
@@ -277,8 +274,7 @@ static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
TypeConverter &converter) {
if (type.getArrayStride() != 0)
return llvm::None;
- auto elementType =
- converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
+ auto elementType = converter.convertType(type.getElementType());
return LLVM::LLVMArrayType::get(elementType, 0);
}
@@ -336,8 +332,7 @@ class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
auto dstType = typeConverter.convertType(op.pointer().getType());
if (!dstType)
return failure();
- rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
- op, dstType.cast<LLVM::LLVMType>(), op.variable());
+ rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable());
return success();
}
};
@@ -667,7 +662,7 @@ class ExecutionModePattern
// int32_t values[]; // optional values
// };
auto llvmI32Type = LLVM::LLVMIntegerType::get(context, 32);
- SmallVector<LLVM::LLVMType, 2> fields;
+ SmallVector<Type, 2> fields;
fields.push_back(llvmI32Type);
ArrayAttr values = op.values();
if (!values.empty()) {
@@ -757,8 +752,7 @@ class GlobalVariablePattern
? LLVM::Linkage::Private
: LLVM::Linkage::External;
rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
- op, dstType.cast<LLVM::LLVMType>(), isConstant, linkage, op.sym_name(),
- Attribute());
+ op, dstType, isConstant, linkage, op.sym_name(), Attribute());
return success();
}
};
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 233c2eadc77c..39680a28a33e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -41,15 +41,15 @@ using namespace mlir;
#define PASS_NAME "convert-std-to-llvm"
// Extract an LLVM IR type from the LLVM IR dialect type.
-static LLVM::LLVMType unwrap(Type type) {
+static Type unwrap(Type type) {
if (!type)
return nullptr;
auto *mlirContext = type.getContext();
- auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
- if (!wrappedLLVMType)
+ if (!LLVM::isCompatibleType(type))
emitError(UnknownLoc::get(mlirContext),
- "conversion resulted in a non-LLVM type");
- return wrappedLLVMType;
+ "conversion resulted in a non-LLVM type ")
+ << type;
+ return type;
}
/// Callback to convert function argument types. It converts a MemRef function
@@ -120,8 +120,11 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
[&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
addConversion([&](VectorType type) { return convertVectorType(type); });
- // LLVMType is legal, so add a pass-through conversion.
- addConversion([](LLVM::LLVMType type) { return type; });
+ // LLVM-compatible types are legal, so add a pass-through conversion.
+ addConversion([](Type type) {
+ return LLVM::isCompatibleType(type) ? llvm::Optional<Type>(type)
+ : llvm::None;
+ });
// Materialization for memrefs creates descriptor structs from individual
// values constituting them, when descriptors are used, i.e. more than one
@@ -170,7 +173,7 @@ MLIRContext &LLVMTypeConverter::getContext() {
return *getDialect()->getContext();
}
-LLVM::LLVMType LLVMTypeConverter::getIndexType() {
+Type LLVMTypeConverter::getIndexType() {
return LLVM::LLVMIntegerType::get(&getContext(), getIndexTypeBitwidth());
}
@@ -205,7 +208,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) {
static constexpr unsigned kRealPosInComplexNumberStruct = 0;
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
Type LLVMTypeConverter::convertComplexType(ComplexType type) {
- auto elementType = convertType(type.getElementType()).cast<LLVM::LLVMType>();
+ auto elementType = convertType(type.getElementType());
return LLVM::LLVMStructType::getLiteral(&getContext(),
{elementType, elementType});
}
@@ -214,7 +217,7 @@ Type LLVMTypeConverter::convertComplexType(ComplexType type) {
// pointer-to-function types.
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
SignatureConversion conversion(type.getNumInputs());
- LLVM::LLVMType converted =
+ Type converted =
convertFunctionSignature(type, /*isVariadic=*/false, conversion);
return LLVM::LLVMPointerType::get(converted);
}
@@ -224,7 +227,7 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
// argument and result types. If MLIR Function has zero results, the LLVM
// Function has one VoidType result. If MLIR Function has more than one result,
// they are into an LLVM StructType in their order of appearance.
-LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
+Type LLVMTypeConverter::convertFunctionSignature(
FunctionType funcTy, bool isVariadic,
LLVMTypeConverter::SignatureConversion &result) {
// Select the argument converter depending on the calling convention.
@@ -240,7 +243,7 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
result.addInputs(en.index(), converted);
}
- SmallVector<LLVM::LLVMType, 8> argTypes;
+ SmallVector<Type, 8> argTypes;
argTypes.reserve(llvm::size(result.getConvertedTypes()));
for (Type type : result.getConvertedTypes())
argTypes.push_back(unwrap(type));
@@ -248,10 +251,9 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
// If function does not return anything, create the void result type,
// if it returns on element, convert it, otherwise pack the result types into
// a struct.
- LLVM::LLVMType resultType =
- funcTy.getNumResults() == 0
- ? LLVM::LLVMVoidType::get(&getContext())
- : unwrap(packFunctionResults(funcTy.getResults()));
+ Type resultType = funcTy.getNumResults() == 0
+ ? LLVM::LLVMVoidType::get(&getContext())
+ : unwrap(packFunctionResults(funcTy.getResults()));
if (!resultType)
return {};
return LLVM::LLVMFunctionType::get(resultType, argTypes, isVariadic);
@@ -259,23 +261,21 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
-LLVM::LLVMType
-LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
- SmallVector<LLVM::LLVMType, 4> inputs;
+Type LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
+ SmallVector<Type, 4> inputs;
for (Type t : type.getInputs()) {
- auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
- if (!converted)
+ auto converted = convertType(t);
+ if (!converted || !LLVM::isCompatibleType(converted))
return {};
if (t.isa<MemRefType, UnrankedMemRefType>())
converted = LLVM::LLVMPointerType::get(converted);
inputs.push_back(converted);
}
- LLVM::LLVMType resultType =
- type.getNumResults() == 0
- ? LLVM::LLVMVoidType::get(&getContext())
- : unwrap(packFunctionResults(type.getResults()));
+ Type resultType = type.getNumResults() == 0
+ ? LLVM::LLVMVoidType::get(&getContext())
+ : unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
@@ -316,19 +316,19 @@ static constexpr unsigned kStridePosInMemRefDescriptor = 4;
/// Index sizes[Rank]; // omitted when rank == 0
/// Index strides[Rank]; // omitted when rank == 0
/// };
-SmallVector<LLVM::LLVMType, 5>
+SmallVector<Type, 5>
LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
bool unpackAggregates) {
assert(isStrided(type) &&
"Non-strided layout maps must have been normalized away");
- LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
+ Type elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
auto ptrTy = LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
auto indexTy = getIndexType();
- SmallVector<LLVM::LLVMType, 5> results = {ptrTy, ptrTy, indexTy};
+ SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
auto rank = type.getRank();
if (rank == 0)
return results;
@@ -345,7 +345,7 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
// When converting a MemRefType to a struct with descriptor fields, do not
// unpack the `sizes` and `strides` arrays.
- SmallVector<LLVM::LLVMType, 5> types =
+ SmallVector<Type, 5> types =
getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
}
@@ -360,8 +360,7 @@ static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
/// stack allocated (alloca) copy of a MemRef descriptor that got casted to
/// be unranked.
-SmallVector<LLVM::LLVMType, 2>
-LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
+SmallVector<Type, 2> LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
return {getIndexType(), LLVM::LLVMPointerType::get(
LLVM::LLVMIntegerType::get(&getContext(), 8))};
}
@@ -395,7 +394,7 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
if (ShapedType::isDynamicStrideOrOffset(offset))
return {};
- LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
+ Type elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
return LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
@@ -409,7 +408,7 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
auto elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
- LLVM::LLVMType vectorType =
+ Type vectorType =
LLVM::LLVMFixedVectorType::get(elementType, type.getShape().back());
auto shape = type.getShape();
for (int i = shape.size() - 2; i >= 0; --i)
@@ -454,10 +453,9 @@ ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
// StructBuilder implementation
//===----------------------------------------------------------------------===//
-StructBuilder::StructBuilder(Value v) : value(v) {
+StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) {
assert(value != nullptr && "value cannot be null");
- structType = value.getType().dyn_cast<LLVM::LLVMType>();
- assert(structType && "expected llvm type");
+ assert(LLVM::isCompatibleType(structType) && "expected llvm type");
}
Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
@@ -479,7 +477,7 @@ void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
Location loc, Type type) {
- Value val = builder.create<LLVM::UndefOp>(loc, type.cast<LLVM::LLVMType>());
+ Value val = builder.create<LLVM::UndefOp>(loc, type);
return ComplexStructBuilder(val);
}
@@ -518,8 +516,7 @@ MemRefDescriptor::MemRefDescriptor(Value descriptor)
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
Type descriptorType) {
- Value descriptor =
- builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
+ Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
return MemRefDescriptor(descriptor);
}
@@ -620,9 +617,8 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
int64_t rank) {
- auto indexTy = indexType.cast<LLVM::LLVMType>();
- auto indexPtrTy = LLVM::LLVMPointerType::get(indexTy);
- auto arrayTy = LLVM::LLVMArrayType::get(indexTy, rank);
+ auto indexPtrTy = LLVM::LLVMPointerType::get(indexType);
+ auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank);
auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy);
// Copy size values to stack-allocated memory.
@@ -774,8 +770,7 @@ UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
Location loc,
Type descriptorType) {
- Value descriptor =
- builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
+ Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
return UnrankedMemRefDescriptor(descriptor);
}
Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
@@ -828,7 +823,7 @@ void UnrankedMemRefDescriptor::computeSizes(
return;
// Cache the index type.
- LLVM::LLVMType indexType = typeConverter.getIndexType();
+ Type indexType = typeConverter.getIndexType();
// Initialize shared constants.
Value one = createIndexAttrConstant(builder, loc, indexType, 1);
@@ -868,7 +863,7 @@ void UnrankedMemRefDescriptor::computeSizes(
Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
Value memRefDescPtr,
- LLVM::LLVMType elemPtrPtrType) {
+ Type elemPtrPtrType) {
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
@@ -877,7 +872,7 @@ Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
Value memRefDescPtr,
- LLVM::LLVMType elemPtrPtrType,
+ Type elemPtrPtrType,
Value allocatedPtr) {
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
@@ -887,7 +882,7 @@ void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
- LLVM::LLVMType elemPtrPtrType) {
+ Type elemPtrPtrType) {
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
@@ -901,7 +896,7 @@ Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
- LLVM::LLVMType elemPtrPtrType,
+ Type elemPtrPtrType,
Value alignedPtr) {
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
@@ -916,7 +911,7 @@ void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
- LLVM::LLVMType elemPtrPtrType) {
+ Type elemPtrPtrType) {
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
@@ -932,8 +927,7 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
- LLVM::LLVMType elemPtrPtrType,
- Value offset) {
+ Type elemPtrPtrType, Value offset) {
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
@@ -949,16 +943,15 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
Value UnrankedMemRefDescriptor::sizeBasePtr(
OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) {
- LLVM::LLVMType elemPtrTy = elemPtrPtrType.getElementType();
- LLVM::LLVMType indexTy = typeConverter.getIndexType();
- LLVM::LLVMType structPtrTy =
+ Type elemPtrTy = elemPtrPtrType.getElementType();
+ Type indexTy = typeConverter.getIndexType();
+ Type structPtrTy =
LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral(
indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy}));
Value structPtr =
builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
- LLVM::LLVMType int32_type =
- unwrap(typeConverter.convertType(builder.getI32Type()));
+ Type int32_type = unwrap(typeConverter.convertType(builder.getI32Type()));
Value zero =
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
Value three = builder.create<LLVM::ConstantOp>(loc, int32_type,
@@ -970,8 +963,7 @@ Value UnrankedMemRefDescriptor::sizeBasePtr(
Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter,
Value sizeBasePtr, Value index) {
- LLVM::LLVMType indexPtrTy =
- LLVM::LLVMPointerType::get(typeConverter.getIndexType());
+ Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
ValueRange({index}));
return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
@@ -981,8 +973,7 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter,
Value sizeBasePtr, Value index,
Value size) {
- LLVM::LLVMType indexPtrTy =
- LLVM::LLVMPointerType::get(typeConverter.getIndexType());
+ Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
ValueRange({index}));
builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
@@ -991,8 +982,7 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value sizeBasePtr, Value rank) {
- LLVM::LLVMType indexPtrTy =
- LLVM::LLVMPointerType::get(typeConverter.getIndexType());
+ Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
ValueRange({rank}));
}
@@ -1001,8 +991,7 @@ Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter,
Value strideBasePtr, Value index,
Value stride) {
- LLVM::LLVMType indexPtrTy =
- LLVM::LLVMPointerType::get(typeConverter.getIndexType());
+ Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value strideStoreGep = builder.create<LLVM::GEPOp>(
loc, indexPtrTy, strideBasePtr, ValueRange({index}));
return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
@@ -1012,8 +1001,7 @@ void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter,
Value strideBasePtr, Value index,
Value stride) {
- LLVM::LLVMType indexPtrTy =
- LLVM::LLVMPointerType::get(typeConverter.getIndexType());
+ Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value strideStoreGep = builder.create<LLVM::GEPOp>(
loc, indexPtrTy, strideBasePtr, ValueRange({index}));
builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
@@ -1028,22 +1016,21 @@ LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
return *getTypeConverter()->getDialect();
}
-LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
+Type ConvertToLLVMPattern::getIndexType() const {
return getTypeConverter()->getIndexType();
}
-LLVM::LLVMType
-ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
+Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
return LLVM::LLVMIntegerType::get(
&getTypeConverter()->getContext(),
getTypeConverter()->getPointerBitwidth(addressSpace));
}
-LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
+Type ConvertToLLVMPattern::getVoidType() const {
return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
}
-LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const {
+Type ConvertToLLVMPattern::getVoidPtrType() const {
return LLVM::LLVMPointerType::get(
LLVM::LLVMIntegerType::get(&getTypeConverter()->getContext(), 8));
}
@@ -1084,7 +1071,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
}
- LLVM::LLVMType elementPtrType = memRefDescriptor.getElementPtrType();
+ Type elementPtrType = memRefDescriptor.getElementPtrType();
return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
: base;
}
@@ -1161,8 +1148,8 @@ Value ConvertToLLVMPattern::getSizeInBytes(
// %0 = getelementptr %elementType* null, %indexType 1
// %1 = ptrtoint %elementType* %0 to %indexType
// which is a common pattern of getting the size of a type in bytes.
- auto convertedPtrType = LLVM::LLVMPointerType::get(
- typeConverter->convertType(type).cast<LLVM::LLVMType>());
+ auto convertedPtrType =
+ LLVM::LLVMPointerType::get(typeConverter->convertType(type));
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
auto gep = rewriter.create<LLVM::GEPOp>(
loc, convertedPtrType,
@@ -1276,7 +1263,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
OpBuilder::InsertionGuard guard(builder);
- LLVM::LLVMType wrapperType =
+ Type wrapperType =
typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
// This conversion can only fail if it could not convert one of the argument
// types. But since it has been applies to a non-wrapper function before, it
@@ -1318,8 +1305,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
builder, loc, typeConverter, unrankedMemRefType,
wrapperArgsRange.take_front(numToDrop));
- auto ptrTy =
- LLVM::LLVMPointerType::get(packed.getType().cast<LLVM::LLVMType>());
+ auto ptrTy = LLVM::LLVMPointerType::get(packed.getType());
Value one = builder.create<LLVM::ConstantOp>(
loc, typeConverter.convertType(builder.getIndexType()),
builder.getIntegerAttr(builder.getIndexType(), 1));
@@ -1494,9 +1480,9 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
// 1-D LLVM vectors.
struct NDVectorTypeInfo {
// LLVM array struct which encodes n-D vectors.
- LLVM::LLVMType llvmArrayTy;
+ Type llvmArrayTy;
// LLVM vector type which encodes the inner 1-D vector type.
- LLVM::LLVMType llvmVectorTy;
+ Type llvmVectorTy;
// Multiplicity of llvmArrayTy to llvmVectorTy.
SmallVector<int64_t, 4> arraySizes;
};
@@ -1510,10 +1496,11 @@ static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
LLVMTypeConverter &converter) {
assert(vectorType.getRank() > 1 && "expected >1D vector type");
NDVectorTypeInfo info;
- info.llvmArrayTy =
- converter.convertType(vectorType).dyn_cast<LLVM::LLVMType>();
- if (!info.llvmArrayTy)
+ info.llvmArrayTy = converter.convertType(vectorType);
+ if (!info.llvmArrayTy || !LLVM::isCompatibleType(info.llvmArrayTy)) {
+ info.llvmArrayTy = nullptr;
return info;
+ }
info.arraySizes.reserve(vectorType.getRank() - 1);
auto llvmTy = info.llvmArrayTy;
while (llvmTy.isa<LLVM::LLVMArrayType>()) {
@@ -1610,14 +1597,14 @@ LogicalResult LLVM::detail::oneToOneRewrite(
static LogicalResult handleMultidimensionalVectors(
Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
- std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
+ std::function<Value(Type, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter) {
auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
if (!vectorType)
return failure();
auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter);
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
- auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
+ auto llvmArrayTy = operands[0].getType();
if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
return failure();
@@ -1645,14 +1632,14 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
// Cannot convert ops if their operands are not of LLVM type.
if (!llvm::all_of(operands.getTypes(),
- [](Type t) { return t.isa<LLVM::LLVMType>(); }))
+ [](Type t) { return isCompatibleType(t); }))
return failure();
- auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
+ auto llvmArrayTy = operands[0].getType();
if (!llvmArrayTy.isa<LLVM::LLVMArrayType>())
return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
- auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy,
+ auto callback = [op, targetOp, &rewriter](Type llvmVectorTy,
ValueRange operands) {
OperationState state(op->getLoc(), targetOp);
state.addTypes(llvmVectorTy);
@@ -1896,16 +1883,18 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
// If constant refers to a function, convert it to "addressof".
if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
- auto type = typeConverter->convertType(op.getResult().getType())
- .dyn_cast_or_null<LLVM::LLVMType>();
- if (!type)
+ auto type = typeConverter->convertType(op.getResult().getType());
+ if (!type || !LLVM::isCompatibleType(type))
return rewriter.notifyMatchFailure(op, "failed to convert result type");
- NamedAttrList attrs(op->getAttrDictionary());
- attrs.erase("value");
- rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
- op, type.cast<LLVM::LLVMType>(), symbolRef.getValue(),
- attrs.getAttrs());
+ auto newOp = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type,
+ symbolRef.getValue());
+ for (const NamedAttribute &attr : op->getAttrs()) {
+ if (attr.first.strref() == "value")
+ continue;
+ newOp.setAttr(attr.first, attr.second);
+ }
+ rewriter.replaceOp(op, newOp->getResults());
return success();
}
@@ -1947,11 +1936,11 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern {
Value createAllocCall(Location loc, StringRef name, Type ptrType,
ArrayRef<Value> params, ModuleOp module,
ConversionPatternRewriter &rewriter) const {
- SmallVector<LLVM::LLVMType, 2> paramTypes;
+ SmallVector<Type, 2> paramTypes;
auto allocFuncOp = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
if (!allocFuncOp) {
for (Value param : params)
- paramTypes.push_back(param.getType().cast<LLVM::LLVMType>());
+ paramTypes.push_back(param.getType());
auto allocFuncType =
LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes);
OpBuilder::InsertionGuard guard(rewriter);
@@ -2206,10 +2195,10 @@ static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
// Get frequently used types.
MLIRContext *context = builder.getContext();
auto voidType = LLVM::LLVMVoidType::get(context);
- LLVM::LLVMType voidPtrType =
+ Type voidPtrType =
LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8));
auto i1Type = LLVM::LLVMIntegerType::get(context, 1);
- LLVM::LLVMType indexType = typeConverter.getIndexType();
+ Type indexType = typeConverter.getIndexType();
// Find the malloc and free, or declare them if necessary.
auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
@@ -2389,17 +2378,15 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
};
/// Returns the LLVM type of the global variable given the memref type `type`.
-static LLVM::LLVMType
-convertGlobalMemrefTypeToLLVM(MemRefType type,
- LLVMTypeConverter &typeConverter) {
+static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
+ LLVMTypeConverter &typeConverter) {
// LLVM type for a global memref will be a multi-dimension array. For
// declarations or uninitialized global memrefs, we can potentially flatten
// this to a 1D array. However, for global_memref's with an initial value,
// we do not intend to flatten the ElementsAttribute when going from std ->
// LLVM dialect, so the LLVM type needs to me a multi-dimension array.
- LLVM::LLVMType elementType =
- unwrap(typeConverter.convertType(type.getElementType()));
- LLVM::LLVMType arrayTy = elementType;
+ Type elementType = unwrap(typeConverter.convertType(type.getElementType()));
+ Type arrayTy = elementType;
// Shape has the outermost dim at index 0, so need to walk it backwards
for (int64_t dim : llvm::reverse(type.getShape()))
arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
@@ -2417,8 +2404,7 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
if (!isConvertibleAndHasIdentityMaps(type))
return failure();
- LLVM::LLVMType arrayTy =
- convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
+ Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
LLVM::Linkage linkage =
global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
@@ -2457,17 +2443,15 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
unsigned memSpace = type.getMemorySpace();
- LLVM::LLVMType arrayTy =
- convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
+ Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
auto addressOf = rewriter.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
// Get the address of the first element in the array by creating a GEP with
// the address of the GV as the base, and (rank + 1) number of 0 indices.
- LLVM::LLVMType elementType =
+ Type elementType =
unwrap(typeConverter->convertType(type.getElementType()));
- LLVM::LLVMType elementPtrType =
- LLVM::LLVMPointerType::get(elementType, memSpace);
+ Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
SmallVector<Value, 4> operands = {addressOf};
operands.insert(operands.end(), type.getRank() + 1,
@@ -2497,10 +2481,9 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
RsqrtOp::Adaptor transformed(operands);
- auto operandType =
- transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
+ auto operandType = transformed.operand().getType();
- if (!operandType)
+ if (!operandType || !LLVM::isCompatibleType(operandType))
return failure();
auto loc = op.getLoc();
@@ -2528,7 +2511,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
return handleMultidimensionalVectors(
op.getOperation(), operands, *getTypeConverter(),
- [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
+ [&](Type llvmVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get(
{llvmVectorTy.cast<LLVM::LLVMFixedVectorType>()
@@ -2620,13 +2603,11 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
// ptr = ExtractValueOp src, 1
auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
// castPtr = BitCastOp i8* to structTy*
- auto castPtr = rewriter
- .create<LLVM::BitcastOp>(
- loc,
- LLVM::LLVMPointerType::get(
- targetStructType.cast<LLVM::LLVMType>()),
- ptr)
- .getResult();
+ auto castPtr =
+ rewriter
+ .create<LLVM::BitcastOp>(
+ loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
+ .getResult();
// struct = LoadOp castPtr
auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
rewriter.replaceOp(memRefCastOp, loadOp.getResult());
@@ -2659,9 +2640,8 @@ static void extractPointersAndOffset(Location loc,
unsigned memorySpace =
operandType.cast<UnrankedMemRefType>().getMemorySpace();
Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
- LLVM::LLVMType llvmElementType =
- unwrap(typeConverter.convertType(elementType));
- LLVM::LLVMType elementPtrPtrType = LLVM::LLVMPointerType::get(
+ Type llvmElementType = unwrap(typeConverter.convertType(elementType));
+ Type elementPtrPtrType = LLVM::LLVMPointerType::get(
LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
// Extract pointer to the underlying ranked memref descriptor and cast it to
@@ -2809,8 +2789,7 @@ struct MemRefReshapeOpLowering
&allocatedPtr, &alignedPtr, &offset);
// Set pointers and offset.
- LLVM::LLVMType llvmElementType =
- unwrap(typeConverter->convertType(elementType));
+ Type llvmElementType = unwrap(typeConverter->convertType(elementType));
auto elementPtrPtrType = LLVM::LLVMPointerType::get(
LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
@@ -2835,7 +2814,7 @@ struct MemRefReshapeOpLowering
rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
Block *initBlock = rewriter.getInsertionBlock();
- LLVM::LLVMType indexType = getTypeConverter()->getIndexType();
+ Type indexType = getTypeConverter()->getIndexType();
Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
@@ -2865,7 +2844,7 @@ struct MemRefReshapeOpLowering
rewriter.setInsertionPointToStart(bodyBlock);
// Copy size from shape to descriptor.
- LLVM::LLVMType llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
+ Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
@@ -2957,9 +2936,8 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
loc,
- LLVM::LLVMPointerType::get(
- typeConverter->convertType(scalarMemRefType).cast<LLVM::LLVMType>(),
- addressSpace),
+ LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
+ addressSpace),
underlyingRankedDesc);
// Get pointer to offset field of memref<element_type> descriptor.
@@ -3435,8 +3413,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
auto sourceElementTy =
- typeConverter->convertType(sourceMemRefType.getElementType())
- .dyn_cast_or_null<LLVM::LLVMType>();
+ typeConverter->convertType(sourceMemRefType.getElementType());
auto viewMemRefType = subViewOp.getType();
auto inferredType = SubViewOp::inferResultType(
@@ -3446,11 +3423,12 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
extractFromI64ArrayAttr(subViewOp.static_strides()))
.cast<MemRefType>();
auto targetElementTy =
- typeConverter->convertType(viewMemRefType.getElementType())
- .dyn_cast<LLVM::LLVMType>();
- auto targetDescTy = typeConverter->convertType(viewMemRefType)
- .dyn_cast_or_null<LLVM::LLVMType>();
- if (!sourceElementTy || !targetDescTy)
+ typeConverter->convertType(viewMemRefType.getElementType());
+ auto targetDescTy = typeConverter->convertType(viewMemRefType);
+ if (!sourceElementTy || !targetDescTy || !targetElementTy ||
+ !LLVM::isCompatibleType(sourceElementTy) ||
+ !LLVM::isCompatibleType(targetElementTy) ||
+ !LLVM::isCompatibleType(targetDescTy))
return failure();
// Extract the offset and strides from the type.
@@ -3461,7 +3439,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
return failure();
// Create the descriptor.
- if (!operands.front().getType().isa<LLVM::LLVMType>())
+ if (!LLVM::isCompatibleType(operands.front().getType()))
return failure();
MemRefDescriptor sourceMemRef(operands.front());
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
@@ -3650,11 +3628,11 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
auto viewMemRefType = viewOp.getType();
auto targetElementTy =
- typeConverter->convertType(viewMemRefType.getElementType())
- .dyn_cast<LLVM::LLVMType>();
- auto targetDescTy =
- typeConverter->convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
- if (!targetDescTy)
+ typeConverter->convertType(viewMemRefType.getElementType());
+ auto targetDescTy = typeConverter->convertType(viewMemRefType);
+ if (!targetDescTy || !targetElementTy ||
+ !LLVM::isCompatibleType(targetElementTy) ||
+ !LLVM::isCompatibleType(targetDescTy))
return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
failure();
@@ -3849,9 +3827,7 @@ struct GenericAtomicRMWOpLowering
auto loc = atomicOp.getLoc();
GenericAtomicRMWOp::Adaptor adaptor(operands);
- LLVM::LLVMType valueType =
- typeConverter->convertType(atomicOp.getResult().getType())
- .cast<LLVM::LLVMType>();
+ Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
// Split the block into initial, loop, and ending parts.
auto *initBlock = rewriter.getInsertionBlock();
@@ -4060,12 +4036,11 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
if (types.size() == 1)
return convertCallingConventionType(types.front());
- SmallVector<LLVM::LLVMType, 8> resultTypes;
+ SmallVector<Type, 8> resultTypes;
resultTypes.reserve(types.size());
for (auto t : types) {
- auto converted =
- convertCallingConventionType(t).dyn_cast_or_null<LLVM::LLVMType>();
- if (!converted)
+ auto converted = convertCallingConventionType(t);
+ if (!converted || !LLVM::isCompatibleType(converted))
return {};
resultTypes.push_back(converted);
}
@@ -4080,8 +4055,7 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
auto indexType = IndexType::get(context);
// Alloca with proper alignment. We do not expect optimizations of this
// alloca op and so we omit allocating at the entry block.
- auto ptrType =
- LLVM::LLVMPointerType::get(operand.getType().cast<LLVM::LLVMType>());
+ auto ptrType = LLVM::LLVMPointerType::get(operand.getType());
Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
IntegerAttr::get(indexType, 1));
Value allocated =
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index e023b6da460b..535cdcb7dfd7 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -152,8 +152,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
// stop depending on translation.
llvm::LLVMContext llvmContext;
align = LLVM::TypeToLLVMIRTranslator(llvmContext)
- .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
- typeConverter.getDataLayout());
+ .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
return success();
}
@@ -193,7 +192,7 @@ static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
Value base;
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
return failure();
- auto pType = LLVM::LLVMPointerType::get(type.template cast<LLVM::LLVMType>());
+ auto pType = LLVM::LLVMPointerType::get(type);
base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
return success();
@@ -1401,7 +1400,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
// Helper for printer method declaration (first hit) and lookup.
static Operation *getPrint(Operation *op, StringRef name,
- ArrayRef<LLVM::LLVMType> params) {
+ ArrayRef<Type> params) {
auto module = op->getParentOfType<ModuleOp>();
auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
if (func)
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 3e3ddc6aaff6..d27f097a3baa 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -30,8 +30,8 @@ using namespace mlir::vector;
static LogicalResult replaceTransferOpWithMubuf(
ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp,
- LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
- Value &offsetSizeInBytes, Value &glc, Value &slc) {
+ Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
+ Value &glc, Value &slc) {
rewriter.replaceOpWithNewOp<ROCDL::MubufLoadOp>(
xferOp, vecTy, dwordConfig, vindex, offsetSizeInBytes, glc, slc);
return success();
@@ -40,8 +40,8 @@ static LogicalResult replaceTransferOpWithMubuf(
static LogicalResult replaceTransferOpWithMubuf(
ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
- LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
- Value &offsetSizeInBytes, Value &glc, Value &slc) {
+ Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
+ Value &glc, Value &slc) {
auto adaptor = TransferWriteOpAdaptor(operands);
rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
dwordConfig, vindex,
@@ -121,16 +121,16 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
Type i64Ty = rewriter.getIntegerType(64);
Value i64x2Ty = rewriter.create<LLVM::BitcastOp>(
loc,
- LLVM::LLVMFixedVectorType::get(
- toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
+ LLVM::LLVMFixedVectorType::get(toLLVMTy(i64Ty).template cast<Type>(),
+ 2),
constConfig);
Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
- loc, toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), dataPtr);
+ loc, toLLVMTy(i64Ty).template cast<Type>(), dataPtr);
Value zero = this->createIndexConstant(rewriter, loc, 0);
Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
loc,
- LLVM::LLVMFixedVectorType::get(
- toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
+ LLVM::LLVMFixedVectorType::get(toLLVMTy(i64Ty).template cast<Type>(),
+ 2),
i64x2Ty, dataPtrAsI64, zero);
dwordConfig =
rewriter.create<LLVM::BitcastOp>(loc, toLLVMTy(i32Vecx4), dwordConfig);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 765538ca7a53..0a9b61628384 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -101,14 +101,14 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
// The result type is either i1 or a vector type <? x i1> if the inputs are
// vectors.
- LLVMType resultType = LLVMIntegerType::get(builder.getContext(), 1);
- auto argType = type.dyn_cast<LLVM::LLVMType>();
- if (!argType)
- return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
- if (auto vecArgType = argType.dyn_cast<LLVM::LLVMFixedVectorType>())
+ Type resultType = LLVMIntegerType::get(builder.getContext(), 1);
+ if (!isCompatibleType(type))
+ return parser.emitError(trailingTypeLoc,
+ "expected LLVM dialect-compatible type");
+ if (auto vecArgType = type.dyn_cast<LLVM::LLVMFixedVectorType>())
resultType =
LLVMFixedVectorType::get(resultType, vecArgType.getNumElements());
- assert(!argType.isa<LLVM::LLVMScalableVectorType>() &&
+ assert(!type.isa<LLVM::LLVMScalableVectorType>() &&
"unhandled scalable vector");
result.addTypes({resultType});
@@ -546,21 +546,21 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
return parser.emitError(trailingTypeLoc,
"expected function with 0 or 1 result");
- LLVM::LLVMType llvmResultType;
+ Type llvmResultType;
if (funcType.getNumResults() == 0) {
llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
} else {
- llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
- if (!llvmResultType)
+ llvmResultType = funcType.getResult(0);
+ if (!isCompatibleType(llvmResultType))
return parser.emitError(trailingTypeLoc,
"expected result to have LLVM type");
}
- SmallVector<LLVM::LLVMType, 8> argTypes;
+ SmallVector<Type, 8> argTypes;
argTypes.reserve(funcType.getNumInputs());
for (Type ty : funcType.getInputs()) {
- if (auto argType = ty.dyn_cast<LLVM::LLVMType>())
- argTypes.push_back(argType);
+ if (isCompatibleType(ty))
+ argTypes.push_back(ty);
else
return parser.emitError(trailingTypeLoc,
"expected LLVM types as inputs");
@@ -693,7 +693,7 @@ static LogicalResult verify(CallOp &op) {
// Type for the callee, we'll get it differently depending if it is a direct
// or indirect call.
- LLVMType fnType;
+ Type fnType;
bool isIndirect = false;
@@ -704,14 +704,10 @@ static LogicalResult verify(CallOp &op) {
if (!op.getNumOperands())
return op.emitOpError(
"must have either a `callee` attribute or at least an operand");
- fnType = op.getOperand(0).getType().dyn_cast<LLVMType>();
- if (!fnType)
- return op.emitOpError("indirect call to a non-llvm type: ")
- << op.getOperand(0).getType();
- auto ptrType = fnType.dyn_cast<LLVMPointerType>();
+ auto ptrType = op.getOperand(0).getType().dyn_cast<LLVMPointerType>();
if (!ptrType)
return op.emitOpError("indirect call expects a pointer as callee: ")
- << fnType;
+ << op.getOperand(0).getType();
fnType = ptrType.getElementType();
} else {
Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
@@ -825,21 +821,21 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
"expected function with 0 or 1 result");
Builder &builder = parser.getBuilder();
- LLVM::LLVMType llvmResultType;
+ Type llvmResultType;
if (funcType.getNumResults() == 0) {
llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
} else {
- llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
- if (!llvmResultType)
+ llvmResultType = funcType.getResult(0);
+ if (!isCompatibleType(llvmResultType))
return parser.emitError(trailingTypeLoc,
"expected result to have LLVM type");
}
- SmallVector<LLVM::LLVMType, 8> argTypes;
+ SmallVector<Type, 8> argTypes;
argTypes.reserve(funcType.getNumInputs());
for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
- auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>();
- if (!argType)
+ auto argType = funcType.getInput(i);
+ if (!isCompatibleType(argType))
return parser.emitError(trailingTypeLoc,
"expected LLVM types as inputs");
argTypes.push_back(argType);
@@ -922,13 +918,13 @@ static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) {
// `containerType`. Position is an integer array attribute where each value
// is a zero-based position of the element in the aggregate type. Return the
// resulting type wrapped in MLIR, or nullptr on error.
-static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
- Type containerType,
- ArrayAttr positionAttr,
- llvm::SMLoc attributeLoc,
- llvm::SMLoc typeLoc) {
- auto llvmType = containerType.dyn_cast<LLVM::LLVMType>();
- if (!llvmType)
+static Type getInsertExtractValueElementType(OpAsmParser &parser,
+ Type containerType,
+ ArrayAttr positionAttr,
+ llvm::SMLoc attributeLoc,
+ llvm::SMLoc typeLoc) {
+ Type llvmType = containerType;
+ if (!isCompatibleType(containerType))
return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
// Infer the element type from the structure type: iteratively step inside the
@@ -1162,7 +1158,7 @@ static LogicalResult verify(AddressOfOp op) {
/// the name of the attribute in ODS.
static StringRef getLinkageAttrName() { return "linkage"; }
-void GlobalOp::build(OpBuilder &builder, OperationState &result, LLVMType type,
+void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
bool isConstant, Linkage linkage, StringRef name,
Attribute value, unsigned addrSpace,
ArrayRef<NamedAttribute> attrs) {
@@ -1212,14 +1208,13 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
/// report the error, the user is expected to produce an appropriate message.
// TODO: make the size depend on data layout rather than on the conversion
// pass option, and pull that information here.
-static LogicalResult verifyCastWithIndex(LLVMType llvmType) {
+static LogicalResult verifyCastWithIndex(Type llvmType) {
return success(llvmType.isa<LLVMIntegerType>());
}
/// Checks if `llvmType` is dialect cast-compatible with built-in `type` and
/// reports errors to the location of `op`.
-static LogicalResult verifyCast(DialectCastOp op, LLVMType llvmType,
- Type type) {
+static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type) {
// Index is compatible with any integer.
if (type.isIndex()) {
if (succeeded(verifyCastWithIndex(llvmType)))
@@ -1387,14 +1382,13 @@ static LogicalResult verifyCast(DialectCastOp op, LLVMType llvmType,
}
static LogicalResult verify(DialectCastOp op) {
- if (auto llvmType = op.getType().dyn_cast<LLVMType>())
- return verifyCast(op, llvmType, op.in().getType());
+ if (isCompatibleType(op.getType()))
+ return verifyCast(op, op.getType(), op.in().getType());
- auto llvmType = op.in().getType().dyn_cast<LLVMType>();
- if (!llvmType)
+ if (!isCompatibleType(op.in().getType()))
return op->emitOpError("expected one LLVM type and one built-in type");
- return verifyCast(op, llvmType, op.getType());
+ return verifyCast(op, op.in().getType(), op.getType());
}
// Parses one of the keywords provided in the list `keywords` and returns the
@@ -1597,7 +1591,7 @@ Block *LLVMFuncOp::addEntryBlock() {
}
void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
- StringRef name, LLVMType type, LLVM::Linkage linkage,
+ StringRef name, Type type, LLVM::Linkage linkage,
ArrayRef<NamedAttribute> attrs,
ArrayRef<DictionaryAttr> argAttrs) {
result.addRegion();
@@ -1633,23 +1627,23 @@ static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
}
// Convert inputs to LLVM types, exit early on error.
- SmallVector<LLVMType, 4> llvmInputs;
+ SmallVector<Type, 4> llvmInputs;
for (auto t : inputs) {
- auto llvmTy = t.dyn_cast<LLVMType>();
- if (!llvmTy) {
+ if (!isCompatibleType(t)) {
parser.emitError(loc, "failed to construct function type: expected LLVM "
"type for function arguments");
return {};
}
- llvmInputs.push_back(llvmTy);
+ llvmInputs.push_back(t);
}
// No output is denoted as "void" in LLVM type system.
- LLVMType llvmOutput = outputs.empty() ? LLVMVoidType::get(b.getContext())
- : outputs.front().dyn_cast<LLVMType>();
- if (!llvmOutput) {
+ Type llvmOutput =
+ outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
+ if (!isCompatibleType(llvmOutput)) {
parser.emitError(loc, "failed to construct function type: expected LLVM "
- "type for function results");
+ "type for function results: ")
+ << llvmOutput;
return {};
}
return LLVMFunctionType::get(llvmOutput, llvmInputs,
@@ -1720,7 +1714,7 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
argTypes.push_back(fnType.getParamType(i));
- LLVMType returnType = fnType.getReturnType();
+ Type returnType = fnType.getReturnType();
if (!returnType.isa<LLVMVoidType>())
resTypes.push_back(returnType);
@@ -1792,11 +1786,10 @@ static LogicalResult verify(LLVMFuncOp op) {
Block &entryBlock = op.front();
for (unsigned i = 0; i < numArguments; ++i) {
Type argType = entryBlock.getArgument(i).getType();
- auto argLLVMType = argType.dyn_cast<LLVMType>();
- if (!argLLVMType)
+ if (!isCompatibleType(argType))
return op.emitOpError("entry block argument #")
<< i << " is not of LLVM type";
- if (op.getType().getParamType(i) != argLLVMType)
+ if (op.getType().getParamType(i) != argType)
return op.emitOpError("the type of entry block argument #")
<< i << " does not match the function signature";
}
@@ -1889,7 +1882,7 @@ static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) {
// attribute-dict? `:` type
static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
OperationState &result) {
- LLVMType type;
+ Type type;
OpAsmParser::OperandType ptr, val;
if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) ||
parser.parseComma() || parser.parseOperand(val) ||
@@ -1907,11 +1900,11 @@ static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
static LogicalResult verify(AtomicRMWOp op) {
auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
- auto valType = op.val().getType().cast<LLVM::LLVMType>();
+ auto valType = op.val().getType();
if (valType != ptrType.getElementType())
return op.emitOpError("expected LLVM IR element type for operand #0 to "
"match type for operand #1");
- auto resType = op.res().getType().cast<LLVM::LLVMType>();
+ auto resType = op.res().getType();
if (resType != valType)
return op.emitOpError(
"expected LLVM IR result type to match type for operand #1");
@@ -1954,7 +1947,7 @@ static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) {
static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
- LLVMType type;
+ Type type;
OpAsmParser::OperandType ptr, cmp, val;
if (parser.parseOperand(ptr) || parser.parseComma() ||
parser.parseOperand(cmp) || parser.parseComma() ||
@@ -1981,8 +1974,8 @@ static LogicalResult verify(AtomicCmpXchgOp op) {
auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
if (!ptrType)
return op.emitOpError("expected LLVM IR pointer type for operand #0");
- auto cmpType = op.cmp().getType().cast<LLVM::LLVMType>();
- auto valType = op.val().getType().cast<LLVM::LLVMType>();
+ auto cmpType = op.cmp().getType();
+ auto valType = op.val().getType();
if (cmpType != ptrType.getElementType() || cmpType != valType)
return op.emitOpError("expected LLVM IR element type for operand #0 to "
"match type for all other operands");
@@ -2088,7 +2081,7 @@ Type LLVMDialect::parseType(DialectAsmParser &parser) const {
/// Print a type registered to this dialect.
void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
- return detail::printType(type.cast<LLVMType>(), os);
+ return detail::printType(type, os);
}
LogicalResult LLVMDialect::verifyDataLayoutString(
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 574d0aa8c37f..3d72e254f338 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -19,11 +19,11 @@ using namespace mlir::LLVM;
// Printing.
//===----------------------------------------------------------------------===//
-static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
+static void printTypeImpl(llvm::raw_ostream &os, Type type,
llvm::SetVector<StringRef> &stack);
/// Returns the keyword to use for the given type.
-static StringRef getTypeKeyword(LLVMType type) {
+static StringRef getTypeKeyword(Type type) {
return TypeSwitch<Type, StringRef>(type)
.Case<LLVMVoidType>([&](Type) { return "void"; })
.Case<LLVMHalfType>([&](Type) { return "half"; })
@@ -64,7 +64,7 @@ static void printStructTypeBody(llvm::raw_ostream &os, LLVMStructType type,
os << '(';
if (type.isIdentified())
stack.insert(type.getName());
- llvm::interleaveComma(type.getBody(), os, [&](LLVMType subtype) {
+ llvm::interleaveComma(type.getBody(), os, [&](Type subtype) {
printTypeImpl(os, subtype, stack);
});
if (type.isIdentified())
@@ -109,9 +109,9 @@ static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType,
os << '<';
printTypeImpl(os, funcType.getReturnType(), stack);
os << " (";
- llvm::interleaveComma(
- funcType.getParams(), os,
- [&os, &stack](LLVMType subtype) { printTypeImpl(os, subtype, stack); });
+ llvm::interleaveComma(funcType.getParams(), os, [&os, &stack](Type subtype) {
+ printTypeImpl(os, subtype, stack);
+ });
if (funcType.isVarArg()) {
if (funcType.getNumParams() != 0)
os << ", ";
@@ -129,7 +129,7 @@ static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType,
/// struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
/// ptr<struct<"b", (ptr<struct<"c">>)>>)>
/// note that "b" is printed twice.
-static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
+static void printTypeImpl(llvm::raw_ostream &os, Type type,
llvm::SetVector<StringRef> &stack) {
if (!type) {
os << "<<NULL-TYPE>>";
@@ -171,7 +171,7 @@ static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
return printFunctionType(os, funcType, stack);
}
-void mlir::LLVM::detail::printType(LLVMType type, DialectAsmPrinter &printer) {
+void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) {
llvm::SetVector<StringRef> stack;
return printTypeImpl(printer.getStream(), type, stack);
}
@@ -180,13 +180,13 @@ void mlir::LLVM::detail::printType(LLVMType type, DialectAsmPrinter &printer) {
// Parsing.
//===----------------------------------------------------------------------===//
-static LLVMType parseTypeImpl(DialectAsmParser &parser,
- llvm::SetVector<StringRef> &stack);
+static Type parseTypeImpl(DialectAsmParser &parser,
+ llvm::SetVector<StringRef> &stack);
/// Helper to be chained with other parsing functions.
static ParseResult parseTypeImpl(DialectAsmParser &parser,
llvm::SetVector<StringRef> &stack,
- LLVMType &result) {
+ Type &result) {
result = parseTypeImpl(parser, stack);
return success(result != nullptr);
}
@@ -196,7 +196,7 @@ static ParseResult parseTypeImpl(DialectAsmParser &parser,
static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
llvm::SetVector<StringRef> &stack) {
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
- LLVMType returnType;
+ Type returnType;
if (parser.parseLess() || parseTypeImpl(parser, stack, returnType) ||
parser.parseLParen())
return LLVMFunctionType();
@@ -210,7 +210,7 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
}
// Parse arguments.
- SmallVector<LLVMType, 8> argTypes;
+ SmallVector<Type, 8> argTypes;
do {
if (succeeded(parser.parseOptionalEllipsis())) {
if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
@@ -235,7 +235,7 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
static LLVMPointerType parsePointerType(DialectAsmParser &parser,
llvm::SetVector<StringRef> &stack) {
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
- LLVMType elementType;
+ Type elementType;
if (parser.parseLess() || parseTypeImpl(parser, stack, elementType))
return LLVMPointerType();
@@ -255,7 +255,7 @@ static LLVMVectorType parseVectorType(DialectAsmParser &parser,
llvm::SetVector<StringRef> &stack) {
SmallVector<int64_t, 2> dims;
llvm::SMLoc dimPos;
- LLVMType elementType;
+ Type elementType;
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
@@ -286,7 +286,7 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser,
llvm::SetVector<StringRef> &stack) {
SmallVector<int64_t, 1> dims;
llvm::SMLoc sizePos;
- LLVMType elementType;
+ Type elementType;
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
@@ -305,11 +305,11 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser,
/// error at `subtypesLoc` in case of failure, uses `stack` to make sure the
/// types printed in the error message look like they did when parsed.
static LLVMStructType trySetStructBody(LLVMStructType type,
- ArrayRef<LLVMType> subtypes,
- bool isPacked, DialectAsmParser &parser,
+ ArrayRef<Type> subtypes, bool isPacked,
+ DialectAsmParser &parser,
llvm::SMLoc subtypesLoc,
llvm::SetVector<StringRef> &stack) {
- for (LLVMType t : subtypes) {
+ for (Type t : subtypes) {
if (!LLVMStructType::isValidElementType(t)) {
parser.emitError(subtypesLoc)
<< "invalid LLVM structure element type: " << t;
@@ -389,12 +389,12 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
// Parse subtypes. For identified structs, put the identifier of the struct on
// the stack to support self-references in the recursive calls.
- SmallVector<LLVMType, 4> subtypes;
+ SmallVector<Type, 4> subtypes;
llvm::SMLoc subtypesLoc = parser.getCurrentLocation();
do {
if (isIdentified)
stack.insert(name);
- LLVMType type = parseTypeImpl(parser, stack);
+ Type type = parseTypeImpl(parser, stack);
if (!type)
return LLVMStructType();
subtypes.push_back(type);
@@ -413,8 +413,8 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
}
/// Parses one of the LLVM dialect types.
-static LLVMType parseTypeImpl(DialectAsmParser &parser,
- llvm::SetVector<StringRef> &stack) {
+static Type parseTypeImpl(DialectAsmParser &parser,
+ llvm::SetVector<StringRef> &stack) {
// Special case for integers (i[1-9][0-9]*) that are literals rather than
// keywords for the parser, so they are not caught by the main dispatch below.
// Try parsing it a built-in integer type instead.
@@ -425,11 +425,11 @@ static LLVMType parseTypeImpl(DialectAsmParser &parser,
OptionalParseResult result = parser.parseOptionalType(maybeIntegerType);
if (result.hasValue()) {
if (failed(*result))
- return LLVMType();
+ return Type();
if (!maybeIntegerType.isSignlessInteger()) {
parser.emitError(keyLoc) << "unexpected type, expected i* or keyword";
- return LLVMType();
+ return Type();
}
return LLVMIntegerType::getChecked(
loc, maybeIntegerType.getIntOrFloatBitWidth());
@@ -438,9 +438,9 @@ static LLVMType parseTypeImpl(DialectAsmParser &parser,
// Dispatch to concrete functions.
StringRef key;
if (failed(parser.parseKeyword(&key)))
- return LLVMType();
+ return Type();
- return StringSwitch<function_ref<LLVMType()>>(key)
+ return StringSwitch<function_ref<Type()>>(key)
.Case("void", [&] { return LLVMVoidType::get(ctx); })
.Case("half", [&] { return LLVMHalfType::get(ctx); })
.Case("bfloat", [&] { return LLVMBFloatType::get(ctx); })
@@ -460,11 +460,11 @@ static LLVMType parseTypeImpl(DialectAsmParser &parser,
.Case("struct", [&] { return parseStructType(parser, stack); })
.Default([&] {
parser.emitError(keyLoc) << "unknown LLVM type: " << key;
- return LLVMType();
+ return Type();
})();
}
-LLVMType mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
+Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
llvm::SetVector<StringRef> stack;
return parseTypeImpl(parser, stack);
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 3d75245a1fb3..c982abf8ad72 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -24,44 +24,32 @@
using namespace mlir;
using namespace mlir::LLVM;
-//===----------------------------------------------------------------------===//
-// LLVMType.
-//===----------------------------------------------------------------------===//
-
-bool LLVMType::classof(Type type) {
- return llvm::isa<LLVMDialect>(type.getDialect());
-}
-
-LLVMDialect &LLVMType::getDialect() {
- return static_cast<LLVMDialect &>(Type::getDialect());
-}
-
//===----------------------------------------------------------------------===//
// Array type.
//===----------------------------------------------------------------------===//
-bool LLVMArrayType::isValidElementType(LLVMType type) {
+bool LLVMArrayType::isValidElementType(Type type) {
return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
}
-LLVMArrayType LLVMArrayType::get(LLVMType elementType, unsigned numElements) {
+LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) {
assert(elementType && "expected non-null subtype");
return Base::get(elementType.getContext(), elementType, numElements);
}
-LLVMArrayType LLVMArrayType::getChecked(Location loc, LLVMType elementType,
+LLVMArrayType LLVMArrayType::getChecked(Location loc, Type elementType,
unsigned numElements) {
assert(elementType && "expected non-null subtype");
return Base::getChecked(loc, elementType, numElements);
}
-LLVMType LLVMArrayType::getElementType() { return getImpl()->elementType; }
+Type LLVMArrayType::getElementType() { return getImpl()->elementType; }
unsigned LLVMArrayType::getNumElements() { return getImpl()->numElements; }
LogicalResult
-LLVMArrayType::verifyConstructionInvariants(Location loc, LLVMType elementType,
+LLVMArrayType::verifyConstructionInvariants(Location loc, Type elementType,
unsigned numElements) {
if (!isValidElementType(elementType))
return emitError(loc, "invalid array element type: ") << elementType;
@@ -72,52 +60,50 @@ LLVMArrayType::verifyConstructionInvariants(Location loc, LLVMType elementType,
// Function type.
//===----------------------------------------------------------------------===//
-bool LLVMFunctionType::isValidArgumentType(LLVMType type) {
+bool LLVMFunctionType::isValidArgumentType(Type type) {
return !type.isa<LLVMVoidType, LLVMFunctionType>();
}
-bool LLVMFunctionType::isValidResultType(LLVMType type) {
+bool LLVMFunctionType::isValidResultType(Type type) {
return !type.isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>();
}
-LLVMFunctionType LLVMFunctionType::get(LLVMType result,
- ArrayRef<LLVMType> arguments,
+LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
bool isVarArg) {
assert(result && "expected non-null result");
return Base::get(result.getContext(), result, arguments, isVarArg);
}
-LLVMFunctionType LLVMFunctionType::getChecked(Location loc, LLVMType result,
- ArrayRef<LLVMType> arguments,
+LLVMFunctionType LLVMFunctionType::getChecked(Location loc, Type result,
+ ArrayRef<Type> arguments,
bool isVarArg) {
assert(result && "expected non-null result");
return Base::getChecked(loc, result, arguments, isVarArg);
}
-LLVMType LLVMFunctionType::getReturnType() {
- return getImpl()->getReturnType();
-}
+Type LLVMFunctionType::getReturnType() { return getImpl()->getReturnType(); }
unsigned LLVMFunctionType::getNumParams() {
return getImpl()->getArgumentTypes().size();
}
-LLVMType LLVMFunctionType::getParamType(unsigned i) {
+Type LLVMFunctionType::getParamType(unsigned i) {
return getImpl()->getArgumentTypes()[i];
}
bool LLVMFunctionType::isVarArg() { return getImpl()->isVariadic(); }
-ArrayRef<LLVMType> LLVMFunctionType::getParams() {
+ArrayRef<Type> LLVMFunctionType::getParams() {
return getImpl()->getArgumentTypes();
}
-LogicalResult LLVMFunctionType::verifyConstructionInvariants(
- Location loc, LLVMType result, ArrayRef<LLVMType> arguments, bool) {
+LogicalResult
+LLVMFunctionType::verifyConstructionInvariants(Location loc, Type result,
+ ArrayRef<Type> arguments, bool) {
if (!isValidResultType(result))
return emitError(loc, "invalid function result type: ") << result;
- for (LLVMType arg : arguments)
+ for (Type arg : arguments)
if (!isValidArgumentType(arg))
return emitError(loc, "invalid function argument type: ") << arg;
@@ -150,27 +136,27 @@ LogicalResult LLVMIntegerType::verifyConstructionInvariants(Location loc,
// Pointer type.
//===----------------------------------------------------------------------===//
-bool LLVMPointerType::isValidElementType(LLVMType type) {
+bool LLVMPointerType::isValidElementType(Type type) {
return !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
LLVMLabelType>();
}
-LLVMPointerType LLVMPointerType::get(LLVMType pointee, unsigned addressSpace) {
+LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
assert(pointee && "expected non-null subtype");
return Base::get(pointee.getContext(), pointee, addressSpace);
}
-LLVMPointerType LLVMPointerType::getChecked(Location loc, LLVMType pointee,
+LLVMPointerType LLVMPointerType::getChecked(Location loc, Type pointee,
unsigned addressSpace) {
return Base::getChecked(loc, pointee, addressSpace);
}
-LLVMType LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
+Type LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
unsigned LLVMPointerType::getAddressSpace() { return getImpl()->addressSpace; }
LogicalResult LLVMPointerType::verifyConstructionInvariants(Location loc,
- LLVMType pointee,
+ Type pointee,
unsigned) {
if (!isValidElementType(pointee))
return emitError(loc, "invalid pointer element type: ") << pointee;
@@ -181,7 +167,7 @@ LogicalResult LLVMPointerType::verifyConstructionInvariants(Location loc,
// Struct type.
//===----------------------------------------------------------------------===//
-bool LLVMStructType::isValidElementType(LLVMType type) {
+bool LLVMStructType::isValidElementType(Type type) {
return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
}
@@ -198,7 +184,7 @@ LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc,
LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
StringRef name,
- ArrayRef<LLVMType> elements,
+ ArrayRef<Type> elements,
bool isPacked) {
std::string stringName = name.str();
unsigned counter = 0;
@@ -214,13 +200,12 @@ LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
}
LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
- ArrayRef<LLVMType> types,
- bool isPacked) {
+ ArrayRef<Type> types, bool isPacked) {
return Base::get(context, types, isPacked);
}
LLVMStructType LLVMStructType::getLiteralChecked(Location loc,
- ArrayRef<LLVMType> types,
+ ArrayRef<Type> types,
bool isPacked) {
return Base::getChecked(loc, types, isPacked);
}
@@ -233,7 +218,7 @@ LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) {
return Base::getChecked(loc, name, /*opaque=*/true);
}
-LogicalResult LLVMStructType::setBody(ArrayRef<LLVMType> types, bool isPacked) {
+LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
assert(isIdentified() && "can only set bodies of identified structs");
assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
"expected valid body types");
@@ -248,7 +233,7 @@ bool LLVMStructType::isOpaque() {
}
bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); }
-ArrayRef<LLVMType> LLVMStructType::getBody() {
+ArrayRef<Type> LLVMStructType::getBody() {
return isIdentified() ? getImpl()->getIdentifiedStructBody()
: getImpl()->getTypeList();
}
@@ -258,10 +243,10 @@ LogicalResult LLVMStructType::verifyConstructionInvariants(Location, StringRef,
return success();
}
-LogicalResult
-LLVMStructType::verifyConstructionInvariants(Location loc,
- ArrayRef<LLVMType> types, bool) {
- for (LLVMType t : types)
+LogicalResult LLVMStructType::verifyConstructionInvariants(Location loc,
+ ArrayRef<Type> types,
+ bool) {
+ for (Type t : types)
if (!isValidElementType(t))
return emitError(loc, "invalid LLVM structure element type: ") << t;
@@ -272,7 +257,7 @@ LLVMStructType::verifyConstructionInvariants(Location loc,
// Vector types.
//===----------------------------------------------------------------------===//
-bool LLVMVectorType::isValidElementType(LLVMType type) {
+bool LLVMVectorType::isValidElementType(Type type) {
return type.isa<LLVMIntegerType, LLVMPointerType>() ||
mlir::LLVM::isCompatibleFloatingPointType(type);
}
@@ -282,7 +267,7 @@ bool LLVMVectorType::classof(Type type) {
return type.isa<LLVMFixedVectorType, LLVMScalableVectorType>();
}
-LLVMType LLVMVectorType::getElementType() {
+Type LLVMVectorType::getElementType() {
// Both derived classes share the implementation type.
return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
}
@@ -296,7 +281,7 @@ llvm::ElementCount LLVMVectorType::getElementCount() {
/// Verifies that the type about to be constructed is well-formed.
LogicalResult
-LLVMVectorType::verifyConstructionInvariants(Location loc, LLVMType elementType,
+LLVMVectorType::verifyConstructionInvariants(Location loc, Type elementType,
unsigned numElements) {
if (numElements == 0)
return emitError(loc, "the number of vector elements must be positive");
@@ -307,14 +292,14 @@ LLVMVectorType::verifyConstructionInvariants(Location loc, LLVMType elementType,
return success();
}
-LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType,
+LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType,
unsigned numElements) {
assert(elementType && "expected non-null subtype");
return Base::get(elementType.getContext(), elementType, numElements);
}
LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc,
- LLVMType elementType,
+ Type elementType,
unsigned numElements) {
assert(elementType && "expected non-null subtype");
return Base::getChecked(loc, elementType, numElements);
@@ -324,14 +309,14 @@ unsigned LLVMFixedVectorType::getNumElements() {
return getImpl()->numElements;
}
-LLVMScalableVectorType LLVMScalableVectorType::get(LLVMType elementType,
+LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType,
unsigned minNumElements) {
assert(elementType && "expected non-null subtype");
return Base::get(elementType.getContext(), elementType, minNumElements);
}
LLVMScalableVectorType
-LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType,
+LLVMScalableVectorType::getChecked(Location loc, Type elementType,
unsigned minNumElements) {
assert(elementType && "expected non-null subtype");
return Base::getChecked(loc, elementType, minNumElements);
@@ -351,16 +336,16 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
return llvm::TypeSwitch<Type, llvm::TypeSize>(type)
.Case<LLVMHalfType, LLVMBFloatType>(
- [](LLVMType) { return llvm::TypeSize::Fixed(16); })
- .Case<LLVMFloatType>([](LLVMType) { return llvm::TypeSize::Fixed(32); })
+ [](Type) { return llvm::TypeSize::Fixed(16); })
+ .Case<LLVMFloatType>([](Type) { return llvm::TypeSize::Fixed(32); })
.Case<LLVMDoubleType, LLVMX86MMXType>(
- [](LLVMType) { return llvm::TypeSize::Fixed(64); })
+ [](Type) { return llvm::TypeSize::Fixed(64); })
.Case<LLVMIntegerType>([](LLVMIntegerType intTy) {
return llvm::TypeSize::Fixed(intTy.getBitWidth());
})
- .Case<LLVMX86FP80Type>([](LLVMType) { return llvm::TypeSize::Fixed(80); })
+ .Case<LLVMX86FP80Type>([](Type) { return llvm::TypeSize::Fixed(80); })
.Case<LLVMPPCFP128Type, LLVMFP128Type>(
- [](LLVMType) { return llvm::TypeSize::Fixed(128); })
+ [](Type) { return llvm::TypeSize::Fixed(128); })
.Case<LLVMVectorType>([](LLVMVectorType t) {
llvm::TypeSize elementSize =
getPrimitiveTypeSizeInBits(t.getElementType());
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index c2f689be493a..f8d6518b23aa 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -53,19 +53,18 @@ static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
parser.addTypeToList(resultType, result.types))
return failure();
- auto type = resultType.cast<LLVM::LLVMType>();
for (auto &attr : result.attributes) {
if (attr.first != "return_value_and_is_valid")
continue;
- auto structType = type.dyn_cast<LLVM::LLVMStructType>();
+ auto structType = resultType.dyn_cast<LLVM::LLVMStructType>();
if (structType && !structType.getBody().empty())
- type = structType.getBody()[0];
+ resultType = structType.getBody()[0];
break;
}
auto int32Ty =
LLVM::LLVMIntegerType::get(parser.getBuilder().getContext(), 32);
- return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
+ return parser.resolveOperands(ops, {int32Ty, resultType, int32Ty, int32Ty},
parser.getNameLoc(), result.operands);
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
index 1d147beafb40..a02a77d5a4c0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
+++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
@@ -72,7 +72,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
Key(StringRef name, bool opaque)
: name(name), identified(true), packed(false), opaque(opaque) {}
/// Constructs a key for a literal struct.
- Key(ArrayRef<LLVMType> types, bool packed)
+ Key(ArrayRef<Type> types, bool packed)
: types(types), identified(false), packed(packed), opaque(false) {}
/// Checks a specific property of the struct.
@@ -96,7 +96,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
}
/// Returns the list of type contained in the key of a literal struct.
- ArrayRef<LLVMType> getTypeList() const {
+ ArrayRef<Type> getTypeList() const {
assert(!isIdentified() &&
"identified struct key cannot have a type list");
return types;
@@ -138,7 +138,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
}
private:
- ArrayRef<LLVMType> types;
+ ArrayRef<Type> types;
StringRef name;
bool identified;
bool packed;
@@ -153,18 +153,18 @@ struct LLVMStructTypeStorage : public TypeStorage {
}
/// Returns the list of types (partially) identifying a literal struct.
- ArrayRef<LLVMType> getTypeList() const {
+ ArrayRef<Type> getTypeList() const {
// If this triggers, use getIdentifiedStructBody() instead.
assert(!isIdentified() && "requested typelist on an identified struct");
- return ArrayRef<LLVMType>(static_cast<const LLVMType *>(keyPtr), keySize());
+ return ArrayRef<Type>(static_cast<const Type *>(keyPtr), keySize());
}
/// Returns the list of types contained in an identified struct.
- ArrayRef<LLVMType> getIdentifiedStructBody() const {
+ ArrayRef<Type> getIdentifiedStructBody() const {
// If this triggers, use getTypeList() instead.
assert(isIdentified() &&
"requested struct body on a non-identified struct");
- return ArrayRef<LLVMType>(identifiedBodyArray, identifiedBodySize());
+ return ArrayRef<Type>(identifiedBodyArray, identifiedBodySize());
}
/// Checks whether the struct is identified.
@@ -199,7 +199,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
/// as initialized and can no longer be mutated.
LLVMStructTypeStorage(const KeyTy &key) {
if (!key.isIdentified()) {
- ArrayRef<LLVMType> types = key.getTypeList();
+ ArrayRef<Type> types = key.getTypeList();
keyPtr = static_cast<const void *>(types.data());
setKeySize(types.size());
llvm::Bitfield::set<KeyFlagPacked>(keySizeAndFlags, key.isPacked());
@@ -232,7 +232,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
/// initialized, succeeds only if the body is equal to the current body. Fails
/// if the struct is marked as intentionally opaque. The struct will be marked
/// as initialized as a result of this operation and can no longer be changed.
- LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef<LLVMType> body,
+ LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef<Type> body,
bool packed) {
if (!isIdentified())
return failure();
@@ -244,7 +244,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
true);
llvm::Bitfield::set<MutableFlagPacked>(identifiedBodySizeAndFlags, packed);
- ArrayRef<LLVMType> typesInAllocator = allocator.copyInto(body);
+ ArrayRef<Type> typesInAllocator = allocator.copyInto(body);
identifiedBodyArray = typesInAllocator.data();
setIdentifiedBodySize(typesInAllocator.size());
@@ -310,7 +310,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
const void *keyPtr = nullptr;
/// Pointer to the first type contained in an identified struct.
- const LLVMType *identifiedBodyArray = nullptr;
+ const Type *identifiedBodyArray = nullptr;
/// Size of the uniquing key combined with identified/literal and
/// packedness bits. Must only be used through the Key* bitfields.
@@ -328,12 +328,11 @@ struct LLVMStructTypeStorage : public TypeStorage {
/// Type storage for LLVM dialect function types. These are uniqued using the
/// list of types they contain and the vararg bit.
struct LLVMFunctionTypeStorage : public TypeStorage {
- using KeyTy = std::tuple<LLVMType, ArrayRef<LLVMType>, bool>;
+ using KeyTy = std::tuple<Type, ArrayRef<Type>, bool>;
/// Construct a storage from the given components. The list is expected to be
/// allocated in the context.
- LLVMFunctionTypeStorage(LLVMType result, ArrayRef<LLVMType> arguments,
- bool variadic)
+ LLVMFunctionTypeStorage(Type result, ArrayRef<Type> arguments, bool variadic)
: argumentTypes(arguments) {
returnTypeAndVariadic.setPointerAndInt(result, variadic);
}
@@ -359,19 +358,19 @@ struct LLVMFunctionTypeStorage : public TypeStorage {
}
/// Returns the list of function argument types.
- ArrayRef<LLVMType> getArgumentTypes() const { return argumentTypes; }
+ ArrayRef<Type> getArgumentTypes() const { return argumentTypes; }
/// Checks whether the function type is variadic.
bool isVariadic() const { return returnTypeAndVariadic.getInt(); }
/// Returns the function result type.
- LLVMType getReturnType() const { return returnTypeAndVariadic.getPointer(); }
+ Type getReturnType() const { return returnTypeAndVariadic.getPointer(); }
private:
/// Function result type packed with the variadic bit.
- llvm::PointerIntPair<LLVMType, 1, bool> returnTypeAndVariadic;
+ llvm::PointerIntPair<Type, 1, bool> returnTypeAndVariadic;
/// Argument types.
- ArrayRef<LLVMType> argumentTypes;
+ ArrayRef<Type> argumentTypes;
};
//===----------------------------------------------------------------------===//
@@ -402,7 +401,7 @@ struct LLVMIntegerTypeStorage : public TypeStorage {
/// Storage type for LLVM dialect pointer types. These are uniqued by a pair of
/// element type and address space.
struct LLVMPointerTypeStorage : public TypeStorage {
- using KeyTy = std::tuple<LLVMType, unsigned>;
+ using KeyTy = std::tuple<Type, unsigned>;
LLVMPointerTypeStorage(const KeyTy &key)
: pointeeType(std::get<0>(key)), addressSpace(std::get<1>(key)) {}
@@ -417,7 +416,7 @@ struct LLVMPointerTypeStorage : public TypeStorage {
return std::make_tuple(pointeeType, addressSpace) == key;
}
- LLVMType pointeeType;
+ Type pointeeType;
unsigned addressSpace;
};
@@ -429,7 +428,7 @@ struct LLVMPointerTypeStorage : public TypeStorage {
/// number: arrays, fixed and scalable vectors. The actual semantics of the
/// type is defined by its kind.
struct LLVMTypeAndSizeStorage : public TypeStorage {
- using KeyTy = std::tuple<LLVMType, unsigned>;
+ using KeyTy = std::tuple<Type, unsigned>;
LLVMTypeAndSizeStorage(const KeyTy &key)
: elementType(std::get<0>(key)), numElements(std::get<1>(key)) {}
@@ -444,7 +443,7 @@ struct LLVMTypeAndSizeStorage : public TypeStorage {
return std::make_tuple(elementType, numElements) == key;
}
- LLVMType elementType;
+ Type elementType;
unsigned numElements;
};
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 9786751ef4b0..89e7decc8152 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -68,8 +68,8 @@ class Importer {
LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
/// Imports `inst` and populates instMap[inst] with the imported Value.
LogicalResult processInstruction(llvm::Instruction *inst);
- /// Creates an LLVMType for `type`.
- LLVMType processType(llvm::Type *type);
+ /// Creates an LLVM-compatible MLIR type for `type`.
+ Type processType(llvm::Type *type);
/// `value` is an SSA-use. Return the remapped version of `value` or a
/// placeholder that will be remapped later if this is an instruction that
/// has not yet been visited.
@@ -87,7 +87,7 @@ class Importer {
SmallVectorImpl<Value> &blockArguments);
/// Returns the builtin type equivalent to be used in attributes for the given
/// LLVM IR dialect type.
- Type getStdTypeForAttr(LLVMType type);
+ Type getStdTypeForAttr(Type type);
/// Return `value` as an attribute to attach to a GlobalOp.
Attribute getConstantAsAttr(llvm::Constant *value);
/// Return `c` as an MLIR Value. This could either be a ConstantOp, or
@@ -150,8 +150,8 @@ Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
context);
}
-LLVMType Importer::processType(llvm::Type *type) {
- if (LLVMType result = typeTranslator.translateType(type))
+Type Importer::processType(llvm::Type *type) {
+ if (Type result = typeTranslator.translateType(type))
return result;
// FIXME: Diagnostic should be able to natively handle types that have
@@ -168,7 +168,7 @@ LLVMType Importer::processType(llvm::Type *type) {
// equivalents. Array types are converted to ranked tensors; nested array types
// are converted to multi-dimensional tensors or vectors, depending on the
// innermost type being a scalar or a vector.
-Type Importer::getStdTypeForAttr(LLVMType type) {
+Type Importer::getStdTypeForAttr(Type type) {
if (!type)
return nullptr;
@@ -252,7 +252,7 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
// Convert constant data to a dense elements attribute.
if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
- LLVMType type = processType(cd->getElementType());
+ Type type = processType(cd->getElementType());
if (!type)
return nullptr;
@@ -315,7 +315,7 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
Attribute valueAttr;
if (GV->hasInitializer())
valueAttr = getConstantAsAttr(GV->getInitializer());
- LLVMType type = processType(GV->getValueType());
+ Type type = processType(GV->getValueType());
if (!type)
return nullptr;
GlobalOp op = b.create<GlobalOp>(
@@ -338,7 +338,7 @@ Value Importer::processConstant(llvm::Constant *c) {
if (Attribute attr = getConstantAsAttr(c)) {
// These constants can be represented as attributes.
OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
- LLVMType type = processType(c->getType());
+ Type type = processType(c->getType());
if (!type)
return nullptr;
if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
@@ -347,7 +347,7 @@ Value Importer::processConstant(llvm::Constant *c) {
return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
}
if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
- LLVMType type = processType(cn->getType());
+ Type type = processType(cn->getType());
if (!type)
return nullptr;
return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
@@ -370,7 +370,7 @@ Value Importer::processConstant(llvm::Constant *c) {
return instMap[c] = instMap[i];
}
if (auto *ue = dyn_cast<llvm::UndefValue>(c)) {
- LLVMType type = processType(ue->getType());
+ Type type = processType(ue->getType());
if (!type)
return nullptr;
return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type);
@@ -388,7 +388,7 @@ Value Importer::processValue(llvm::Value *value) {
// this instruction yet, create an unknown op and remap it later.
if (isa<llvm::Instruction>(value)) {
OperationState state(UnknownLoc::get(context), "llvm.unknown");
- LLVMType type = processType(value->getType());
+ Type type = processType(value->getType());
if (!type)
return nullptr;
state.addTypes(type);
@@ -578,7 +578,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
}
state.addOperands(ops);
if (!inst->getType()->isVoidTy()) {
- LLVMType type = processType(inst->getType());
+ Type type = processType(inst->getType());
if (!type)
return failure();
state.addTypes(type);
@@ -629,7 +629,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
return success();
}
case llvm::Instruction::PHI: {
- LLVMType type = processType(inst->getType());
+ Type type = processType(inst->getType());
if (!type)
return failure();
v = b.getInsertionBlock()->addArgument(type);
@@ -648,7 +648,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
SmallVector<Type, 2> tys;
if (!ci->getType()->isVoidTy()) {
- LLVMType type = processType(inst->getType());
+ Type type = processType(inst->getType());
if (!type)
return failure();
tys.push_back(type);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index c28588d32ad6..da9c734fbd80 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -762,17 +762,17 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
// TODO: refactor function type creation which usually occurs in std-LLVM
// conversion.
- SmallVector<LLVM::LLVMType, 8> operandTypes;
+ SmallVector<Type, 8> operandTypes;
operandTypes.reserve(inlineAsmOp.operands().size());
for (auto t : inlineAsmOp.operands().getTypes())
- operandTypes.push_back(t.cast<LLVM::LLVMType>());
+ operandTypes.push_back(t);
- LLVM::LLVMType resultType;
+ Type resultType;
if (inlineAsmOp.getNumResults() == 0) {
resultType = LLVM::LLVMVoidType::get(mlirModule->getContext());
} else {
assert(inlineAsmOp.getNumResults() == 1);
- resultType = inlineAsmOp.getResultTypes()[0].cast<LLVM::LLVMType>();
+ resultType = inlineAsmOp.getResultTypes()[0];
}
auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes);
llvm::InlineAsm *inlineAsmInst =
@@ -813,7 +813,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
}
if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
- llvm::Type *ty = convertType(lpOp.getType().cast<LLVMType>());
+ llvm::Type *ty = convertType(lpOp.getType());
llvm::LandingPadInst *lpi =
builder.CreateLandingPad(ty, lpOp.getNumOperands());
@@ -872,8 +872,8 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
blockMapping[switchOp.defaultDestination()],
switchOp.caseDestinations().size(), branchWeights);
- auto *ty = llvm::cast<llvm::IntegerType>(
- convertType(switchOp.value().getType().cast<LLVMType>()));
+ auto *ty =
+ llvm::cast<llvm::IntegerType>(convertType(switchOp.value().getType()));
for (auto i :
llvm::zip(switchOp.case_values()->cast<DenseIntElementsAttr>(),
switchOp.caseDestinations()))
@@ -927,8 +927,8 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
unsigned numPredecessors =
std::distance(predecessors.begin(), predecessors.end());
for (auto arg : bb.getArguments()) {
- auto wrappedType = arg.getType().dyn_cast<LLVM::LLVMType>();
- if (!wrappedType)
+ auto wrappedType = arg.getType();
+ if (!isCompatibleType(wrappedType))
return emitError(bb.front().getLoc(),
"block argument does not have an LLVM type");
llvm::Type *type = convertType(wrappedType);
@@ -1094,7 +1094,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
argIdx, LLVMDialect::getNoAliasAttrName())) {
// NB: Attribute already verified to be boolean, so check if we can indeed
// attach the attribute to this argument, based on its type.
- auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
+ auto argTy = mlirArg.getType();
if (!argTy.isa<LLVM::LLVMPointerType>())
return func.emitError(
"llvm.noalias attribute attached to LLVM non-pointer argument");
@@ -1106,7 +1106,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
argIdx, LLVMDialect::getAlignAttrName())) {
// NB: Attribute already verified to be int, so check if we can indeed
// attach the attribute to this argument, based on its type.
- auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
+ auto argTy = mlirArg.getType();
if (!argTy.isa<LLVM::LLVMPointerType>())
return func.emitError(
"llvm.align attribute attached to LLVM non-pointer argument");
@@ -1190,7 +1190,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
return success();
}
-llvm::Type *ModuleTranslation::convertType(LLVMType type) {
+llvm::Type *ModuleTranslation::convertType(Type type) {
return typeTranslator.translateType(type);
}
diff --git a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp
index 2a4325f0df97..ecde56cc78f6 100644
--- a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp
@@ -27,14 +27,14 @@ class TypeToLLVMIRTranslatorImpl {
TypeToLLVMIRTranslatorImpl(llvm::LLVMContext &context) : context(context) {}
/// Translates a single type.
- llvm::Type *translateType(LLVM::LLVMType type) {
+ llvm::Type *translateType(Type type) {
// If the conversion is already known, just return it.
if (knownTranslations.count(type))
return knownTranslations.lookup(type);
// Dispatch to an appropriate function.
llvm::Type *translated =
- llvm::TypeSwitch<LLVM::LLVMType, llvm::Type *>(type)
+ llvm::TypeSwitch<Type, llvm::Type *>(type)
.Case([this](LLVM::LLVMVoidType) {
return llvm::Type::getVoidTy(context);
})
@@ -76,7 +76,7 @@ class TypeToLLVMIRTranslatorImpl {
LLVM::LLVMStructType, LLVM::LLVMFixedVectorType,
LLVM::LLVMScalableVectorType>(
[this](auto type) { return this->translate(type); })
- .Default([](LLVM::LLVMType t) -> llvm::Type * {
+ .Default([](Type t) -> llvm::Type * {
llvm_unreachable("unknown LLVM dialect type");
});
@@ -147,7 +147,7 @@ class TypeToLLVMIRTranslatorImpl {
}
/// Translates a list of types.
- void translateTypes(ArrayRef<LLVM::LLVMType> types,
+ void translateTypes(ArrayRef<Type> types,
SmallVectorImpl<llvm::Type *> &result) {
result.reserve(result.size() + types.size());
for (auto type : types)
@@ -161,7 +161,7 @@ class TypeToLLVMIRTranslatorImpl {
/// results to avoid repeated recursive calls and makes sure identified
/// structs with the same name (that is, equal) are resolved to an existing
/// type instead of creating a new type.
- llvm::DenseMap<LLVM::LLVMType, llvm::Type *> knownTranslations;
+ llvm::DenseMap<Type, llvm::Type *> knownTranslations;
};
} // end namespace detail
} // end namespace LLVM
@@ -172,12 +172,12 @@ LLVM::TypeToLLVMIRTranslator::TypeToLLVMIRTranslator(llvm::LLVMContext &context)
LLVM::TypeToLLVMIRTranslator::~TypeToLLVMIRTranslator() {}
-llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(LLVM::LLVMType type) {
+llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(Type type) {
return impl->translateType(type);
}
unsigned LLVM::TypeToLLVMIRTranslator::getPreferredAlignment(
- LLVM::LLVMType type, const llvm::DataLayout &layout) {
+ Type type, const llvm::DataLayout &layout) {
return layout.getPrefTypeAlignment(translateType(type));
}
@@ -191,12 +191,12 @@ class TypeFromLLVMIRTranslatorImpl {
TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
/// Translates the given type.
- LLVM::LLVMType translateType(llvm::Type *type) {
+ Type translateType(llvm::Type *type) {
if (knownTranslations.count(type))
return knownTranslations.lookup(type);
- LLVM::LLVMType translated =
- llvm::TypeSwitch<llvm::Type *, LLVM::LLVMType>(type)
+ Type translated =
+ llvm::TypeSwitch<llvm::Type *, Type>(type)
.Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
llvm::ScalableVectorType>(
@@ -211,7 +211,7 @@ class TypeFromLLVMIRTranslatorImpl {
private:
/// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
/// type.
- LLVM::LLVMType translatePrimitiveType(llvm::Type *type) {
+ Type translatePrimitiveType(llvm::Type *type) {
if (type->isVoidTy())
return LLVM::LLVMVoidType::get(&context);
if (type->isHalfTy())
@@ -238,33 +238,33 @@ class TypeFromLLVMIRTranslatorImpl {
}
/// Translates the given array type.
- LLVM::LLVMType translate(llvm::ArrayType *type) {
+ Type translate(llvm::ArrayType *type) {
return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
type->getNumElements());
}
/// Translates the given function type.
- LLVM::LLVMType translate(llvm::FunctionType *type) {
- SmallVector<LLVM::LLVMType, 8> paramTypes;
+ Type translate(llvm::FunctionType *type) {
+ SmallVector<Type, 8> paramTypes;
translateTypes(type->params(), paramTypes);
return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
paramTypes, type->isVarArg());
}
/// Translates the given integer type.
- LLVM::LLVMType translate(llvm::IntegerType *type) {
+ Type translate(llvm::IntegerType *type) {
return LLVM::LLVMIntegerType::get(&context, type->getBitWidth());
}
/// Translates the given pointer type.
- LLVM::LLVMType translate(llvm::PointerType *type) {
+ Type translate(llvm::PointerType *type) {
return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
type->getAddressSpace());
}
/// Translates the given structure type.
- LLVM::LLVMType translate(llvm::StructType *type) {
- SmallVector<LLVM::LLVMType, 8> subtypes;
+ Type translate(llvm::StructType *type) {
+ SmallVector<Type, 8> subtypes;
if (type->isLiteral()) {
translateTypes(type->subtypes(), subtypes);
return LLVM::LLVMStructType::getLiteral(&context, subtypes,
@@ -286,20 +286,20 @@ class TypeFromLLVMIRTranslatorImpl {
}
/// Translates the given fixed-vector type.
- LLVM::LLVMType translate(llvm::FixedVectorType *type) {
+ Type translate(llvm::FixedVectorType *type) {
return LLVM::LLVMFixedVectorType::get(translateType(type->getElementType()),
type->getNumElements());
}
/// Translates the given scalable-vector type.
- LLVM::LLVMType translate(llvm::ScalableVectorType *type) {
+ Type translate(llvm::ScalableVectorType *type) {
return LLVM::LLVMScalableVectorType::get(
translateType(type->getElementType()), type->getMinNumElements());
}
/// Translates a list of types.
void translateTypes(ArrayRef<llvm::Type *> types,
- SmallVectorImpl<LLVM::LLVMType> &result) {
+ SmallVectorImpl<Type> &result) {
result.reserve(result.size() + types.size());
for (llvm::Type *type : types)
result.push_back(translateType(type));
@@ -307,7 +307,7 @@ class TypeFromLLVMIRTranslatorImpl {
/// Map of known translations. Serves as a cache and as recursion stopper for
/// translating recursive structs.
- llvm::DenseMap<llvm::Type *, LLVM::LLVMType> knownTranslations;
+ llvm::DenseMap<llvm::Type *, Type> knownTranslations;
/// The context in which MLIR types are created.
MLIRContext &context;
@@ -321,6 +321,6 @@ LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
-LLVM::LLVMType LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
+Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
return impl->translateType(type);
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index d02c252c0bf3..5e2f666c5b83 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -163,7 +163,7 @@ func @call_non_llvm() {
// -----
func @call_non_llvm_indirect(%arg0 : i32) {
- // expected-error at +1 {{'llvm.call' op operand #0 must be LLVM dialect type, but got 'i32'}}
+ // expected-error at +1 {{'llvm.call' op operand #0 must be LLVM dialect-compatible type, but got 'i32'}}
"llvm.call"(%arg0) : (i32) -> ()
}
diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index f748b56b1bdf..428b61589134 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -135,7 +135,7 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) {
} else if (isResultName(op, name)) {
bs << formatv("valueMapping[op.{0}()]", name);
} else if (name == "_resultType") {
- bs << "convertType(op.getResult().getType().cast<LLVM::LLVMType>())";
+ bs << "convertType(op.getResult().getType())";
} else if (name == "_hasResult") {
bs << "opInst.getNumResults() == 1";
} else if (name == "_location") {
More information about the Mlir-commits
mailing list