[Mlir-commits] [mlir] d0f19ce - [mlir] Handle different pointer sizes in unranked memref descriptors

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Feb 9 11:15:03 PST 2023


Author: Krzysztof Drewniak
Date: 2023-02-09T19:14:58Z
New Revision: d0f19ce774b0d3fa5cd3b22561cc75c7b38b290d

URL: https://github.com/llvm/llvm-project/commit/d0f19ce774b0d3fa5cd3b22561cc75c7b38b290d
DIFF: https://github.com/llvm/llvm-project/commit/d0f19ce774b0d3fa5cd3b22561cc75c7b38b290d.diff

LOG: [mlir] Handle different pointer sizes in unranked memref descriptors

The code for unranked memref descriptors assumed that
sizeof(!llvm.ptr) == sizeof(!llvm.ptr<N>) for all address spaces N.
This is not always true (ex. the AMDGPU compiler backend has
sizeof(!llvm.ptr) = 64 bits but sizeof(!llvm.ptr<5>) = 32 bits, where
address space 5 is used for stack allocations). While this is merely
an overallocation in the case where a non-0 address space has pointers
smaller than the default, the existing code could cause OOB memory
accesses when sizeof(!llvm.ptr<N>) > sizeof(!llvm.ptr).

So, add an address spaces parameter to computeSizes in order to
partially resolve this class of bugs. Note that the LLVM data layout
in the conversion passes is currently set to "" and not constructed
from the MLIR data layout or some other source, but this could change
in the future.

Depends on D142159

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D141293

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
    mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h
    mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
    mlir/lib/Conversion/LLVMCommon/Pattern.cpp
    mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
    mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
    mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
index bf1bd8f7f325a..a68e0879444db 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -157,11 +157,11 @@ class UnrankedMemRefDescriptor : public StructBuilder {
                                         Type descriptorType);
 
   /// Builds IR extracting the rank from the descriptor
-  Value rank(OpBuilder &builder, Location loc);
+  Value rank(OpBuilder &builder, Location loc) const;
   /// Builds IR setting the rank in the descriptor
   void setRank(OpBuilder &builder, Location loc, Value value);
   /// Builds IR extracting ranked memref descriptor ptr
-  Value memRefDescPtr(OpBuilder &builder, Location loc);
+  Value memRefDescPtr(OpBuilder &builder, Location loc) const;
   /// Builds IR setting ranked memref descriptor ptr
   void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
 
@@ -183,10 +183,13 @@ class UnrankedMemRefDescriptor : public StructBuilder {
   static unsigned getNumUnpackedValues() { return 2; }
 
   /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
-  /// and appends the corresponding values into `sizes`.
+  /// and appends the corresponding values into `sizes`. `addressSpaces`
+  /// which must have the same length as `values`, is needed to handle layouts
+  /// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
   static void computeSizes(OpBuilder &builder, Location loc,
                            LLVMTypeConverter &typeConverter,
                            ArrayRef<UnrankedMemRefDescriptor> values,
+                           ArrayRef<unsigned> addressSpaces,
                            SmallVectorImpl<Value> &sizes);
 
   /// TODO: The following accessors don't take alignment rules between elements

diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h
index 3523f98780b11..1a5b97eb92830 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h
@@ -41,7 +41,7 @@ class StructBuilder {
 
 protected:
   /// Builds IR to extract a value from the struct at position pos
-  Value extractPtr(OpBuilder &builder, Location loc, unsigned pos);
+  Value extractPtr(OpBuilder &builder, Location loc, unsigned pos) const;
   /// Builds IR to set a value in the struct at position pos
   void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
 };

diff  --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 8ce9fb69ee6df..17259e48e3a3a 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -296,7 +296,7 @@ UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
   Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
   return UnrankedMemRefDescriptor(descriptor);
 }
-Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
+Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const {
   return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
 }
 void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
@@ -304,7 +304,7 @@ void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
   setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
 }
 Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
-                                              Location loc) {
+                                              Location loc) const {
   return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
 }
 void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
@@ -341,24 +341,24 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
 
 void UnrankedMemRefDescriptor::computeSizes(
     OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
-    ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
+    ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
+    SmallVectorImpl<Value> &sizes) {
   if (values.empty())
     return;
-
+  assert(values.size() == addressSpaces.size() &&
+         "must provide address space for each descriptor");
   // Cache the index type.
   Type indexType = typeConverter.getIndexType();
 
   // Initialize shared constants.
   Value one = createIndexAttrConstant(builder, loc, indexType, 1);
   Value two = createIndexAttrConstant(builder, loc, indexType, 2);
-  Value pointerSize = createIndexAttrConstant(
-      builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
   Value indexSize =
       createIndexAttrConstant(builder, loc, indexType,
                               ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));
 
   sizes.reserve(sizes.size() + values.size());
-  for (UnrankedMemRefDescriptor desc : values) {
+  for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
     // Emit IR computing the memory necessary to store the descriptor. This
     // assumes the descriptor to be
     //   { type*, type*, index, index[rank], index[rank] }
@@ -366,6 +366,9 @@ void UnrankedMemRefDescriptor::computeSizes(
     //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
     // TODO: consider including the actual size (including eventual padding due
     // to data layout) into the unranked descriptor.
+    Value pointerSize = createIndexAttrConstant(
+        builder, loc, indexType,
+        ceilDiv(typeConverter.getPointerBitwidth(addressSpace), 8));
     Value doublePointerSize =
         builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
 

diff  --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 69a0172ebf0b1..d3983a3c0a944 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -232,18 +232,27 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
          "expected as may original types as operands");
 
   // Find operands of unranked memref type and store them.
-  SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
-  for (unsigned i = 0, e = operands.size(); i < e; ++i)
-    if (origTypes[i].isa<UnrankedMemRefType>())
+  SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
+  SmallVector<unsigned> unrankedAddressSpaces;
+  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+    if (auto memRefType = origTypes[i].dyn_cast<UnrankedMemRefType>()) {
       unrankedMemrefs.emplace_back(operands[i]);
+      FailureOr<unsigned> addressSpace =
+          getTypeConverter()->getMemRefAddressSpace(memRefType);
+      if (failed(addressSpace))
+        return failure();
+      unrankedAddressSpaces.emplace_back(*addressSpace);
+    }
+  }
 
   if (unrankedMemrefs.empty())
     return success();
 
   // Compute allocation sizes.
-  SmallVector<Value, 4> sizes;
+  SmallVector<Value> sizes;
   UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
-                                         unrankedMemrefs, sizes);
+                                         unrankedMemrefs, unrankedAddressSpaces,
+                                         sizes);
 
   // Get frequently used types.
   MLIRContext *context = builder.getContext();

diff  --git a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
index b5b192d0f0399..1cd0bd85f9894 100644
--- a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
@@ -23,7 +23,7 @@ StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) {
 }
 
 Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
-                                unsigned pos) {
+                                unsigned pos) const {
   return builder.create<LLVM::ExtractValueOp>(loc, value, pos);
 }
 

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index c3ff7d86852ae..5dccb9b5f9ea9 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1329,7 +1329,7 @@ struct MemRefReshapeOpLowering
     targetDesc.setRank(rewriter, loc, resultRank);
     SmallVector<Value, 4> sizes;
     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
-                                           targetDesc, sizes);
+                                           targetDesc, addressSpace, sizes);
     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
         loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
         sizes.front());

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 509a03c1096a5..cdd8cd77aa9c0 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -742,8 +742,6 @@ convertTransferReadToLoads(vector::TransferReadOp op,
   if (failed(warpMatrixInfo))
     return failure();
 
-  Attribute memorySpace =
-      op.getSource().getType().cast<MemRefType>().getMemorySpace();
   bool isLdMatrixCompatible =
       isSharedMemory(op.getSource().getType().cast<MemRefType>()) &&
       nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

diff  --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
index b59c0b60497da..daa824d84ba74 100644
--- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
@@ -122,9 +122,9 @@ func.func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
   // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index)
   // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
   // These sizes may depend on the data layout, not matching specific values.
-  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant
 
+  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
   // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
@@ -153,13 +153,12 @@ func.func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes
   // CHECK: %[[DESC_2:.*]] = llvm.insertvalue %[[ALLOCA]], %[[DESC_1]][1]
   %0 = memref.cast %arg0: memref<4x3xf32> to memref<*xf32>
 
-
   // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index)
   // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
   // These sizes may depend on the data layout, not matching specific values.
-  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant
 
+  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]

diff  --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
index 6987bdcc104d5..9624b18f30dde 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -408,8 +408,8 @@ func.func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
 // Compute size in bytes to allocate result ranked descriptor
 // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : i64
 // CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}}
 // CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x i8

diff  --git a/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir b/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir
index e14f70df15b97..36206f1b57b08 100644
--- a/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir
@@ -323,8 +323,8 @@ func.func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
 // Compute size in bytes to allocate result ranked descriptor
 // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : i64
 // CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}}
 // CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x i8


        


More information about the Mlir-commits mailing list