[Mlir-commits] [mlir] d3a9807 - [mlir] Remove most uses of LLVMDialect::getModule

Alex Zinenko llvmlistbot at llvm.org
Thu Aug 6 01:54:38 PDT 2020


Author: Alex Zinenko
Date: 2020-08-06T10:54:30+02:00
New Revision: d3a9807674c1d7000bd5ec4028be399c81cbd098

URL: https://github.com/llvm/llvm-project/commit/d3a9807674c1d7000bd5ec4028be399c81cbd098
DIFF: https://github.com/llvm/llvm-project/commit/d3a9807674c1d7000bd5ec4028be399c81cbd098.diff

LOG: [mlir] Remove most uses of LLVMDialect::getModule

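This prepares for the removal of llvm::Module and LLVMContext from
mlir::LLVMDialect. Data layout queries that previously went through
LLVMDialect::getLLVMModule() now use a new LLVMDialect::getDataLayout()
accessor. The ConvertToLLVMPattern::getModule() helper and the
`llvm::Module &` parameter of ConvertToLLVMPattern::getDataPtr() are removed.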

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D85371

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
    mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 43a1587f0353..1bf46b91b7a6 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -118,8 +118,7 @@ class LLVMTypeConverter : public TypeConverter {
   unsigned getPointerBitwidth(unsigned addressSpace = 0);
 
 protected:
-  /// LLVM IR module used to parse/create types.
-  llvm::Module *module;
+  /// Pointer to the LLVM dialect.
   LLVM::LLVMDialect *llvmDialect;
 
 private:
@@ -400,9 +399,6 @@ class ConvertToLLVMPattern : public ConversionPattern {
   /// Returns the LLVM IR context.
   llvm::LLVMContext &getContext() const;
 
-  /// Returns the LLVM IR module associated with the LLVM dialect.
-  llvm::Module &getModule() const;
-
   /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
   /// defined by the used type converter.
   LLVM::LLVMType getIndexType() const;
@@ -437,8 +433,8 @@ class ConvertToLLVMPattern : public ConversionPattern {
                              ConversionPatternRewriter &rewriter) const;
 
   Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
-                   ValueRange indices, ConversionPatternRewriter &rewriter,
-                   llvm::Module &module) const;
+                   ValueRange indices,
+                   ConversionPatternRewriter &rewriter) const;
 
   /// Returns the type of a pointer to an element of the memref.
   Type getElementPtrType(MemRefType type) const;

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 4d99bf265c65..6b265e73d897 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -25,6 +25,7 @@ def LLVM_Dialect : Dialect {
     llvm::LLVMContext &getLLVMContext();
     llvm::Module &getLLVMModule();
     llvm::sys::SmartMutex<true> &getLLVMContextMutex();
+    const llvm::DataLayout &getDataLayout();
 
   private:
     friend LLVMType;

diff  --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index 14011e08de02..c5ecaf798ebd 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -66,12 +66,7 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
 private:
   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
 
-  llvm::LLVMContext &getLLVMContext() {
-    return getLLVMDialect()->getLLVMContext();
-  }
-
   void initializeCachedTypes() {
-    const llvm::Module &module = llvmDialect->getLLVMModule();
     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
     llvmPointerPointerType = llvmPointerType.getPointerTo();
@@ -79,7 +74,7 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
     llvmIntPtrType = LLVM::LLVMType::getIntNTy(
-        llvmDialect, module.getDataLayout().getPointerSizeInBits());
+        llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits());
   }
 
   LLVM::LLVMType getVoidType() { return llvmVoidType; }
@@ -95,9 +90,9 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
 
   LLVM::LLVMType getIntPtrType() {
-    const llvm::Module &module = getLLVMDialect()->getLLVMModule();
     return LLVM::LLVMType::getIntNTy(
-        getLLVMDialect(), module.getDataLayout().getPointerSizeInBits());
+        getLLVMDialect(),
+        getLLVMDialect()->getDataLayout().getPointerSizeInBits());
   }
 
   // Allocate a void pointer on the stack.

diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index e6527e0ca42b..c1a64bd091a9 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -59,10 +59,6 @@ class VulkanLaunchFuncToVulkanCallsPass
 private:
   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
 
-  llvm::LLVMContext &getLLVMContext() {
-    return getLLVMDialect()->getLLVMContext();
-  }
-
   void initializeCachedTypes() {
     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
     llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 9777071aa124..e7c8770ed8f3 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -128,10 +128,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
     : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
       options(options) {
   assert(llvmDialect && "LLVM IR dialect is not registered");
-  module = &llvmDialect->getLLVMModule();
   if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
     this->options.indexBitwidth =
-        module->getDataLayout().getPointerSizeInBits();
+        llvmDialect->getDataLayout().getPointerSizeInBits();
 
   // Register conversions for the standard types.
   addConversion([&](ComplexType type) { return convertComplexType(type); });
@@ -196,7 +195,7 @@ MLIRContext &LLVMTypeConverter::getContext() {
 
 /// Get the LLVM context.
 llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
-  return module->getContext();
+  return llvmDialect->getLLVMContext();
 }
 
 LLVM::LLVMType LLVMTypeConverter::getIndexType() {
@@ -204,7 +203,7 @@ LLVM::LLVMType LLVMTypeConverter::getIndexType() {
 }
 
 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
-  return module->getDataLayout().getPointerSizeInBits(addressSpace);
+  return llvmDialect->getDataLayout().getPointerSizeInBits(addressSpace);
 }
 
 Type LLVMTypeConverter::convertIndexType(IndexType type) {
@@ -849,10 +848,6 @@ llvm::LLVMContext &ConvertToLLVMPattern::getContext() const {
   return typeConverter.getLLVMContext();
 }
 
-llvm::Module &ConvertToLLVMPattern::getModule() const {
-  return getDialect().getLLVMModule();
-}
-
 LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
   return typeConverter.getIndexType();
 }
@@ -910,10 +905,9 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
   return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
 }
 
-Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
-                                       Value memRefDesc, ValueRange indices,
-                                       ConversionPatternRewriter &rewriter,
-                                       llvm::Module &module) const {
+Value ConvertToLLVMPattern::getDataPtr(
+    Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
+    ConversionPatternRewriter &rewriter) const {
   LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
   int64_t offset;
   SmallVector<int64_t, 4> strides;
@@ -2451,7 +2445,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
     auto type = loadOp.getMemRefType();
 
     Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
-                               transformed.indices(), rewriter, getModule());
+                               transformed.indices(), rewriter);
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
     return success();
   }
@@ -2469,7 +2463,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
     StoreOp::Adaptor transformed(operands);
 
     Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
-                               transformed.indices(), rewriter, getModule());
+                               transformed.indices(), rewriter);
     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
                                                dataPtr);
     return success();
@@ -2489,7 +2483,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
     auto type = prefetchOp.getMemRefType();
 
     Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
-                               transformed.indices(), rewriter, getModule());
+                               transformed.indices(), rewriter);
 
     // Replace with llvm.prefetch.
     auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
@@ -3086,7 +3080,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
     auto resultType = adaptor.value().getType();
     auto memRefType = atomicOp.getMemRefType();
     auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(),
-                              adaptor.indices(), rewriter, getModule());
+                              adaptor.indices(), rewriter);
     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
         op, resultType, *maybeKind, dataPtr, adaptor.value(),
         LLVM::AtomicOrdering::acq_rel);
@@ -3152,7 +3146,7 @@ struct GenericAtomicRMWOpLowering
     rewriter.setInsertionPointToEnd(initBlock);
     auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
     auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
-                              adaptor.indices(), rewriter, getModule());
+                              adaptor.indices(), rewriter);
     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
 

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 64d38e4cc293..011143b810d9 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -131,7 +131,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
   LLVM::LLVMDialect *dialect = typeConverter.getDialect();
   align = LLVM::TypeToLLVMIRTranslator(dialect->getLLVMContext())
               .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
-                                     dialect->getLLVMModule().getDataLayout());
+                                     dialect->getDataLayout());
   return success();
 }
 
@@ -1152,7 +1152,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
     //    address space 0.
     // TODO: support alignment when possible.
     Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
-                               adaptor.indices(), rewriter, getModule());
+                               adaptor.indices(), rewriter);
     auto vecTy =
         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
     Value vectorDataPtr;

diff  --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 9c8556dbab3a..f699ef054d1a 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -103,7 +103,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
     // indices, so no need to calculat offset size in bytes again in
     // the MUBUF instruction.
     Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
-                               adaptor.indices(), rewriter, getModule());
+                               adaptor.indices(), rewriter);
 
     // 1. Create and fill a <4 x i32> dwordConfig with:
     //    1st two elements holding the address of dataPtr.

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 47129d7bd615..6a70af4744d1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1741,6 +1741,9 @@ llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
 llvm::sys::SmartMutex<true> &LLVMDialect::getLLVMContextMutex() {
   return impl->mutex;
 }
+const llvm::DataLayout &LLVMDialect::getDataLayout() {
+  return impl->module.getDataLayout();
+}
 
 /// Parse a type registered to this dialect.
 Type LLVMDialect::parseType(DialectAsmParser &parser) const {


        


More information about the Mlir-commits mailing list