[Mlir-commits] [mlir] 8c2025c - [MLIR] Refactor memref type -> LLVM Type conversion

Rahul Joshi llvmlistbot at llvm.org
Wed Nov 4 10:33:31 PST 2020


Author: Rahul Joshi
Date: 2020-11-04T10:32:56-08:00
New Revision: 8c2025cc617aa0413dee269084b4a66bcb7de4d5

URL: https://github.com/llvm/llvm-project/commit/8c2025cc617aa0413dee269084b4a66bcb7de4d5
DIFF: https://github.com/llvm/llvm-project/commit/8c2025cc617aa0413dee269084b4a66bcb7de4d5.diff

LOG: [MLIR] Refactor memref type -> LLVM Type conversion

- Eliminate duplicated information about mapping from memref -> its descriptor fields
  by consolidating that mapping in two functions: getMemRefDescriptorFields and
  getUnrankedMemRefDescriptorFields.
- Change convertMemRefType() and convertUnrankedMemRefType() to use these
  functions.
- Remove convertMemRefSignature and convertUnrankedMemRefSignature.

Differential Revision: https://reviews.llvm.org/D90707

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index c52de63224b1..e7aa9d5ae516 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -164,12 +164,20 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a memref type into an LLVM type that captures the relevant data.
   Type convertMemRefType(MemRefType type);
 
-  /// Convert a memref type into a list of non-aggregate LLVM IR types that
-  /// contain all the relevant data. In particular, the list will contain:
+  /// Convert a memref type into a list of LLVM IR types that will form the
+  /// memref descriptor, or an empty list if the element type cannot be
+  /// converted. If `unpackAggregates` is true, the `sizes` and `strides`
+  /// arrays are unpacked into individual index-typed elements; otherwise they
+  /// are kept as rank-sized arrays of index type. The list will contain:
   /// - two pointers to the memref element type, followed by
-  /// - an integer offset, followed by
-  /// - one integer size per dimension of the memref, followed by
-  /// - one integer stride per dimension of the memref.
+  /// - an index-typed offset, followed by
+  /// - (if unpackAggregates = true)
+  ///   - one index-typed size per dimension of the memref, followed by
+  ///   - one index-typed stride per dimension of the memref.
+  /// - (if unpackAggregates = false)
+  ///   - one rank-sized array of index type holding the sizes, followed by
+  ///   - one rank-sized array of index type holding the strides.
+  /// Rank-0 memrefs get no sizes/strides. With unpackAggregates = true:
   /// For example, memref<?x?xf32> is converted to the following list:
   /// - `!llvm<"float*">` (allocated pointer),
   /// - `!llvm<"float*">` (aligned pointer),
@@ -177,17 +185,19 @@ class LLVMTypeConverter : public TypeConverter {
   /// - `!llvm.i64`, `!llvm.i64` (sizes),
   /// - `!llvm.i64`, `!llvm.i64` (strides).
   /// These types can be recomposed to a memref descriptor struct.
-  SmallVector<Type, 5> convertMemRefSignature(MemRefType type);
+  SmallVector<LLVM::LLVMType, 5>
+  getMemRefDescriptorFields(MemRefType type, bool unpackAggregates);
 
   /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
-  /// that contain all the relevant data. In particular, this list contains:
+  /// that will form the unranked memref descriptor. In particular, this list
+  /// contains:
   /// - an integer rank, followed by
   /// - a pointer to the memref descriptor struct.
   /// For example, memref<*xf32> is converted to the following list:
   /// !llvm.i64 (rank)
   /// !llvm<"i8*"> (type-erased pointer).
   /// These types can be recomposed to a unranked memref descriptor struct.
-  SmallVector<Type, 2> convertUnrankedMemRefSignature();
+  SmallVector<LLVM::LLVMType, 2> getUnrankedMemRefDescriptorFields();
 
   // Convert an unranked memref type to an LLVM type that captures the
   // runtime rank and a pointer to the static ranked memref desc

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 03f604b652ab..3cd28bf919e8 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -61,14 +61,17 @@ LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
                                                Type type,
                                                SmallVectorImpl<Type> &result) {
   if (auto memref = type.dyn_cast<MemRefType>()) {
-    auto converted = converter.convertMemRefSignature(memref);
+    // In signatures, MemRef descriptors are expanded into lists of
+    // non-aggregate values; an empty list means the conversion failed.
+    auto converted =
+        converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
     if (converted.empty())
       return failure();
     result.append(converted.begin(), converted.end());
     return success();
   }
   if (type.isa<UnrankedMemRefType>()) {
-    auto converted = converter.convertUnrankedMemRefSignature();
+    auto converted = converter.getUnrankedMemRefDescriptorFields();
     if (converted.empty())
       return failure();
     result.append(converted.begin(), converted.end());
@@ -216,32 +219,6 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
   return converted.getPointerTo();
 }
 
-/// In signatures, MemRef descriptors are expanded into lists of non-aggregate
-/// values.
-SmallVector<Type, 5>
-LLVMTypeConverter::convertMemRefSignature(MemRefType type) {
-  SmallVector<Type, 5> results;
-  assert(isStrided(type) &&
-         "Non-strided layout maps must have been normalized away");
-
-  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
-  if (!elementType)
-    return {};
-  auto indexTy = getIndexType();
-
-  results.insert(results.begin(), 2,
-                 elementType.getPointerTo(type.getMemorySpace()));
-  results.push_back(indexTy);
-  auto rank = type.getRank();
-  results.insert(results.end(), 2 * rank, indexTy);
-  return results;
-}
-
-/// In signatures, unranked MemRef descriptors are expanded into a pair "rank,
-/// pointer to descriptor".
-SmallVector<Type, 2> LLVMTypeConverter::convertUnrankedMemRefSignature() {
-  return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())};
-}

 // Function types are converted to LLVM Function types by recursively converting
 // argument and result types.  If MLIR Function has zero results, the LLVM
@@ -305,69 +282,92 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
   return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
 }
 
-// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
-// contains:
-//   1. the pointer to the data buffer, followed by
-//   2.  a lowered `index`-type integer containing the distance between the
-//   beginning of the buffer and the first element to be accessed through the
-//   view, followed by
-//   3. an array containing as many `index`-type integers as the rank of the
-//   MemRef: the array represents the size, in number of elements, of the memref
-//   along the given dimension. For constant MemRef dimensions, the
-//   corresponding size entry is a constant whose runtime value must match the
-//   static value, followed by
-//   4. a second array containing as many `index`-type integers as the rank of
-//   the MemRef: the second array represents the "stride" (in tensor abstraction
-//   sense), i.e. the number of consecutive elements of the underlying buffer.
-//   TODO: add assertions for the static cases.
-//
-// template <typename Elem, size_t Rank>
-// struct {
-//   Elem *allocatedPtr;
-//   Elem *alignedPtr;
-//   int64_t offset;
-//   int64_t sizes[Rank]; // omitted when rank == 0
-//   int64_t strides[Rank]; // omitted when rank == 0
-// };
 static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
 static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1;
 static constexpr unsigned kOffsetPosInMemRefDescriptor = 2;
 static constexpr unsigned kSizePosInMemRefDescriptor = 3;
 static constexpr unsigned kStridePosInMemRefDescriptor = 4;
-Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset));
-  assert(strideSuccess &&
+
+/// Convert a memref type into the list of LLVM IR types forming the memref
+/// descriptor (empty if the element type fails to convert), namely:
+///  1. The pointer to the allocated data buffer, followed by
+///  2. The pointer to the aligned data buffer, followed by
+///  3. A lowered `index`-type integer containing the distance between the
+///  beginning of the buffer and the first element to be accessed through the
+///  view, followed by
+///  4. An array containing as many `index`-type integers as the rank of the
+///  MemRef: the array represents the size, in number of elements, of the memref
+///  along the given dimension. For constant MemRef dimensions, the
+///  corresponding size entry is a constant whose runtime value must match the
+///  static value, followed by
+///  5. A second array containing as many `index`-type integers as the rank of
+///  the MemRef: the array represents the stride of each dimension, i.e. the
+///  number of buffer elements between consecutive indices along it.
+///  TODO: add assertions for the static cases.
+///
+///  If `unpackAggregates` is set to true, the arrays described in (4) and (5)
+///  are expanded into individual index-type elements.
+///
+///  template <typename Elem, typename Index, size_t Rank>
+///  struct {
+///    Elem *allocatedPtr;
+///    Elem *alignedPtr;
+///    Index offset;
+///    Index sizes[Rank]; // omitted when rank == 0
+///    Index strides[Rank]; // omitted when rank == 0
+///  };
+SmallVector<LLVM::LLVMType, 5>
+LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
+                                             bool unpackAggregates) {
+  assert(isStrided(type) &&
          "Non-strided layout maps must have been normalized away");
-  (void)strideSuccess;
+
   LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
   if (!elementType)
     return {};
   auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
   auto indexTy = getIndexType();
+
+  SmallVector<LLVM::LLVMType, 5> results = {ptrTy, ptrTy, indexTy};
   auto rank = type.getRank();
-  if (rank > 0) {
-    auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank());
-    return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy, arrayTy, arrayTy);
-  }
-  return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy);
+  if (rank == 0)
+    return results;
+
+  if (unpackAggregates)
+    results.insert(results.end(), 2 * rank, indexTy);
+  else
+    results.insert(results.end(), 2, LLVM::LLVMType::getArrayTy(indexTy, rank));
+  return results;
 }
 
-// Converts UnrankedMemRefType to LLVMType. The result is a descriptor which
-// contains:
-// 1. int64_t rank, the dynamic rank of this MemRef
-// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
-//    stack allocated (alloca) copy of a MemRef descriptor that got casted to
-//    be unranked.
+/// Converts MemRefType to LLVMType: a struct packing the descriptor fields
+/// from `getMemRefDescriptorFields`, or a null Type if conversion fails.
+Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
+  // Sizes/strides stay arrays; an empty field list signals failure.
+  auto types = getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
+  if (types.empty())
+    return {};
+  return LLVM::LLVMType::getStructTy(&getContext(), types);
+}
 
 static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
 static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
 
+/// Return the list of non-aggregate LLVM IR types that form an unranked memref
+/// descriptor. No type argument is needed because every unranked memref has
+/// the same descriptor layout, whose fields are:
+/// 1. an index-typed rank, the dynamic rank of this MemRef;
+/// 2. an i8* (type-erased) pointer to the ranked MemRef descriptor. This is
+///    expected to be a stack-allocated (alloca) copy of a ranked MemRef
+///    descriptor that was cast to be unranked.
+SmallVector<LLVM::LLVMType, 2>
+LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
+  return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())};
+}
+
 Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
-  auto rankTy = getIndexType();
-  auto ptrTy = LLVM::LLVMType::getInt8PtrTy(&getContext());
-  return LLVM::LLVMType::getStructTy(rankTy, ptrTy);
+  return LLVM::LLVMType::getStructTy(&getContext(),
+                                     getUnrankedMemRefDescriptorFields());
 }
 
 /// Convert a memref type to a bare pointer to the memref element type.


        


More information about the Mlir-commits mailing list