[Mlir-commits] [mlir] 18cc07a - [MLIR][GPU] Add 16-bit version of cudaMemset in cudaRuntimeWrappers

Uday Bondhugula llvmlistbot at llvm.org
Thu Jun 8 05:05:46 PDT 2023


Author: Navdeep Katel
Date: 2023-06-08T17:33:26+05:30
New Revision: 18cc07aa07f6784cc59a4b4cfe33522867805586

URL: https://github.com/llvm/llvm-project/commit/18cc07aa07f6784cc59a4b4cfe33522867805586
DIFF: https://github.com/llvm/llvm-project/commit/18cc07aa07f6784cc59a4b4cfe33522867805586.diff

LOG: [MLIR][GPU] Add 16-bit version of cudaMemset in cudaRuntimeWrappers

Add 16-bit version of cudaMemset in cudaRuntimeWrappers and update the GPU to LLVM lowering.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D151642

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
    mlir/test/Conversion/GPUCommon/typed-pointers.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index b528a0e256afe..272a9074f804e 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -97,6 +97,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
   Type llvmPointerPointerType =
       this->getTypeConverter()->getPointerType(llvmPointerType);
   Type llvmInt8Type = IntegerType::get(context, 8);
+  Type llvmInt16Type = IntegerType::get(context, 16);
   Type llvmInt32Type = IntegerType::get(context, 32);
   Type llvmInt64Type = IntegerType::get(context, 64);
   Type llvmInt8PointerType =
@@ -186,7 +187,14 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
        llvmIntPtrType /* intptr_t sizeBytes */,
        llvmPointerType /* void *stream */}};
-  FunctionCallBuilder memsetCallBuilder = {
+  FunctionCallBuilder memset16CallBuilder = {
+      "mgpuMemset16",
+      llvmVoidType,
+      {llvmPointerType /* void *dst */,
+       llvmInt16Type /* unsigned short value */,
+       llvmIntPtrType /* intptr_t sizeBytes */,
+       llvmPointerType /* void *stream */}};
+  FunctionCallBuilder memset32CallBuilder = {
       "mgpuMemset32",
       llvmVoidType,
       {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
@@ -1365,22 +1373,29 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
   auto loc = memsetOp.getLoc();
 
   Type valueType = adaptor.getValue().getType();
-  if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
-    return rewriter.notifyMatchFailure(memsetOp,
-                                       "value must be a 32 bit scalar");
+  unsigned bitWidth = valueType.getIntOrFloatBitWidth();
+  // Ints and floats of 16 or 32 bit width are allowed.
+  if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
+    return rewriter.notifyMatchFailure(
+        memsetOp, "value must be a 16 or 32 bit int or float");
   }
 
+  unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
+  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
+
   MemRefDescriptor dstDesc(adaptor.getDst());
   Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
 
   auto value =
-      rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.getValue());
+      rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
   auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                  dstDesc.alignedPtr(rewriter, loc),
                                  *getTypeConverter());
 
   auto stream = adaptor.getAsyncDependencies().front();
-  memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});
+  FunctionCallBuilder builder =
+      valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
+  builder.create(loc, rewriter, {dst, value, numElements, stream});
 
   rewriter.replaceOp(memsetOp, {stream});
   return success();

diff  --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index f654f6a2a5ef7..c811b72a21403 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -176,6 +176,12 @@ extern "C" void mgpuMemset32(void *dst, unsigned int value, size_t count,
                                         value, count, stream));
 }
 
+extern "C" void mgpuMemset16(void *dst, unsigned short value, size_t count,
+                             CUstream stream) {
+  CUDA_REPORT_IF_ERROR(cuMemsetD16Async(reinterpret_cast<CUdeviceptr>(dst),
+                                        value, count, stream));
+}
+
 ///
 /// Helper functions for writing mlir example code
 ///

diff  --git a/mlir/test/Conversion/GPUCommon/typed-pointers.mlir b/mlir/test/Conversion/GPUCommon/typed-pointers.mlir
index 9752ffa007b02..2fa6c854c5678 100644
--- a/mlir/test/Conversion/GPUCommon/typed-pointers.mlir
+++ b/mlir/test/Conversion/GPUCommon/typed-pointers.mlir
@@ -42,8 +42,8 @@ module attributes {gpu.container_module} {
 
 module attributes {gpu.container_module} {
 
-  // CHECK: func @foo
-  func.func @foo(%dst : memref<7xf32, 1>, %value : f32) {
+  // CHECK: func @memset_f32
+  func.func @memset_f32(%dst : memref<7xf32, 1>, %value : f32) {
     // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
     %t0 = gpu.wait async
     // CHECK: %[[size_bytes:.*]] = llvm.mlir.constant
@@ -59,3 +59,23 @@ module attributes {gpu.container_module} {
   }
 }
 
+// -----
+
+module attributes {gpu.container_module} {
+
+  // CHECK: func @memset_f16
+  func.func @memset_f16(%dst : memref<7xf16, 1>, %value : f16) {
+    // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
+    %t0 = gpu.wait async
+    // CHECK: %[[size_bytes:.*]] = llvm.mlir.constant
+    // CHECK: %[[value:.*]] = llvm.bitcast
+    // CHECK: %[[addr_cast:.*]] = llvm.addrspacecast
+    // CHECK: %[[dst:.*]] = llvm.bitcast %[[addr_cast]]
+    // CHECK: llvm.call @mgpuMemset16(%[[dst]], %[[value]], %[[size_bytes]], %[[t0]])
+    %t1 = gpu.memset async [%t0] %dst, %value : memref<7xf16, 1>, f16
+    // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
+    // CHECK: llvm.call @mgpuStreamDestroy(%[[t0]])
+    gpu.wait [%t1]
+    return
+  }
+}


        


More information about the Mlir-commits mailing list