[Mlir-commits] [mlir] eafca23 - [mlir][MemRef] Add required address space cast when lowering alloc to LLVM

Markus Böck llvmlistbot at llvm.org
Mon Feb 6 03:09:44 PST 2023


Author: Markus Böck
Date: 2023-02-06T12:10:07+01:00
New Revision: eafca2303769800f5da4bc4cbf9e842c6a8cde9f

URL: https://github.com/llvm/llvm-project/commit/eafca2303769800f5da4bc4cbf9e842c6a8cde9f
DIFF: https://github.com/llvm/llvm-project/commit/eafca2303769800f5da4bc4cbf9e842c6a8cde9f.diff

LOG: [mlir][MemRef] Add required address space cast when lowering alloc to LLVM

`memref.alloc` uses either `malloc` or a pluggable allocation function to allocate the required memory. Both of these functions always return an `llvm.ptr<i8>`, i.e. a pointer in the default address space. When allocating for a memref in a different memory space, however, no address space cast is created, leading to invalid LLVM IR being generated.

This is currently not caught by the MLIR verifier, since the pointer to the memory is always bitcast and `llvm.bitcast` currently lacks a verifier disallowing address space changes. Translating to actual LLVM IR, however, trips the LLVM verifier, since `bitcast` cannot convert a pointer from one address space to another: https://godbolt.org/z/3a1z97rc9
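
A rough sketch of the pre-patch lowering for an allocation in memory space 5 (illustrative only, not the verbatim output of the pass; `%size` stands in for the computed allocation size):

  %memory = llvm.call @malloc(%size) : (i64) -> !llvm.ptr<i8>
  // The bitcast changes the address space from 0 to 5; the LLVM verifier
  // rejects such a cast once this is translated to LLVM IR proper.
  %ptr = llvm.bitcast %memory : !llvm.ptr<i8> to !llvm.ptr<f32, 5>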

This patch fixes the issue by generating an address space cast whenever the address space of the pointer returned by the allocation function does not match the memory space of the resulting memref.
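
With the patch, the lowering inserts an explicit `llvm.addrspacecast` before the bitcast, matching the updated CHECK lines in memref-to-llvm.mlir below (again only a sketch for a memref of f32 in memory space 5):

  %memory = llvm.call @malloc(%size) : (i64) -> !llvm.ptr<i8>
  // First cast from the default address space to the memref's memory space,
  // then bitcast to the element pointer type within that space.
  %cast = llvm.addrspacecast %memory : !llvm.ptr<i8> to !llvm.ptr<i8, 5>
  %ptr = llvm.bitcast %cast : !llvm.ptr<i8, 5> to !llvm.ptr<f32, 5>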

I am not sure whether this is actually a real-life problem. I found the issue while converting the pass to use opaque pointers, which gets rid of all the bitcasts and hence surfaced type errors when the address space cast was missing.

Differential Revision: https://reviews.llvm.org/D143341

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index 4a5be48707097..8d99e1ffed9be 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -50,6 +50,23 @@ Value AllocationOpLLVMLowering::createAligned(
   return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
 }
 
+static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
+                                 Location loc, Value allocatedPtr,
+                                 MemRefType memRefType, Type elementPtrType,
+                                 LLVMTypeConverter &typeConverter) {
+  auto allocatedPtrTy = allocatedPtr.getType().cast<LLVM::LLVMPointerType>();
+  if (allocatedPtrTy.getAddressSpace() != memRefType.getMemorySpaceAsInt())
+    allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+        loc,
+        LLVM::LLVMPointerType::get(allocatedPtrTy.getElementType(),
+                                   memRefType.getMemorySpaceAsInt()),
+        allocatedPtr);
+
+  allocatedPtr =
+      rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);
+  return allocatedPtr;
+}
+
 std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
     ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
     Operation *op, Value alignment) const {
@@ -64,8 +81,10 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
   LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
       getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
   auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
-  Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
-                                                        results.getResult());
+
+  Value allocatedPtr =
+      castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
+                          elementPtrType, *getTypeConverter());
 
   Value alignedPtr = allocatedPtr;
   if (alignment) {
@@ -126,10 +145,9 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
       getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
   auto results = rewriter.create<LLVM::CallOp>(
       loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
-  Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
-                                                        results.getResult());
 
-  return allocatedPtr;
+  return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
+                             elementPtrType, *getTypeConverter());
 }
 
 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index b6c73e9a917bc..3f61f6d78a7bb 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -182,6 +182,11 @@ func.func @dim_of_unranked(%unranked: memref<*xi32>) -> index {
 
 // CHECK-LABEL: func @address_space(
 func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
+  // CHECK: %[[MEMORY:.*]] = llvm.call @malloc(%{{.*}})
+  // CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[MEMORY]] : !llvm.ptr<i8> to !llvm.ptr<i8, 5>
+  // CHECK: %[[BCAST:.*]] = llvm.bitcast %[[CAST]]
+  // CHECK: llvm.insertvalue %[[BCAST]], %{{[[:alnum:]]+}}[0]
+  // CHECK: llvm.insertvalue %[[BCAST]], %{{[[:alnum:]]+}}[1]
   %0 = memref.alloc() : memref<32xf32, affine_map<(d0) -> (d0)>, 5>
   %1 = arith.constant 7 : index
   // CHECK: llvm.load %{{.*}} : !llvm.ptr<f32, 5>


        

