[Mlir-commits] [mlir] 2c48e36 - [MLIR] Adding gpu.host_register op and lower it to a runtime call.

Christian Sigg llvmlistbot at llvm.org
Mon Aug 10 13:46:27 PDT 2020


Author: Christian Sigg
Date: 2020-08-10T22:46:17+02:00
New Revision: 2c48e3629cfb25ac1034117fa8945fd0d342f2ae

URL: https://github.com/llvm/llvm-project/commit/2c48e3629cfb25ac1034117fa8945fd0d342f2ae
DIFF: https://github.com/llvm/llvm-project/commit/2c48e3629cfb25ac1034117fa8945fd0d342f2ae.diff

LOG: [MLIR] Adding gpu.host_register op and lower it to a runtime call.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D85631

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/mlir-cuda-runner/all-reduce-and.mlir
    mlir/test/mlir-cuda-runner/all-reduce-max.mlir
    mlir/test/mlir-cuda-runner/all-reduce-min.mlir
    mlir/test/mlir-cuda-runner/all-reduce-op.mlir
    mlir/test/mlir-cuda-runner/all-reduce-or.mlir
    mlir/test/mlir-cuda-runner/all-reduce-region.mlir
    mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
    mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
    mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
    mlir/test/mlir-cuda-runner/shuffle.mlir
    mlir/test/mlir-cuda-runner/two-modules.mlir
    mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
    mlir/test/mlir-rocm-runner/two-modules.mlir
    mlir/test/mlir-rocm-runner/vecadd.mlir
    mlir/test/mlir-rocm-runner/vector-transferops.mlir
    mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
    mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 25bff67d4a6e..2f8f87fa0e41 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -440,6 +440,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
                                 ConversionPatternRewriter &rewriter,
                                 SmallVectorImpl<Value> &sizes) const;
 
+  /// Computes the size of the given type in bytes.
+  Value getSizeInBytes(Location loc, Type type,
+                       ConversionPatternRewriter &rewriter) const;
+
   /// Computes the total size in bytes required to store the given shape.
   Value getCumulativeSizeInBytes(Location loc, Type elementType,
                                  ArrayRef<Value> shape,

diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index c0a6ac101d7b..288031c598ff 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -741,4 +741,16 @@ def GPU_ModuleEndOp : GPU_Op<"module_end", [
   let printer = [{ p << getOperationName(); }];
 }
 
+def GPU_HostRegisterOp : GPU_Op<"host_register">,
+    Arguments<(ins AnyUnrankedMemRef:$value)> {
+  let summary = "Registers a memref for access from the device.";
+  let description = [{
+    This op registers the host memory pointed to by a memref to be accessed from
+    a device.
+  }];
+
+  let assemblyFormat = "$value attr-dict `:` type($value)";
+  let verifier = [{ return success(); }];
+}
+
 #endif // GPU_OPS
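
A usage sketch of the new op in host code (hand-written for illustration,
mirroring the test updates below; the function name and %data buffer are
hypothetical):

  func @register_example() {
    %data = alloc() : memref<2x6xi32>
    // Erase the static shape; the op takes a type-erased, unranked memref.
    %cast = memref_cast %data : memref<2x6xi32> to memref<*xi32>
    // Page-lock the underlying host allocation for device access.
    gpu.host_register %cast : memref<*xi32>
    return
  }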

diff  --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index e14c9fbc6718..8aa843308cff 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -117,6 +117,26 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       "mgpuStreamSynchronize",
       llvmVoidType,
       {llvmPointerType /* void *stream */}};
+  FunctionCallBuilder hostRegisterCallBuilder = {
+      "mgpuMemHostRegisterMemRef",
+      llvmVoidType,
+      {llvmIntPtrType /* intptr_t rank */,
+       llvmPointerType /* void *memrefDesc */,
+       llvmIntPtrType /* intptr_t elementSizeBytes */}};
+};
+
+/// A rewrite pattern to convert gpu.host_register operations into a GPU
+/// runtime call. Currently it supports CUDA and ROCm (HIP).
+class ConvertHostRegisterOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
+public:
+  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 /// A rewrite pattern to convert gpu.launch_func operations into a sequence of
@@ -192,6 +212,33 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
       builder.getSymbolRefAttr(function), arguments);
 }
 
+// Returns whether value is of LLVM type.
+static bool isLLVMType(Value value) {
+  return value.getType().isa<LLVM::LLVMType>();
+}
+
+LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (!llvm::all_of(operands, isLLVMType))
+    return rewriter.notifyMatchFailure(
+        op, "Cannot convert if operands aren't of LLVM type.");
+
+  Location loc = op->getLoc();
+
+  auto memRefType = cast<gpu::HostRegisterOp>(op).value().getType();
+  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
+  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
+
+  auto arguments =
+      typeConverter.promoteOperands(loc, op->getOperands(), operands, rewriter);
+  arguments.push_back(elementSize);
+  hostRegisterCallBuilder.create(loc, rewriter, arguments);
+
+  rewriter.eraseOp(op);
+  return success();
+}
+
 // Creates a struct containing all kernel parameters on the stack and returns
 // an array of type-erased pointers to the fields of the struct. The array can
 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
@@ -269,11 +316,6 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
       LLVM::Linkage::Internal);
 }
 
-// Returns whether value is of LLVM type.
-static bool isLLVMType(Value value) {
-  return value.getType().isa<LLVM::LLVMType>();
-}
-
 // Emits LLVM IR to launch a kernel function. Expects the module that contains
 // the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
 // hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
@@ -351,6 +393,7 @@ mlir::createGpuToLLVMConversionPass(StringRef gpuBinaryAnnotation) {
 void mlir::populateGpuToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
     StringRef gpuBinaryAnnotation) {
+  patterns.insert<ConvertHostRegisterOpToGpuRuntimeCallPattern>(converter);
   patterns.insert<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
       converter, gpuBinaryAnnotation);
   patterns.insert<EraseGpuModuleOpPattern>(&converter.getContext());
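
Once this pattern fires, each gpu.host_register op becomes a call to the
mgpuMemHostRegisterMemRef wrapper declared above: promoteOperands expands the
unranked-memref operand into its (rank, descriptor-pointer) pair, and the
element size from getSizeInBytes is appended. Roughly (a hand-written sketch;
SSA names and the LLVM dialect type spellings of this vintage are
illustrative, not compiler output):

  llvm.call @mgpuMemHostRegisterMemRef(%rank, %desc, %elem_size)
      : (!llvm.i64, !llvm<"i8*">, !llvm.i64) -> ()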

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 5f8a38743400..08d1c32d13c5 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -927,30 +927,32 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
                         : createIndexConstant(rewriter, loc, s));
 }
 
-Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
-    Location loc, Type elementType, ArrayRef<Value> sizes,
-    ConversionPatternRewriter &rewriter) const {
-  // Compute the total number of memref elements.
-  Value cumulativeSizeInBytes =
-      sizes.empty() ? createIndexConstant(rewriter, loc, 1) : sizes.front();
-  for (unsigned i = 1, e = sizes.size(); i < e; ++i)
-    cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
-        loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, sizes[i]});
-
+Value ConvertToLLVMPattern::getSizeInBytes(
+    Location loc, Type type, ConversionPatternRewriter &rewriter) const {
   // Compute the size of an individual element. This emits the MLIR equivalent
   // of the following sizeof(...) implementation in LLVM IR:
   //   %0 = getelementptr %elementType* null, %indexType 1
   //   %1 = ptrtoint %elementType* %0 to %indexType
   // which is a common pattern of getting the size of a type in bytes.
-  auto convertedPtrType = typeConverter.convertType(elementType)
-                              .cast<LLVM::LLVMType>()
-                              .getPointerTo();
+  auto convertedPtrType =
+      typeConverter.convertType(type).cast<LLVM::LLVMType>().getPointerTo();
   auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
   auto gep = rewriter.create<LLVM::GEPOp>(
       loc, convertedPtrType,
       ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
-  auto elementSize =
-      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
+  return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
+}
+
+Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
+    Location loc, Type elementType, ArrayRef<Value> sizes,
+    ConversionPatternRewriter &rewriter) const {
+  // Compute the total number of memref elements.
+  Value cumulativeSizeInBytes =
+      sizes.empty() ? createIndexConstant(rewriter, loc, 1) : sizes.front();
+  for (unsigned i = 1, e = sizes.size(); i < e; ++i)
+    cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
+        loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, sizes[i]});
+  auto elementSize = this->getSizeInBytes(loc, elementType, rewriter);
   return rewriter.create<LLVM::MulOp>(
       loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, elementSize});
 }
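
As a concrete instance of the null-GEP "sizeof" idiom used by getSizeInBytes,
for an f32 element the emitted code corresponds to the following LLVM IR,
which folds to the constant 4:

  %0 = getelementptr float, float* null, i64 1
  %1 = ptrtoint float* %0 to i64   ; sizeof(float) == 4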

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
index f89f91415724..ef8d50580d9e 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
@@ -25,9 +25,9 @@ func @main() {
   %c6 = constant 6 : index
 
   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_data : memref<*xi32>
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_sum : memref<*xi32>
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
index 4adf8a73d924..be8087a55ea2 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
@@ -25,9 +25,9 @@ func @main() {
   %c6 = constant 6 : index
 
   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_data : memref<*xi32>
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_sum : memref<*xi32>
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
index 8cb3116e9d0d..ad03ed5497f0 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
@@ -25,9 +25,9 @@ func @main() {
   %c6 = constant 6 : index
 
   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_data : memref<*xi32>
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_sum : memref<*xi32>
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
index 72306674c3ff..a639c699027b 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
@@ -11,7 +11,7 @@ func @main() {
   %sy = dim %dst, %c1 : memref<?x?x?xf32>
   %sz = dim %dst, %c0 : memref<?x?x?xf32>
   %cast_dst = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
+  gpu.host_register %cast_dst : memref<*xf32>
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz) {
     %t0 = muli %tz, %block_y : index
@@ -28,5 +28,4 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
index 7d0ed929322e..1ed85e61ded5 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
@@ -25,9 +25,9 @@ func @main() {
   %c6 = constant 6 : index
 
   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_data : memref<*xi32>
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_sum : memref<*xi32>
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
index a9426c658978..49f24db131c7 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
@@ -8,7 +8,7 @@ func @main() {
   %c0 = constant 0 : index
   %sx = dim %dst, %c0 : memref<?xf32>
   %cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
+  gpu.host_register %cast_dst : memref<*xf32>
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
     %val = index_cast %tx : index to i32
@@ -25,5 +25,4 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(memref<*xf32>)

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
index 67461783b257..88cd9036998f 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
@@ -25,9 +25,9 @@ func @main() {
   %c6 = constant 6 : index
 
   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_data : memref<*xi32>
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_sum : memref<*xi32>
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
index 80339c36fb38..fb458eb375d0 100644
--- a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
+++ b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
@@ -18,7 +18,7 @@ func @main() {
   %21 = constant 5 : i32
   %22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
   %23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%23) : (memref<*xf32>) -> ()
+  gpu.host_register %23 : memref<*xf32>
   call @print_memref_f32(%23) : (memref<*xf32>) -> ()
   %24 = constant 1.0 : f32
   call @other_func(%24, %22) : (f32, memref<?xf32>) -> ()
@@ -26,5 +26,4 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)

diff  --git a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
index b88d8e1b8ba1..f3fb3b219ce0 100644
--- a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
+++ b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
@@ -26,11 +26,11 @@ func @main() {
   %c6 = constant 6 : index
 
   %cast_data = memref_cast %data : memref<2x6xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast_data) : (memref<*xf32>) -> ()
+  gpu.host_register %cast_data : memref<*xf32>
   %cast_sum = memref_cast %sum : memref<2xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast_sum) : (memref<*xf32>) -> ()
+  gpu.host_register %cast_sum : memref<*xf32>
   %cast_mul = memref_cast %mul : memref<2xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast_mul) : (memref<*xf32>) -> ()
+  gpu.host_register %cast_mul : memref<*xf32>
 
   store %cst0, %data[%c0, %c0] : memref<2x6xf32>
   store %cst1, %data[%c0, %c1] : memref<2x6xf32>
@@ -66,5 +66,4 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(memref<*xf32>)

diff  --git a/mlir/test/mlir-cuda-runner/shuffle.mlir b/mlir/test/mlir-cuda-runner/shuffle.mlir
index a4563cc0c381..9846455142d6 100644
--- a/mlir/test/mlir-cuda-runner/shuffle.mlir
+++ b/mlir/test/mlir-cuda-runner/shuffle.mlir
@@ -7,8 +7,8 @@ func @main() {
   %one = constant 1 : index
   %c0 = constant 0 : index
   %sx = dim %dst, %c0 : memref<?xf32>
-  %cast_dest = memref_cast %dst : memref<?xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast_dest) : (memref<*xf32>) -> ()
+  %cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
+  gpu.host_register %cast_dst : memref<*xf32>
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
     %t0 = index_cast %tx : index to i32
@@ -24,9 +24,8 @@ func @main() {
     store %value, %dst[%tx] : memref<?xf32>
     gpu.terminator
   }
-  call @print_memref_f32(%cast_dest) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%cast_dst) : (memref<*xf32>) -> ()
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)

diff  --git a/mlir/test/mlir-cuda-runner/two-modules.mlir b/mlir/test/mlir-cuda-runner/two-modules.mlir
index 9bdda2ae9c66..2ea58ae55b5e 100644
--- a/mlir/test/mlir-cuda-runner/two-modules.mlir
+++ b/mlir/test/mlir-cuda-runner/two-modules.mlir
@@ -8,7 +8,7 @@ func @main() {
   %c0 = constant 0 : index
   %sx = dim %dst, %c0 : memref<?xi32>
   %cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_dst : memref<*xi32>
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
     %t0 = index_cast %tx : index to i32
@@ -25,5 +25,4 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterInt32(%memref : memref<*xi32>)
 func @print_memref_i32(%memref : memref<*xi32>)

diff  --git a/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir b/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
index 4b1137468b14..68e31fede8dc 100644
--- a/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
+++ b/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
@@ -18,7 +18,7 @@ func @main() {
   %21 = constant 5 : i32
   %22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
   %cast = memref_cast %22 : memref<?xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%cast) : (memref<*xf32>) -> ()
+  gpu.host_register %cast : memref<*xf32>
   %23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
   call @print_memref_f32(%23) : (memref<*xf32>) -> ()
   %24 = constant 1.0 : f32
@@ -28,6 +28,5 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)

diff  --git a/mlir/test/mlir-rocm-runner/two-modules.mlir b/mlir/test/mlir-rocm-runner/two-modules.mlir
index d6b92229b585..f196c8e8fefe 100644
--- a/mlir/test/mlir-rocm-runner/two-modules.mlir
+++ b/mlir/test/mlir-rocm-runner/two-modules.mlir
@@ -8,7 +8,7 @@ func @main() {
   %c1 = constant 1 : index
   %sx = dim %dst, %c0 : memref<?xi32>
   %cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
-  call @mgpuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
+  gpu.host_register %cast_dst : memref<*xi32>
   %dst_device = call @mgpuMemGetDeviceMemRef1dInt32(%dst) : (memref<?xi32>) -> (memref<?xi32>)
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %c1, %block_z = %c1) {
@@ -26,6 +26,5 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @mgpuMemGetDeviceMemRef1dInt32(%ptr : memref<?xi32>) -> (memref<?xi32>)
 func @print_memref_i32(%ptr : memref<*xi32>)

diff  --git a/mlir/test/mlir-rocm-runner/vecadd.mlir b/mlir/test/mlir-rocm-runner/vecadd.mlir
index a86412ff8fef..df5c073f9b81 100644
--- a/mlir/test/mlir-rocm-runner/vecadd.mlir
+++ b/mlir/test/mlir-rocm-runner/vecadd.mlir
@@ -26,9 +26,9 @@ func @main() {
   %6 = memref_cast %3 : memref<?xf32> to memref<*xf32>
   %7 = memref_cast %4 : memref<?xf32> to memref<*xf32>
   %8 = memref_cast %5 : memref<?xf32> to memref<*xf32>
-  call @mgpuMemHostRegisterFloat(%6) : (memref<*xf32>) -> ()
-  call @mgpuMemHostRegisterFloat(%7) : (memref<*xf32>) -> ()
-  call @mgpuMemHostRegisterFloat(%8) : (memref<*xf32>) -> ()
+  gpu.host_register %6 : memref<*xf32>
+  gpu.host_register %7 : memref<*xf32>
+  gpu.host_register %8 : memref<*xf32>
   %9 = call @mgpuMemGetDeviceMemRef1dFloat(%3) : (memref<?xf32>) -> (memref<?xf32>)
   %10 = call @mgpuMemGetDeviceMemRef1dFloat(%4) : (memref<?xf32>) -> (memref<?xf32>)
   %11 = call @mgpuMemGetDeviceMemRef1dFloat(%5) : (memref<?xf32>) -> (memref<?xf32>)
@@ -38,6 +38,5 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)

diff  --git a/mlir/test/mlir-rocm-runner/vector-transferops.mlir b/mlir/test/mlir-rocm-runner/vector-transferops.mlir
index b028f91f8394..873897011464 100644
--- a/mlir/test/mlir-rocm-runner/vector-transferops.mlir
+++ b/mlir/test/mlir-rocm-runner/vector-transferops.mlir
@@ -55,8 +55,8 @@ func @main() {
   %cast0 = memref_cast %22 : memref<?xf32> to memref<*xf32>
   %cast1 = memref_cast %23 : memref<?xf32> to memref<*xf32>
 
-  call @mgpuMemHostRegisterFloat(%cast0) : (memref<*xf32>) -> ()
-  call @mgpuMemHostRegisterFloat(%cast1) : (memref<*xf32>) -> ()
+  gpu.host_register %cast0 : memref<*xf32>
+  gpu.host_register %cast1 : memref<*xf32>
 
   %24 = call @mgpuMemGetDeviceMemRef1dFloat(%22) : (memref<?xf32>) -> (memref<?xf32>)
   %26 = call @mgpuMemGetDeviceMemRef1dFloat(%23) : (memref<?xf32>) -> (memref<?xf32>)
@@ -71,6 +71,5 @@ func @main() {
   return
 }
 
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)

diff  --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index 8e2dc029fa9f..517fc9fc18f5 100644
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -75,17 +75,19 @@ extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
   CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
 }
 
-// Allows to register a MemRef with the CUDA runtime. Initializes array with
-// value. Helpful until we have transfer functions implemented.
-template <typename T>
-void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &memRef, T value) {
-  llvm::SmallVector<int64_t, 4> denseStrides(memRef.rank);
-  llvm::ArrayRef<int64_t> sizes(memRef.sizes, memRef.rank);
-  llvm::ArrayRef<int64_t> strides(memRef.strides, memRef.rank);
+// Registers a MemRef with the CUDA runtime. Helpful until we have transfer
+// functions implemented.
+extern "C" void
+mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
+                          int64_t elementSizeBytes) {
+
+  llvm::SmallVector<int64_t, 4> denseStrides(rank);
+  llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
+  llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
 
   std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
                    std::multiplies<int64_t>());
-  auto count = denseStrides.front();
+  auto sizeBytes = denseStrides.front() * elementSizeBytes;
 
   // Only densely packed tensors are currently supported.
   std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
@@ -93,17 +95,6 @@ void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &memRef, T value) {
   denseStrides.back() = 1;
   assert(strides == llvm::makeArrayRef(denseStrides));
 
-  auto *pointer = memRef.data + memRef.offset;
-  std::fill_n(pointer, count, value);
-  mgpuMemHostRegister(pointer, count * sizeof(T));
-}
-
-extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
-  UnrankedMemRefType<float> memRef = {rank, ptr};
-  mgpuMemHostRegisterMemRef(DynamicMemRefType<float>(memRef), 1.23f);
-}
-
-extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
-  UnrankedMemRefType<int32_t> memRef = {rank, ptr};
-  mgpuMemHostRegisterMemRef(DynamicMemRefType<int32_t>(memRef), 123);
+  auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+  mgpuMemHostRegister(ptr, sizeBytes);
 }
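
Both wrappers derive the dense strides with the same partial-sum-and-rotate
trick before asserting that the buffer is densely packed. A standalone C++
sketch of that logic, using a hypothetical memref<2x6xi32> (sizes {2, 6},
4-byte elements) for illustration:

  #include <algorithm>
  #include <cassert>
  #include <cstdint>
  #include <functional>
  #include <numeric>
  #include <vector>

  int main() {
    std::vector<int64_t> sizes = {2, 6};
    int64_t elementSizeBytes = 4;

    // Suffix products of the sizes, filled right-to-left: {12, 6}.
    std::vector<int64_t> denseStrides(sizes.size());
    std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
                     std::multiplies<int64_t>());

    // The front element is now the total element count: 12 * 4 = 48 bytes.
    int64_t sizeBytes = denseStrides.front() * elementSizeBytes;
    assert(sizeBytes == 48);

    // Shift left and set the innermost stride to 1, giving {6, 1}: the
    // strides of a densely packed, row-major 2x6 buffer.
    std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
                denseStrides.end());
    denseStrides.back() = 1;
    assert((denseStrides == std::vector<int64_t>{6, 1}));
    return 0;
  }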

diff  --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
index a64815007661..9184c9fa20fa 100644
--- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
@@ -76,17 +76,19 @@ extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
   HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0));
 }
 
-// Allows to register a MemRef with the ROCM runtime. Initializes array with
-// value. Helpful until we have transfer functions implemented.
-template <typename T>
-void mgpuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
-                               llvm::ArrayRef<int64_t> strides, T value) {
-  assert(sizes.size() == strides.size());
-  llvm::SmallVector<int64_t, 4> denseStrides(strides.size());
+// Registers a MemRef with the ROCm runtime. Helpful until we have transfer
+// functions implemented.
+extern "C" void
+mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
+                          int64_t elementSizeBytes) {
+
+  llvm::SmallVector<int64_t, 4> denseStrides(rank);
+  llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
+  llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
 
   std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
                    std::multiplies<int64_t>());
-  auto count = denseStrides.front();
+  auto sizeBytes = denseStrides.front() * elementSizeBytes;
 
   // Only densely packed tensors are currently supported.
   std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
@@ -94,22 +96,8 @@ void mgpuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
   denseStrides.back() = 1;
   assert(strides == llvm::makeArrayRef(denseStrides));
 
-  std::fill_n(pointer, count, value);
-  mgpuMemHostRegister(pointer, count * sizeof(T));
-}
-
-extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
-  auto *desc = static_cast<StridedMemRefType<float, 1> *>(ptr);
-  auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
-  auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
-  mgpuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 1.23f);
-}
-
-extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
-  auto *desc = static_cast<StridedMemRefType<int32_t, 1> *>(ptr);
-  auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
-  auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
-  mgpuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 123);
+  auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+  mgpuMemHostRegister(ptr, sizeBytes);
 }
 
 template <typename T>

