[Mlir-commits] [mlir] d3a9807 - [mlir] Remove most uses of LLVMDialect::getModule
Alex Zinenko
llvmlistbot at llvm.org
Thu Aug 6 01:54:38 PDT 2020
Author: Alex Zinenko
Date: 2020-08-06T10:54:30+02:00
New Revision: d3a9807674c1d7000bd5ec4028be399c81cbd098
URL: https://github.com/llvm/llvm-project/commit/d3a9807674c1d7000bd5ec4028be399c81cbd098
DIFF: https://github.com/llvm/llvm-project/commit/d3a9807674c1d7000bd5ec4028be399c81cbd098.diff
LOG: [mlir] Remove most uses of LLVMDialect::getModule
This prepares for the removal of llvm::Module and LLVMContext from the
mlir::LLVMDialect.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D85371
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 43a1587f0353..1bf46b91b7a6 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -118,8 +118,7 @@ class LLVMTypeConverter : public TypeConverter {
unsigned getPointerBitwidth(unsigned addressSpace = 0);
protected:
- /// LLVM IR module used to parse/create types.
- llvm::Module *module;
+ /// Pointer to the LLVM dialect.
LLVM::LLVMDialect *llvmDialect;
private:
@@ -400,9 +399,6 @@ class ConvertToLLVMPattern : public ConversionPattern {
/// Returns the LLVM IR context.
llvm::LLVMContext &getContext() const;
- /// Returns the LLVM IR module associated with the LLVM dialect.
- llvm::Module &getModule() const;
-
/// Gets the MLIR type wrapping the LLVM integer type whose bit width is
/// defined by the used type converter.
LLVM::LLVMType getIndexType() const;
@@ -437,8 +433,8 @@ class ConvertToLLVMPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const;
Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
- ValueRange indices, ConversionPatternRewriter &rewriter,
- llvm::Module &module) const;
+ ValueRange indices,
+ ConversionPatternRewriter &rewriter) const;
/// Returns the type of a pointer to an element of the memref.
Type getElementPtrType(MemRefType type) const;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 4d99bf265c65..6b265e73d897 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -25,6 +25,7 @@ def LLVM_Dialect : Dialect {
llvm::LLVMContext &getLLVMContext();
llvm::Module &getLLVMModule();
llvm::sys::SmartMutex<true> &getLLVMContextMutex();
+ const llvm::DataLayout &getDataLayout();
private:
friend LLVMType;
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index 14011e08de02..c5ecaf798ebd 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -66,12 +66,7 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
private:
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
- llvm::LLVMContext &getLLVMContext() {
- return getLLVMDialect()->getLLVMContext();
- }
-
void initializeCachedTypes() {
- const llvm::Module &module = llvmDialect->getLLVMModule();
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
llvmPointerPointerType = llvmPointerType.getPointerTo();
@@ -79,7 +74,7 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
llvmIntPtrType = LLVM::LLVMType::getIntNTy(
- llvmDialect, module.getDataLayout().getPointerSizeInBits());
+ llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits());
}
LLVM::LLVMType getVoidType() { return llvmVoidType; }
@@ -95,9 +90,9 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
LLVM::LLVMType getIntPtrType() {
- const llvm::Module &module = getLLVMDialect()->getLLVMModule();
return LLVM::LLVMType::getIntNTy(
- getLLVMDialect(), module.getDataLayout().getPointerSizeInBits());
+ getLLVMDialect(),
+ getLLVMDialect()->getDataLayout().getPointerSizeInBits());
}
// Allocate a void pointer on the stack.
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index e6527e0ca42b..c1a64bd091a9 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -59,10 +59,6 @@ class VulkanLaunchFuncToVulkanCallsPass
private:
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
- llvm::LLVMContext &getLLVMContext() {
- return getLLVMDialect()->getLLVMContext();
- }
-
void initializeCachedTypes() {
llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 9777071aa124..e7c8770ed8f3 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -128,10 +128,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
options(options) {
assert(llvmDialect && "LLVM IR dialect is not registered");
- module = &llvmDialect->getLLVMModule();
if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
this->options.indexBitwidth =
- module->getDataLayout().getPointerSizeInBits();
+ llvmDialect->getDataLayout().getPointerSizeInBits();
// Register conversions for the standard types.
addConversion([&](ComplexType type) { return convertComplexType(type); });
@@ -196,7 +195,7 @@ MLIRContext &LLVMTypeConverter::getContext() {
/// Get the LLVM context.
llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
- return module->getContext();
+ return llvmDialect->getLLVMContext();
}
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
@@ -204,7 +203,7 @@ LLVM::LLVMType LLVMTypeConverter::getIndexType() {
}
unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
- return module->getDataLayout().getPointerSizeInBits(addressSpace);
+ return llvmDialect->getDataLayout().getPointerSizeInBits(addressSpace);
}
Type LLVMTypeConverter::convertIndexType(IndexType type) {
@@ -849,10 +848,6 @@ llvm::LLVMContext &ConvertToLLVMPattern::getContext() const {
return typeConverter.getLLVMContext();
}
-llvm::Module &ConvertToLLVMPattern::getModule() const {
- return getDialect().getLLVMModule();
-}
-
LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
return typeConverter.getIndexType();
}
@@ -910,10 +905,9 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
}
-Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
- Value memRefDesc, ValueRange indices,
- ConversionPatternRewriter &rewriter,
- llvm::Module &module) const {
+Value ConvertToLLVMPattern::getDataPtr(
+ Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
+ ConversionPatternRewriter &rewriter) const {
LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
int64_t offset;
SmallVector<int64_t, 4> strides;
@@ -2451,7 +2445,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
auto type = loadOp.getMemRefType();
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter, getModule());
+ transformed.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
return success();
}
@@ -2469,7 +2463,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
StoreOp::Adaptor transformed(operands);
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter, getModule());
+ transformed.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
dataPtr);
return success();
@@ -2489,7 +2483,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
auto type = prefetchOp.getMemRefType();
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter, getModule());
+ transformed.indices(), rewriter);
// Replace with llvm.prefetch.
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
@@ -3086,7 +3080,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
auto resultType = adaptor.value().getType();
auto memRefType = atomicOp.getMemRefType();
auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(),
- adaptor.indices(), rewriter, getModule());
+ adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
op, resultType, *maybeKind, dataPtr, adaptor.value(),
LLVM::AtomicOrdering::acq_rel);
@@ -3152,7 +3146,7 @@ struct GenericAtomicRMWOpLowering
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
- adaptor.indices(), rewriter, getModule());
+ adaptor.indices(), rewriter);
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 64d38e4cc293..011143b810d9 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -131,7 +131,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
align = LLVM::TypeToLLVMIRTranslator(dialect->getLLVMContext())
.getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
- dialect->getLLVMModule().getDataLayout());
+ dialect->getDataLayout());
return success();
}
@@ -1152,7 +1152,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
// address space 0.
// TODO: support alignment when possible.
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
- adaptor.indices(), rewriter, getModule());
+ adaptor.indices(), rewriter);
auto vecTy =
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
Value vectorDataPtr;
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 9c8556dbab3a..f699ef054d1a 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -103,7 +103,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
// indices, so no need to calculat offset size in bytes again in
// the MUBUF instruction.
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
- adaptor.indices(), rewriter, getModule());
+ adaptor.indices(), rewriter);
// 1. Create and fill a <4 x i32> dwordConfig with:
// 1st two elements holding the address of dataPtr.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 47129d7bd615..6a70af4744d1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1741,6 +1741,9 @@ llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
llvm::sys::SmartMutex<true> &LLVMDialect::getLLVMContextMutex() {
return impl->mutex;
}
+const llvm::DataLayout &LLVMDialect::getDataLayout() {
+ return impl->module.getDataLayout();
+}
/// Parse a type registered to this dialect.
Type LLVMDialect::parseType(DialectAsmParser &parser) const {
More information about the Mlir-commits mailing list