[Mlir-commits] [mlir] 3a577f5 - Rename MemRefDescriptor::getElementType() to MemRefDescriptor::getElementPtrType().

Christian Sigg llvmlistbot at llvm.org
Wed Sep 9 02:45:48 PDT 2020


Author: Christian Sigg
Date: 2020-09-09T11:45:39+02:00
New Revision: 3a577f544618d9713aca5052e55143142d23f427

URL: https://github.com/llvm/llvm-project/commit/3a577f544618d9713aca5052e55143142d23f427
DIFF: https://github.com/llvm/llvm-project/commit/3a577f544618d9713aca5052e55143142d23f427.diff

LOG: Rename MemRefDescriptor::getElementType() to MemRefDescriptor::getElementPtrType().

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D87284

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 63ffd7837382..ab047a08f404 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -34,6 +34,7 @@ class UnrankedMemRefType;
 namespace LLVM {
 class LLVMDialect;
 class LLVMType;
+class LLVMPointerType;
 } // namespace LLVM
 
 /// Callback to convert function argument types. It converts a MemRef function
@@ -281,8 +282,8 @@ class MemRefDescriptor : public StructBuilder {
   void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
                          uint64_t stride);
 
-  /// Returns the (LLVM) type this descriptor points to.
-  LLVM::LLVMType getElementType();
+  /// Returns the (LLVM) pointer type this descriptor contains.
+  LLVM::LLVMPointerType getElementPtrType();
 
   /// Builds IR populating a MemRef descriptor structure from a list of
   /// individual values composing that descriptor, in the following order:

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 55a926ef1423..2aa589a0fb7b 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -642,9 +642,11 @@ void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
             createIndexAttrConstant(builder, loc, indexType, stride));
 }
 
-LLVM::LLVMType MemRefDescriptor::getElementType() {
-  return value.getType().cast<LLVM::LLVMType>().getStructElementType(
-      kAlignedPtrPosInMemRefDescriptor);
+LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
+  return value.getType()
+      .cast<LLVM::LLVMType>()
+      .getStructElementType(kAlignedPtrPosInMemRefDescriptor)
+      .cast<LLVM::LLVMPointerType>();
 }
 
 /// Creates a MemRef descriptor structure from a list of individual values
@@ -894,7 +896,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
 Value ConvertToLLVMPattern::getDataPtr(
     Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
     ConversionPatternRewriter &rewriter) const {
-  LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
+  LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementPtrType();
   int64_t offset;
   SmallVector<int64_t, 4> strides;
   auto successStrides = getStridesAndOffset(type, strides, offset);

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d51a96dca384..73fd3285ec97 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -198,7 +198,7 @@ static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
   Value base;
   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
     return failure();
-  auto pType = MemRefDescriptor(memref).getElementType();
+  auto pType = MemRefDescriptor(memref).getElementPtrType();
   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
   return success();
 }
@@ -225,7 +225,7 @@ static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
   Value base;
   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
     return failure();
-  auto pType = MemRefDescriptor(memref).getElementType();
+  auto pType = MemRefDescriptor(memref).getElementPtrType();
   auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
   return success();
@@ -1151,7 +1151,7 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
 
     // Create descriptor.
     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
-    Type llvmTargetElementTy = desc.getElementType();
+    Type llvmTargetElementTy = desc.getElementPtrType();
     // Set allocated ptr.
     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
     allocated =


        


More information about the Mlir-commits mailing list