[Mlir-commits] [mlir] 7ed9cfc - [mlir] Remove static constructors from LLVMType
Alex Zinenko
llvmlistbot at llvm.org
Wed Dec 23 04:12:55 PST 2020
Author: Alex Zinenko
Date: 2020-12-23T13:12:47+01:00
New Revision: 7ed9cfc7b19fdba9eb441ce1a8ba82cda14d76a8
URL: https://github.com/llvm/llvm-project/commit/7ed9cfc7b19fdba9eb441ce1a8ba82cda14d76a8
DIFF: https://github.com/llvm/llvm-project/commit/7ed9cfc7b19fdba9eb441ce1a8ba82cda14d76a8.diff
LOG: [mlir] Remove static constructors from LLVMType
LLVMType contains numerous static constructors that were initially introduced
for API compatibility with LLVM. Most of these merely forward their arguments to
`SpecificType::get` (MLIR defines classes for all types, unlike LLVM IR), while
some introduce subtle semantic differences due to different modeling of MLIR
types (e.g., structs are not auto-renamed in case of conflicts). Furthermore,
these constructors don't match MLIR idioms and actively prevent us from making
the LLVM dialect type system more open. Remove them and use `SpecificType::get`
instead.
Depends On D93680
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D93681
Added:
Modified:
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/lib/Transforms/TestConvertCallOp.cpp
Removed:
################################################################################
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index a04b3ecd4dae..6fbf29f4128d 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -111,10 +111,11 @@ class PrintOpLowering : public ConversionPattern {
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
- auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context);
- auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
- auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
- /*isVarArg=*/true);
+ auto llvmI32Ty = LLVM::LLVMIntegerType::get(context, 32);
+ auto llvmI8PtrTy =
+ LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8));
+ auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
+ /*isVarArg=*/true);
// Insert the printf function into the body of the parent module.
PatternRewriter::InsertionGuard insertGuard(rewriter);
@@ -133,8 +134,8 @@ class PrintOpLowering : public ConversionPattern {
if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
OpBuilder::InsertionGuard insertGuard(builder);
builder.setInsertionPointToStart(module.getBody());
- auto type = LLVM::LLVMType::getArrayTy(
- LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size());
+ auto type = LLVM::LLVMArrayType::get(
+ LLVM::LLVMIntegerType::get(builder.getContext(), 8), value.size());
global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::Internal, name,
builder.getStringAttr(value));
@@ -143,11 +144,13 @@ class PrintOpLowering : public ConversionPattern {
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
- loc, LLVM::LLVMType::getInt64Ty(builder.getContext()),
+ loc, LLVM::LLVMIntegerType::get(builder.getContext(), 64),
builder.getIntegerAttr(builder.getIndexType(), 0));
return builder.create<LLVM::GEPOp>(
- loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr,
- ArrayRef<Value>({cst0, cst0}));
+ loc,
+ LLVM::LLVMPointerType::get(
+ LLVM::LLVMIntegerType::get(builder.getContext(), 8)),
+ globalPtr, ArrayRef<Value>({cst0, cst0}));
}
};
} // end anonymous namespace
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index a04b3ecd4dae..6fbf29f4128d 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -111,10 +111,11 @@ class PrintOpLowering : public ConversionPattern {
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
- auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context);
- auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
- auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
- /*isVarArg=*/true);
+ auto llvmI32Ty = LLVM::LLVMIntegerType::get(context, 32);
+ auto llvmI8PtrTy =
+ LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8));
+ auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
+ /*isVarArg=*/true);
// Insert the printf function into the body of the parent module.
PatternRewriter::InsertionGuard insertGuard(rewriter);
@@ -133,8 +134,8 @@ class PrintOpLowering : public ConversionPattern {
if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
OpBuilder::InsertionGuard insertGuard(builder);
builder.setInsertionPointToStart(module.getBody());
- auto type = LLVM::LLVMType::getArrayTy(
- LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size());
+ auto type = LLVM::LLVMArrayType::get(
+ LLVM::LLVMIntegerType::get(builder.getContext(), 8), value.size());
global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::Internal, name,
builder.getStringAttr(value));
@@ -143,11 +144,13 @@ class PrintOpLowering : public ConversionPattern {
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
- loc, LLVM::LLVMType::getInt64Ty(builder.getContext()),
+ loc, LLVM::LLVMIntegerType::get(builder.getContext(), 64),
builder.getIntegerAttr(builder.getIndexType(), 0));
return builder.create<LLVM::GEPOp>(
- loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr,
- ArrayRef<Value>({cst0, cst0}));
+ loc,
+ LLVM::LLVMPointerType::get(
+ LLVM::LLVMIntegerType::get(builder.getContext(), 8)),
+ globalPtr, ArrayRef<Value>({cst0, cst0}));
}
};
} // end anonymous namespace
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 552fe15e6899..4968b33f47a4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -150,7 +150,7 @@ def LLVM_ICmpOp : LLVM_Op<"icmp", [NoSideEffect]> {
let builders = [
OpBuilderDAG<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs),
[{
- build($_builder, $_state, LLVMType::getInt1Ty(lhs.getType().getContext()),
+ build($_builder, $_state, LLVMIntegerType::get(lhs.getType().getContext(), 1),
$_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
}]>];
let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
@@ -198,7 +198,7 @@ def LLVM_FCmpOp : LLVM_Op<"fcmp", [NoSideEffect]> {
let builders = [
OpBuilderDAG<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs),
[{
- build($_builder, $_state, LLVMType::getInt1Ty(lhs.getType().getContext()),
+ build($_builder, $_state, LLVMIntegerType::get(lhs.getType().getContext(), 1),
$_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
}]>];
let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index e1938c12c809..7c7731946ba8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -53,9 +53,7 @@ class LLVMIntegerType;
///
/// The LLVM dialect in MLIR fully reflects the LLVM IR type system, prodiving a
/// separate MLIR type for each LLVM IR type. All types are represented as
-/// separate subclasses and are compatible with the isa/cast infrastructure. For
-/// convenience, the base class provides most of the APIs available on
-/// llvm::Type in addition to MLIR-compatible APIs.
+/// separate subclasses and are compatible with the isa/cast infrastructure.
///
/// The LLVM dialect type system is closed: parametric types can only refer to
/// other LLVM dialect types. This is consistent with LLVM IR and enables a more
@@ -64,6 +62,11 @@ class LLVMIntegerType;
/// Similarly to other MLIR types, LLVM dialect types are owned by the MLIR
/// context, have an immutable identifier (for most types except identified
/// structs, the entire type is the identifier) and are thread-safe.
+///
+/// This class is a thin common base class for different types available in the
+/// LLVM dialect. It intentionally does not provide the API similar to
+/// llvm::Type to avoid confusion and highlight potentially expensive operations
+/// (e.g., type creation in MLIR takes a lock, so it's better to cache types).
class LLVMType : public Type {
public:
/// Inherit base constructors.
@@ -79,98 +82,6 @@ class LLVMType : public Type {
static bool classof(Type type);
LLVMDialect &getDialect();
-
- /// Utilities used to generate floating point types.
- static LLVMType getDoubleTy(MLIRContext *context);
- static LLVMType getFloatTy(MLIRContext *context);
- static LLVMType getBFloatTy(MLIRContext *context);
- static LLVMType getHalfTy(MLIRContext *context);
- static LLVMType getFP128Ty(MLIRContext *context);
- static LLVMType getX86_FP80Ty(MLIRContext *context);
-
- /// Utilities used to generate integer types.
- static LLVMType getIntNTy(MLIRContext *context, unsigned numBits);
- static LLVMType getInt1Ty(MLIRContext *context) {
- return getIntNTy(context, /*numBits=*/1);
- }
- static LLVMType getInt8Ty(MLIRContext *context) {
- return getIntNTy(context, /*numBits=*/8);
- }
- static LLVMType getInt8PtrTy(MLIRContext *context);
- static LLVMType getInt16Ty(MLIRContext *context) {
- return getIntNTy(context, /*numBits=*/16);
- }
- static LLVMType getInt32Ty(MLIRContext *context) {
- return getIntNTy(context, /*numBits=*/32);
- }
- static LLVMType getInt64Ty(MLIRContext *context) {
- return getIntNTy(context, /*numBits=*/64);
- }
-
- /// Utilities used to generate other miscellaneous types.
- static LLVMType getArrayTy(LLVMType elementType, uint64_t numElements);
- static LLVMType getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
- bool isVarArg);
- static LLVMType getFunctionTy(LLVMType result, bool isVarArg) {
- return getFunctionTy(result, llvm::None, isVarArg);
- }
- static LLVMType getStructTy(MLIRContext *context, ArrayRef<LLVMType> elements,
- bool isPacked = false);
- static LLVMType getStructTy(MLIRContext *context, bool isPacked = false) {
- return getStructTy(context, llvm::None, isPacked);
- }
- template <typename... Args>
- static typename std::enable_if<llvm::are_base_of<LLVMType, Args...>::value,
- LLVMType>::type
- getStructTy(LLVMType elt1, Args... elts) {
- SmallVector<LLVMType, 8> fields({elt1, elts...});
- return getStructTy(elt1.getContext(), fields);
- }
- static LLVMType getVectorTy(LLVMType elementType, unsigned numElements);
-
- /// Void type utilities.
- static LLVMType getVoidTy(MLIRContext *context);
-
- // Creation and setting of LLVM's identified struct types
- static LLVMType createStructTy(MLIRContext *context,
- ArrayRef<LLVMType> elements,
- Optional<StringRef> name,
- bool isPacked = false);
-
- static LLVMType createStructTy(MLIRContext *context,
- Optional<StringRef> name) {
- return createStructTy(context, llvm::None, name);
- }
-
- static LLVMType createStructTy(ArrayRef<LLVMType> elements,
- Optional<StringRef> name,
- bool isPacked = false) {
- assert(!elements.empty() &&
- "This method may not be invoked with an empty list");
- LLVMType ele0 = elements.front();
- return createStructTy(ele0.getContext(), elements, name, isPacked);
- }
-
- template <typename... Args>
- static typename std::enable_if_t<llvm::are_base_of<LLVMType, Args...>::value,
- LLVMType>
- createStructTy(StringRef name, LLVMType elt1, Args... elts) {
- SmallVector<LLVMType, 8> fields({elt1, elts...});
- Optional<StringRef> opt_name(name);
- return createStructTy(elt1.getContext(), fields, opt_name);
- }
-
- static LLVMType setStructTyBody(LLVMType structType,
- ArrayRef<LLVMType> elements,
- bool isPacked = false);
-
- template <typename... Args>
- static typename std::enable_if_t<llvm::are_base_of<LLVMType, Args...>::value,
- LLVMType>
- setStructTyBody(LLVMType structType, LLVMType elt1, Args... elts) {
- SmallVector<LLVMType, 8> fields({elt1, elts...});
- return setStructTyBody(structType, fields);
- }
};
//===----------------------------------------------------------------------===//
@@ -386,6 +297,14 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
static LLVMStructType getIdentified(MLIRContext *context, StringRef name);
static LLVMStructType getIdentifiedChecked(Location loc, StringRef name);
+ /// Gets a new identified struct with the given body. The body _cannot_ be
+ /// changed later. If a struct with the given name already exists, renames
+ /// the struct by appending a `.` followed by a number to the name. Renaming
+ /// happens even if the existing struct has the same body.
+ static LLVMStructType getNewIdentified(MLIRContext *context, StringRef name,
+ ArrayRef<LLVMType> elements,
+ bool isPacked = false);
+
/// Gets or creates a literal struct with the given body in the provided
/// context.
static LLVMStructType getLiteral(MLIRContext *context,
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 2415924557db..3daa70b0a952 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -52,7 +52,7 @@ namespace {
// Async Runtime API function types.
struct AsyncAPI {
static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
- auto ref = LLVM::LLVMType::getInt8PtrTy(ctx);
+ auto ref = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
auto count = IntegerType::get(ctx, 32);
return FunctionType::get(ctx, {ref, count}, {});
}
@@ -78,7 +78,7 @@ struct AsyncAPI {
}
static FunctionType executeFunctionType(MLIRContext *ctx) {
- auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
+ auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
return FunctionType::get(ctx, {hdl, resume}, {});
}
@@ -90,22 +90,22 @@ struct AsyncAPI {
}
static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
- auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
+ auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
}
static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
- auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
+ auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
}
// Auxiliary coroutine resume intrinsic wrapper.
static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
- auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
- auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
- return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false);
+ auto voidTy = LLVM::LLVMVoidType::get(ctx);
+ auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
+ return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
}
};
} // namespace
@@ -155,7 +155,7 @@ static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
ArrayRef<LLVM::LLVMType> params) {
if (module.lookupSymbol(name))
return;
- LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false);
+ LLVM::LLVMType type = LLVM::LLVMFunctionType::get(ret, params);
builder.create<LLVM::LLVMFuncOp>(name, type);
}
@@ -168,13 +168,13 @@ static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
module.getBody()->getTerminator());
auto token = LLVMTokenType::get(ctx);
- auto voidTy = LLVMType::getVoidTy(ctx);
+ auto voidTy = LLVMVoidType::get(ctx);
- auto i8 = LLVMType::getInt8Ty(ctx);
- auto i1 = LLVMType::getInt1Ty(ctx);
- auto i32 = LLVMType::getInt32Ty(ctx);
- auto i64 = LLVMType::getInt64Ty(ctx);
- auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
+ auto i8 = LLVMIntegerType::get(ctx, 8);
+ auto i1 = LLVMIntegerType::get(ctx, 1);
+ auto i32 = LLVMIntegerType::get(ctx, 32);
+ auto i64 = LLVMIntegerType::get(ctx, 64);
+ auto i8Ptr = LLVMPointerType::get(i8);
addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr});
addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {});
@@ -201,9 +201,9 @@ static void addCRuntimeDeclarations(ModuleOp module) {
ImplicitLocOpBuilder builder(module.getLoc(),
module.getBody()->getTerminator());
- auto voidTy = LLVMType::getVoidTy(ctx);
- auto i64 = LLVMType::getInt64Ty(ctx);
- auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
+ auto voidTy = LLVMVoidType::get(ctx);
+ auto i64 = LLVMIntegerType::get(ctx, 64);
+ auto i8Ptr = LLVMPointerType::get(LLVMIntegerType::get(ctx, 8));
addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
@@ -227,11 +227,11 @@ static void addResumeFunction(ModuleOp module) {
if (module.lookupSymbol(kResume))
return;
- auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
- auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
+ auto voidTy = LLVM::LLVMVoidType::get(ctx);
+ auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
- loc, kResume, LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false));
+ loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
resumeOp.setPrivate();
auto *block = resumeOp.addEntryBlock();
@@ -297,10 +297,10 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
MLIRContext *ctx = func.getContext();
auto token = LLVM::LLVMTokenType::get(ctx);
- auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
- auto i32 = LLVM::LLVMType::getInt32Ty(ctx);
- auto i64 = LLVM::LLVMType::getInt64Ty(ctx);
- auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
+ auto i1 = LLVM::LLVMIntegerType::get(ctx, 1);
+ auto i32 = LLVM::LLVMIntegerType::get(ctx, 32);
+ auto i64 = LLVM::LLVMIntegerType::get(ctx, 64);
+ auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
Block *entryBlock = func.addEntryBlock();
Location loc = func.getBody().getLoc();
@@ -421,8 +421,8 @@ static void addSuspensionPoint(CoroMachinery coro, Value coroState,
OpBuilder &builder) {
Location loc = op->getLoc();
MLIRContext *ctx = op->getContext();
- auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
- auto i8 = LLVM::LLVMType::getInt8Ty(ctx);
+ auto i1 = LLVM::LLVMIntegerType::get(ctx, 1);
+ auto i8 = LLVM::LLVMIntegerType::get(ctx, 8);
// Add a coroutine suspension in place of original `op` in the split block.
OpBuilder::InsertionGuard guard(builder);
@@ -568,7 +568,7 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
MLIRContext *ctx = type.getContext();
// Convert async tokens and groups to opaque pointers.
if (type.isa<TokenType, GroupType>())
- return LLVM::LLVMType::getInt8PtrTy(ctx);
+ return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
return type;
}
};
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index d35aa0346f74..6859834de67f 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -55,8 +55,7 @@ class FunctionCallBuilder {
FunctionCallBuilder(StringRef functionName, LLVM::LLVMType returnType,
ArrayRef<LLVM::LLVMType> argumentTypes)
: functionName(functionName),
- functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes,
- /*isVarArg=*/false)) {}
+ functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
LLVM::CallOp create(Location loc, OpBuilder &builder,
ArrayRef<Value> arguments) const;
@@ -74,14 +73,15 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
protected:
MLIRContext *context = &this->getTypeConverter()->getContext();
- LLVM::LLVMType llvmVoidType = LLVM::LLVMType::getVoidTy(context);
- LLVM::LLVMType llvmPointerType = LLVM::LLVMType::getInt8PtrTy(context);
+ LLVM::LLVMType llvmVoidType = LLVM::LLVMVoidType::get(context);
+ LLVM::LLVMType llvmPointerType =
+ LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8));
LLVM::LLVMType llvmPointerPointerType =
LLVM::LLVMPointerType::get(llvmPointerType);
- LLVM::LLVMType llvmInt8Type = LLVM::LLVMType::getInt8Ty(context);
- LLVM::LLVMType llvmInt32Type = LLVM::LLVMType::getInt32Ty(context);
- LLVM::LLVMType llvmInt64Type = LLVM::LLVMType::getInt64Ty(context);
- LLVM::LLVMType llvmIntPtrType = LLVM::LLVMType::getIntNTy(
+ LLVM::LLVMType llvmInt8Type = LLVM::LLVMIntegerType::get(context, 8);
+ LLVM::LLVMType llvmInt32Type = LLVM::LLVMIntegerType::get(context, 32);
+ LLVM::LLVMType llvmInt64Type = LLVM::LLVMIntegerType::get(context, 64);
+ LLVM::LLVMType llvmIntPtrType = LLVM::LLVMIntegerType::get(
context, this->getTypeConverter()->getPointerBitwidth(0));
FunctionCallBuilder moduleLoadCallBuilder = {
@@ -515,7 +515,8 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
argumentTypes.reserve(numArguments);
for (auto argument : arguments)
argumentTypes.push_back(argument.getType().cast<LLVM::LLVMType>());
- auto structType = LLVM::LLVMType::createStructTy(argumentTypes, StringRef());
+ auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
+ argumentTypes);
auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
builder.getI32IntegerAttr(1));
auto structPtr = builder.create<LLVM::AllocaOp>(
@@ -716,10 +717,10 @@ mlir::createGpuToLLVMConversionPass(StringRef gpuBinaryAnnotation) {
void mlir::populateGpuToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
StringRef gpuBinaryAnnotation) {
- converter.addConversion(
- [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
- return LLVM::LLVMType::getInt8PtrTy(context);
- });
+ converter.addConversion([context = &converter.getContext()](
+ gpu::AsyncTokenType type) -> Type {
+ return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8));
+ });
patterns.insert<ConvertAllocOpToGpuRuntimeCallPattern,
ConvertDeallocOpToGpuRuntimeCallPattern,
ConvertHostRegisterOpToGpuRuntimeCallPattern,
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 914b7ee50cf9..3acc73415ef1 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -39,7 +39,7 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
auto elementType = typeConverter->convertType(type.getElementType())
.template cast<LLVM::LLVMType>();
- auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
+ auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
std::string name = std::string(
llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
auto globalOp = rewriter.create<LLVM::GlobalOp>(
@@ -85,7 +85,7 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
// Rewrite workgroup memory attributions to addresses of global buffers.
rewriter.setInsertionPointToStart(&gpuFuncOp.front());
unsigned numProperArguments = gpuFuncOp.getNumArguments();
- auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
+ auto i32Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 32);
Value zero = nullptr;
if (!workgroupBuffers.empty())
@@ -114,7 +114,7 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
// Rewrite private memory attributions to alloca'ed buffers.
unsigned numWorkgroupAttributions =
gpuFuncOp.getNumWorkgroupAttributions();
- auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
+ auto int64Ty = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64);
for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
Value attribution = en.value();
auto type = attribution.getType().cast<MemRefType>();
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index a51dff51cac4..0a1e76b99dbe 100644
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -48,13 +48,16 @@ struct GPUIndexIntrinsicOpLowering : public ConvertOpToLLVMPattern<Op> {
Value newOp;
switch (dimensionToIndex(op)) {
case X:
- newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(context));
+ newOp =
+ rewriter.create<XOp>(loc, LLVM::LLVMIntegerType::get(context, 32));
break;
case Y:
- newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(context));
+ newOp =
+ rewriter.create<YOp>(loc, LLVM::LLVMIntegerType::get(context, 32));
break;
case Z:
- newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(context));
+ newOp =
+ rewriter.create<ZOp>(loc, LLVM::LLVMIntegerType::get(context, 32));
break;
default:
return failure();
@@ -62,10 +65,10 @@ struct GPUIndexIntrinsicOpLowering : public ConvertOpToLLVMPattern<Op> {
if (indexBitwidth > 32) {
newOp = rewriter.create<LLVM::SExtOp>(
- loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp);
+ loc, LLVM::LLVMIntegerType::get(context, indexBitwidth), newOp);
} else if (indexBitwidth < 32) {
newOp = rewriter.create<LLVM::TruncOp>(
- loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp);
+ loc, LLVM::LLVMIntegerType::get(context, indexBitwidth), newOp);
}
rewriter.replaceOp(op, {newOp});
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index b2887aa1d782..631eca5cd32d 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -85,7 +85,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
return operand;
return rewriter.create<LLVM::FPExtOp>(
- operand.getLoc(), LLVM::LLVMType::getFloatTy(rewriter.getContext()),
+ operand.getLoc(), LLVM::LLVMFloatType::get(rewriter.getContext()),
operand);
}
@@ -96,8 +96,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
for (Value operand : operands) {
operandTypes.push_back(operand.getType().cast<LLVMType>());
}
- return LLVMType::getFunctionTy(resultType, operandTypes,
- /*isVarArg=*/false);
+ return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}
StringRef getFunctionName(LLVM::LLVMType type) const {
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index cea1cdc7e25f..f747f519c66b 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -57,10 +57,10 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
gpu::ShuffleOpAdaptor adaptor(operands);
auto valueTy = adaptor.value().getType().cast<LLVM::LLVMType>();
- auto int32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
- auto predTy = LLVM::LLVMType::getInt1Ty(rewriter.getContext());
- auto resultTy =
- LLVM::LLVMType::getStructTy(rewriter.getContext(), {valueTy, predTy});
+ auto int32Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 32);
+ auto predTy = LLVM::LLVMIntegerType::get(rewriter.getContext(), 1);
+ auto resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
+ {valueTy, predTy});
Value one = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(1));
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index c676cd256d66..4b657d25f51e 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -57,11 +57,12 @@ class VulkanLaunchFuncToVulkanCallsPass
VulkanLaunchFuncToVulkanCallsPass> {
private:
void initializeCachedTypes() {
- llvmFloatType = LLVM::LLVMType::getFloatTy(&getContext());
- llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext());
- llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext());
- llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext());
- llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext());
+ llvmFloatType = LLVM::LLVMFloatType::get(&getContext());
+ llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
+ llvmPointerType = LLVM::LLVMPointerType::get(
+ LLVM::LLVMIntegerType::get(&getContext(), 8));
+ llvmInt32Type = LLVM::LLVMIntegerType::get(&getContext(), 32);
+ llvmInt64Type = LLVM::LLVMIntegerType::get(&getContext(), 64);
}
LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
@@ -77,12 +78,12 @@ class VulkanLaunchFuncToVulkanCallsPass
// };
auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType);
auto llvmArrayRankElementSizeType =
- LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
+ LLVM::LLVMArrayType::get(getInt64Type(), rank);
// Create a type
// `!llvm<"{ `element-type`*, `element-type`*, i64,
// [`rank` x i64], [`rank` x i64]}">`.
- return LLVM::LLVMType::getStructTy(
+ return LLVM::LLVMStructType::getLiteral(
&getContext(),
{llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
@@ -242,7 +243,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
// int16_t and bitcast the descriptor.
if (type.isa<LLVM::LLVMHalfType>()) {
auto memRefTy =
- getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext()));
+ getMemRefType(rank, LLVM::LLVMIntegerType::get(&getContext(), 16));
ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
}
@@ -296,47 +297,46 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
if (!module.lookupSymbol(kSetEntryPoint)) {
builder.create<LLVM::LLVMFuncOp>(
loc, kSetEntryPoint,
- LLVM::LLVMType::getFunctionTy(getVoidType(),
- {getPointerType(), getPointerType()},
- /*isVarArg=*/false));
+ LLVM::LLVMFunctionType::get(getVoidType(),
+ {getPointerType(), getPointerType()}));
}
if (!module.lookupSymbol(kSetNumWorkGroups)) {
builder.create<LLVM::LLVMFuncOp>(
loc, kSetNumWorkGroups,
- LLVM::LLVMType::getFunctionTy(
- getVoidType(),
- {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
- /*isVarArg=*/false));
+ LLVM::LLVMFunctionType::get(getVoidType(),
+ {getPointerType(), getInt64Type(),
+ getInt64Type(), getInt64Type()}));
}
if (!module.lookupSymbol(kSetBinaryShader)) {
builder.create<LLVM::LLVMFuncOp>(
loc, kSetBinaryShader,
- LLVM::LLVMType::getFunctionTy(
- getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
- /*isVarArg=*/false));
+ LLVM::LLVMFunctionType::get(
+ getVoidType(),
+ {getPointerType(), getPointerType(), getInt32Type()}));
}
if (!module.lookupSymbol(kRunOnVulkan)) {
builder.create<LLVM::LLVMFuncOp>(
loc, kRunOnVulkan,
- LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
- /*isVarArg=*/false));
+ LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
}
for (unsigned i = 1; i <= 3; i++) {
- for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(&getContext()),
- LLVM::LLVMType::getInt32Ty(&getContext()),
- LLVM::LLVMType::getInt16Ty(&getContext()),
- LLVM::LLVMType::getInt8Ty(&getContext()),
- LLVM::LLVMType::getHalfTy(&getContext())}) {
+ SmallVector<LLVM::LLVMType, 5> types{
+ LLVM::LLVMFloatType::get(&getContext()),
+ LLVM::LLVMIntegerType::get(&getContext(), 32),
+ LLVM::LLVMIntegerType::get(&getContext(), 16),
+ LLVM::LLVMIntegerType::get(&getContext(), 8),
+ LLVM::LLVMHalfType::get(&getContext())};
+ for (auto type : types) {
std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
std::string(stringifyType(type));
if (type.isa<LLVM::LLVMHalfType>())
- type = LLVM::LLVMType::getInt16Ty(&getContext());
+ type = LLVM::LLVMIntegerType::get(&getContext(), 16);
if (!module.lookupSymbol(fnName)) {
- auto fnType = LLVM::LLVMType::getFunctionTy(
+ auto fnType = LLVM::LLVMFunctionType::get(
getVoidType(),
{getPointerType(), getInt32Type(), getInt32Type(),
LLVM::LLVMPointerType::get(getMemRefType(i, type))},
@@ -348,16 +348,13 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
if (!module.lookupSymbol(kInitVulkan)) {
builder.create<LLVM::LLVMFuncOp>(
- loc, kInitVulkan,
- LLVM::LLVMType::getFunctionTy(getPointerType(), {},
- /*isVarArg=*/false));
+ loc, kInitVulkan, LLVM::LLVMFunctionType::get(getPointerType(), {}));
}
if (!module.lookupSymbol(kDeinitVulkan)) {
builder.create<LLVM::LLVMFuncOp>(
loc, kDeinitVulkan,
- LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
- /*isVarArg=*/false));
+ LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
}
}
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 5546c82a9e69..c86ae710aac6 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -86,7 +86,7 @@ static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
auto *context = t.getContext();
auto int64Ty = converter.convertType(IntegerType::get(context, 64))
.cast<LLVM::LLVMType>();
- return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
+ return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty});
}
namespace {
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index 1724c7044339..3b4a8d66001d 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -60,7 +60,7 @@ static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
static void copy(Location loc, Value dst, Value src, Value size,
OpBuilder &builder) {
MLIRContext *context = builder.getContext();
- auto llvmI1Type = LLVM::LLVMType::getInt1Ty(context);
+ auto llvmI1Type = LLVM::LLVMIntegerType::get(context, 1);
Value isVolatile = builder.create<LLVM::ConstantOp>(
loc, llvmI1Type, builder.getBoolAttr(false));
builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
@@ -183,9 +183,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
rewriter.setInsertionPointToStart(module.getBody());
kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), newKernelFuncName,
- LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(context),
- ArrayRef<LLVM::LLVMType>(),
- /*isVarArg=*/false));
+ LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
+ ArrayRef<LLVM::LLVMType>()));
rewriter.setInsertionPoint(launchOp);
}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 7da9c47f9219..2633f4bdfe6f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -195,8 +195,8 @@ convertStructTypeWithOffset(spirv::StructType type,
llvm::map_range(type.getElementTypes(), [&](Type elementType) {
return converter.convertType(elementType).cast<LLVM::LLVMType>();
}));
- return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
- /*isPacked=*/false);
+ return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
+ /*isPacked=*/false);
}
/// Converts SPIR-V struct with no offset to packed LLVM struct.
@@ -206,15 +206,15 @@ static Type convertStructTypePacked(spirv::StructType type,
llvm::map_range(type.getElementTypes(), [&](Type elementType) {
return converter.convertType(elementType).cast<LLVM::LLVMType>();
}));
- return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
- /*isPacked=*/true);
+ return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
+ /*isPacked=*/true);
}
/// Creates LLVM dialect constant with the given value.
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
unsigned value) {
return rewriter.create<LLVM::ConstantOp>(
- loc, LLVM::LLVMType::getInt32Ty(rewriter.getContext()),
+ loc, LLVM::LLVMIntegerType::get(rewriter.getContext(), 32),
rewriter.getIntegerAttr(rewriter.getI32Type(), value));
}
@@ -258,7 +258,7 @@ static Optional<Type> convertArrayType(spirv::ArrayType type,
auto llvmElementType =
converter.convertType(elementType).cast<LLVM::LLVMType>();
unsigned numElements = type.getNumElements();
- return LLVM::LLVMType::getArrayTy(llvmElementType, numElements);
+ return LLVM::LLVMArrayType::get(llvmElementType, numElements);
}
/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
@@ -279,7 +279,7 @@ static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
return llvm::None;
auto elementType =
converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
- return LLVM::LLVMType::getArrayTy(elementType, 0);
+ return LLVM::LLVMArrayType::get(elementType, 0);
}
/// Converts SPIR-V struct to LLVM struct. There is no support of structs with
@@ -666,15 +666,15 @@ class ExecutionModePattern
// int32_t executionMode;
// int32_t values[]; // optional values
// };
- auto llvmI32Type = LLVM::LLVMType::getInt32Ty(context);
+ auto llvmI32Type = LLVM::LLVMIntegerType::get(context, 32);
SmallVector<LLVM::LLVMType, 2> fields;
fields.push_back(llvmI32Type);
ArrayAttr values = op.values();
if (!values.empty()) {
- auto arrayType = LLVM::LLVMType::getArrayTy(llvmI32Type, values.size());
+ auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
fields.push_back(arrayType);
}
- auto structType = LLVM::LLVMType::getStructTy(context, fields);
+ auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
// Create `llvm.mlir.global` with initializer region containing one block.
auto global = rewriter.create<LLVM::GlobalOp>(
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index f4d1df81565b..233c2eadc77c 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -171,7 +171,7 @@ MLIRContext &LLVMTypeConverter::getContext() {
}
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
- return LLVM::LLVMType::getIntNTy(&getContext(), getIndexTypeBitwidth());
+ return LLVM::LLVMIntegerType::get(&getContext(), getIndexTypeBitwidth());
}
unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
@@ -183,18 +183,18 @@ Type LLVMTypeConverter::convertIndexType(IndexType type) {
}
Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
- return LLVM::LLVMType::getIntNTy(&getContext(), type.getWidth());
+ return LLVM::LLVMIntegerType::get(&getContext(), type.getWidth());
}
Type LLVMTypeConverter::convertFloatType(FloatType type) {
if (type.isa<Float32Type>())
- return LLVM::LLVMType::getFloatTy(&getContext());
+ return LLVM::LLVMFloatType::get(&getContext());
if (type.isa<Float64Type>())
- return LLVM::LLVMType::getDoubleTy(&getContext());
+ return LLVM::LLVMDoubleType::get(&getContext());
if (type.isa<Float16Type>())
- return LLVM::LLVMType::getHalfTy(&getContext());
+ return LLVM::LLVMHalfType::get(&getContext());
if (type.isa<BFloat16Type>())
- return LLVM::LLVMType::getBFloatTy(&getContext());
+ return LLVM::LLVMBFloatType::get(&getContext());
llvm_unreachable("non-float type in convertFloatType");
}
@@ -206,7 +206,8 @@ static constexpr unsigned kRealPosInComplexNumberStruct = 0;
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
Type LLVMTypeConverter::convertComplexType(ComplexType type) {
auto elementType = convertType(type.getElementType()).cast<LLVM::LLVMType>();
- return LLVM::LLVMType::getStructTy(&getContext(), {elementType, elementType});
+ return LLVM::LLVMStructType::getLiteral(&getContext(),
+ {elementType, elementType});
}
// Except for signatures, MLIR function types are converted into LLVM
@@ -249,11 +250,11 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
// a struct.
LLVM::LLVMType resultType =
funcTy.getNumResults() == 0
- ? LLVM::LLVMType::getVoidTy(&getContext())
+ ? LLVM::LLVMVoidType::get(&getContext())
: unwrap(packFunctionResults(funcTy.getResults()));
if (!resultType)
return {};
- return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
+ return LLVM::LLVMFunctionType::get(resultType, argTypes, isVariadic);
}
/// Converts the function type to a C-compatible format, in particular using
@@ -273,12 +274,12 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
LLVM::LLVMType resultType =
type.getNumResults() == 0
- ? LLVM::LLVMType::getVoidTy(&getContext())
+ ? LLVM::LLVMVoidType::get(&getContext())
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
- return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
+ return LLVM::LLVMFunctionType::get(resultType, inputs);
}
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
@@ -335,7 +336,7 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
if (unpackAggregates)
results.insert(results.end(), 2 * rank, indexTy);
else
- results.insert(results.end(), 2, LLVM::LLVMType::getArrayTy(indexTy, rank));
+ results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
return results;
}
@@ -346,7 +347,7 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
// unpack the `sizes` and `strides` arrays.
SmallVector<LLVM::LLVMType, 5> types =
getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
- return LLVM::LLVMType::getStructTy(&getContext(), types);
+ return LLVM::LLVMStructType::getLiteral(&getContext(), types);
}
static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
@@ -361,12 +362,13 @@ static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
/// be unranked.
SmallVector<LLVM::LLVMType, 2>
LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
- return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())};
+ return {getIndexType(), LLVM::LLVMPointerType::get(
+ LLVM::LLVMIntegerType::get(&getContext(), 8))};
}
Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
- return LLVM::LLVMType::getStructTy(&getContext(),
- getUnrankedMemRefDescriptorFields());
+ return LLVM::LLVMStructType::getLiteral(&getContext(),
+ getUnrankedMemRefDescriptorFields());
}
/// Convert a memref type to a bare pointer to the memref element type.
@@ -407,11 +409,11 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
auto elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
- auto vectorType =
- LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
+ LLVM::LLVMType vectorType =
+ LLVM::LLVMFixedVectorType::get(elementType, type.getShape().back());
auto shape = type.getShape();
for (int i = shape.size() - 2; i >= 0; --i)
- vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
+ vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
return vectorType;
}
@@ -620,7 +622,7 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
int64_t rank) {
auto indexTy = indexType.cast<LLVM::LLVMType>();
auto indexPtrTy = LLVM::LLVMPointerType::get(indexTy);
- auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank);
+ auto arrayTy = LLVM::LLVMArrayType::get(indexTy, rank);
auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy);
// Copy size values to stack-allocated memory.
@@ -949,8 +951,9 @@ Value UnrankedMemRefDescriptor::sizeBasePtr(
Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) {
LLVM::LLVMType elemPtrTy = elemPtrPtrType.getElementType();
LLVM::LLVMType indexTy = typeConverter.getIndexType();
- LLVM::LLVMType structPtrTy = LLVM::LLVMPointerType::get(
- LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy));
+ LLVM::LLVMType structPtrTy =
+ LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral(
+ indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy}));
Value structPtr =
builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
@@ -1031,17 +1034,18 @@ LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
LLVM::LLVMType
ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
- return LLVM::LLVMType::getIntNTy(
+ return LLVM::LLVMIntegerType::get(
&getTypeConverter()->getContext(),
getTypeConverter()->getPointerBitwidth(addressSpace));
}
LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
- return LLVM::LLVMType::getVoidTy(&getTypeConverter()->getContext());
+ return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
}
LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const {
- return LLVM::LLVMType::getInt8PtrTy(&getTypeConverter()->getContext());
+ return LLVM::LLVMPointerType::get(
+ LLVM::LLVMIntegerType::get(&getTypeConverter()->getContext(), 8));
}
Value ConvertToLLVMPattern::createIndexConstant(
@@ -1724,8 +1728,7 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
if (!abortFunc) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
- auto abortFuncTy =
- LLVM::LLVMType::getFunctionTy(getVoidType(), {}, /*isVarArg=*/false);
+ auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
"abort", abortFuncTy);
}
@@ -1950,8 +1953,7 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern {
for (Value param : params)
paramTypes.push_back(param.getType().cast<LLVM::LLVMType>());
auto allocFuncType =
- LLVM::LLVMType::getFunctionTy(getVoidPtrType(), paramTypes,
- /*isVarArg=*/false);
+ LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
allocFuncOp = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
@@ -2203,9 +2205,10 @@ static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
// Get frequently used types.
MLIRContext *context = builder.getContext();
- auto voidType = LLVM::LLVMType::getVoidTy(context);
- auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(context);
- auto i1Type = LLVM::LLVMType::getInt1Ty(context);
+ auto voidType = LLVM::LLVMVoidType::get(context);
+ LLVM::LLVMType voidPtrType =
+ LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8));
+ auto i1Type = LLVM::LLVMIntegerType::get(context, 1);
LLVM::LLVMType indexType = typeConverter.getIndexType();
// Find the malloc and free, or declare them if necessary.
@@ -2216,8 +2219,8 @@ static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
builder.setInsertionPointToStart(module.getBody());
mallocFunc = builder.create<LLVM::LLVMFuncOp>(
builder.getUnknownLoc(), "malloc",
- LLVM::LLVMType::getFunctionTy(
- voidPtrType, llvm::makeArrayRef(indexType), /*isVarArg=*/false));
+ LLVM::LLVMFunctionType::get(voidPtrType, llvm::makeArrayRef(indexType),
+ /*isVarArg=*/false));
}
auto freeFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("free");
if (!freeFunc && !toDynamic) {
@@ -2225,8 +2228,8 @@ static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
builder.setInsertionPointToStart(module.getBody());
freeFunc = builder.create<LLVM::LLVMFuncOp>(
builder.getUnknownLoc(), "free",
- LLVM::LLVMType::getFunctionTy(voidType, llvm::makeArrayRef(voidPtrType),
- /*isVarArg=*/false));
+ LLVM::LLVMFunctionType::get(voidType, llvm::makeArrayRef(voidPtrType),
+ /*isVarArg=*/false));
}
// Initialize shared constants.
@@ -2372,8 +2375,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
op->getParentOfType<ModuleOp>().getBody());
freeFunc = rewriter.create<LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), "free",
- LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
- /*isVarArg=*/false));
+ LLVM::LLVMFunctionType::get(getVoidType(), getVoidPtrType()));
}
MemRefDescriptor memref(transformed.memref());
@@ -2400,7 +2402,7 @@ convertGlobalMemrefTypeToLLVM(MemRefType type,
LLVM::LLVMType arrayTy = elementType;
// Shape has the outermost dim at index 0, so need to walk it backwards
for (int64_t dim : llvm::reverse(type.getShape()))
- arrayTy = LLVM::LLVMType::getArrayTy(arrayTy, dim);
+ arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
return arrayTy;
}
@@ -2855,7 +2857,7 @@ struct MemRefReshapeOpLowering
Value zeroIndex = createIndexConstant(rewriter, loc, 0);
Value pred = rewriter.create<LLVM::ICmpOp>(
- loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()),
+ loc, LLVM::LLVMIntegerType::get(rewriter.getContext(), 1),
LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
Block *bodyBlock =
@@ -3889,8 +3891,9 @@ struct GenericAtomicRMWOpLowering
// Append the cmpxchg op to the end of the loop block.
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
- auto boolType = LLVM::LLVMType::getInt1Ty(rewriter.getContext());
- auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
+ auto boolType = LLVM::LLVMIntegerType::get(rewriter.getContext(), 1);
+ auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
+ {valueType, boolType});
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, pairType, dataPtr, loopArgument, result, successOrdering,
failureOrdering);
@@ -4067,13 +4070,13 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
resultTypes.push_back(converted);
}
- return LLVM::LLVMType::getStructTy(&getContext(), resultTypes);
+ return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
}
Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
OpBuilder &builder) {
auto *context = builder.getContext();
- auto int64Ty = LLVM::LLVMType::getInt64Ty(builder.getContext());
+ auto int64Ty = LLVM::LLVMIntegerType::get(builder.getContext(), 64);
auto indexType = IndexType::get(context);
// Alloca with proper alignment. We do not expect optimizations of this
// alloca op and so we omit allocating at the entry block.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index bcc91e304e72..b315417a420b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -209,7 +209,7 @@ static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
return failure();
auto pType = MemRefDescriptor(memref).getElementPtrType();
- auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
+ auto ptrsType = LLVM::LLVMFixedVectorType::get(pType, vType.getDimSize(0));
ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
return success();
}
@@ -748,7 +748,7 @@ class VectorExtractOpConversion
// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
- auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
+ auto i64Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64);
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
@@ -856,7 +856,7 @@ class VectorInsertOpConversion
}
// Insertion of an element into a 1-D LLVM vector.
- auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
+ auto i64Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64);
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
@@ -1123,7 +1123,7 @@ class VectorTypeCastOpConversion
}))
return failure();
- auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
+ auto int64Ty = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64);
// Create descriptor.
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
@@ -1362,11 +1362,11 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
switch (conversion) {
case PrintConversion::ZeroExt64:
value = rewriter.create<ZeroExtendIOp>(
- loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
+ loc, value, LLVM::LLVMIntegerType::get(rewriter.getContext(), 64));
break;
case PrintConversion::SignExt64:
value = rewriter.create<SignExtendIOp>(
- loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
+ loc, value, LLVM::LLVMIntegerType::get(rewriter.getContext(), 64));
break;
case PrintConversion::None:
break;
@@ -1410,27 +1410,25 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
OpBuilder moduleBuilder(module.getBodyRegion());
return moduleBuilder.create<LLVM::LLVMFuncOp>(
op->getLoc(), name,
- LLVM::LLVMType::getFunctionTy(
- LLVM::LLVMType::getVoidTy(op->getContext()), params,
- /*isVarArg=*/false));
+ LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
+ params));
}
// Helpers for method names.
Operation *getPrintI64(Operation *op) const {
return getPrint(op, "printI64",
- LLVM::LLVMType::getInt64Ty(op->getContext()));
+ LLVM::LLVMIntegerType::get(op->getContext(), 64));
}
Operation *getPrintU64(Operation *op) const {
return getPrint(op, "printU64",
- LLVM::LLVMType::getInt64Ty(op->getContext()));
+ LLVM::LLVMIntegerType::get(op->getContext(), 64));
}
Operation *getPrintFloat(Operation *op) const {
- return getPrint(op, "printF32",
- LLVM::LLVMType::getFloatTy(op->getContext()));
+ return getPrint(op, "printF32", LLVM::LLVMFloatType::get(op->getContext()));
}
Operation *getPrintDouble(Operation *op) const {
return getPrint(op, "printF64",
- LLVM::LLVMType::getDoubleTy(op->getContext()));
+ LLVM::LLVMDoubleType::get(op->getContext()));
}
Operation *getPrintOpen(Operation *op) const {
return getPrint(op, "printOpen", {});
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 1335f33e10aa..3e3ddc6aaff6 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -121,7 +121,7 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
Type i64Ty = rewriter.getIntegerType(64);
Value i64x2Ty = rewriter.create<LLVM::BitcastOp>(
loc,
- LLVM::LLVMType::getVectorTy(
+ LLVM::LLVMFixedVectorType::get(
toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
constConfig);
Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
@@ -129,7 +129,7 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
Value zero = this->createIndexConstant(rewriter, loc, 0);
Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
loc,
- LLVM::LLVMType::getVectorTy(
+ LLVM::LLVMFixedVectorType::get(
toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
i64x2Ty, dataPtrAsI64, zero);
dwordConfig =
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2bdbb877ec84..765538ca7a53 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -101,12 +101,13 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
// The result type is either i1 or a vector type <? x i1> if the inputs are
// vectors.
- auto resultType = LLVMType::getInt1Ty(builder.getContext());
+ LLVMType resultType = LLVMIntegerType::get(builder.getContext(), 1);
auto argType = type.dyn_cast<LLVM::LLVMType>();
if (!argType)
return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
if (auto vecArgType = argType.dyn_cast<LLVM::LLVMFixedVectorType>())
- resultType = LLVMType::getVectorTy(resultType, vecArgType.getNumElements());
+ resultType =
+ LLVMFixedVectorType::get(resultType, vecArgType.getNumElements());
assert(!argType.isa<LLVM::LLVMScalableVectorType>() &&
"unhandled scalable vector");
@@ -547,7 +548,7 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
LLVM::LLVMType llvmResultType;
if (funcType.getNumResults() == 0) {
- llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
+ llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
} else {
llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
if (!llvmResultType)
@@ -565,8 +566,7 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
"expected LLVM types as inputs");
}
- auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
- /*isVarArg=*/false);
+ auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
auto funcArguments = llvm::makeArrayRef(operands).drop_front();
@@ -827,7 +827,7 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
Builder &builder = parser.getBuilder();
LLVM::LLVMType llvmResultType;
if (funcType.getNumResults() == 0) {
- llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
+ llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
} else {
llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
if (!llvmResultType)
@@ -844,8 +844,7 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
"expected LLVM types as inputs");
argTypes.push_back(argType);
}
- auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
- /*isVarArg=*/false);
+ auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
auto funcArguments =
@@ -1477,8 +1476,8 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
if (types.empty()) {
if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
MLIRContext *context = parser.getBuilder().getContext();
- auto arrayType = LLVM::LLVMType::getArrayTy(
- LLVM::LLVMType::getInt8Ty(context), strAttr.getValue().size());
+ auto arrayType = LLVM::LLVMArrayType::get(
+ LLVM::LLVMIntegerType::get(context, 8), strAttr.getValue().size());
types.push_back(arrayType);
} else {
return parser.emitError(parser.getNameLoc(),
@@ -1539,7 +1538,7 @@ void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
ArrayRef<NamedAttribute> attrs) {
auto containerType = v1.getType().cast<LLVM::LLVMVectorType>();
auto vType =
- LLVMType::getVectorTy(containerType.getElementType(), mask.size());
+ LLVMFixedVectorType::get(containerType.getElementType(), mask.size());
build(b, result, vType, v1, v2, mask);
result.addAttributes(attrs);
}
@@ -1574,7 +1573,7 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
return parser.emitError(
loc, "expected LLVM IR dialect vector type for operand #1");
auto vType =
- LLVMType::getVectorTy(containerType.getElementType(), maskAttr.size());
+ LLVMFixedVectorType::get(containerType.getElementType(), maskAttr.size());
result.addTypes(vType);
return success();
}
@@ -1646,15 +1645,15 @@ static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
}
// No output is denoted as "void" in LLVM type system.
- LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(b.getContext())
+ LLVMType llvmOutput = outputs.empty() ? LLVMVoidType::get(b.getContext())
: outputs.front().dyn_cast<LLVMType>();
if (!llvmOutput) {
parser.emitError(loc, "failed to construct function type: expected LLVM "
"type for function results");
return {};
}
- return LLVMType::getFunctionTy(llvmOutput, llvmInputs,
- variadicFlag.isVariadic());
+ return LLVMFunctionType::get(llvmOutput, llvmInputs,
+ variadicFlag.isVariadic());
}
// Parses an LLVM function.
@@ -1970,8 +1969,9 @@ static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
parser.resolveOperand(val, type, result.operands))
return failure();
- auto boolType = LLVMType::getInt1Ty(builder.getContext());
- auto resultType = LLVMType::getStructTy(type, boolType);
+ auto boolType = LLVMIntegerType::get(builder.getContext(), 1);
+ auto resultType =
+ LLVMStructType::getLiteral(builder.getContext(), {type, boolType});
result.addTypes(resultType);
return success();
@@ -2159,8 +2159,8 @@ Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
// Create the global at the entry of the module.
OpBuilder moduleBuilder(module.getBodyRegion());
MLIRContext *ctx = builder.getContext();
- auto type =
- LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(ctx), value.size());
+ auto type = LLVM::LLVMArrayType::get(LLVM::LLVMIntegerType::get(ctx, 8),
+ value.size());
auto global = moduleBuilder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, linkage, name,
builder.getStringAttr(value));
@@ -2168,10 +2168,11 @@ Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
- loc, LLVM::LLVMType::getInt64Ty(ctx),
+ loc, LLVM::LLVMIntegerType::get(ctx, 64),
builder.getIntegerAttr(builder.getIndexType(), 0));
- return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMType::getInt8PtrTy(ctx),
- globalPtr, ValueRange{cst0, cst0});
+ return builder.create<LLVM::GEPOp>(
+ loc, LLVM::LLVMPointerType::get(LLVMIntegerType::get(ctx, 8)), globalPtr,
+ ValueRange{cst0, cst0});
}
bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 0616efb7ef3f..3d75245a1fb3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -36,106 +36,9 @@ LLVMDialect &LLVMType::getDialect() {
return static_cast<LLVMDialect &>(Type::getDialect());
}
-//----------------------------------------------------------------------------//
-// Utilities used to generate floating point types.
-
-LLVMType LLVMType::getDoubleTy(MLIRContext *context) {
- return LLVMDoubleType::get(context);
-}
-
-LLVMType LLVMType::getFloatTy(MLIRContext *context) {
- return LLVMFloatType::get(context);
-}
-
-LLVMType LLVMType::getBFloatTy(MLIRContext *context) {
- return LLVMBFloatType::get(context);
-}
-
-LLVMType LLVMType::getHalfTy(MLIRContext *context) {
- return LLVMHalfType::get(context);
-}
-
-LLVMType LLVMType::getFP128Ty(MLIRContext *context) {
- return LLVMFP128Type::get(context);
-}
-
-LLVMType LLVMType::getX86_FP80Ty(MLIRContext *context) {
- return LLVMX86FP80Type::get(context);
-}
-
-//----------------------------------------------------------------------------//
-// Utilities used to generate integer types.
-
-LLVMType LLVMType::getIntNTy(MLIRContext *context, unsigned numBits) {
- return LLVMIntegerType::get(context, numBits);
-}
-
-LLVMType LLVMType::getInt8PtrTy(MLIRContext *context) {
- return LLVMPointerType::get(LLVMIntegerType::get(context, 8));
-}
-
-//----------------------------------------------------------------------------//
-// Utilities used to generate other miscellaneous types.
-
-LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) {
- return LLVMArrayType::get(elementType, numElements);
-}
-
-LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
- bool isVarArg) {
- return LLVMFunctionType::get(result, params, isVarArg);
-}
-
-LLVMType LLVMType::getStructTy(MLIRContext *context,
- ArrayRef<LLVMType> elements, bool isPacked) {
- return LLVMStructType::getLiteral(context, elements, isPacked);
-}
-
-LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
- return LLVMFixedVectorType::get(elementType, numElements);
-}
-
-//----------------------------------------------------------------------------//
-// Void type utilities.
-
-LLVMType LLVMType::getVoidTy(MLIRContext *context) {
- return LLVMVoidType::get(context);
-}
-
-//----------------------------------------------------------------------------//
-// Creation and setting of LLVM's identified struct types
-
-LLVMType LLVMType::createStructTy(MLIRContext *context,
- ArrayRef<LLVMType> elements,
- Optional<StringRef> name, bool isPacked) {
- assert(name.hasValue() &&
- "identified structs with no identifier not supported");
- StringRef stringNameBase = name.getValueOr("");
- std::string stringName = stringNameBase.str();
- unsigned counter = 0;
- do {
- auto type = LLVMStructType::getIdentified(context, stringName);
- if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
- counter += 1;
- stringName =
- (Twine(stringNameBase) + "." + std::to_string(counter)).str();
- continue;
- }
- return type;
- } while (true);
-}
-
-LLVMType LLVMType::setStructTyBody(LLVMType structType,
- ArrayRef<LLVMType> elements, bool isPacked) {
- LogicalResult couldSet =
- structType.cast<LLVMStructType>().setBody(elements, isPacked);
- assert(succeeded(couldSet) && "failed to set the body");
- (void)couldSet;
- return structType;
-}
-
//===----------------------------------------------------------------------===//
// Array type.
+//===----------------------------------------------------------------------===//
bool LLVMArrayType::isValidElementType(LLVMType type) {
return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
@@ -167,6 +70,7 @@ LLVMArrayType::verifyConstructionInvariants(Location loc, LLVMType elementType,
//===----------------------------------------------------------------------===//
// Function type.
+//===----------------------------------------------------------------------===//
bool LLVMFunctionType::isValidArgumentType(LLVMType type) {
return !type.isa<LLVMVoidType, LLVMFunctionType>();
@@ -222,6 +126,7 @@ LogicalResult LLVMFunctionType::verifyConstructionInvariants(
//===----------------------------------------------------------------------===//
// Integer type.
+//===----------------------------------------------------------------------===//
LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) {
return Base::get(ctx, bitwidth);
@@ -243,6 +148,7 @@ LogicalResult LLVMIntegerType::verifyConstructionInvariants(Location loc,
//===----------------------------------------------------------------------===//
// Pointer type.
+//===----------------------------------------------------------------------===//
bool LLVMPointerType::isValidElementType(LLVMType type) {
return !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
@@ -273,6 +179,7 @@ LogicalResult LLVMPointerType::verifyConstructionInvariants(Location loc,
//===----------------------------------------------------------------------===//
// Struct type.
+//===----------------------------------------------------------------------===//
bool LLVMStructType::isValidElementType(LLVMType type) {
return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
@@ -289,6 +196,23 @@ LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc,
return Base::getChecked(loc, name, /*opaque=*/false);
}
+LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
+ StringRef name,
+ ArrayRef<LLVMType> elements,
+ bool isPacked) {
+ std::string stringName = name.str();
+ unsigned counter = 0;
+ do {
+ auto type = LLVMStructType::getIdentified(context, stringName);
+ if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
+ counter += 1;
+ stringName = (Twine(name) + "." + std::to_string(counter)).str();
+ continue;
+ }
+ return type;
+ } while (true);
+}
+
LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
ArrayRef<LLVMType> types,
bool isPacked) {
@@ -346,6 +270,7 @@ LLVMStructType::verifyConstructionInvariants(Location loc,
//===----------------------------------------------------------------------===//
// Vector types.
+//===----------------------------------------------------------------------===//
bool LLVMVectorType::isValidElementType(LLVMType type) {
return type.isa<LLVMIntegerType, LLVMPointerType>() ||
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index c202075fa206..c2f689be493a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -63,7 +63,8 @@ static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
break;
}
- auto int32Ty = LLVM::LLVMType::getInt32Ty(parser.getBuilder().getContext());
+ auto int32Ty =
+ LLVM::LLVMIntegerType::get(parser.getBuilder().getContext(), 32);
return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
parser.getNameLoc(), result.operands);
}
@@ -72,8 +73,8 @@ static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
OperationState &result) {
MLIRContext *context = parser.getBuilder().getContext();
- auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
- auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
+ auto int32Ty = LLVM::LLVMIntegerType::get(context, 32);
+ auto int1Ty = LLVM::LLVMIntegerType::get(context, 1);
SmallVector<OpAsmParser::OperandType, 8> ops;
Type type;
@@ -87,12 +88,12 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
static LogicalResult verify(MmaOp op) {
MLIRContext *context = op.getContext();
- auto f16Ty = LLVM::LLVMType::getHalfTy(context);
- auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2);
- auto f32Ty = LLVM::LLVMType::getFloatTy(context);
- auto f16x2x4StructTy = LLVM::LLVMType::getStructTy(
+ auto f16Ty = LLVM::LLVMHalfType::get(context);
+ auto f16x2Ty = LLVM::LLVMFixedVectorType::get(f16Ty, 2);
+ auto f32Ty = LLVM::LLVMFloatType::get(context);
+ auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
- auto f32x8StructTy = LLVM::LLVMType::getStructTy(
+ auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 8d0c96ce2aa1..f50c49f03a07 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -46,9 +46,9 @@ static ParseResult parseROCDLMubufLoadOp(OpAsmParser &parser,
return failure();
MLIRContext *context = parser.getBuilder().getContext();
- auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
- auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
- auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4);
+ auto int32Ty = LLVM::LLVMIntegerType::get(context, 32);
+ auto int1Ty = LLVM::LLVMIntegerType::get(context, 1);
+ auto i32x4Ty = LLVM::LLVMFixedVectorType::get(int32Ty, 4);
return parser.resolveOperands(ops,
{i32x4Ty, int32Ty, int32Ty, int1Ty, int1Ty},
parser.getNameLoc(), result.operands);
@@ -65,9 +65,9 @@ static ParseResult parseROCDLMubufStoreOp(OpAsmParser &parser,
return failure();
MLIRContext *context = parser.getBuilder().getContext();
- auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
- auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
- auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4);
+ auto int32Ty = LLVM::LLVMIntegerType::get(context, 32);
+ auto int1Ty = LLVM::LLVMIntegerType::get(context, 1);
+ auto i32x4Ty = LLVM::LLVMFixedVectorType::get(int32Ty, 4);
if (parser.resolveOperands(ops,
{type, i32x4Ty, int32Ty, int32Ty, int1Ty, int1Ty},
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 0b2cf7de270f..c28588d32ad6 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -769,13 +769,12 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
LLVM::LLVMType resultType;
if (inlineAsmOp.getNumResults() == 0) {
- resultType = LLVM::LLVMType::getVoidTy(mlirModule->getContext());
+ resultType = LLVM::LLVMVoidType::get(mlirModule->getContext());
} else {
assert(inlineAsmOp.getNumResults() == 1);
resultType = inlineAsmOp.getResultTypes()[0].cast<LLVM::LLVMType>();
}
- auto ft = LLVM::LLVMType::getFunctionTy(resultType, operandTypes,
- /*isVarArg=*/false);
+ auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes);
llvm::InlineAsm *inlineAsmInst =
inlineAsmOp.asm_dialect().hasValue()
? llvm::InlineAsm::get(
diff --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
index 61062c7938fe..82cc95aac8a8 100644
--- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
@@ -45,7 +45,8 @@ class TestConvertCallOp
// Populate type conversions.
LLVMTypeConverter type_converter(m.getContext());
type_converter.addConversion([&](test::TestType type) {
- return LLVM::LLVMType::getInt8PtrTy(m.getContext());
+ return LLVM::LLVMPointerType::get(
+ LLVM::LLVMIntegerType::get(m.getContext(), 8));
});
// Populate patterns.
More information about the Mlir-commits
mailing list