[Mlir-commits] [mlir] 2c48e36 - [MLIR] Adding gpu.host_register op and lower it to a runtime call.
Christian Sigg
llvmlistbot at llvm.org
Mon Aug 10 13:46:27 PDT 2020
Author: Christian Sigg
Date: 2020-08-10T22:46:17+02:00
New Revision: 2c48e3629cfb25ac1034117fa8945fd0d342f2ae
URL: https://github.com/llvm/llvm-project/commit/2c48e3629cfb25ac1034117fa8945fd0d342f2ae
DIFF: https://github.com/llvm/llvm-project/commit/2c48e3629cfb25ac1034117fa8945fd0d342f2ae.diff
LOG: [MLIR] Adding gpu.host_register op and lower it to a runtime call.
Reviewed By: herhut
Differential Revision: https://reviews.llvm.org/D85631
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/mlir-cuda-runner/all-reduce-and.mlir
mlir/test/mlir-cuda-runner/all-reduce-max.mlir
mlir/test/mlir-cuda-runner/all-reduce-min.mlir
mlir/test/mlir-cuda-runner/all-reduce-op.mlir
mlir/test/mlir-cuda-runner/all-reduce-or.mlir
mlir/test/mlir-cuda-runner/all-reduce-region.mlir
mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
mlir/test/mlir-cuda-runner/shuffle.mlir
mlir/test/mlir-cuda-runner/two-modules.mlir
mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
mlir/test/mlir-rocm-runner/two-modules.mlir
mlir/test/mlir-rocm-runner/vecadd.mlir
mlir/test/mlir-rocm-runner/vector-transferops.mlir
mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 25bff67d4a6e..2f8f87fa0e41 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -440,6 +440,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter,
SmallVectorImpl<Value> &sizes) const;
+ /// Computes the size in bytes of the given type.
+ Value getSizeInBytes(Location loc, Type type,
+ ConversionPatternRewriter &rewriter) const;
+
/// Computes total size in bytes of to store the given shape.
Value getCumulativeSizeInBytes(Location loc, Type elementType,
ArrayRef<Value> shape,
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index c0a6ac101d7b..288031c598ff 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -741,4 +741,16 @@ def GPU_ModuleEndOp : GPU_Op<"module_end", [
let printer = [{ p << getOperationName(); }];
}
+def GPU_HostRegisterOp : GPU_Op<"host_register">,
+ Arguments<(ins AnyUnrankedMemRef:$value)> {
+ let summary = "Registers a memref for access from device.";
+ let description = [{
+ This op registers the host memory pointed to by a memref to be accessed from
+ a device.
+ }];
+
+ let assemblyFormat = "$value attr-dict `:` type($value)";
+ let verifier = [{ return success(); }];
+}
+
#endif // GPU_OPS
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index e14c9fbc6718..8aa843308cff 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -117,6 +117,26 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
"mgpuStreamSynchronize",
llvmVoidType,
{llvmPointerType /* void *stream */}};
+ FunctionCallBuilder hostRegisterCallBuilder = {
+ "mgpuMemHostRegisterMemRef",
+ llvmVoidType,
+ {llvmIntPtrType /* intptr_t rank */,
+ llvmPointerType /* void *memrefDesc */,
+ llvmIntPtrType /* intptr_t elementSizeBytes */}};
+};
+
+/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertHostRegisterOpToGpuRuntimeCallPattern
+ : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
+public:
+ ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+ : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
+
+private:
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite patter to convert gpu.launch_func operations into a sequence of
@@ -192,6 +212,33 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
builder.getSymbolRefAttr(function), arguments);
}
+// Returns whether value is of LLVM type.
+static bool isLLVMType(Value value) {
+ return value.getType().isa<LLVM::LLVMType>();
+}
+
+LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
+ Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (!llvm::all_of(operands, isLLVMType))
+ return rewriter.notifyMatchFailure(
+ op, "Cannot convert if operands aren't of LLVM type.");
+
+ Location loc = op->getLoc();
+
+ auto memRefType = cast<gpu::HostRegisterOp>(op).value().getType();
+ auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
+ auto elementSize = getSizeInBytes(loc, elementType, rewriter);
+
+ auto arguments =
+ typeConverter.promoteOperands(loc, op->getOperands(), operands, rewriter);
+ arguments.push_back(elementSize);
+ hostRegisterCallBuilder.create(loc, rewriter, arguments);
+
+ rewriter.eraseOp(op);
+ return success();
+}
+
// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
@@ -269,11 +316,6 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
LLVM::Linkage::Internal);
}
-// Returns whether value is of LLVM type.
-static bool isLLVMType(Value value) {
- return value.getType().isa<LLVM::LLVMType>();
-}
-
// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
@@ -351,6 +393,7 @@ mlir::createGpuToLLVMConversionPass(StringRef gpuBinaryAnnotation) {
void mlir::populateGpuToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
StringRef gpuBinaryAnnotation) {
+ patterns.insert<ConvertHostRegisterOpToGpuRuntimeCallPattern>(converter);
patterns.insert<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
converter, gpuBinaryAnnotation);
patterns.insert<EraseGpuModuleOpPattern>(&converter.getContext());
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 5f8a38743400..08d1c32d13c5 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -927,30 +927,32 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
: createIndexConstant(rewriter, loc, s));
}
-Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
- Location loc, Type elementType, ArrayRef<Value> sizes,
- ConversionPatternRewriter &rewriter) const {
- // Compute the total number of memref elements.
- Value cumulativeSizeInBytes =
- sizes.empty() ? createIndexConstant(rewriter, loc, 1) : sizes.front();
- for (unsigned i = 1, e = sizes.size(); i < e; ++i)
- cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
- loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, sizes[i]});
-
+Value ConvertToLLVMPattern::getSizeInBytes(
+ Location loc, Type type, ConversionPatternRewriter &rewriter) const {
// Compute the size of an individual element. This emits the MLIR equivalent
// of the following sizeof(...) implementation in LLVM IR:
// %0 = getelementptr %elementType* null, %indexType 1
// %1 = ptrtoint %elementType* %0 to %indexType
// which is a common pattern of getting the size of a type in bytes.
- auto convertedPtrType = typeConverter.convertType(elementType)
- .cast<LLVM::LLVMType>()
- .getPointerTo();
+ auto convertedPtrType =
+ typeConverter.convertType(type).cast<LLVM::LLVMType>().getPointerTo();
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
auto gep = rewriter.create<LLVM::GEPOp>(
loc, convertedPtrType,
ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
- auto elementSize =
- rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
+ return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
+}
+
+Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
+ Location loc, Type elementType, ArrayRef<Value> sizes,
+ ConversionPatternRewriter &rewriter) const {
+ // Compute the total number of memref elements.
+ Value cumulativeSizeInBytes =
+ sizes.empty() ? createIndexConstant(rewriter, loc, 1) : sizes.front();
+ for (unsigned i = 1, e = sizes.size(); i < e; ++i)
+ cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
+ loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, sizes[i]});
+ auto elementSize = this->getSizeInBytes(loc, elementType, rewriter);
return rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, elementSize});
}
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
index f89f91415724..ef8d50580d9e 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
@@ -25,9 +25,9 @@ func @main() {
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
- call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ gpu.host_register %cast_data : memref<*xi32>
%cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+ gpu.host_register %cast_sum : memref<*xi32>
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
return
}
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
index 4adf8a73d924..be8087a55ea2 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
@@ -25,9 +25,9 @@ func @main() {
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
- call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ gpu.host_register %cast_data : memref<*xi32>
%cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+ gpu.host_register %cast_sum : memref<*xi32>
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
return
}
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
index 8cb3116e9d0d..ad03ed5497f0 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
@@ -25,9 +25,9 @@ func @main() {
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
- call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ gpu.host_register %cast_data : memref<*xi32>
%cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+ gpu.host_register %cast_sum : memref<*xi32>
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
return
}
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
index 72306674c3ff..a639c699027b 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
@@ -11,7 +11,7 @@ func @main() {
%sy = dim %dst, %c1 : memref<?x?x?xf32>
%sz = dim %dst, %c0 : memref<?x?x?xf32>
%cast_dst = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
- call @mgpuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
+ gpu.host_register %cast_dst : memref<*xf32>
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz) {
%t0 = muli %tz, %block_y : index
@@ -28,5 +28,4 @@ func @main() {
return
}
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
index 7d0ed929322e..1ed85e61ded5 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
@@ -25,9 +25,9 @@ func @main() {
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
- call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ gpu.host_register %cast_data : memref<*xi32>
%cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+ gpu.host_register %cast_sum : memref<*xi32>
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
return
}
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
index a9426c658978..49f24db131c7 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
@@ -8,7 +8,7 @@ func @main() {
%c0 = constant 0 : index
%sx = dim %dst, %c0 : memref<?xf32>
%cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
- call @mgpuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
+ gpu.host_register %cast_dst : memref<*xf32>
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%val = index_cast %tx : index to i32
@@ -25,5 +25,4 @@ func @main() {
return
}
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
index 67461783b257..88cd9036998f 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
@@ -25,9 +25,9 @@ func @main() {
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
- call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ gpu.host_register %cast_data : memref<*xi32>
%cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+ gpu.host_register %cast_sum : memref<*xi32>
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
return
}
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
index 80339c36fb38..fb458eb375d0 100644
--- a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
+++ b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
@@ -18,7 +18,7 @@ func @main() {
%21 = constant 5 : i32
%22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
%23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
- call @mgpuMemHostRegisterFloat(%23) : (memref<*xf32>) -> ()
+ gpu.host_register %23 : memref<*xf32>
call @print_memref_f32(%23) : (memref<*xf32>) -> ()
%24 = constant 1.0 : f32
call @other_func(%24, %22) : (f32, memref<?xf32>) -> ()
@@ -26,5 +26,4 @@ func @main() {
return
}
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
index b88d8e1b8ba1..f3fb3b219ce0 100644
--- a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
+++ b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
@@ -26,11 +26,11 @@ func @main() {
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xf32> to memref<*xf32>
- call @mgpuMemHostRegisterFloat(%cast_data) : (memref<*xf32>) -> ()
+ gpu.host_register %cast_data : memref<*xf32>
%cast_sum = memref_cast %sum : memref<2xf32> to memref<*xf32>
- call @mgpuMemHostRegisterFloat(%cast_sum) : (memref<*xf32>) -> ()
+ gpu.host_register %cast_sum : memref<*xf32>
%cast_mul = memref_cast %mul : memref<2xf32> to memref<*xf32>
- call @mgpuMemHostRegisterFloat(%cast_mul) : (memref<*xf32>) -> ()
+ gpu.host_register %cast_mul : memref<*xf32>
store %cst0, %data[%c0, %c0] : memref<2x6xf32>
store %cst1, %data[%c0, %c1] : memref<2x6xf32>
@@ -66,5 +66,4 @@ func @main() {
return
}
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/shuffle.mlir b/mlir/test/mlir-cuda-runner/shuffle.mlir
index a4563cc0c381..9846455142d6 100644
--- a/mlir/test/mlir-cuda-runner/shuffle.mlir
+++ b/mlir/test/mlir-cuda-runner/shuffle.mlir
@@ -7,8 +7,8 @@ func @main() {
%one = constant 1 : index
%c0 = constant 0 : index
%sx = dim %dst, %c0 : memref<?xf32>
- %cast_dest = memref_cast %dst : memref<?xf32> to memref<*xf32>
- call @mgpuMemHostRegisterFloat(%cast_dest) : (memref<*xf32>) -> ()
+ %cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
+ gpu.host_register %cast_dst : memref<*xf32>
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%t0 = index_cast %tx : index to i32
@@ -24,9 +24,8 @@ func @main() {
store %value, %dst[%tx] : memref<?xf32>
gpu.terminator
}
- call @print_memref_f32(%cast_dest) : (memref<*xf32>) -> ()
+ call @print_memref_f32(%cast_dst) : (memref<*xf32>) -> ()
return
}
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/two-modules.mlir b/mlir/test/mlir-cuda-runner/two-modules.mlir
index 9bdda2ae9c66..2ea58ae55b5e 100644
--- a/mlir/test/mlir-cuda-runner/two-modules.mlir
+++ b/mlir/test/mlir-cuda-runner/two-modules.mlir
@@ -8,7 +8,7 @@ func @main() {
%c0 = constant 0 : index
%sx = dim %dst, %c0 : memref<?xi32>
%cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
- call @mgpuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
+ gpu.host_register %cast_dst : memref<*xi32>
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%t0 = index_cast %tx : index to i32
@@ -25,5 +25,4 @@ func @main() {
return
}
-func @mgpuMemHostRegisterInt32(%memref : memref<*xi32>)
func @print_memref_i32(%memref : memref<*xi32>)
diff --git a/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir b/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
index 4b1137468b14..68e31fede8dc 100644
--- a/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
+++ b/mlir/test/mlir-rocm-runner/gpu-to-hsaco.mlir
@@ -18,7 +18,7 @@ func @main() {
%21 = constant 5 : i32
%22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
%cast = memref_cast %22 : memref<?xf32> to memref<*xf32>
- call @mgpuMemHostRegisterFloat(%cast) : (memref<*xf32>) -> ()
+ gpu.host_register %cast : memref<*xf32>
%23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
call @print_memref_f32(%23) : (memref<*xf32>) -> ()
%24 = constant 1.0 : f32
@@ -28,6 +28,5 @@ func @main() {
return
}
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-rocm-runner/two-modules.mlir b/mlir/test/mlir-rocm-runner/two-modules.mlir
index d6b92229b585..f196c8e8fefe 100644
--- a/mlir/test/mlir-rocm-runner/two-modules.mlir
+++ b/mlir/test/mlir-rocm-runner/two-modules.mlir
@@ -8,7 +8,7 @@ func @main() {
%c1 = constant 1 : index
%sx = dim %dst, %c0 : memref<?xi32>
%cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
- call @mgpuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
+ gpu.host_register %cast_dst : memref<*xi32>
%dst_device = call @mgpuMemGetDeviceMemRef1dInt32(%dst) : (memref<?xi32>) -> (memref<?xi32>)
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %c1, %block_z = %c1) {
@@ -26,6 +26,5 @@ func @main() {
return
}
-func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @mgpuMemGetDeviceMemRef1dInt32(%ptr : memref<?xi32>) -> (memref<?xi32>)
func @print_memref_i32(%ptr : memref<*xi32>)
diff --git a/mlir/test/mlir-rocm-runner/vecadd.mlir b/mlir/test/mlir-rocm-runner/vecadd.mlir
index a86412ff8fef..df5c073f9b81 100644
--- a/mlir/test/mlir-rocm-runner/vecadd.mlir
+++ b/mlir/test/mlir-rocm-runner/vecadd.mlir
@@ -26,9 +26,9 @@ func @main() {
%6 = memref_cast %3 : memref<?xf32> to memref<*xf32>
%7 = memref_cast %4 : memref<?xf32> to memref<*xf32>
%8 = memref_cast %5 : memref<?xf32> to memref<*xf32>
- call @mgpuMemHostRegisterFloat(%6) : (memref<*xf32>) -> ()
- call @mgpuMemHostRegisterFloat(%7) : (memref<*xf32>) -> ()
- call @mgpuMemHostRegisterFloat(%8) : (memref<*xf32>) -> ()
+ gpu.host_register %6 : memref<*xf32>
+ gpu.host_register %7 : memref<*xf32>
+ gpu.host_register %8 : memref<*xf32>
%9 = call @mgpuMemGetDeviceMemRef1dFloat(%3) : (memref<?xf32>) -> (memref<?xf32>)
%10 = call @mgpuMemGetDeviceMemRef1dFloat(%4) : (memref<?xf32>) -> (memref<?xf32>)
%11 = call @mgpuMemGetDeviceMemRef1dFloat(%5) : (memref<?xf32>) -> (memref<?xf32>)
@@ -38,6 +38,5 @@ func @main() {
return
}
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-rocm-runner/vector-transferops.mlir b/mlir/test/mlir-rocm-runner/vector-transferops.mlir
index b028f91f8394..873897011464 100644
--- a/mlir/test/mlir-rocm-runner/vector-transferops.mlir
+++ b/mlir/test/mlir-rocm-runner/vector-transferops.mlir
@@ -55,8 +55,8 @@ func @main() {
%cast0 = memref_cast %22 : memref<?xf32> to memref<*xf32>
%cast1 = memref_cast %23 : memref<?xf32> to memref<*xf32>
- call @mgpuMemHostRegisterFloat(%cast0) : (memref<*xf32>) -> ()
- call @mgpuMemHostRegisterFloat(%cast1) : (memref<*xf32>) -> ()
+ gpu.host_register %cast0 : memref<*xf32>
+ gpu.host_register %cast1 : memref<*xf32>
%24 = call @mgpuMemGetDeviceMemRef1dFloat(%22) : (memref<?xf32>) -> (memref<?xf32>)
%26 = call @mgpuMemGetDeviceMemRef1dFloat(%23) : (memref<?xf32>) -> (memref<?xf32>)
@@ -71,6 +71,5 @@ func @main() {
return
}
-func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index 8e2dc029fa9f..517fc9fc18f5 100644
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -75,17 +75,19 @@ extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
}
-// Allows to register a MemRef with the CUDA runtime. Initializes array with
-// value. Helpful until we have transfer functions implemented.
-template <typename T>
-void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &memRef, T value) {
- llvm::SmallVector<int64_t, 4> denseStrides(memRef.rank);
- llvm::ArrayRef<int64_t> sizes(memRef.sizes, memRef.rank);
- llvm::ArrayRef<int64_t> strides(memRef.strides, memRef.rank);
+// Registers a MemRef with the CUDA runtime. Helpful until we have
+// transfer functions implemented.
+extern "C" void
+mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
+ int64_t elementSizeBytes) {
+
+ llvm::SmallVector<int64_t, 4> denseStrides(rank);
+ llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
+ llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
std::multiplies<int64_t>());
- auto count = denseStrides.front();
+ auto sizeBytes = denseStrides.front() * elementSizeBytes;
// Only densely packed tensors are currently supported.
std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
@@ -93,17 +95,6 @@ void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &memRef, T value) {
denseStrides.back() = 1;
assert(strides == llvm::makeArrayRef(denseStrides));
- auto *pointer = memRef.data + memRef.offset;
- std::fill_n(pointer, count, value);
- mgpuMemHostRegister(pointer, count * sizeof(T));
-}
-
-extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
- UnrankedMemRefType<float> memRef = {rank, ptr};
- mgpuMemHostRegisterMemRef(DynamicMemRefType<float>(memRef), 1.23f);
-}
-
-extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
- UnrankedMemRefType<int32_t> memRef = {rank, ptr};
- mgpuMemHostRegisterMemRef(DynamicMemRefType<int32_t>(memRef), 123);
+ auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+ mgpuMemHostRegister(ptr, sizeBytes);
}
diff --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
index a64815007661..9184c9fa20fa 100644
--- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
@@ -76,17 +76,19 @@ extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0));
}
-// Allows to register a MemRef with the ROCM runtime. Initializes array with
-// value. Helpful until we have transfer functions implemented.
-template <typename T>
-void mgpuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
- llvm::ArrayRef<int64_t> strides, T value) {
- assert(sizes.size() == strides.size());
- llvm::SmallVector<int64_t, 4> denseStrides(strides.size());
+// Registers a MemRef with the ROCm runtime. Helpful until we have
+// transfer functions implemented.
+extern "C" void
+mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
+ int64_t elementSizeBytes) {
+
+ llvm::SmallVector<int64_t, 4> denseStrides(rank);
+ llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
+ llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
std::multiplies<int64_t>());
- auto count = denseStrides.front();
+ auto sizeBytes = denseStrides.front() * elementSizeBytes;
// Only densely packed tensors are currently supported.
std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
@@ -94,22 +96,8 @@ void mgpuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
denseStrides.back() = 1;
assert(strides == llvm::makeArrayRef(denseStrides));
- std::fill_n(pointer, count, value);
- mgpuMemHostRegister(pointer, count * sizeof(T));
-}
-
-extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
- auto *desc = static_cast<StridedMemRefType<float, 1> *>(ptr);
- auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
- auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
- mgpuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 1.23f);
-}
-
-extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
- auto *desc = static_cast<StridedMemRefType<int32_t, 1> *>(ptr);
- auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
- auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
- mgpuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 123);
+ auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+ mgpuMemHostRegister(ptr, sizeBytes);
}
template <typename T>
More information about the Mlir-commits
mailing list