[Mlir-commits] [mlir] 8de43b9 - [mlir] Remove instance methods from LLVMType
Alex Zinenko
llvmlistbot at llvm.org
Tue Dec 22 14:35:03 PST 2020
Author: Alex Zinenko
Date: 2020-12-22T23:34:54+01:00
New Revision: 8de43b926f0e960bbc5b6a53d1b613c46b7c774b
URL: https://github.com/llvm/llvm-project/commit/8de43b926f0e960bbc5b6a53d1b613c46b7c774b
DIFF: https://github.com/llvm/llvm-project/commit/8de43b926f0e960bbc5b6a53d1b613c46b7c774b.diff
LOG: [mlir] Remove instance methods from LLVMType
LLVMType contains multiple instance methods that were introduced initially for
compatibility with LLVM API. These methods boil down to `cast` followed by
type-specific call. Arguably, they are mostly used in an LLVM cast-follows-isa
anti-pattern. This doesn't connect nicely to the rest of the MLIR
infrastructure and actively prevents it from making the LLVM dialect type
system more open, e.g., reusing built-in types when appropriate. Remove such
instance methods and replace their uses with appropriate casts and methods on
derived classes. In some cases, the result may look slightly more verbose, but
most cases should actually use a stricter subtype of LLVMType anyway and avoid
the isa/cast.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D93680
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/ExecutionEngine/JitRunner.cpp
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 7c069c9cd556..63ff16a84ab8 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -446,7 +446,8 @@ class UnrankedMemRefDescriptor : public StructBuilder {
/// Builds IR extracting the pointer to the first element of the size array.
static Value sizeBasePtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
- Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType);
+ Value memRefDescPtr,
+ LLVM::LLVMPointerType elemPtrPtrType);
/// Builds IR extracting the size[index] from the descriptor.
static Value size(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter, Value sizeBasePtr,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index df022ef47b33..552fe15e6899 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -51,7 +51,7 @@ def LLVM_VoidResultTypeOpBuilder :
[{
auto llvmType = resultType.dyn_cast<LLVMType>(); (void)llvmType;
assert(llvmType && "result must be an LLVM type");
- assert(llvmType.isVoidTy() &&
+ assert(llvmType.isa<LLVMVoidType>() &&
"for zero-result operands, only 'void' is accepted as result type");
build($_builder, $_state, operands, attributes);
}]>;
@@ -288,7 +288,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
OpBuilderDAG<(ins "Value":$addr, CArg<"unsigned", "0">:$alignment,
CArg<"bool", "false">:$isVolatile, CArg<"bool", "false">:$isNonTemporal),
[{
- auto type = addr.getType().cast<LLVMType>().getPointerElementTy();
+ auto type = addr.getType().cast<LLVMPointerType>().getElementType();
build($_builder, $_state, type, addr, alignment, isVolatile, isNonTemporal);
}]>,
OpBuilderDAG<(ins "Type":$t, "Value":$addr,
@@ -443,8 +443,8 @@ def LLVM_CallOp : LLVM_Op<"call"> {
OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
- LLVMType resultType = func.getType().getFunctionResultType();
- if (!resultType.isVoidTy())
+ LLVMType resultType = func.getType().getReturnType();
+ if (!resultType.isa<LLVM::LLVMVoidType>())
$_state.addTypes(resultType);
$_state.addAttribute("callee", $_builder.getSymbolRefAttr(func));
$_state.addAttributes(attributes);
@@ -515,12 +515,10 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> {
OpBuilderDAG<(ins "Value":$v1, "Value":$v2, "ArrayAttr":$mask,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
let verifier = [{
- auto wrappedVectorType1 = v1().getType().cast<LLVMType>();
- auto wrappedVectorType2 = v2().getType().cast<LLVMType>();
- if (!wrappedVectorType2.isVectorTy())
- return emitOpError("expected LLVM IR Dialect vector type for operand #2");
- if (wrappedVectorType1.getVectorElementType() !=
- wrappedVectorType2.getVectorElementType())
+ auto wrappedVectorType1 = v1().getType().cast<LLVMVectorType>();
+ auto wrappedVectorType2 = v2().getType().cast<LLVMVectorType>();
+ if (wrappedVectorType1.getElementType() !=
+ wrappedVectorType2.getElementType())
return emitOpError("expected matching LLVM IR Dialect element types");
return success();
}];
@@ -768,13 +766,13 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof"> {
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
build($_builder, $_state,
- global.getType().getPointerTo(global.addr_space()),
+ LLVM::LLVMPointerType::get(global.getType(), global.addr_space()),
global.sym_name(), attrs);}]>,
OpBuilderDAG<(ins "LLVMFuncOp":$func,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
build($_builder, $_state,
- func.getType().getPointerTo(), func.getName(), attrs);}]>
+ LLVM::LLVMPointerType::get(func.getType()), func.getName(), attrs);}]>
];
let extraClassDeclaration = [{
@@ -970,12 +968,12 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func",
// to match the signature of the function.
Block *addEntryBlock();
- LLVMType getType() {
+ LLVMFunctionType getType() {
return (*this)->getAttrOfType<TypeAttr>(getTypeAttrName())
- .getValue().cast<LLVMType>();
+ .getValue().cast<LLVMFunctionType>();
}
bool isVarArg() {
- return getType().isFunctionVarArg();
+ return getType().isVarArg();
}
// Hook for OpTrait::FunctionLike, returns the number of function arguments`.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index f92bdf9e3041..e1938c12c809 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -80,58 +80,6 @@ class LLVMType : public Type {
LLVMDialect &getDialect();
- /// Returns the size of a primitive type (including vectors) in bits, for
- /// example, the size of !llvm.i16 is 16 and the size of !llvm.vec<4 x i16>
- /// is 64. Returns 0 for non-primitive (aggregates such as struct) or types
- /// that don't have a size (such as void).
- llvm::TypeSize getPrimitiveSizeInBits();
-
- /// Floating-point type utilities.
- bool isBFloatTy() { return isa<LLVMBFloatType>(); }
- bool isHalfTy() { return isa<LLVMHalfType>(); }
- bool isFloatTy() { return isa<LLVMFloatType>(); }
- bool isDoubleTy() { return isa<LLVMDoubleType>(); }
- bool isFP128Ty() { return isa<LLVMFP128Type>(); }
- bool isX86_FP80Ty() { return isa<LLVMX86FP80Type>(); }
- bool isFloatingPointTy() {
- return isa<LLVMHalfType>() || isa<LLVMBFloatType>() ||
- isa<LLVMFloatType>() || isa<LLVMDoubleType>() ||
- isa<LLVMFP128Type>() || isa<LLVMX86FP80Type>();
- }
-
- /// Array type utilities.
- LLVMType getArrayElementType();
- unsigned getArrayNumElements();
- bool isArrayTy();
-
- /// Integer type utilities.
- bool isIntegerTy() { return isa<LLVMIntegerType>(); }
- bool isIntegerTy(unsigned bitwidth);
- unsigned getIntegerBitWidth();
-
- /// Vector type utilities.
- LLVMType getVectorElementType();
- unsigned getVectorNumElements();
- llvm::ElementCount getVectorElementCount();
- bool isVectorTy();
-
- /// Function type utilities.
- LLVMType getFunctionParamType(unsigned argIdx);
- unsigned getFunctionNumParams();
- LLVMType getFunctionResultType();
- bool isFunctionTy();
- bool isFunctionVarArg();
-
- /// Pointer type utilities.
- LLVMType getPointerTo(unsigned addrSpace = 0);
- LLVMType getPointerElementTy();
- bool isPointerTy();
-
- /// Struct type utilities.
- LLVMType getStructElementType(unsigned i);
- unsigned getStructNumElements();
- bool isStructTy();
-
/// Utilities used to generate floating point types.
static LLVMType getDoubleTy(MLIRContext *context);
static LLVMType getFloatTy(MLIRContext *context);
@@ -148,9 +96,7 @@ class LLVMType : public Type {
static LLVMType getInt8Ty(MLIRContext *context) {
return getIntNTy(context, /*numBits=*/8);
}
- static LLVMType getInt8PtrTy(MLIRContext *context) {
- return getInt8Ty(context).getPointerTo();
- }
+ static LLVMType getInt8PtrTy(MLIRContext *context);
static LLVMType getInt16Ty(MLIRContext *context) {
return getIntNTy(context, /*numBits=*/16);
}
@@ -184,7 +130,6 @@ class LLVMType : public Type {
/// Void type utilities.
static LLVMType getVoidTy(MLIRContext *context);
- bool isVoidTy();
// Creation and setting of LLVM's identified struct types
static LLVMType createStructTy(MLIRContext *context,
@@ -585,6 +530,24 @@ LLVMType parseType(DialectAsmParser &parser);
void printType(LLVMType type, DialectAsmPrinter &printer);
} // namespace detail
+//===----------------------------------------------------------------------===//
+// Utility functions.
+//===----------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is compatible with the LLVM dialect.
+inline bool isCompatibleType(Type type) { return type.isa<LLVMType>(); }
+
+inline bool isCompatibleFloatingPointType(Type type) {
+ return type.isa<LLVMHalfType, LLVMBFloatType, LLVMFloatType, LLVMDoubleType,
+ LLVMFP128Type, LLVMX86FP80Type>();
+}
+
+/// Returns the size of the given primitive LLVM dialect-compatible type
+/// (including vectors) in bits, for example, the size of !llvm.i16 is 16 and
+/// the size of !llvm.vec<4 x i16> is 64. Returns 0 for non-primitive
+/// (aggregates such as struct) or types that don't have a size (such as void).
+llvm::TypeSize getPrimitiveTypeSizeInBits(Type type);
+
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 1f9b860eb52e..3c73cdf64eb7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -109,10 +109,11 @@ def NVVM_ShflBflyOp :
let verifier = [{
if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
return success();
- auto type = getType().cast<LLVM::LLVMType>();
- if (!type.isStructTy() || type.getStructNumElements() != 2 ||
- !type.getStructElementType(1).isIntegerTy(
- /*Bitwidth=*/1))
+ auto type = getType().dyn_cast<LLVM::LLVMStructType>();
+ auto elementType = (type && type.getBody().size() == 2)
+ ? type.getBody()[1].dyn_cast<LLVM::LLVMIntegerType>()
+ : nullptr;
+ if (!elementType || elementType.getBitWidth() != 1)
return emitError("expected return type to be a two-element struct with "
"i1 as the second element");
return success();
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 273754fe2480..65545d8ab2de 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -79,7 +79,7 @@ struct AsyncAPI {
static FunctionType executeFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
- auto resume = resumeFunctionType(ctx).getPointerTo();
+ auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
return FunctionType::get(ctx, {hdl, resume}, {});
}
@@ -91,13 +91,13 @@ struct AsyncAPI {
static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
- auto resume = resumeFunctionType(ctx).getPointerTo();
+ auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
}
static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
- auto resume = resumeFunctionType(ctx).getPointerTo();
+ auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
}
@@ -507,7 +507,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// A pointer to coroutine resume intrinsic wrapper.
auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
auto resumePtr = builder.create<LLVM::AddressOfOp>(
- loc, resumeFnTy.getPointerTo(), kResume);
+ loc, LLVM::LLVMPointerType::get(resumeFnTy), kResume);
// Save the coroutine state: @llvm.coro.save
auto coroSave = builder.create<LLVM::CallOp>(
@@ -750,7 +750,7 @@ class AwaitOpLoweringBase : public ConversionPattern {
// A pointer to coroutine resume intrinsic wrapper.
auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
auto resumePtr = builder.create<LLVM::AddressOfOp>(
- loc, resumeFnTy.getPointerTo(), kResume);
+ loc, LLVM::LLVMPointerType::get(resumeFnTy), kResume);
// Save the coroutine state: @llvm.coro.save
auto coroSave = builder.create<LLVM::CallOp>(
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index 41a079c44eea..bbb2bf1e04ff 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -55,14 +55,14 @@ class FunctionCallBuilder {
FunctionCallBuilder(StringRef functionName, LLVM::LLVMType returnType,
ArrayRef<LLVM::LLVMType> argumentTypes)
: functionName(functionName),
- functionType(LLVM::LLVMType::getFunctionTy(returnType, argumentTypes,
- /*isVarArg=*/false)) {}
+ functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes,
+ /*isVarArg=*/false)) {}
LLVM::CallOp create(Location loc, OpBuilder &builder,
ArrayRef<Value> arguments) const;
private:
StringRef functionName;
- LLVM::LLVMType functionType;
+ LLVM::LLVMFunctionType functionType;
};
template <typename OpTy>
@@ -76,7 +76,8 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
LLVM::LLVMType llvmVoidType = LLVM::LLVMType::getVoidTy(context);
LLVM::LLVMType llvmPointerType = LLVM::LLVMType::getInt8PtrTy(context);
- LLVM::LLVMType llvmPointerPointerType = llvmPointerType.getPointerTo();
+ LLVM::LLVMType llvmPointerPointerType =
+ LLVM::LLVMPointerType::get(llvmPointerType);
LLVM::LLVMType llvmInt8Type = LLVM::LLVMType::getInt8Ty(context);
LLVM::LLVMType llvmInt32Type = LLVM::LLVMType::getInt32Ty(context);
LLVM::LLVMType llvmInt64Type = LLVM::LLVMType::getInt64Ty(context);
@@ -312,7 +313,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
.create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
}();
return builder.create<LLVM::CallOp>(
- loc, const_cast<LLVM::LLVMType &>(functionType).getFunctionResultType(),
+ loc, const_cast<LLVM::LLVMFunctionType &>(functionType).getReturnType(),
builder.getSymbolRefAttr(function), arguments);
}
@@ -518,7 +519,7 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
builder.getI32IntegerAttr(1));
auto structPtr = builder.create<LLVM::AllocaOp>(
- loc, structType.getPointerTo(), one, /*alignment=*/0);
+ loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
auto arraySize = builder.create<LLVM::ConstantOp>(
loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
@@ -529,7 +530,7 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
auto index = builder.create<LLVM::ConstantOp>(
loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
auto fieldPtr = builder.create<LLVM::GEPOp>(
- loc, argumentTypes[en.index()].getPointerTo(), structPtr,
+ loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
ArrayRef<Value>{zero, index.getResult()});
builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index bf17200e594f..914b7ee50cf9 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -51,8 +51,8 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
// Rewrite the original GPU function to an LLVM function.
auto funcType = typeConverter->convertType(gpuFuncOp.getType())
- .template cast<LLVM::LLVMType>()
- .getPointerElementTy();
+ .template cast<LLVM::LLVMPointerType>()
+ .getElementType();
// Remap proper input types.
TypeConverter::SignatureConversion signatureConversion(
@@ -94,10 +94,11 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
for (auto en : llvm::enumerate(workgroupBuffers)) {
LLVM::GlobalOp global = en.value();
Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
- auto elementType = global.getType().getArrayElementType();
+ auto elementType =
+ global.getType().cast<LLVM::LLVMArrayType>().getElementType();
Value memory = rewriter.create<LLVM::GEPOp>(
- loc, elementType.getPointerTo(global.addr_space()), address,
- ArrayRef<Value>{zero, zero});
+ loc, LLVM::LLVMPointerType::get(elementType, global.addr_space()),
+ address, ArrayRef<Value>{zero, zero});
// Build a memref descriptor pointing to the buffer to plug with the
// existing memref infrastructure. This may use more registers than
@@ -123,9 +124,10 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
// Explicitly drop memory space when lowering private memory
// attributions since NVVM models it as `alloca`s in the default
// memory space and does not support `alloca`s with addrspace(5).
- auto ptrType = typeConverter->convertType(type.getElementType())
- .template cast<LLVM::LLVMType>()
- .getPointerTo(AllocaAddrSpace);
+ auto ptrType = LLVM::LLVMPointerType::get(
+ typeConverter->convertType(type.getElementType())
+ .template cast<LLVM::LLVMType>(),
+ AllocaAddrSpace);
Value numElements = rewriter.create<LLVM::ConstantOp>(
gpuFuncOp.getLoc(), int64Ty,
rewriter.getI64IntegerAttr(type.getNumElements()));
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 9d08aeee1906..b2887aa1d782 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -57,7 +57,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
LLVMType resultType =
castedOperands.front().getType().cast<LLVM::LLVMType>();
LLVMType funcType = getFunctionType(resultType, castedOperands);
- StringRef funcName = getFunctionName(funcType.getFunctionResultType());
+ StringRef funcName = getFunctionName(
+ funcType.cast<LLVM::LLVMFunctionType>().getReturnType());
if (funcName.empty())
return failure();
@@ -80,7 +81,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
private:
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
- if (!type.isHalfTy())
+ if (!type.isa<LLVM::LLVMHalfType>())
return operand;
return rewriter.create<LLVM::FPExtOp>(
@@ -100,9 +101,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
}
StringRef getFunctionName(LLVM::LLVMType type) const {
- if (type.isFloatTy())
+ if (type.isa<LLVM::LLVMFloatType>())
return f32Func;
- if (type.isDoubleTy())
+ if (type.isa<LLVM::LLVMDoubleType>())
return f64Func;
return "";
}
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 355bced96ae7..c676cd256d66 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -75,7 +75,7 @@ class VulkanLaunchFuncToVulkanCallsPass
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
- auto llvmPtrToElementType = elemenType.getPointerTo();
+ auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType);
auto llvmArrayRankElementSizeType =
LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
@@ -131,16 +131,18 @@ class VulkanLaunchFuncToVulkanCallsPass
/// Returns a string representation from the given `type`.
StringRef stringifyType(LLVM::LLVMType type) {
- if (type.isFloatTy())
+ if (type.isa<LLVM::LLVMFloatType>())
return "Float";
- if (type.isHalfTy())
+ if (type.isa<LLVM::LLVMHalfType>())
return "Half";
- if (type.isIntegerTy(32))
- return "Int32";
- if (type.isIntegerTy(16))
- return "Int16";
- if (type.isIntegerTy(8))
- return "Int8";
+ if (auto intType = type.dyn_cast<LLVM::LLVMIntegerType>()) {
+ if (intType.getBitWidth() == 32)
+ return "Int32";
+ if (intType.getBitWidth() == 16)
+ return "Int16";
+ if (intType.getBitWidth() == 8)
+ return "Int8";
+ }
llvm_unreachable("unsupported type");
}
@@ -238,11 +240,11 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
// Special case for fp16 type. Since it is not a supported type in C we use
// int16_t and bitcast the descriptor.
- if (type.isHalfTy()) {
+ if (type.isa<LLVM::LLVMHalfType>()) {
auto memRefTy =
getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext()));
ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
- loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
+ loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
}
// Create call to `bindMemRef`.
builder.create<LLVM::CallOp>(
@@ -257,11 +259,12 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
auto llvmPtrDescriptorTy =
- ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
+ ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
if (!llvmPtrDescriptorTy)
return failure();
- auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
+ auto llvmDescriptorTy =
+ llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>();
// template <typename Elem, size_t Rank>
// struct {
// Elem *allocated;
@@ -270,15 +273,19 @@ LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
- if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
+ if (!llvmDescriptorTy)
return failure();
- type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy();
- if (llvmDescriptorTy.getStructNumElements() == 3) {
+ type = llvmDescriptorTy.getBody()[0]
+ .cast<LLVM::LLVMPointerType>()
+ .getElementType();
+ if (llvmDescriptorTy.getBody().size() == 3) {
rank = 0;
return success();
}
- rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
+ rank = llvmDescriptorTy.getBody()[3]
+ .cast<LLVM::LLVMArrayType>()
+ .getNumElements();
return success();
}
@@ -326,13 +333,13 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
LLVM::LLVMType::getHalfTy(&getContext())}) {
std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
std::string(stringifyType(type));
- if (type.isHalfTy())
+ if (type.isa<LLVM::LLVMHalfType>())
type = LLVM::LLVMType::getInt16Ty(&getContext());
if (!module.lookupSymbol(fnName)) {
auto fnType = LLVM::LLVMType::getFunctionTy(
getVoidType(),
{getPointerType(), getInt32Type(), getInt32Type(),
- getMemRefType(i, type).getPointerTo()},
+ LLVM::LLVMPointerType::get(getMemRefType(i, type))},
/*isVarArg=*/false);
builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index cacb4787edd4..7da9c47f9219 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -66,8 +66,10 @@ static unsigned getBitWidth(Type type) {
/// Returns the bit width of LLVMType integer or vector.
static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
- return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth()
- : type.getIntegerBitWidth();
+ auto vectorType = type.dyn_cast<LLVM::LLVMVectorType>();
+ return (vectorType ? vectorType.getElementType() : type)
+ .cast<LLVM::LLVMIntegerType>()
+ .getBitWidth();
}
/// Creates `IntegerAttribute` with all bits set for given type
@@ -265,7 +267,7 @@ static Type convertPointerType(spirv::PointerType type,
TypeConverter &converter) {
auto pointeeType =
converter.convertType(type.getPointeeType()).cast<LLVM::LLVMType>();
- return pointeeType.getPointerTo();
+ return LLVM::LLVMPointerType::get(pointeeType);
}
/// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 6fbcc220a86b..e37e7e2dc0c1 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -215,7 +215,7 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
SignatureConversion conversion(type.getNumInputs());
LLVM::LLVMType converted =
convertFunctionSignature(type, /*isVariadic=*/false, conversion);
- return converted.getPointerTo();
+ return LLVM::LLVMPointerType::get(converted);
}
@@ -267,7 +267,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
if (!converted)
return {};
if (t.isa<MemRefType, UnrankedMemRefType>())
- converted = converted.getPointerTo();
+ converted = LLVM::LLVMPointerType::get(converted);
inputs.push_back(converted);
}
@@ -324,7 +324,7 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
- auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
+ auto ptrTy = LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
auto indexTy = getIndexType();
SmallVector<LLVM::LLVMType, 5> results = {ptrTy, ptrTy, indexTy};
@@ -396,7 +396,7 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
- return elementType.getPointerTo(type.getMemorySpace());
+ return LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
}
// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
@@ -460,7 +460,7 @@ StructBuilder::StructBuilder(Value v) : value(v) {
Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
unsigned pos) {
- Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
+ Type type = structType.cast<LLVM::LLVMStructType>().getBody()[pos];
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
builder.getI64ArrayAttr(pos));
}
@@ -507,8 +507,9 @@ Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
MemRefDescriptor::MemRefDescriptor(Value descriptor)
: StructBuilder(descriptor) {
assert(value != nullptr && "value cannot be null");
- indexType = value.getType().cast<LLVM::LLVMType>().getStructElementType(
- kOffsetPosInMemRefDescriptor);
+ indexType = value.getType()
+ .cast<LLVM::LLVMStructType>()
+ .getBody()[kOffsetPosInMemRefDescriptor];
}
/// Builds IR creating an `undef` value of the descriptor type.
@@ -618,9 +619,9 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
int64_t rank) {
auto indexTy = indexType.cast<LLVM::LLVMType>();
- auto indexPtrTy = indexTy.getPointerTo();
+ auto indexPtrTy = LLVM::LLVMPointerType::get(indexTy);
auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank);
- auto arrayPtrTy = arrayTy.getPointerTo();
+ auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy);
// Copy size values to stack-allocated memory.
auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
@@ -675,8 +676,8 @@ void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
return value.getType()
- .cast<LLVM::LLVMType>()
- .getStructElementType(kAlignedPtrPosInMemRefDescriptor)
+ .cast<LLVM::LLVMStructType>()
+ .getBody()[kAlignedPtrPosInMemRefDescriptor]
.cast<LLVM::LLVMPointerType>();
}
@@ -922,7 +923,7 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
Value offsetGep = builder.create<LLVM::GEPOp>(
loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
offsetGep = builder.create<LLVM::BitcastOp>(
- loc, typeConverter.getIndexType().getPointerTo(), offsetGep);
+ loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
return builder.create<LLVM::LoadOp>(loc, offsetGep);
}
@@ -939,19 +940,17 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
Value offsetGep = builder.create<LLVM::GEPOp>(
loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
offsetGep = builder.create<LLVM::BitcastOp>(
- loc, typeConverter.getIndexType().getPointerTo(), offsetGep);
+ loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
}
-Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,
- LLVMTypeConverter &typeConverter,
- Value memRefDescPtr,
- LLVM::LLVMType elemPtrPtrType) {
- LLVM::LLVMType elemPtrTy = elemPtrPtrType.getPointerElementTy();
+Value UnrankedMemRefDescriptor::sizeBasePtr(
+ OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
+ Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) {
+ LLVM::LLVMType elemPtrTy = elemPtrPtrType.getElementType();
LLVM::LLVMType indexTy = typeConverter.getIndexType();
- LLVM::LLVMType structPtrTy =
- LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy)
- .getPointerTo();
+ LLVM::LLVMType structPtrTy = LLVM::LLVMPointerType::get(
+ LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy));
Value structPtr =
builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
@@ -961,14 +960,15 @@ Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
Value three = builder.create<LLVM::ConstantOp>(loc, int32_type,
builder.getI32IntegerAttr(3));
- return builder.create<LLVM::GEPOp>(loc, indexTy.getPointerTo(), structPtr,
- ValueRange({zero, three}));
+ return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy),
+ structPtr, ValueRange({zero, three}));
}
Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter,
Value sizeBasePtr, Value index) {
- LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ LLVM::LLVMType indexPtrTy =
+ LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
ValueRange({index}));
return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
@@ -978,7 +978,8 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter,
Value sizeBasePtr, Value index,
Value size) {
- LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ LLVM::LLVMType indexPtrTy =
+ LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
ValueRange({index}));
builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
@@ -987,7 +988,8 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value sizeBasePtr, Value rank) {
- LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ LLVM::LLVMType indexPtrTy =
+ LLVM::LLVMPointerType::get(typeConverter.getIndexType());
return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
ValueRange({rank}));
}
@@ -996,7 +998,8 @@ Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter,
Value strideBasePtr, Value index,
Value stride) {
- LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ LLVM::LLVMType indexPtrTy =
+ LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value strideStoreGep = builder.create<LLVM::GEPOp>(
loc, indexPtrTy, strideBasePtr, ValueRange({index}));
return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
@@ -1006,7 +1009,8 @@ void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter,
Value strideBasePtr, Value index,
Value stride) {
- LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
+ LLVM::LLVMType indexPtrTy =
+ LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value strideStoreGep = builder.create<LLVM::GEPOp>(
loc, indexPtrTy, strideBasePtr, ValueRange({index}));
builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
@@ -1100,7 +1104,7 @@ bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const {
Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
auto elementType = type.getElementType();
auto structElementType = unwrap(typeConverter->convertType(elementType));
- return structElementType.getPointerTo(type.getMemorySpace());
+ return LLVM::LLVMPointerType::get(structElementType, type.getMemorySpace());
}
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
@@ -1158,8 +1162,8 @@ Value ConvertToLLVMPattern::getSizeInBytes(
// %0 = getelementptr %elementType* null, %indexType 1
// %1 = ptrtoint %elementType* %0 to %indexType
// which is a common pattern of getting the size of a type in bytes.
- auto convertedPtrType =
- typeConverter->convertType(type).cast<LLVM::LLVMType>().getPointerTo();
+ auto convertedPtrType = LLVM::LLVMPointerType::get(
+ typeConverter->convertType(type).cast<LLVM::LLVMType>());
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
auto gep = rewriter.create<LLVM::GEPOp>(
loc, convertedPtrType,
@@ -1315,7 +1319,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
builder, loc, typeConverter, unrankedMemRefType,
wrapperArgsRange.take_front(numToDrop));
- auto ptrTy = packed.getType().cast<LLVM::LLVMType>().getPointerTo();
+ auto ptrTy =
+ LLVM::LLVMPointerType::get(packed.getType().cast<LLVM::LLVMType>());
Value one = builder.create<LLVM::ConstantOp>(
loc, typeConverter.convertType(builder.getIndexType()),
builder.getIntegerAttr(builder.getIndexType(), 1));
@@ -1512,11 +1517,12 @@ static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
return info;
info.arraySizes.reserve(vectorType.getRank() - 1);
auto llvmTy = info.llvmArrayTy;
- while (llvmTy.isArrayTy()) {
- info.arraySizes.push_back(llvmTy.getArrayNumElements());
- llvmTy = llvmTy.getArrayElementType();
+ while (llvmTy.isa<LLVM::LLVMArrayType>()) {
+ info.arraySizes.push_back(
+ llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
+ llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
}
- if (!llvmTy.isVectorTy())
+ if (!llvmTy.isa<LLVM::LLVMVectorType>())
return info;
info.llvmVectorTy = llvmTy;
return info;
@@ -1644,7 +1650,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
return failure();
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
- if (!llvmArrayTy.isArrayTy())
+ if (!llvmArrayTy.isa<LLVM::LLVMArrayType>())
return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy,
@@ -2457,13 +2463,14 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
LLVM::LLVMType arrayTy =
convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
auto addressOf = rewriter.create<LLVM::AddressOfOp>(
- loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name());
+ loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
// Get the address of the first element in the array by creating a GEP with
// the address of the GV as the base, and (rank + 1) number of 0 indices.
LLVM::LLVMType elementType =
unwrap(typeConverter->convertType(type.getElementType()));
- LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace);
+ LLVM::LLVMType elementPtrType =
+ LLVM::LLVMPointerType::get(elementType, memSpace);
SmallVector<Value, 4> operands = {addressOf};
operands.insert(operands.end(), type.getRank() + 1,
@@ -2504,9 +2511,9 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
- if (!operandType.isArrayTy()) {
+ if (!operandType.isa<LLVM::LLVMArrayType>()) {
LLVM::ConstantOp one;
- if (operandType.isVectorTy()) {
+ if (operandType.isa<LLVM::LLVMVectorType>()) {
one = rewriter.create<LLVM::ConstantOp>(
loc, operandType,
SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
@@ -2526,8 +2533,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
op.getOperation(), operands, *getTypeConverter(),
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
- floatType),
+ mlir::VectorType::get(
+ {llvmVectorTy.cast<LLVM::LLVMFixedVectorType>()
+ .getNumElements()},
+ floatType),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
@@ -2614,12 +2623,13 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
// ptr = ExtractValueOp src, 1
auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
// castPtr = BitCastOp i8* to structTy*
- auto castPtr =
- rewriter
- .create<LLVM::BitcastOp>(
- loc, targetStructType.cast<LLVM::LLVMType>().getPointerTo(),
- ptr)
- .getResult();
+ auto castPtr = rewriter
+ .create<LLVM::BitcastOp>(
+ loc,
+ LLVM::LLVMPointerType::get(
+ targetStructType.cast<LLVM::LLVMType>()),
+ ptr)
+ .getResult();
// struct = LoadOp castPtr
auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
rewriter.replaceOp(memRefCastOp, loadOp.getResult());
@@ -2654,8 +2664,8 @@ static void extractPointersAndOffset(Location loc,
Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
LLVM::LLVMType llvmElementType =
unwrap(typeConverter.convertType(elementType));
- LLVM::LLVMType elementPtrPtrType =
- llvmElementType.getPointerTo(memorySpace).getPointerTo();
+ LLVM::LLVMType elementPtrPtrType = LLVM::LLVMPointerType::get(
+ LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
// Extract pointer to the underlying ranked memref descriptor and cast it to
// ElemType**.
@@ -2700,8 +2710,8 @@ struct MemRefReinterpretCastOpLowering
MemRefType targetMemRefType =
castOp.getResult().getType().cast<MemRefType>();
auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
- .dyn_cast_or_null<LLVM::LLVMType>();
- if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+ .dyn_cast_or_null<LLVM::LLVMStructType>();
+ if (!llvmTargetDescriptorTy)
return failure();
// Create descriptor.
@@ -2804,8 +2814,8 @@ struct MemRefReshapeOpLowering
// Set pointers and offset.
LLVM::LLVMType llvmElementType =
unwrap(typeConverter->convertType(elementType));
- LLVM::LLVMType elementPtrPtrType =
- llvmElementType.getPointerTo(addressSpace).getPointerTo();
+ auto elementPtrPtrType = LLVM::LLVMPointerType::get(
+ LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
elementPtrPtrType, allocatedPtr);
UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
@@ -2858,7 +2868,7 @@ struct MemRefReshapeOpLowering
rewriter.setInsertionPointToStart(bodyBlock);
// Copy size from shape to descriptor.
- LLVM::LLVMType llvmIndexPtrType = indexType.getPointerTo();
+ LLVM::LLVMType llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
@@ -2950,14 +2960,14 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
loc,
- typeConverter->convertType(scalarMemRefType)
- .cast<LLVM::LLVMType>()
- .getPointerTo(addressSpace),
+ LLVM::LLVMPointerType::get(
+ typeConverter->convertType(scalarMemRefType).cast<LLVM::LLVMType>(),
+ addressSpace),
underlyingRankedDesc);
// Get pointer to offset field of memref<element_type> descriptor.
- Type indexPtrTy =
- getTypeConverter()->getIndexType().getPointerTo(addressSpace);
+ Type indexPtrTy = LLVM::LLVMPointerType::get(
+ getTypeConverter()->getIndexType(), addressSpace);
Value two = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getI32Type()),
rewriter.getI32IntegerAttr(2));
@@ -3120,10 +3130,10 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
auto targetType =
typeConverter->convertType(indexCastOp.getResult().getType())
- .cast<LLVM::LLVMType>();
- auto sourceType = transformed.in().getType().cast<LLVM::LLVMType>();
- unsigned targetBits = targetType.getIntegerBitWidth();
- unsigned sourceBits = sourceType.getIntegerBitWidth();
+ .cast<LLVM::LLVMIntegerType>();
+ auto sourceType = transformed.in().getType().cast<LLVM::LLVMIntegerType>();
+ unsigned targetBits = targetType.getBitWidth();
+ unsigned sourceBits = sourceType.getBitWidth();
if (targetBits == sourceBits)
rewriter.replaceOp(indexCastOp, transformed.in());
@@ -3462,14 +3472,18 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
// Copy the buffer pointer from the old descriptor to the new one.
Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
- loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()),
+ loc,
+ LLVM::LLVMPointerType::get(targetElementTy,
+ viewMemRefType.getMemorySpace()),
extracted);
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
// Copy the buffer pointer from the old descriptor to the new one.
extracted = sourceMemRef.alignedPtr(rewriter, loc);
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
- loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()),
+ loc,
+ LLVM::LLVMPointerType::get(targetElementTy,
+ viewMemRefType.getMemorySpace()),
extracted);
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
@@ -3662,7 +3676,9 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
- loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()),
+ loc,
+ LLVM::LLVMPointerType::get(targetElementTy,
+ srcMemRefType.getMemorySpace()),
allocatedPtr);
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
@@ -3671,7 +3687,9 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
alignedPtr, adaptor.byte_shift());
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
- loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()),
+ loc,
+ LLVM::LLVMPointerType::get(targetElementTy,
+ srcMemRefType.getMemorySpace()),
alignedPtr);
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
@@ -4064,7 +4082,8 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
auto indexType = IndexType::get(context);
// Alloca with proper alignment. We do not expect optimizations of this
// alloca op and so we omit allocating at the entry block.
- auto ptrType = operand.getType().cast<LLVM::LLVMType>().getPointerTo();
+ auto ptrType =
+ LLVM::LLVMPointerType::get(operand.getType().cast<LLVM::LLVMType>());
Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
IntegerAttr::get(indexType, 1));
Value allocated =
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a982b90e0e93..bcc91e304e72 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -193,7 +193,7 @@ static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
Value base;
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
return failure();
- auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
+ auto pType = LLVM::LLVMPointerType::get(type.template cast<LLVM::LLVMType>());
base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
return success();
@@ -1100,14 +1100,14 @@ class VectorTypeCastOpConversion
return failure();
auto llvmSourceDescriptorTy =
- operands[0].getType().dyn_cast<LLVM::LLVMType>();
- if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
+ operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
+ if (!llvmSourceDescriptorTy)
return failure();
MemRefDescriptor sourceMemRef(operands[0]);
auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
- .dyn_cast_or_null<LLVM::LLVMType>();
- if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+ .dyn_cast_or_null<LLVM::LLVMStructType>();
+ if (!llvmTargetDescriptorTy)
return failure();
// Only contiguous source buffers supported atm.
@@ -1231,15 +1231,15 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
// TODO: support alignment when possible.
Value dataPtr = this->getStridedElementPtr(
loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
- auto vecTy =
- toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
+ auto vecTy = toLLVMTy(xferOp.getVectorType())
+ .template cast<LLVM::LLVMFixedVectorType>();
Value vectorDataPtr;
if (memRefType.getMemorySpace() == 0)
- vectorDataPtr =
- rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
+ vectorDataPtr = rewriter.create<LLVM::BitcastOp>(
+ loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
else
vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
- loc, vecTy.getPointerTo(), dataPtr);
+ loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
if (!xferOp.isMaskedDim(0))
return replaceTransferOpWithLoadOrStore(rewriter,
@@ -1253,7 +1253,7 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
//
// TODO: when the leaf transfer rank is k > 1, we need the last `k`
// dimensions here.
- unsigned vecWidth = vecTy.getVectorNumElements();
+ unsigned vecWidth = vecTy.getNumElements();
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value off = xferOp.indices()[lastIndex];
Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 973b116ef498..1335f33e10aa 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -78,9 +78,9 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
auto toLLVMTy = [&](Type t) {
return this->getTypeConverter()->convertType(t);
};
- LLVM::LLVMType vecTy =
- toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
- unsigned vecWidth = vecTy.getVectorNumElements();
+ auto vecTy = toLLVMTy(xferOp.getVectorType())
+ .template cast<LLVM::LLVMFixedVectorType>();
+ unsigned vecWidth = vecTy.getNumElements();
Location loc = xferOp->getLoc();
// The backend result vector scalarization have trouble scalarize
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 7b1300da1783..2bdbb877ec84 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -105,9 +105,10 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
auto argType = type.dyn_cast<LLVM::LLVMType>();
if (!argType)
return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
- if (argType.isVectorTy())
- resultType =
- LLVMType::getVectorTy(resultType, argType.getVectorNumElements());
+ if (auto vecArgType = argType.dyn_cast<LLVM::LLVMFixedVectorType>())
+ resultType = LLVMType::getVectorTy(resultType, vecArgType.getNumElements());
+ assert(!argType.isa<LLVM::LLVMScalableVectorType>() &&
+ "unhandled scalable vector");
result.addTypes({resultType});
return success();
@@ -118,7 +119,7 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
//===----------------------------------------------------------------------===//
static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
- auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
+ auto elemTy = op.getType().cast<LLVM::LLVMPointerType>().getElementType();
auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()},
{op.getType()});
@@ -363,14 +364,11 @@ static void printLoadOp(OpAsmPrinter &p, LoadOp &op) {
// the resulting type wrapped in MLIR, or nullptr on error.
static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
llvm::SMLoc trailingTypeLoc) {
- auto llvmTy = type.dyn_cast<LLVM::LLVMType>();
+ auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>();
if (!llvmTy)
- return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
- nullptr;
- if (!llvmTy.isPointerTy())
return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
nullptr;
- return llvmTy.getPointerElementTy();
+ return llvmTy.getElementType();
}
// <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
@@ -569,7 +567,7 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
/*isVarArg=*/false);
- auto wrappedFuncType = llvmFuncType.getPointerTo();
+ auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
auto funcArguments = llvm::makeArrayRef(operands).drop_front();
@@ -613,7 +611,7 @@ static LogicalResult verify(LandingpadOp op) {
for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
value = op.getOperand(idx);
- bool isFilter = value.getType().cast<LLVMType>().isArrayTy();
+ bool isFilter = value.getType().isa<LLVMArrayType>();
if (isFilter) {
// FIXME: Verify filter clauses when arrays are appropriately handled
} else {
@@ -646,7 +644,7 @@ static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) {
for (auto value : op.getOperands()) {
// Similar to llvm - if clause is an array type then it is filter
// clause else catch clause
- bool isArrayTy = value.getType().cast<LLVMType>().isArrayTy();
+ bool isArrayTy = value.getType().isa<LLVMArrayType>();
p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
<< value.getType() << ") ";
}
@@ -728,37 +726,37 @@ static LogicalResult verify(CallOp &op) {
fnType = fn.getType();
}
- if (!fnType.isFunctionTy())
+
+ LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
+ if (!funcType)
return op.emitOpError("callee does not have a functional type: ") << fnType;
// Verify that the operand and result types match the callee.
- if (!fnType.isFunctionVarArg() &&
- fnType.getFunctionNumParams() != (op.getNumOperands() - isIndirect))
+ if (!funcType.isVarArg() &&
+ funcType.getNumParams() != (op.getNumOperands() - isIndirect))
return op.emitOpError()
<< "incorrect number of operands ("
<< (op.getNumOperands() - isIndirect)
- << ") for callee (expecting: " << fnType.getFunctionNumParams()
- << ")";
+ << ") for callee (expecting: " << funcType.getNumParams() << ")";
- if (fnType.getFunctionNumParams() > (op.getNumOperands() - isIndirect))
+ if (funcType.getNumParams() > (op.getNumOperands() - isIndirect))
return op.emitOpError() << "incorrect number of operands ("
<< (op.getNumOperands() - isIndirect)
<< ") for varargs callee (expecting at least: "
- << fnType.getFunctionNumParams() << ")";
+ << funcType.getNumParams() << ")";
- for (unsigned i = 0, e = fnType.getFunctionNumParams(); i != e; ++i)
- if (op.getOperand(i + isIndirect).getType() !=
- fnType.getFunctionParamType(i))
+ for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
+ if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i))
return op.emitOpError() << "operand type mismatch for operand " << i
<< ": " << op.getOperand(i + isIndirect).getType()
- << " != " << fnType.getFunctionParamType(i);
+ << " != " << funcType.getParamType(i);
if (op.getNumResults() &&
- op.getResult(0).getType() != fnType.getFunctionResultType())
+ op.getResult(0).getType() != funcType.getReturnType())
return op.emitOpError()
<< "result type mismatch: " << op.getResult(0).getType()
- << " != " << fnType.getFunctionResultType();
+ << " != " << funcType.getReturnType();
return success();
}
@@ -848,7 +846,7 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
}
auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
/*isVarArg=*/false);
- auto wrappedFuncType = llvmFuncType.getPointerTo();
+ auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
auto funcArguments =
ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
@@ -875,8 +873,8 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
Value vector, Value position,
ArrayRef<NamedAttribute> attrs) {
- auto wrappedVectorType = vector.getType().cast<LLVM::LLVMType>();
- auto llvmType = wrappedVectorType.getVectorElementType();
+ auto vectorType = vector.getType().cast<LLVM::LLVMVectorType>();
+ auto llvmType = vectorType.getElementType();
build(b, result, llvmType, vector, position);
result.addAttributes(attrs);
}
@@ -903,11 +901,11 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
parser.resolveOperand(vector, type, result.operands) ||
parser.resolveOperand(position, positionType, result.operands))
return failure();
- auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>();
- if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
+ auto vectorType = type.dyn_cast<LLVM::LLVMVectorType>();
+ if (!vectorType)
return parser.emitError(
loc, "expected LLVM IR dialect vector type for operand #1");
- result.addTypes(wrappedVectorType.getVectorElementType());
+ result.addTypes(vectorType.getElementType());
return success();
}
@@ -930,8 +928,8 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
ArrayAttr positionAttr,
llvm::SMLoc attributeLoc,
llvm::SMLoc typeLoc) {
- auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>();
- if (!wrappedContainerType)
+ auto llvmType = containerType.dyn_cast<LLVM::LLVMType>();
+ if (!llvmType)
return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
// Infer the element type from the structure type: iteratively step inside the
@@ -945,26 +943,24 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
"expected an array of integer literals"),
nullptr;
int position = positionElementAttr.getInt();
- if (wrappedContainerType.isArrayTy()) {
- if (position < 0 || static_cast<unsigned>(position) >=
- wrappedContainerType.getArrayNumElements())
+ if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
+ if (position < 0 ||
+ static_cast<unsigned>(position) >= arrayType.getNumElements())
return parser.emitError(attributeLoc, "position out of bounds"),
nullptr;
- wrappedContainerType = wrappedContainerType.getArrayElementType();
- } else if (wrappedContainerType.isStructTy()) {
- if (position < 0 || static_cast<unsigned>(position) >=
- wrappedContainerType.getStructNumElements())
+ llvmType = arrayType.getElementType();
+ } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
+ if (position < 0 ||
+ static_cast<unsigned>(position) >= structType.getBody().size())
return parser.emitError(attributeLoc, "position out of bounds"),
nullptr;
- wrappedContainerType =
- wrappedContainerType.getStructElementType(position);
+ llvmType = structType.getBody()[position];
} else {
- return parser.emitError(typeLoc,
- "expected wrapped LLVM IR structure/array type"),
+ return parser.emitError(typeLoc, "expected LLVM IR structure/array type"),
nullptr;
}
}
- return wrappedContainerType;
+ return llvmType;
}
// <operation> ::= `llvm.extractvalue` ssa-use
@@ -1021,11 +1017,11 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
parser.parseColonType(vectorType))
return failure();
- auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>();
- if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
+ auto llvmVectorType = vectorType.dyn_cast<LLVM::LLVMVectorType>();
+ if (!llvmVectorType)
return parser.emitError(
loc, "expected LLVM IR dialect vector type for operand #1");
- auto valueType = wrappedVectorType.getVectorElementType();
+ Type valueType = llvmVectorType.getElementType();
if (!valueType)
return failure();
@@ -1145,12 +1141,14 @@ static LogicalResult verify(AddressOfOp op) {
return op.emitOpError(
"must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
- if (global && global.getType().getPointerTo(global.addr_space()) !=
- op.getResult().getType())
+ if (global &&
+ LLVM::LLVMPointerType::get(global.getType(), global.addr_space()) !=
+ op.getResult().getType())
return op.emitOpError(
"the type must be a pointer to the type of the referenced global");
- if (function && function.getType().getPointerTo() != op.getResult().getType())
+ if (function && LLVM::LLVMPointerType::get(function.getType()) !=
+ op.getResult().getType())
return op.emitOpError(
"the type must be a pointer to the type of the referenced function");
@@ -1276,11 +1274,11 @@ static LogicalResult verifyCast(DialectCastOp op, LLVMType llvmType,
if (vectorType.getRank() != 1)
return op->emitOpError("only 1-d vector is allowed");
- auto llvmVector = llvmType.dyn_cast<LLVMVectorType>();
- if (llvmVector.isa<LLVMScalableVectorType>())
+ auto llvmVector = llvmType.dyn_cast<LLVMFixedVectorType>();
+ if (!llvmVector)
return op->emitOpError("only fixed-sized vector is allowed");
- if (vectorType.getDimSize(0) != llvmVector.getVectorNumElements())
+ if (vectorType.getDimSize(0) != llvmVector.getNumElements())
return op->emitOpError(
"invalid cast between vectors with mismatching sizes");
@@ -1375,7 +1373,10 @@ static LogicalResult verifyCast(DialectCastOp op, LLVMType llvmType,
"be an index-compatible integer");
auto ptrType = structType.getBody()[1].dyn_cast<LLVMPointerType>();
- if (!ptrType || !ptrType.getPointerElementTy().isIntegerTy(8))
+ auto ptrElementType =
+ ptrType ? ptrType.getElementType().dyn_cast<LLVMIntegerType>()
+ : nullptr;
+ if (!ptrElementType || ptrElementType.getBitWidth() != 8)
return op->emitOpError("expected second element of a memref descriptor "
"to be an !llvm.ptr<i8>");
@@ -1503,9 +1504,11 @@ static LogicalResult verify(GlobalOp op) {
return op.emitOpError("must appear at the module level");
if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
- auto type = op.getType();
- if (!type.isArrayTy() || !type.getArrayElementType().isIntegerTy(8) ||
- type.getArrayNumElements() != strAttr.getValue().size())
+ auto type = op.getType().dyn_cast<LLVMArrayType>();
+ LLVMIntegerType elementType =
+ type ? type.getElementType().dyn_cast<LLVMIntegerType>() : nullptr;
+ if (!elementType || elementType.getBitWidth() != 8 ||
+ type.getNumElements() != strAttr.getValue().size())
return op.emitOpError(
"requires an i8 array type of the length equal to that of the string "
"attribute");
@@ -1534,9 +1537,9 @@ static LogicalResult verify(GlobalOp op) {
void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
Value v1, Value v2, ArrayAttr mask,
ArrayRef<NamedAttribute> attrs) {
- auto wrappedContainerType1 = v1.getType().cast<LLVM::LLVMType>();
- auto vType = LLVMType::getVectorTy(
- wrappedContainerType1.getVectorElementType(), mask.size());
+ auto containerType = v1.getType().cast<LLVM::LLVMVectorType>();
+ auto vType =
+ LLVMType::getVectorTy(containerType.getElementType(), mask.size());
build(b, result, vType, v1, v2, mask);
result.addAttributes(attrs);
}
@@ -1566,12 +1569,12 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
parser.resolveOperand(v1, typeV1, result.operands) ||
parser.resolveOperand(v2, typeV2, result.operands))
return failure();
- auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>();
- if (!wrappedContainerType1 || !wrappedContainerType1.isVectorTy())
+ auto containerType = typeV1.dyn_cast<LLVM::LLVMVectorType>();
+ if (!containerType)
return parser.emitError(
loc, "expected LLVM IR dialect vector type for operand #1");
- auto vType = LLVMType::getVectorTy(
- wrappedContainerType1.getVectorElementType(), maskAttr.size());
+ auto vType =
+ LLVMType::getVectorTy(containerType.getElementType(), maskAttr.size());
result.addTypes(vType);
return success();
}
@@ -1588,9 +1591,9 @@ Block *LLVMFuncOp::addEntryBlock() {
auto *entry = new Block;
push_back(entry);
- LLVMType type = getType();
- for (unsigned i = 0, e = type.getFunctionNumParams(); i < e; ++i)
- entry->addArgument(type.getFunctionParamType(i));
+ LLVMFunctionType type = getType();
+ for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
+ entry->addArgument(type.getParamType(i));
return entry;
}
@@ -1608,7 +1611,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
if (argAttrs.empty())
return;
- unsigned numInputs = type.getFunctionNumParams();
+ unsigned numInputs = type.cast<LLVMFunctionType>().getNumParams();
assert(numInputs == argAttrs.size() &&
"expected as many argument attribute lists as arguments");
SmallString<8> argAttrName;
@@ -1711,15 +1714,15 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
p << stringifyLinkage(op.linkage()) << ' ';
p.printSymbolName(op.getName());
- LLVMType fnType = op.getType();
+ LLVMFunctionType fnType = op.getType();
SmallVector<Type, 8> argTypes;
SmallVector<Type, 1> resTypes;
- argTypes.reserve(fnType.getFunctionNumParams());
- for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i)
- argTypes.push_back(fnType.getFunctionParamType(i));
+ argTypes.reserve(fnType.getNumParams());
+ for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
+ argTypes.push_back(fnType.getParamType(i));
- LLVMType returnType = fnType.getFunctionResultType();
- if (!returnType.isVoidTy())
+ LLVMType returnType = fnType.getReturnType();
+ if (!returnType.isa<LLVMVoidType>())
resTypes.push_back(returnType);
impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes);
@@ -1737,8 +1740,8 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
// attribute is present. This can check for preconditions of the
// getNumArguments hook not failing.
LogicalResult LLVMFuncOp::verifyType() {
- auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>();
- if (!llvmType || !llvmType.isFunctionTy())
+ auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMFunctionType>();
+ if (!llvmType)
return emitOpError("requires '" + getTypeAttrName() +
"' attribute of wrapped LLVM function type");
@@ -1747,9 +1750,7 @@ LogicalResult LLVMFuncOp::verifyType() {
// Hook for OpTrait::FunctionLike, returns the number of function arguments.
// Depends on the type attribute being correct as checked by verifyType
-unsigned LLVMFuncOp::getNumFuncArguments() {
- return getType().getFunctionNumParams();
-}
+unsigned LLVMFuncOp::getNumFuncArguments() { return getType().getNumParams(); }
// Hook for OpTrait::FunctionLike, returns the number of function results.
// Depends on the type attribute being correct as checked by verifyType
@@ -1759,7 +1760,7 @@ unsigned LLVMFuncOp::getNumFuncResults() {
// If we modeled a void return as one result, then it would be possible to
// attach an MLIR result attribute to it, and it isn't clear what semantics we
// would assign to that.
- if (getType().getFunctionResultType().isVoidTy())
+ if (getType().getReturnType().isa<LLVMVoidType>())
return 0;
return 1;
}
@@ -1788,7 +1789,7 @@ static LogicalResult verify(LLVMFuncOp op) {
if (op.isVarArg())
return op.emitOpError("only external functions can be variadic");
- unsigned numArguments = op.getType().getFunctionNumParams();
+ unsigned numArguments = op.getType().getNumParams();
Block &entryBlock = op.front();
for (unsigned i = 0; i < numArguments; ++i) {
Type argType = entryBlock.getArgument(i).getType();
@@ -1796,7 +1797,7 @@ static LogicalResult verify(LLVMFuncOp op) {
if (!argLLVMType)
return op.emitOpError("entry block argument #")
<< i << " is not of LLVM type";
- if (op.getType().getFunctionParamType(i) != argLLVMType)
+ if (op.getType().getParamType(i) != argLLVMType)
return op.emitOpError("the type of entry block argument #")
<< i << " does not match the function signature";
}
@@ -1896,7 +1897,8 @@ static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
parseAtomicOrdering(parser, result, "ordering") ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
- parser.resolveOperand(ptr, type.getPointerTo(), result.operands) ||
+ parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
+ result.operands) ||
parser.resolveOperand(val, type, result.operands))
return failure();
@@ -1905,9 +1907,9 @@ static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
}
static LogicalResult verify(AtomicRMWOp op) {
- auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
+ auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
auto valType = op.val().getType().cast<LLVM::LLVMType>();
- if (valType != ptrType.getPointerElementTy())
+ if (valType != ptrType.getElementType())
return op.emitOpError("expected LLVM IR element type for operand #0 to "
"match type for operand #1");
auto resType = op.res().getType().cast<LLVM::LLVMType>();
@@ -1915,17 +1917,21 @@ static LogicalResult verify(AtomicRMWOp op) {
return op.emitOpError(
"expected LLVM IR result type to match type for operand #1");
if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
- if (!valType.isFloatingPointTy())
+ if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
return op.emitOpError("expected LLVM IR floating point type");
} else if (op.bin_op() == AtomicBinOp::xchg) {
- if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
- !valType.isIntegerTy(32) && !valType.isIntegerTy(64) &&
- !valType.isBFloatTy() && !valType.isHalfTy() && !valType.isFloatTy() &&
- !valType.isDoubleTy())
+ auto intType = valType.dyn_cast<LLVMIntegerType>();
+ unsigned intBitWidth = intType ? intType.getBitWidth() : 0;
+ if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
+ intBitWidth != 64 && !valType.isa<LLVMBFloatType>() &&
+ !valType.isa<LLVMHalfType>() && !valType.isa<LLVMFloatType>() &&
+ !valType.isa<LLVMDoubleType>())
return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
} else {
- if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
- !valType.isIntegerTy(32) && !valType.isIntegerTy(64))
+ auto intType = valType.dyn_cast<LLVMIntegerType>();
+ unsigned intBitWidth = intType ? intType.getBitWidth() : 0;
+ if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
+ intBitWidth != 64)
return op.emitOpError("expected LLVM IR integer type");
}
return success();
@@ -1958,7 +1964,8 @@ static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
parseAtomicOrdering(parser, result, "failure_ordering") ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
- parser.resolveOperand(ptr, type.getPointerTo(), result.operands) ||
+ parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
+ result.operands) ||
parser.resolveOperand(cmp, type, result.operands) ||
parser.resolveOperand(val, type, result.operands))
return failure();
@@ -1971,18 +1978,20 @@ static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
}
static LogicalResult verify(AtomicCmpXchgOp op) {
- auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
- if (!ptrType.isPointerTy())
+ auto ptrType = op.ptr().getType().dyn_cast<LLVM::LLVMPointerType>();
+ if (!ptrType)
return op.emitOpError("expected LLVM IR pointer type for operand #0");
auto cmpType = op.cmp().getType().cast<LLVM::LLVMType>();
auto valType = op.val().getType().cast<LLVM::LLVMType>();
- if (cmpType != ptrType.getPointerElementTy() || cmpType != valType)
+ if (cmpType != ptrType.getElementType() || cmpType != valType)
return op.emitOpError("expected LLVM IR element type for operand #0 to "
"match type for all other operands");
- if (!valType.isPointerTy() && !valType.isIntegerTy(8) &&
- !valType.isIntegerTy(16) && !valType.isIntegerTy(32) &&
- !valType.isIntegerTy(64) && !valType.isBFloatTy() &&
- !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy())
+ auto intType = valType.dyn_cast<LLVMIntegerType>();
+ unsigned intBitWidth = intType ? intType.getBitWidth() : 0;
+ if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 &&
+ intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 &&
+ !valType.isa<LLVMBFloatType>() && !valType.isa<LLVMHalfType>() &&
+ !valType.isa<LLVMFloatType>() && !valType.isa<LLVMDoubleType>())
return op.emitOpError("unexpected LLVM IR type");
if (op.success_ordering() < AtomicOrdering::monotonic ||
op.failure_ordering() < AtomicOrdering::monotonic)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index a89287b764e5..0616efb7ef3f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -36,129 +36,6 @@ LLVMDialect &LLVMType::getDialect() {
return static_cast<LLVMDialect &>(Type::getDialect());
}
-//----------------------------------------------------------------------------//
-// Misc type utilities.
-
-llvm::TypeSize LLVMType::getPrimitiveSizeInBits() {
- return llvm::TypeSwitch<LLVMType, llvm::TypeSize>(*this)
- .Case<LLVMHalfType, LLVMBFloatType>(
- [](LLVMType) { return llvm::TypeSize::Fixed(16); })
- .Case<LLVMFloatType>([](LLVMType) { return llvm::TypeSize::Fixed(32); })
- .Case<LLVMDoubleType, LLVMX86MMXType>(
- [](LLVMType) { return llvm::TypeSize::Fixed(64); })
- .Case<LLVMIntegerType>([](LLVMIntegerType intTy) {
- return llvm::TypeSize::Fixed(intTy.getBitWidth());
- })
- .Case<LLVMX86FP80Type>([](LLVMType) { return llvm::TypeSize::Fixed(80); })
- .Case<LLVMPPCFP128Type, LLVMFP128Type>(
- [](LLVMType) { return llvm::TypeSize::Fixed(128); })
- .Case<LLVMVectorType>([](LLVMVectorType t) {
- llvm::TypeSize elementSize =
- t.getElementType().getPrimitiveSizeInBits();
- llvm::ElementCount elementCount = t.getElementCount();
- assert(!elementSize.isScalable() &&
- "vector type should have fixed-width elements");
- return llvm::TypeSize(elementSize.getFixedSize() *
- elementCount.getKnownMinValue(),
- elementCount.isScalable());
- })
- .Default([](LLVMType ty) {
- assert((ty.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
- LLVMTokenType, LLVMStructType, LLVMArrayType,
- LLVMPointerType, LLVMFunctionType>()) &&
- "unexpected missing support for primitive type");
- return llvm::TypeSize::Fixed(0);
- });
-}
-
-//----------------------------------------------------------------------------//
-// Integer type utilities.
-
-bool LLVMType::isIntegerTy(unsigned bitwidth) {
- if (auto intType = dyn_cast<LLVMIntegerType>())
- return intType.getBitWidth() == bitwidth;
- return false;
-}
-unsigned LLVMType::getIntegerBitWidth() {
- return cast<LLVMIntegerType>().getBitWidth();
-}
-
-LLVMType LLVMType::getArrayElementType() {
- return cast<LLVMArrayType>().getElementType();
-}
-
-//----------------------------------------------------------------------------//
-// Array type utilities.
-
-unsigned LLVMType::getArrayNumElements() {
- return cast<LLVMArrayType>().getNumElements();
-}
-
-bool LLVMType::isArrayTy() { return isa<LLVMArrayType>(); }
-
-//----------------------------------------------------------------------------//
-// Vector type utilities.
-
-LLVMType LLVMType::getVectorElementType() {
- return cast<LLVMVectorType>().getElementType();
-}
-
-unsigned LLVMType::getVectorNumElements() {
- return cast<LLVMFixedVectorType>().getNumElements();
-}
-llvm::ElementCount LLVMType::getVectorElementCount() {
- return cast<LLVMVectorType>().getElementCount();
-}
-
-bool LLVMType::isVectorTy() { return isa<LLVMVectorType>(); }
-
-//----------------------------------------------------------------------------//
-// Function type utilities.
-
-LLVMType LLVMType::getFunctionParamType(unsigned argIdx) {
- return cast<LLVMFunctionType>().getParamType(argIdx);
-}
-
-unsigned LLVMType::getFunctionNumParams() {
- return cast<LLVMFunctionType>().getNumParams();
-}
-
-LLVMType LLVMType::getFunctionResultType() {
- return cast<LLVMFunctionType>().getReturnType();
-}
-
-bool LLVMType::isFunctionTy() { return isa<LLVMFunctionType>(); }
-
-bool LLVMType::isFunctionVarArg() {
- return cast<LLVMFunctionType>().isVarArg();
-}
-
-//----------------------------------------------------------------------------//
-// Pointer type utilities.
-
-LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
- return LLVMPointerType::get(*this, addrSpace);
-}
-
-LLVMType LLVMType::getPointerElementTy() {
- return cast<LLVMPointerType>().getElementType();
-}
-
-bool LLVMType::isPointerTy() { return isa<LLVMPointerType>(); }
-
-//----------------------------------------------------------------------------//
-// Struct type utilities.
-
-LLVMType LLVMType::getStructElementType(unsigned i) {
- return cast<LLVMStructType>().getBody()[i];
-}
-
-unsigned LLVMType::getStructNumElements() {
- return cast<LLVMStructType>().getBody().size();
-}
-
-bool LLVMType::isStructTy() { return isa<LLVMStructType>(); }
-
//----------------------------------------------------------------------------//
// Utilities used to generate floating point types.
@@ -193,6 +70,10 @@ LLVMType LLVMType::getIntNTy(MLIRContext *context, unsigned numBits) {
return LLVMIntegerType::get(context, numBits);
}
+LLVMType LLVMType::getInt8PtrTy(MLIRContext *context) {
+ return LLVMPointerType::get(LLVMIntegerType::get(context, 8));
+}
+
//----------------------------------------------------------------------------//
// Utilities used to generate other miscellaneous types.
@@ -221,8 +102,6 @@ LLVMType LLVMType::getVoidTy(MLIRContext *context) {
return LLVMVoidType::get(context);
}
-bool LLVMType::isVoidTy() { return isa<LLVMVoidType>(); }
-
//----------------------------------------------------------------------------//
// Creation and setting of LLVM's identified struct types
@@ -470,7 +349,7 @@ LLVMStructType::verifyConstructionInvariants(Location loc,
bool LLVMVectorType::isValidElementType(LLVMType type) {
return type.isa<LLVMIntegerType, LLVMPointerType>() ||
- type.isFloatingPointTy();
+ mlir::LLVM::isCompatibleFloatingPointType(type);
}
/// Support type casting functionality.
@@ -536,3 +415,42 @@ LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType,
unsigned LLVMScalableVectorType::getMinNumElements() {
return getImpl()->numElements;
}
+
+//===----------------------------------------------------------------------===//
+// Utility functions.
+//===----------------------------------------------------------------------===//
+
+llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
+ assert(isCompatibleType(type) &&
+ "expected a type compatible with the LLVM dialect");
+
+ return llvm::TypeSwitch<Type, llvm::TypeSize>(type)
+ .Case<LLVMHalfType, LLVMBFloatType>(
+ [](LLVMType) { return llvm::TypeSize::Fixed(16); })
+ .Case<LLVMFloatType>([](LLVMType) { return llvm::TypeSize::Fixed(32); })
+ .Case<LLVMDoubleType, LLVMX86MMXType>(
+ [](LLVMType) { return llvm::TypeSize::Fixed(64); })
+ .Case<LLVMIntegerType>([](LLVMIntegerType intTy) {
+ return llvm::TypeSize::Fixed(intTy.getBitWidth());
+ })
+ .Case<LLVMX86FP80Type>([](LLVMType) { return llvm::TypeSize::Fixed(80); })
+ .Case<LLVMPPCFP128Type, LLVMFP128Type>(
+ [](LLVMType) { return llvm::TypeSize::Fixed(128); })
+ .Case<LLVMVectorType>([](LLVMVectorType t) {
+ llvm::TypeSize elementSize =
+ getPrimitiveTypeSizeInBits(t.getElementType());
+ llvm::ElementCount elementCount = t.getElementCount();
+ assert(!elementSize.isScalable() &&
+ "vector type should have fixed-width elements");
+ return llvm::TypeSize(elementSize.getFixedSize() *
+ elementCount.getKnownMinValue(),
+ elementCount.isScalable());
+ })
+ .Default([](Type ty) {
+ assert((ty.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
+ LLVMTokenType, LLVMStructType, LLVMArrayType,
+ LLVMPointerType, LLVMFunctionType>()) &&
+ "unexpected missing support for primitive type");
+ return llvm::TypeSize::Fixed(0);
+ });
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 707ff7c1b089..c202075fa206 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -57,8 +57,9 @@ static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
for (auto &attr : result.attributes) {
if (attr.first != "return_value_and_is_valid")
continue;
- if (type.isStructTy() && type.getStructNumElements() > 0)
- type = type.getStructElementType(0);
+ auto structType = type.dyn_cast<LLVM::LLVMStructType>();
+ if (structType && !structType.getBody().empty())
+ type = structType.getBody()[0];
break;
}
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index a323e68170c1..bfdae2b4588d 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -196,19 +196,30 @@ template <typename Type>
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
template <>
Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
- if (!mainFunction.getType().getFunctionResultType().isIntegerTy(32))
+ auto resultType = mainFunction.getType()
+ .cast<LLVM::LLVMFunctionType>()
+ .getReturnType()
+ .dyn_cast<LLVM::LLVMIntegerType>();
+ if (!resultType || resultType.getBitWidth() != 32)
return make_string_error("only single llvm.i32 function result supported");
return Error::success();
}
template <>
Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
- if (!mainFunction.getType().getFunctionResultType().isIntegerTy(64))
+ auto resultType = mainFunction.getType()
+ .cast<LLVM::LLVMFunctionType>()
+ .getReturnType()
+ .dyn_cast<LLVM::LLVMIntegerType>();
+ if (!resultType || resultType.getBitWidth() != 64)
return make_string_error("only single llvm.i64 function result supported");
return Error::success();
}
template <>
Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
- if (!mainFunction.getType().getFunctionResultType().isFloatTy())
+ if (!mainFunction.getType()
+ .cast<LLVM::LLVMFunctionType>()
+ .getReturnType()
+ .isa<LLVM::LLVMFloatType>())
return make_string_error("only single llvm.f32 function result supported");
return Error::success();
}
@@ -220,7 +231,7 @@ Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
if (!mainFunction || mainFunction.isExternal())
return make_string_error("entry point not found");
- if (mainFunction.getType().getFunctionNumParams() != 0)
+ if (mainFunction.getType().cast<LLVM::LLVMFunctionType>().getNumParams() != 0)
return make_string_error("function inputs not supported");
if (Error error = checkCompatibleReturnType<Type>(mainFunction))
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 7f89a41de5db..9786751ef4b0 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -172,57 +172,57 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
if (!type)
return nullptr;
- if (type.isIntegerTy())
- return b.getIntegerType(type.getIntegerBitWidth());
+ if (auto intType = type.dyn_cast<LLVMIntegerType>())
+ return b.getIntegerType(intType.getBitWidth());
- if (type.isFloatTy())
+ if (type.isa<LLVMFloatType>())
return b.getF32Type();
- if (type.isDoubleTy())
+ if (type.isa<LLVMDoubleType>())
return b.getF64Type();
// LLVM vectors can only contain scalars.
- if (type.isVectorTy()) {
- auto numElements = type.getVectorElementCount();
+ if (auto vectorType = type.dyn_cast<LLVM::LLVMVectorType>()) {
+ auto numElements = vectorType.getElementCount();
if (numElements.isScalable()) {
emitError(unknownLoc) << "scalable vectors not supported";
return nullptr;
}
- Type elementType = getStdTypeForAttr(type.getVectorElementType());
+ Type elementType = getStdTypeForAttr(vectorType.getElementType());
if (!elementType)
return nullptr;
return VectorType::get(numElements.getKnownMinValue(), elementType);
}
// LLVM arrays can contain other arrays or vectors.
- if (type.isArrayTy()) {
+ if (auto arrayType = type.dyn_cast<LLVMArrayType>()) {
// Recover the nested array shape.
SmallVector<int64_t, 4> shape;
- shape.push_back(type.getArrayNumElements());
- while (type.getArrayElementType().isArrayTy()) {
- type = type.getArrayElementType();
- shape.push_back(type.getArrayNumElements());
+ shape.push_back(arrayType.getNumElements());
+ while (arrayType.getElementType().isa<LLVMArrayType>()) {
+ arrayType = arrayType.getElementType().cast<LLVMArrayType>();
+ shape.push_back(arrayType.getNumElements());
}
// If the innermost type is a vector, use the multi-dimensional vector as
// attribute type.
- if (type.getArrayElementType().isVectorTy()) {
- LLVMType vectorType = type.getArrayElementType();
- auto numElements = vectorType.getVectorElementCount();
+ if (auto vectorType =
+ arrayType.getElementType().dyn_cast<LLVMVectorType>()) {
+ auto numElements = vectorType.getElementCount();
if (numElements.isScalable()) {
emitError(unknownLoc) << "scalable vectors not supported";
return nullptr;
}
shape.push_back(numElements.getKnownMinValue());
- Type elementType = getStdTypeForAttr(vectorType.getVectorElementType());
+ Type elementType = getStdTypeForAttr(vectorType.getElementType());
if (!elementType)
return nullptr;
return VectorType::get(shape, elementType);
}
// Otherwise use a tensor.
- Type elementType = getStdTypeForAttr(type.getArrayElementType());
+ Type elementType = getStdTypeForAttr(arrayType.getElementType());
if (!elementType)
return nullptr;
return RankedTensorType::get(shape, elementType);
@@ -261,7 +261,7 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
if (!attrType)
return nullptr;
- if (type.isIntegerTy()) {
+ if (type.isa<LLVMIntegerType>()) {
SmallVector<APInt, 8> values;
values.reserve(cd->getNumElements());
for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
@@ -269,7 +269,7 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
return DenseElementsAttr::get(attrType, values);
}
- if (type.isFloatTy() || type.isDoubleTy()) {
+ if (type.isa<LLVMFloatType>() || type.isa<LLVMDoubleType>()) {
SmallVector<APFloat, 8> values;
values.reserve(cd->getNumElements());
for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
@@ -777,7 +777,8 @@ LogicalResult Importer::processFunction(llvm::Function *f) {
instMap.clear();
unknownInstMap.clear();
- LLVMType functionType = processType(f->getFunctionType());
+ auto functionType =
+ processType(f->getFunctionType()).dyn_cast<LLVMFunctionType>();
if (!functionType)
return failure();
@@ -805,8 +806,8 @@ LogicalResult Importer::processFunction(llvm::Function *f) {
// Add function arguments to the entry block.
for (auto kv : llvm::enumerate(f->args()))
- instMap[&kv.value()] = blockList[0]->addArgument(
- functionType.getFunctionParamType(kv.index()));
+ instMap[&kv.value()] =
+ blockList[0]->addArgument(functionType.getParamType(kv.index()));
for (auto bbs : llvm::zip(*f, blockList)) {
if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 8c650506e2d7..ae0745b0be28 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -969,7 +969,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
// NB: Attribute already verified to be boolean, so check if we can indeed
// attach the attribute to this argument, based on its type.
auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
- if (!argTy.isPointerTy())
+ if (!argTy || !argTy.isa<LLVM::LLVMPointerType>())
return func.emitError(
"llvm.noalias attribute attached to LLVM non-pointer argument");
if (attr.getValue())
@@ -981,7 +981,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
// NB: Attribute already verified to be int, so check if we can indeed
// attach the attribute to this argument, based on its type.
auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
- if (!argTy.isPointerTy())
+ if (!argTy || !argTy.isa<LLVM::LLVMPointerType>())
return func.emitError(
"llvm.align attribute attached to LLVM non-pointer argument");
llvmArg.addAttrs(
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 9461ebbd9ede..d02c252c0bf3 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -98,7 +98,7 @@ func @gep_non_function_type(%pos : !llvm.i64, %base : !llvm.ptr<float>) {
// -----
func @load_non_llvm_type(%foo : memref<f32>) {
- // expected-error@+1 {{expected LLVM IR dialect type}}
+ // expected-error@+1 {{expected LLVM pointer type}}
llvm.load %foo : memref<f32>
}
@@ -112,7 +112,7 @@ func @load_non_ptr_type(%foo : !llvm.float) {
// -----
func @store_non_llvm_type(%foo : memref<f32>, %bar : !llvm.float) {
- // expected-error@+1 {{expected LLVM IR dialect type}}
+ // expected-error@+1 {{expected LLVM pointer type}}
llvm.store %bar, %foo : memref<f32>
}
@@ -267,7 +267,7 @@ func @insertvalue_array_out_of_bounds() {
// -----
func @insertvalue_wrong_nesting() {
- // expected-error@+1 {{expected wrapped LLVM IR structure/array type}}
+ // expected-error@+1 {{expected LLVM IR structure/array type}}
llvm.insertvalue %a, %b[0,0] : !llvm.struct<(i32)>
}
@@ -311,7 +311,7 @@ func @extractvalue_array_out_of_bounds() {
// -----
func @extractvalue_wrong_nesting() {
- // expected-error@+1 {{expected wrapped LLVM IR structure/array type}}
+ // expected-error@+1 {{expected LLVM IR structure/array type}}
llvm.extractvalue %b[0,0] : !llvm.struct<(i32)>
}
More information about the Mlir-commits
mailing list