[Mlir-commits] [mlir] 5446ec8 - [mlir] take MLIRContext instead of LLVMDialect in getters of LLVMType's
Alex Zinenko
llvmlistbot at llvm.org
Thu Aug 6 02:05:49 PDT 2020
Author: Alex Zinenko
Date: 2020-08-06T11:05:40+02:00
New Revision: 5446ec8507080ac075274a30c7cf25652d9a860f
URL: https://github.com/llvm/llvm-project/commit/5446ec8507080ac075274a30c7cf25652d9a860f
DIFF: https://github.com/llvm/llvm-project/commit/5446ec8507080ac075274a30c7cf25652d9a860f.diff
LOG: [mlir] take MLIRContext instead of LLVMDialect in getters of LLVMType's
The historical modeling of the LLVM dialect types wrapped LLVM IR types and
therefore needed access to the instance of LLVMContext stored in the
LLVMDialect. The new modeling does not rely on that and only needs the
MLIRContext that is used for uniquing, similarly to other MLIR types. Change
the LLVMType::get<Kind>Ty functions to take `MLIRContext *` instead of
`LLVMDialect *` as the first argument. This brings the code base closer to
completely removing the dependence on LLVMContext from the LLVMDialect,
together with the additional support for thread-safe use of that context.
Depends On D85371
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D85372
Added:
Modified:
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
Removed:
################################################################################
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index af4130c6a5ca..74b32dc0ca11 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -56,19 +56,15 @@ class PrintOpLowering : public ConversionPattern {
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
auto memRefShape = memRefType.getShape();
auto loc = op->getLoc();
- auto *llvmDialect =
- op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
- assert(llvmDialect && "expected llvm dialect to be registered");
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
// Get a symbol reference to the printf function, inserting it if necessary.
- auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect);
+ auto printfRef = getOrInsertPrintf(rewriter, parentModule);
Value formatSpecifierCst = getOrCreateGlobalString(
- loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule,
- llvmDialect);
+ loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule);
Value newLineCst = getOrCreateGlobalString(
- loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect);
+ loc, rewriter, "nl", StringRef("\n\0", 2), parentModule);
// Create a loop for each of the dimensions within the shape.
SmallVector<Value, 4> loopIvs;
@@ -108,16 +104,15 @@ class PrintOpLowering : public ConversionPattern {
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
- ModuleOp module,
- LLVM::LLVMDialect *llvmDialect) {
+ ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
return SymbolRefAttr::get("printf", context);
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
- auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
- auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+ auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context);
+ auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
/*isVarArg=*/true);
@@ -132,15 +127,14 @@ class PrintOpLowering : public ConversionPattern {
/// name, creating the string if necessary.
static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
StringRef name, StringRef value,
- ModuleOp module,
- LLVM::LLVMDialect *llvmDialect) {
+ ModuleOp module) {
// Create the global at the entry of the module.
LLVM::GlobalOp global;
if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
OpBuilder::InsertionGuard insertGuard(builder);
builder.setInsertionPointToStart(module.getBody());
auto type = LLVM::LLVMType::getArrayTy(
- LLVM::LLVMType::getInt8Ty(llvmDialect), value.size());
+ LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size());
global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::Internal, name,
builder.getStringAttr(value));
@@ -149,10 +143,10 @@ class PrintOpLowering : public ConversionPattern {
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
- loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+ loc, LLVM::LLVMType::getInt64Ty(builder.getContext()),
builder.getIntegerAttr(builder.getIndexType(), 0));
return builder.create<LLVM::GEPOp>(
- loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
+ loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr,
ArrayRef<Value>({cst0, cst0}));
}
};
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index af4130c6a5ca..74b32dc0ca11 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -56,19 +56,15 @@ class PrintOpLowering : public ConversionPattern {
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
auto memRefShape = memRefType.getShape();
auto loc = op->getLoc();
- auto *llvmDialect =
- op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
- assert(llvmDialect && "expected llvm dialect to be registered");
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
// Get a symbol reference to the printf function, inserting it if necessary.
- auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect);
+ auto printfRef = getOrInsertPrintf(rewriter, parentModule);
Value formatSpecifierCst = getOrCreateGlobalString(
- loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule,
- llvmDialect);
+ loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule);
Value newLineCst = getOrCreateGlobalString(
- loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect);
+ loc, rewriter, "nl", StringRef("\n\0", 2), parentModule);
// Create a loop for each of the dimensions within the shape.
SmallVector<Value, 4> loopIvs;
@@ -108,16 +104,15 @@ class PrintOpLowering : public ConversionPattern {
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
- ModuleOp module,
- LLVM::LLVMDialect *llvmDialect) {
+ ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
return SymbolRefAttr::get("printf", context);
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
- auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
- auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+ auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context);
+ auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
/*isVarArg=*/true);
@@ -132,15 +127,14 @@ class PrintOpLowering : public ConversionPattern {
/// name, creating the string if necessary.
static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
StringRef name, StringRef value,
- ModuleOp module,
- LLVM::LLVMDialect *llvmDialect) {
+ ModuleOp module) {
// Create the global at the entry of the module.
LLVM::GlobalOp global;
if (!(global = module.lookupSymbol<LLVM::GlobalOp>(name))) {
OpBuilder::InsertionGuard insertGuard(builder);
builder.setInsertionPointToStart(module.getBody());
auto type = LLVM::LLVMType::getArrayTy(
- LLVM::LLVMType::getInt8Ty(llvmDialect), value.size());
+ LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size());
global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::Internal, name,
builder.getStringAttr(value));
@@ -149,10 +143,10 @@ class PrintOpLowering : public ConversionPattern {
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
- loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+ loc, LLVM::LLVMType::getInt64Ty(builder.getContext()),
builder.getIntegerAttr(builder.getIndexType(), 0));
return builder.create<LLVM::GEPOp>(
- loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
+ loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr,
ArrayRef<Value>({cst0, cst0}));
}
};
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 2853ef631761..1614a244070f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -59,8 +59,7 @@ struct LLVMDialectImpl;
/// global and use it to compute the address of the first character in the
/// string (operations inserted at the builder insertion point).
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name,
- StringRef value, LLVM::Linkage linkage,
- LLVM::LLVMDialect *llvmDialect);
+ StringRef value, LLVM::Linkage linkage);
/// LLVM requires some operations to be inside of a Module operation. This
/// function confirms that the Operation has the desired properties.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 6b265e73d897..fb4001f1bc9b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -58,8 +58,7 @@ class LLVMI<int width>
"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy(" # width # ")">]>,
"LLVM dialect " # width # "-bit integer">,
BuildableType<
- "::mlir::LLVM::LLVMType::getIntNTy("
- "$_builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(),"
+ "::mlir::LLVM::LLVMType::getIntNTy($_builder.getContext(),"
# width # ")">;
def LLVMI1 : LLVMI<1>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 768d8db121df..a4a0db171e81 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -151,8 +151,7 @@ def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>,
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, ICmpPredicate predicate, Value lhs, "
"Value rhs", [{
- LLVMDialect *dialect = &lhs.getType().cast<LLVMType>().getDialect();
- build(b, result, LLVMType::getInt1Ty(dialect),
+ build(b, result, LLVMType::getInt1Ty(lhs.getType().getContext()),
b.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
}]>];
let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
@@ -198,8 +197,7 @@ def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>,
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, FCmpPredicate predicate, Value lhs, "
"Value rhs", [{
- LLVMDialect *dialect = &lhs.getType().cast<LLVMType>().getDialect();
- build(b, result, LLVMType::getInt1Ty(dialect),
+ build(b, result, LLVMType::getInt1Ty(lhs.getType().getContext()),
b.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
}]>];
let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 7d7839c166f7..0d3a6f3249b1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -152,32 +152,32 @@ class LLVMType : public Type::TypeBase<LLVMType, Type, TypeStorage> {
bool isStructTy();
/// Utilities used to generate floating point types.
- static LLVMType getDoubleTy(LLVMDialect *dialect);
- static LLVMType getFloatTy(LLVMDialect *dialect);
- static LLVMType getBFloatTy(LLVMDialect *dialect);
- static LLVMType getHalfTy(LLVMDialect *dialect);
- static LLVMType getFP128Ty(LLVMDialect *dialect);
- static LLVMType getX86_FP80Ty(LLVMDialect *dialect);
+ static LLVMType getDoubleTy(MLIRContext *context);
+ static LLVMType getFloatTy(MLIRContext *context);
+ static LLVMType getBFloatTy(MLIRContext *context);
+ static LLVMType getHalfTy(MLIRContext *context);
+ static LLVMType getFP128Ty(MLIRContext *context);
+ static LLVMType getX86_FP80Ty(MLIRContext *context);
/// Utilities used to generate integer types.
- static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits);
- static LLVMType getInt1Ty(LLVMDialect *dialect) {
- return getIntNTy(dialect, /*numBits=*/1);
+ static LLVMType getIntNTy(MLIRContext *context, unsigned numBits);
+ static LLVMType getInt1Ty(MLIRContext *context) {
+ return getIntNTy(context, /*numBits=*/1);
}
- static LLVMType getInt8Ty(LLVMDialect *dialect) {
- return getIntNTy(dialect, /*numBits=*/8);
+ static LLVMType getInt8Ty(MLIRContext *context) {
+ return getIntNTy(context, /*numBits=*/8);
}
- static LLVMType getInt8PtrTy(LLVMDialect *dialect) {
- return getInt8Ty(dialect).getPointerTo();
+ static LLVMType getInt8PtrTy(MLIRContext *context) {
+ return getInt8Ty(context).getPointerTo();
}
- static LLVMType getInt16Ty(LLVMDialect *dialect) {
- return getIntNTy(dialect, /*numBits=*/16);
+ static LLVMType getInt16Ty(MLIRContext *context) {
+ return getIntNTy(context, /*numBits=*/16);
}
- static LLVMType getInt32Ty(LLVMDialect *dialect) {
- return getIntNTy(dialect, /*numBits=*/32);
+ static LLVMType getInt32Ty(MLIRContext *context) {
+ return getIntNTy(context, /*numBits=*/32);
}
- static LLVMType getInt64Ty(LLVMDialect *dialect) {
- return getIntNTy(dialect, /*numBits=*/64);
+ static LLVMType getInt64Ty(MLIRContext *context) {
+ return getIntNTy(context, /*numBits=*/64);
}
/// Utilities used to generate other miscellaneous types.
@@ -187,33 +187,33 @@ class LLVMType : public Type::TypeBase<LLVMType, Type, TypeStorage> {
static LLVMType getFunctionTy(LLVMType result, bool isVarArg) {
return getFunctionTy(result, llvm::None, isVarArg);
}
- static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef<LLVMType> elements,
+ static LLVMType getStructTy(MLIRContext *context, ArrayRef<LLVMType> elements,
bool isPacked = false);
- static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) {
- return getStructTy(dialect, llvm::None, isPacked);
+ static LLVMType getStructTy(MLIRContext *context, bool isPacked = false) {
+ return getStructTy(context, llvm::None, isPacked);
}
template <typename... Args>
static typename std::enable_if<llvm::are_base_of<LLVMType, Args...>::value,
LLVMType>::type
getStructTy(LLVMType elt1, Args... elts) {
SmallVector<LLVMType, 8> fields({elt1, elts...});
- return getStructTy(&elt1.getDialect(), fields);
+ return getStructTy(elt1.getContext(), fields);
}
static LLVMType getVectorTy(LLVMType elementType, unsigned numElements);
/// Void type utilities.
- static LLVMType getVoidTy(LLVMDialect *dialect);
+ static LLVMType getVoidTy(MLIRContext *context);
bool isVoidTy();
// Creation and setting of LLVM's identified struct types
- static LLVMType createStructTy(LLVMDialect *dialect,
+ static LLVMType createStructTy(MLIRContext *context,
ArrayRef<LLVMType> elements,
Optional<StringRef> name,
bool isPacked = false);
- static LLVMType createStructTy(LLVMDialect *dialect,
+ static LLVMType createStructTy(MLIRContext *context,
Optional<StringRef> name) {
- return createStructTy(dialect, llvm::None, name);
+ return createStructTy(context, llvm::None, name);
}
static LLVMType createStructTy(ArrayRef<LLVMType> elements,
@@ -222,7 +222,7 @@ class LLVMType : public Type::TypeBase<LLVMType, Type, TypeStorage> {
assert(!elements.empty() &&
"This method may not be invoked with an empty list");
LLVMType ele0 = elements.front();
- return createStructTy(&ele0.getDialect(), elements, name, isPacked);
+ return createStructTy(ele0.getContext(), elements, name, isPacked);
}
template <typename... Args>
@@ -231,7 +231,7 @@ class LLVMType : public Type::TypeBase<LLVMType, Type, TypeStorage> {
createStructTy(StringRef name, LLVMType elt1, Args... elts) {
SmallVector<LLVMType, 8> fields({elt1, elts...});
Optional<StringRef> opt_name(name);
- return createStructTy(&elt1.getDialect(), fields, opt_name);
+ return createStructTy(elt1.getContext(), fields, opt_name);
}
static LLVMType setStructTyBody(LLVMType structType,
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index c5ecaf798ebd..e186a335214b 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -67,14 +67,14 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
void initializeCachedTypes() {
- llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
- llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+ llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext());
+ llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext());
llvmPointerPointerType = llvmPointerType.getPointerTo();
- llvmInt8Type = LLVM::LLVMType::getInt8Ty(llvmDialect);
- llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
- llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
+ llvmInt8Type = LLVM::LLVMType::getInt8Ty(&getContext());
+ llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext());
+ llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext());
llvmIntPtrType = LLVM::LLVMType::getIntNTy(
- llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits());
+ &getContext(), llvmDialect->getDataLayout().getPointerSizeInBits());
}
LLVM::LLVMType getVoidType() { return llvmVoidType; }
@@ -91,7 +91,7 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
LLVM::LLVMType getIntPtrType() {
return LLVM::LLVMType::getIntNTy(
- getLLVMDialect(),
+ &getContext(),
getLLVMDialect()->getDataLayout().getPointerSizeInBits());
}
@@ -340,7 +340,7 @@ Value GpuLaunchFuncToGpuRuntimeCallsPass::generateKernelNameConstant(
std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
return LLVM::createGlobalString(
loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
- LLVM::Linkage::Internal, llvmDialect);
+ LLVM::Linkage::Internal);
}
// Emits LLVM IR to launch a kernel function. Expects the module that contains
@@ -378,9 +378,9 @@ void GpuLaunchFuncToGpuRuntimeCallsPass::translateGpuLaunchCalls(
SmallString<128> nameBuffer(kernelModule.getName());
nameBuffer.append(kGpuBinaryStorageSuffix);
- Value data = LLVM::createGlobalString(
- loc, builder, nameBuffer.str(), binaryAttr.getValue(),
- LLVM::Linkage::Internal, getLLVMDialect());
+ Value data =
+ LLVM::createGlobalString(loc, builder, nameBuffer.str(),
+ binaryAttr.getValue(), LLVM::Linkage::Internal);
// Emit the load module call to load the module data. Error checking is done
// in the called helper function.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index f6aede4c0d70..f15ccdc7104c 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -89,7 +89,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
// Rewrite workgroup memory attributions to addresses of global buffers.
rewriter.setInsertionPointToStart(&gpuFuncOp.front());
unsigned numProperArguments = gpuFuncOp.getNumArguments();
- auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
+ auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
Value zero = nullptr;
if (!workgroupBuffers.empty())
@@ -117,7 +117,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
// Rewrite private memory attributions to alloca'ed buffers.
unsigned numWorkgroupAttributions =
gpuFuncOp.getNumWorkgroupAttributions();
- auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
+ auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
Value attribution = en.value();
auto type = attribution.getType().cast<MemRefType>();
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index 6d4b99a77d7d..93381054dd21 100644
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -46,17 +46,17 @@ struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
- auto dialect = typeConverter.getDialect();
+ MLIRContext *context = rewriter.getContext();
Value newOp;
switch (dimensionToIndex(cast<Op>(op))) {
case X:
- newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+ newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(context));
break;
case Y:
- newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+ newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(context));
break;
case Z:
- newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
+ newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(context));
break;
default:
return failure();
@@ -64,10 +64,10 @@ struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern {
if (indexBitwidth > 32) {
newOp = rewriter.create<LLVM::SExtOp>(
- loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+ loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp);
} else if (indexBitwidth < 32) {
newOp = rewriter.create<LLVM::TruncOp>(
- loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
+ loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp);
}
rewriter.replaceOp(op, {newOp});
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 58b5f1dbc975..fc743823fd31 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -85,7 +85,7 @@ struct OpToFuncCallLowering : public ConvertToLLVMPattern {
return operand;
return rewriter.create<LLVM::FPExtOp>(
- operand.getLoc(), LLVM::LLVMType::getFloatTy(&type.getDialect()),
+ operand.getLoc(), LLVM::LLVMType::getFloatTy(rewriter.getContext()),
operand);
}
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index afb6d2875866..76c166842c2d 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -57,11 +57,11 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
Location loc = op->getLoc();
gpu::ShuffleOpAdaptor adaptor(operands);
- auto dialect = typeConverter.getDialect();
auto valueTy = adaptor.value().getType().cast<LLVM::LLVMType>();
- auto int32Type = LLVM::LLVMType::getInt32Ty(dialect);
- auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
- auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy});
+ auto int32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
+ auto predTy = LLVM::LLVMType::getInt1Ty(rewriter.getContext());
+ auto resultTy =
+ LLVM::LLVMType::getStructTy(rewriter.getContext(), {valueTy, predTy});
Value one = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(1));
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index c1a64bd091a9..377b4f0e3e55 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -57,15 +57,12 @@ class VulkanLaunchFuncToVulkanCallsPass
: public ConvertVulkanLaunchFuncToVulkanCallsBase<
VulkanLaunchFuncToVulkanCallsPass> {
private:
- LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
-
void initializeCachedTypes() {
- llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
- llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
- llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
- llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
- llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
- llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
+ llvmFloatType = LLVM::LLVMType::getFloatTy(&getContext());
+ llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext());
+ llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext());
+ llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext());
+ llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext());
}
LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
@@ -87,7 +84,7 @@ class VulkanLaunchFuncToVulkanCallsPass
// `!llvm<"{ `element-type`*, `element-type`*, i64,
// [`rank` x i64], [`rank` x i64]}">`.
return LLVM::LLVMType::getStructTy(
- llvmDialect,
+ &getContext(),
{llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
}
@@ -153,7 +150,6 @@ class VulkanLaunchFuncToVulkanCallsPass
void runOnOperation() override;
private:
- LLVM::LLVMDialect *llvmDialect;
LLVM::LLVMType llvmFloatType;
LLVM::LLVMType llvmVoidType;
LLVM::LLVMType llvmPointerType;
@@ -245,7 +241,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
// int16_t and bitcast the descriptor.
if (type.isHalfTy()) {
auto memRefTy =
- getMemRefType(rank, LLVM::LLVMType::getInt16Ty(llvmDialect));
+ getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext()));
ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
}
@@ -324,15 +320,15 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
}
for (unsigned i = 1; i <= 3; i++) {
- for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(llvmDialect),
- LLVM::LLVMType::getInt32Ty(llvmDialect),
- LLVM::LLVMType::getInt16Ty(llvmDialect),
- LLVM::LLVMType::getInt8Ty(llvmDialect),
- LLVM::LLVMType::getHalfTy(llvmDialect)}) {
+ for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(&getContext()),
+ LLVM::LLVMType::getInt32Ty(&getContext()),
+ LLVM::LLVMType::getInt16Ty(&getContext()),
+ LLVM::LLVMType::getInt8Ty(&getContext()),
+ LLVM::LLVMType::getHalfTy(&getContext())}) {
std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
std::string(stringifyType(type));
if (type.isHalfTy())
- type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(llvmDialect));
+ type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(&getContext()));
if (!module.lookupSymbol(fnName)) {
auto fnType = LLVM::LLVMType::getFunctionTy(
getVoidType(),
@@ -368,8 +364,7 @@ Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
- shaderName, LLVM::Linkage::Internal,
- getLLVMDialect());
+ shaderName, LLVM::Linkage::Internal);
}
void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
@@ -388,7 +383,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
// that data to runtime call.
Value ptrToSPIRVBinary = LLVM::createGlobalString(
loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
- LLVM::Linkage::Internal, getLLVMDialect());
+ LLVM::Linkage::Internal);
// Create LLVM constant for the size of SPIR-V binary shader.
Value binarySize = builder.create<LLVM::ConstantOp>(
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 0c326287e69e..024f2b14a989 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -186,15 +186,15 @@ static Type convertStructTypePacked(spirv::StructType type,
llvm::map_range(type.getElementTypes(), [&](Type elementType) {
return converter.convertType(elementType).cast<LLVM::LLVMType>();
}));
- return LLVM::LLVMType::getStructTy(converter.getDialect(), elementsVector,
+ return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
/*isPacked=*/true);
}
/// Creates LLVM dialect constant with the given value.
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
- LLVMTypeConverter &converter, unsigned value) {
+ unsigned value) {
return rewriter.create<LLVM::ConstantOp>(
- loc, LLVM::LLVMType::getInt32Ty(converter.getDialect()),
+ loc, LLVM::LLVMType::getInt32Ty(rewriter.getContext()),
rewriter.getIntegerAttr(rewriter.getI32Type(), value));
}
@@ -1002,7 +1002,7 @@ class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
return failure();
Location loc = varOp.getLoc();
- Value size = createI32ConstantOf(loc, rewriter, typeConverter, 1);
+ Value size = createI32ConstantOf(loc, rewriter, 1);
if (!init) {
rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
return success();
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index e7c8770ed8f3..a5ecbe4381de 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -199,7 +199,7 @@ llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
}
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
- return LLVM::LLVMType::getIntNTy(llvmDialect, getIndexTypeBitwidth());
+ return LLVM::LLVMType::getIntNTy(&getContext(), getIndexTypeBitwidth());
}
unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
@@ -211,19 +211,19 @@ Type LLVMTypeConverter::convertIndexType(IndexType type) {
}
Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
- return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
+ return LLVM::LLVMType::getIntNTy(&getContext(), type.getWidth());
}
Type LLVMTypeConverter::convertFloatType(FloatType type) {
switch (type.getKind()) {
case mlir::StandardTypes::F32:
- return LLVM::LLVMType::getFloatTy(llvmDialect);
+ return LLVM::LLVMType::getFloatTy(&getContext());
case mlir::StandardTypes::F64:
- return LLVM::LLVMType::getDoubleTy(llvmDialect);
+ return LLVM::LLVMType::getDoubleTy(&getContext());
case mlir::StandardTypes::F16:
- return LLVM::LLVMType::getHalfTy(llvmDialect);
+ return LLVM::LLVMType::getHalfTy(&getContext());
case mlir::StandardTypes::BF16: {
- return LLVM::LLVMType::getBFloatTy(llvmDialect);
+ return LLVM::LLVMType::getBFloatTy(&getContext());
}
default:
llvm_unreachable("non-float type in convertFloatType");
@@ -238,7 +238,7 @@ static constexpr unsigned kRealPosInComplexNumberStruct = 0;
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
Type LLVMTypeConverter::convertComplexType(ComplexType type) {
auto elementType = convertType(type.getElementType()).cast<LLVM::LLVMType>();
- return LLVM::LLVMType::getStructTy(llvmDialect, {elementType, elementType});
+ return LLVM::LLVMType::getStructTy(&getContext(), {elementType, elementType});
}
// Except for signatures, MLIR function types are converted into LLVM
@@ -274,7 +274,7 @@ LLVMTypeConverter::convertMemRefSignature(MemRefType type) {
/// In signatures, unranked MemRef descriptors are expanded into a pair "rank,
/// pointer to descriptor".
SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
- return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(llvmDialect)};
+ return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())};
}
// Function types are converted to LLVM Function types by recursively converting
@@ -307,7 +307,7 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
// a struct.
LLVM::LLVMType resultType =
type.getNumResults() == 0
- ? LLVM::LLVMType::getVoidTy(llvmDialect)
+ ? LLVM::LLVMType::getVoidTy(&getContext())
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
@@ -331,7 +331,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
LLVM::LLVMType resultType =
type.getNumResults() == 0
- ? LLVM::LLVMType::getVoidTy(llvmDialect)
+ ? LLVM::LLVMType::getVoidTy(&getContext())
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
@@ -400,7 +400,7 @@ static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
auto rankTy = getIndexType();
- auto ptrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+ auto ptrTy = LLVM::LLVMType::getInt8PtrTy(&getContext());
return LLVM::LLVMType::getStructTy(rankTy, ptrTy);
}
@@ -853,11 +853,11 @@ LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
}
LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
- return LLVM::LLVMType::getVoidTy(&getDialect());
+ return LLVM::LLVMType::getVoidTy(&typeConverter.getContext());
}
LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const {
- return LLVM::LLVMType::getInt8PtrTy(&getDialect());
+ return LLVM::LLVMType::getInt8PtrTy(&typeConverter.getContext());
}
Value ConvertToLLVMPattern::createIndexConstant(
@@ -2025,9 +2025,10 @@ static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
unrankedMemrefs, sizes);
// Get frequently used types.
- auto voidType = LLVM::LLVMType::getVoidTy(typeConverter.getDialect());
- auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect());
- auto i1Type = LLVM::LLVMType::getInt1Ty(typeConverter.getDialect());
+ MLIRContext *context = builder.getContext();
+ auto voidType = LLVM::LLVMType::getVoidTy(context);
+ auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(context);
+ auto i1Type = LLVM::LLVMType::getInt1Ty(context);
LLVM::LLVMType indexType = typeConverter.getIndexType();
// Find the malloc and free, or declare them if necessary.
@@ -3168,7 +3169,7 @@ struct GenericAtomicRMWOpLowering
// Append the cmpxchg op to the end of the loop block.
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
- auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
+ auto boolType = LLVM::LLVMType::getInt1Ty(rewriter.getContext());
auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, pairType, dataPtr, loopArgument, result, successOrdering,
@@ -3330,13 +3331,13 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
resultTypes.push_back(converted);
}
- return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
+ return LLVM::LLVMType::getStructTy(&getContext(), resultTypes);
}
Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
OpBuilder &builder) {
auto *context = builder.getContext();
- auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect());
+ auto int64Ty = LLVM::LLVMType::getInt64Ty(builder.getContext());
auto indexType = IndexType::get(context);
// Alloca with proper alignment. We do not expect optimizations of this
// alloca op and so we omit allocating at the entry block.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 011143b810d9..f5d66f28f0ce 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -715,7 +715,7 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern {
// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
- auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
+ auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
@@ -832,7 +832,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern {
}
// Insertion of an element into a 1-D LLVM vector.
- auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
+ auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, typeConverter.convertType(oneDVectorType), extracted,
@@ -1074,7 +1074,7 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
if (failed(successStrides) || !isContiguous)
return failure();
- auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
+ auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
// Create descriptor.
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
@@ -1263,11 +1263,10 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
int64_t rank) const {
Location loc = op->getLoc();
if (rank == 0) {
- if (value.getType() ==
- LLVM::LLVMType::getInt1Ty(typeConverter.getDialect())) {
+ if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) {
// Convert i1 (bool) to i32 so we can use the print_i32 method.
// This avoids the need for a print_i1 method with an unclear ABI.
- auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
+ auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
auto trueVal = rewriter.create<ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(1));
auto falseVal = rewriter.create<ConstantOp>(
@@ -1303,8 +1302,8 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
}
// Helper for printer method declaration (first hit) and lookup.
- static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
- StringRef name, ArrayRef<LLVM::LLVMType> params) {
+ static Operation *getPrint(Operation *op, StringRef name,
+ ArrayRef<LLVM::LLVMType> params) {
auto module = op->getParentOfType<ModuleOp>();
auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
if (func)
@@ -1312,42 +1311,39 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
OpBuilder moduleBuilder(module.getBodyRegion());
return moduleBuilder.create<LLVM::LLVMFuncOp>(
op->getLoc(), name,
- LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
- params, /*isVarArg=*/false));
+ LLVM::LLVMType::getFunctionTy(
+ LLVM::LLVMType::getVoidTy(op->getContext()), params,
+ /*isVarArg=*/false));
}
// Helpers for method names.
Operation *getPrintI32(Operation *op) const {
- LLVM::LLVMDialect *dialect = typeConverter.getDialect();
- return getPrint(op, dialect, "print_i32",
- LLVM::LLVMType::getInt32Ty(dialect));
+ return getPrint(op, "print_i32",
+ LLVM::LLVMType::getInt32Ty(op->getContext()));
}
Operation *getPrintI64(Operation *op) const {
- LLVM::LLVMDialect *dialect = typeConverter.getDialect();
- return getPrint(op, dialect, "print_i64",
- LLVM::LLVMType::getInt64Ty(dialect));
+ return getPrint(op, "print_i64",
+ LLVM::LLVMType::getInt64Ty(op->getContext()));
}
Operation *getPrintFloat(Operation *op) const {
- LLVM::LLVMDialect *dialect = typeConverter.getDialect();
- return getPrint(op, dialect, "print_f32",
- LLVM::LLVMType::getFloatTy(dialect));
+ return getPrint(op, "print_f32",
+ LLVM::LLVMType::getFloatTy(op->getContext()));
}
Operation *getPrintDouble(Operation *op) const {
- LLVM::LLVMDialect *dialect = typeConverter.getDialect();
- return getPrint(op, dialect, "print_f64",
- LLVM::LLVMType::getDoubleTy(dialect));
+ return getPrint(op, "print_f64",
+ LLVM::LLVMType::getDoubleTy(op->getContext()));
}
Operation *getPrintOpen(Operation *op) const {
- return getPrint(op, typeConverter.getDialect(), "print_open", {});
+ return getPrint(op, "print_open", {});
}
Operation *getPrintClose(Operation *op) const {
- return getPrint(op, typeConverter.getDialect(), "print_close", {});
+ return getPrint(op, "print_close", {});
}
Operation *getPrintComma(Operation *op) const {
- return getPrint(op, typeConverter.getDialect(), "print_comma", {});
+ return getPrint(op, "print_comma", {});
}
Operation *getPrintNewline(Operation *op) const {
- return getPrint(op, typeConverter.getDialect(), "print_newline", {});
+ return getPrint(op, "print_newline", {});
}
};
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 6a70af4744d1..876fd05aa1b5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -101,8 +101,7 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
// The result type is either i1 or a vector type <? x i1> if the inputs are
// vectors.
- auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
- auto resultType = LLVMType::getInt1Ty(dialect);
+ auto resultType = LLVMType::getInt1Ty(builder.getContext());
auto argType = type.dyn_cast<LLVM::LLVMType>();
if (!argType)
return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
@@ -393,11 +392,9 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
return parser.emitError(trailingTypeLoc,
"expected function with 0 or 1 result");
- auto *llvmDialect =
- builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
LLVM::LLVMType llvmResultType;
if (funcType.getNumResults() == 0) {
- llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
+ llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
} else {
llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
if (!llvmResultType)
@@ -601,11 +598,9 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
"expected function with 0 or 1 result");
Builder &builder = parser.getBuilder();
- auto *llvmDialect =
- builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
LLVM::LLVMType llvmResultType;
if (funcType.getNumResults() == 0) {
- llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
+ llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
} else {
llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
if (!llvmResultType)
@@ -1101,9 +1096,8 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
if (types.empty()) {
if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
MLIRContext *context = parser.getBuilder().getContext();
- auto *dialect = context->getRegisteredDialect<LLVMDialect>();
auto arrayType = LLVM::LLVMType::getArrayTy(
- LLVM::LLVMType::getInt8Ty(dialect), strAttr.getValue().size());
+ LLVM::LLVMType::getInt8Ty(context), strAttr.getValue().size());
types.push_back(arrayType);
} else {
return parser.emitError(parser.getNameLoc(),
@@ -1265,14 +1259,8 @@ static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
llvmInputs.push_back(llvmTy);
}
- // Get the dialect from the input type, if any exist. Look it up in the
- // context otherwise.
- LLVMDialect *dialect =
- llvmInputs.empty() ? b.getContext()->getRegisteredDialect<LLVMDialect>()
- : &llvmInputs.front().getDialect();
-
// No output is denoted as "void" in LLVM type system.
- LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect)
+ LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(b.getContext())
: outputs.front().dyn_cast<LLVMType>();
if (!llvmOutput) {
parser.emitError(loc, "failed to construct function type: expected LLVM "
@@ -1605,8 +1593,7 @@ static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
parser.resolveOperand(val, type, result.operands))
return failure();
- auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
- auto boolType = LLVMType::getInt1Ty(dialect);
+ auto boolType = LLVMType::getInt1Ty(builder.getContext());
auto resultType = LLVMType::getStructTy(type, boolType);
result.addTypes(resultType);
@@ -1777,8 +1764,7 @@ LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
StringRef name, StringRef value,
- LLVM::Linkage linkage,
- LLVM::LLVMDialect *llvmDialect) {
+ LLVM::Linkage linkage) {
assert(builder.getInsertionBlock() &&
builder.getInsertionBlock()->getParentOp() &&
"expected builder to point to a block constrained in an op");
@@ -1788,8 +1774,9 @@ Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
// Create the global at the entry of the module.
OpBuilder moduleBuilder(module.getBodyRegion());
- auto type = LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect),
- value.size());
+ MLIRContext *ctx = builder.getContext();
+ auto type =
+ LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(ctx), value.size());
auto global = moduleBuilder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, linkage, name,
builder.getStringAttr(value));
@@ -1797,10 +1784,9 @@ Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
- loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+ loc, LLVM::LLVMType::getInt64Ty(ctx),
builder.getIntegerAttr(builder.getIndexType(), 0));
- return builder.create<LLVM::GEPOp>(loc,
- LLVM::LLVMType::getInt8PtrTy(llvmDialect),
+ return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMType::getInt8PtrTy(ctx),
globalPtr, ArrayRef<Value>({cst0, cst0}));
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index fa25f2dcdad8..f8cadeb0c40f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -127,35 +127,35 @@ bool LLVMType::isStructTy() { return isa<LLVMStructType>(); }
//----------------------------------------------------------------------------//
// Utilities used to generate floating point types.
-LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
- return LLVMDoubleType::get(dialect->getContext());
+LLVMType LLVMType::getDoubleTy(MLIRContext *context) {
+ return LLVMDoubleType::get(context);
}
-LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
- return LLVMFloatType::get(dialect->getContext());
+LLVMType LLVMType::getFloatTy(MLIRContext *context) {
+ return LLVMFloatType::get(context);
}
-LLVMType LLVMType::getBFloatTy(LLVMDialect *dialect) {
- return LLVMBFloatType::get(dialect->getContext());
+LLVMType LLVMType::getBFloatTy(MLIRContext *context) {
+ return LLVMBFloatType::get(context);
}
-LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
- return LLVMHalfType::get(dialect->getContext());
+LLVMType LLVMType::getHalfTy(MLIRContext *context) {
+ return LLVMHalfType::get(context);
}
-LLVMType LLVMType::getFP128Ty(LLVMDialect *dialect) {
- return LLVMFP128Type::get(dialect->getContext());
+LLVMType LLVMType::getFP128Ty(MLIRContext *context) {
+ return LLVMFP128Type::get(context);
}
-LLVMType LLVMType::getX86_FP80Ty(LLVMDialect *dialect) {
- return LLVMX86FP80Type::get(dialect->getContext());
+LLVMType LLVMType::getX86_FP80Ty(MLIRContext *context) {
+ return LLVMX86FP80Type::get(context);
}
//----------------------------------------------------------------------------//
// Utilities used to generate integer types.
-LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
- return LLVMIntegerType::get(dialect->getContext(), numBits);
+LLVMType LLVMType::getIntNTy(MLIRContext *context, unsigned numBits) {
+ return LLVMIntegerType::get(context, numBits);
}
//----------------------------------------------------------------------------//
@@ -170,9 +170,9 @@ LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
return LLVMFunctionType::get(result, params, isVarArg);
}
-LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
+LLVMType LLVMType::getStructTy(MLIRContext *context,
ArrayRef<LLVMType> elements, bool isPacked) {
- return LLVMStructType::getLiteral(dialect->getContext(), elements, isPacked);
+ return LLVMStructType::getLiteral(context, elements, isPacked);
}
LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
@@ -182,8 +182,8 @@ LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
//----------------------------------------------------------------------------//
// Void type utilities.
-LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
- return LLVMVoidType::get(dialect->getContext());
+LLVMType LLVMType::getVoidTy(MLIRContext *context) {
+ return LLVMVoidType::get(context);
}
bool LLVMType::isVoidTy() { return isa<LLVMVoidType>(); }
@@ -191,7 +191,7 @@ bool LLVMType::isVoidTy() { return isa<LLVMVoidType>(); }
//----------------------------------------------------------------------------//
// Creation and setting of LLVM's identified struct types
-LLVMType LLVMType::createStructTy(LLVMDialect *dialect,
+LLVMType LLVMType::createStructTy(MLIRContext *context,
ArrayRef<LLVMType> elements,
Optional<StringRef> name, bool isPacked) {
assert(name.hasValue() &&
@@ -200,8 +200,7 @@ LLVMType LLVMType::createStructTy(LLVMDialect *dialect,
std::string stringName = stringNameBase.str();
unsigned counter = 0;
do {
- auto type =
- LLVMStructType::getIdentified(dialect->getContext(), stringName);
+ auto type = LLVMStructType::getIdentified(context, stringName);
if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
counter += 1;
stringName =
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 9a694a5e9899..9a09488570e1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -41,12 +41,6 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
p << " : " << op->getResultTypes();
}
-static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
- return parser.getBuilder()
- .getContext()
- ->getRegisteredDialect<LLVM::LLVMDialect>();
-}
-
// <operation> ::=
// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
// ({return_value_and_is_valid})? : result_type
@@ -69,7 +63,7 @@ static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
break;
}
- auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
+ auto int32Ty = LLVM::LLVMType::getInt32Ty(parser.getBuilder().getContext());
return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
parser.getNameLoc(), result.operands);
}
@@ -77,9 +71,9 @@ static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
OperationState &result) {
- auto llvmDialect = getLlvmDialect(parser);
- auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
- auto int1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
+ MLIRContext *context = parser.getBuilder().getContext();
+ auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
+ auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
SmallVector<OpAsmParser::OperandType, 8> ops;
Type type;
@@ -92,14 +86,14 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
}
static LogicalResult verify(MmaOp op) {
- auto dialect = op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
- auto f16Ty = LLVM::LLVMType::getHalfTy(dialect);
+ MLIRContext *context = op.getContext();
+ auto f16Ty = LLVM::LLVMType::getHalfTy(context);
auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2);
- auto f32Ty = LLVM::LLVMType::getFloatTy(dialect);
+ auto f32Ty = LLVM::LLVMType::getFloatTy(context);
auto f16x2x4StructTy = LLVM::LLVMType::getStructTy(
- dialect, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+ context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
auto f32x8StructTy = LLVM::LLVMType::getStructTy(
- dialect, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
+ context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
op.getOperandTypes().end());
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index f3771dd57719..47089b9d934d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -34,12 +34,6 @@ using namespace ROCDL;
// Parsing for ROCDL ops
//===----------------------------------------------------------------------===//
-static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
- return parser.getBuilder()
- .getContext()
- ->getRegisteredDialect<LLVM::LLVMDialect>();
-}
-
// <operation> ::=
// `llvm.amdgcn.buffer.load.* %rsrc, %vindex, %offset, %glc, %slc :
// result_type`
@@ -51,8 +45,9 @@ static ParseResult parseROCDLMubufLoadOp(OpAsmParser &parser,
parser.addTypeToList(type, result.types))
return failure();
- auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
- auto int1Ty = LLVM::LLVMType::getInt1Ty(getLlvmDialect(parser));
+ MLIRContext *context = parser.getBuilder().getContext();
+ auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
+ auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4);
return parser.resolveOperands(ops,
{i32x4Ty, int32Ty, int32Ty, int1Ty, int1Ty},
@@ -69,8 +64,9 @@ static ParseResult parseROCDLMubufStoreOp(OpAsmParser &parser,
if (parser.parseOperandList(ops, 6) || parser.parseColonType(type))
return failure();
- auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
- auto int1Ty = LLVM::LLVMType::getInt1Ty(getLlvmDialect(parser));
+ MLIRContext *context = parser.getBuilder().getContext();
+ auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
+ auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4);
if (parser.resolveOperands(ops,
More information about the Mlir-commits
mailing list