[Mlir-commits] [mlir] 0aaf2e3 - [mlir][GPU] add required address space cast when lowering to LLVM

Markus Böck llvmlistbot at llvm.org
Mon Feb 13 13:23:51 PST 2023


Author: Markus Böck
Date: 2023-02-13T22:24:20+01:00
New Revision: 0aaf2e3bc057aa1d784455f8f4da66bc464733d6

URL: https://github.com/llvm/llvm-project/commit/0aaf2e3bc057aa1d784455f8f4da66bc464733d6
DIFF: https://github.com/llvm/llvm-project/commit/0aaf2e3bc057aa1d784455f8f4da66bc464733d6.diff

LOG: [mlir][GPU] add required address space cast when lowering to LLVM

The runtime functions that `memset` and `memcpy` are lowered to are declared with pointers to the default address space (0), whereas the ops themselves accept memrefs in any address space.
Such cases currently cause no issues in MLIR's LLVM dialect because the `bitcast` verifier is too lenient, but actual LLVM IR does not allow casting between address spaces with `bitcast`: https://godbolt.org/z/3a1z97rc9

This patch fixes the issue by inserting an address space cast ahead of the bitcast, so that the pointer is first moved into the correct address space and only then bitcast.

Differential Revision: https://reviews.llvm.org/D143866

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
    mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 3f61be5471e7b..4bb0e3ae028f7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -92,7 +92,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
   MLIRContext *context = &this->getTypeConverter()->getContext();
 
   Type llvmVoidType = LLVM::LLVMVoidType::get(context);
-  Type llvmPointerType =
+  LLVM::LLVMPointerType llvmPointerType =
       LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
   Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
   Type llvmInt8Type = IntegerType::get(context, 8);
@@ -807,6 +807,22 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
+static Value bitAndAddrspaceCast(Location loc,
+                                 ConversionPatternRewriter &rewriter,
+                                 LLVM::LLVMPointerType destinationType,
+                                 Value sourcePtr,
+                                 LLVMTypeConverter &typeConverter) {
+  auto sourceTy = sourcePtr.getType().cast<LLVM::LLVMPointerType>();
+  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
+    sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+        loc,
+        typeConverter.getPointerType(sourceTy.getElementType(),
+                                     destinationType.getAddressSpace()),
+        sourcePtr);
+
+  return rewriter.create<LLVM::BitcastOp>(loc, destinationType, sourcePtr);
+}
+
 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -829,11 +845,13 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
   auto sizeBytes =
       rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
 
-  auto src = rewriter.create<LLVM::BitcastOp>(
-      loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
-  auto dst = rewriter.create<LLVM::BitcastOp>(
-      loc, llvmPointerType,
-      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc));
+  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
+                                 srcDesc.alignedPtr(rewriter, loc),
+                                 *getTypeConverter());
+  auto dst = bitAndAddrspaceCast(
+      loc, rewriter, llvmPointerType,
+      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
+      *getTypeConverter());
 
   auto stream = adaptor.getAsyncDependencies().front();
   memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
@@ -866,8 +884,9 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
 
   auto value =
       rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.getValue());
-  auto dst = rewriter.create<LLVM::BitcastOp>(
-      loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));
+  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
+                                 dstDesc.alignedPtr(rewriter, loc),
+                                 *getTypeConverter());
 
   auto stream = adaptor.getAsyncDependencies().front();
   memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});

diff  --git a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
index df10b317aa499..89c0268a0f72e 100644
--- a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
@@ -7,8 +7,10 @@ module attributes {gpu.container_module} {
     // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
     %t0 = gpu.wait async
     // CHECK: %[[size_bytes:.*]] = llvm.ptrtoint
+    // CHECK-NOT: llvm.addrspacecast
     // CHECK: %[[src:.*]] = llvm.bitcast
-    // CHECK: %[[dst:.*]] = llvm.bitcast
+    // CHECK: %[[addr_cast:.*]] = llvm.addrspacecast
+    // CHECK: %[[dst:.*]] = llvm.bitcast %[[addr_cast]]
     // CHECK: llvm.call @mgpuMemcpy(%[[dst]], %[[src]], %[[size_bytes]], %[[t0]])
     %t1 = gpu.memcpy async [%t0] %dst, %src : memref<7xf32, 1>, memref<7xf32>
     // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])

diff  --git a/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir
index ef5b6ef2c7bb6..562c15583369b 100644
--- a/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir
@@ -8,7 +8,8 @@ module attributes {gpu.container_module} {
     %t0 = gpu.wait async
     // CHECK: %[[size_bytes:.*]] = llvm.mlir.constant
     // CHECK: %[[value:.*]] = llvm.bitcast
-    // CHECK: %[[dst:.*]] = llvm.bitcast
+    // CHECK: %[[addr_cast:.*]] = llvm.addrspacecast
+    // CHECK: %[[dst:.*]] = llvm.bitcast %[[addr_cast]]
     // CHECK: llvm.call @mgpuMemset32(%[[dst]], %[[value]], %[[size_bytes]], %[[t0]])
     %t1 = gpu.memset async [%t0] %dst, %value : memref<7xf32, 1>, f32
     // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])


        


More information about the Mlir-commits mailing list