[Mlir-commits] [mlir] 62adfed - Unrank mcuMemHostRegister tensor argument.
Christian Sigg
llvmlistbot at llvm.org
Tue May 19 04:59:10 PDT 2020
Author: Christian Sigg
Date: 2020-05-19T13:58:54+02:00
New Revision: 62adfed30a125b0057d28b570b353bce4d23df83
URL: https://github.com/llvm/llvm-project/commit/62adfed30a125b0057d28b570b353bce4d23df83
DIFF: https://github.com/llvm/llvm-project/commit/62adfed30a125b0057d28b570b353bce4d23df83.diff
LOG: Unrank mcuMemHostRegister tensor argument.
Reviewers: herhut
Reviewed By: herhut
Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D80118
Added:
Modified:
mlir/test/mlir-cuda-runner/all-reduce-and.mlir
mlir/test/mlir-cuda-runner/all-reduce-max.mlir
mlir/test/mlir-cuda-runner/all-reduce-min.mlir
mlir/test/mlir-cuda-runner/all-reduce-op.mlir
mlir/test/mlir-cuda-runner/all-reduce-or.mlir
mlir/test/mlir-cuda-runner/all-reduce-region.mlir
mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
mlir/test/mlir-cuda-runner/shuffle.mlir
mlir/test/mlir-cuda-runner/two-modules.mlir
mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
Removed:
################################################################################
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
index 8eb2e0d72ec5..d3ad7a802537 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
@@ -24,10 +24,10 @@ func @main() {
%c5 = constant 5 : index
%c6 = constant 6 : index
- %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
- call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
- %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
- call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -52,14 +52,12 @@ func @main() {
gpu.terminator
}
- %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+ call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
// CHECK: [0, 2]
return
}
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
index 1213625fa9a0..ae2f6c3d6b3e 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
@@ -24,10 +24,10 @@ func @main() {
%c5 = constant 5 : index
%c6 = constant 6 : index
- %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
- call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
- %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
- call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -52,14 +52,12 @@ func @main() {
gpu.terminator
}
- %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+ call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
// CHECK: [16, 11]
return
}
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
index f467c80027c2..0cd4f11daf10 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
@@ -24,10 +24,10 @@ func @main() {
%c5 = constant 5 : index
%c6 = constant 6 : index
- %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
- call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
- %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
- call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -52,14 +52,12 @@ func @main() {
gpu.terminator
}
- %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+ call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
// CHECK: [0, 2]
return
}
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
index 60acd91f5e75..eb522d2910a6 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
@@ -8,7 +8,8 @@ func @main() {
%sx = dim %dst, 2 : memref<?x?x?xf32>
%sy = dim %dst, 1 : memref<?x?x?xf32>
%sz = dim %dst, 0 : memref<?x?x?xf32>
- call @mcuMemHostRegisterMemRef3dFloat(%dst) : (memref<?x?x?xf32>) -> ()
+ %cast_dst = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz) {
%t0 = muli %tz, %block_y : index
@@ -21,10 +22,9 @@ func @main() {
store %sum, %dst[%tz, %ty, %tx] : memref<?x?x?xf32>
gpu.terminator
}
- %U = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
- call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+ call @print_memref_f32(%cast_dst) : (memref<*xf32>) -> ()
return
}
-func @mcuMemHostRegisterMemRef3dFloat(%ptr : memref<?x?x?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
index 3135970620dc..cc9eae9e8b66 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
@@ -24,10 +24,10 @@ func @main() {
%c5 = constant 5 : index
%c6 = constant 6 : index
- %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
- call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
- %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
- call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -52,14 +52,12 @@ func @main() {
gpu.terminator
}
- %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+ call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
// CHECK: [31, 15]
return
}
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
index 6d967d155930..69499215707e 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
@@ -6,7 +6,8 @@ func @main() {
%dst = memref_cast %arg : memref<35xf32> to memref<?xf32>
%one = constant 1 : index
%sx = dim %dst, 0 : memref<?xf32>
- call @mcuMemHostRegisterMemRef1dFloat(%dst) : (memref<?xf32>) -> ()
+ %cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%val = index_cast %tx : index to i32
@@ -19,10 +20,9 @@ func @main() {
store %res, %dst[%tx] : memref<?xf32>
gpu.terminator
}
- %U = memref_cast %dst : memref<?xf32> to memref<*xf32>
- call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+ call @print_memref_f32(%cast_dst) : (memref<*xf32>) -> ()
return
}
-func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
index 913088f1164e..a32c4d3eb93e 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
@@ -24,10 +24,10 @@ func @main() {
%c5 = constant 5 : index
%c6 = constant 6 : index
- %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
- call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
- %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
- call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -52,14 +52,12 @@ func @main() {
gpu.terminator
}
- %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
- call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+ call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
// CHECK: [31, 1]
return
}
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
index bf882761f148..242cc9c28c00 100644
--- a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
+++ b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
@@ -16,8 +16,8 @@ func @main() {
%arg0 = alloc() : memref<5xf32>
%21 = constant 5 : i32
%22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
- call @mcuMemHostRegisterMemRef1dFloat(%22) : (memref<?xf32>) -> ()
%23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%23) : (memref<*xf32>) -> ()
call @print_memref_f32(%23) : (memref<*xf32>) -> ()
%24 = constant 1.0 : f32
call @other_func(%24, %22) : (f32, memref<?xf32>) -> ()
@@ -25,5 +25,5 @@ func @main() {
return
}
-func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
index c989db9b1cb9..a7b143f760a7 100644
--- a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
+++ b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
@@ -25,12 +25,12 @@ func @main() {
%c5 = constant 5 : index
%c6 = constant 6 : index
- %cast_data = memref_cast %data : memref<2x6xf32> to memref<?x?xf32>
- call @mcuMemHostRegisterMemRef2dFloat(%cast_data) : (memref<?x?xf32>) -> ()
- %cast_sum = memref_cast %sum : memref<2xf32> to memref<?xf32>
- call @mcuMemHostRegisterMemRef1dFloat(%cast_sum) : (memref<?xf32>) -> ()
- %cast_mul = memref_cast %mul : memref<2xf32> to memref<?xf32>
- call @mcuMemHostRegisterMemRef1dFloat(%cast_mul) : (memref<?xf32>) -> ()
+ %cast_data = memref_cast %data : memref<2x6xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%cast_data) : (memref<*xf32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%cast_sum) : (memref<*xf32>) -> ()
+ %cast_mul = memref_cast %mul : memref<2xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%cast_mul) : (memref<*xf32>) -> ()
store %cst0, %data[%c0, %c0] : memref<2x6xf32>
store %cst1, %data[%c0, %c1] : memref<2x6xf32>
@@ -57,17 +57,14 @@ func @main() {
gpu.terminator
}
- %ptr_sum = memref_cast %sum : memref<2xf32> to memref<*xf32>
- call @print_memref_f32(%ptr_sum) : (memref<*xf32>) -> ()
+ call @print_memref_f32(%cast_sum) : (memref<*xf32>) -> ()
// CHECK: [31, 39]
- %ptr_mul = memref_cast %mul : memref<2xf32> to memref<*xf32>
- call @print_memref_f32(%ptr_mul) : (memref<*xf32>) -> ()
+ call @print_memref_f32(%cast_mul) : (memref<*xf32>) -> ()
// CHECK: [0, 27720]
return
}
-func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
-func @mcuMemHostRegisterMemRef2dFloat(%ptr : memref<?x?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/shuffle.mlir b/mlir/test/mlir-cuda-runner/shuffle.mlir
index 0067ef440571..09fbef0095d8 100644
--- a/mlir/test/mlir-cuda-runner/shuffle.mlir
+++ b/mlir/test/mlir-cuda-runner/shuffle.mlir
@@ -6,7 +6,8 @@ func @main() {
%dst = memref_cast %arg : memref<13xf32> to memref<?xf32>
%one = constant 1 : index
%sx = dim %dst, 0 : memref<?xf32>
- call @mcuMemHostRegisterMemRef1dFloat(%dst) : (memref<?xf32>) -> ()
+ %cast_dest = memref_cast %dst : memref<?xf32> to memref<*xf32>
+ call @mcuMemHostRegisterFloat(%cast_dest) : (memref<*xf32>) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%t0 = index_cast %tx : index to i32
@@ -22,10 +23,9 @@ func @main() {
store %value, %dst[%tx] : memref<?xf32>
gpu.terminator
}
- %U = memref_cast %dst : memref<?xf32> to memref<*xf32>
- call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+ call @print_memref_f32(%cast_dest) : (memref<*xf32>) -> ()
return
}
-func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/two-modules.mlir b/mlir/test/mlir-cuda-runner/two-modules.mlir
index 0f01b36f5cee..68c936596315 100644
--- a/mlir/test/mlir-cuda-runner/two-modules.mlir
+++ b/mlir/test/mlir-cuda-runner/two-modules.mlir
@@ -6,7 +6,8 @@ func @main() {
%dst = memref_cast %arg : memref<13xi32> to memref<?xi32>
%one = constant 1 : index
%sx = dim %dst, 0 : memref<?xi32>
- call @mcuMemHostRegisterMemRef1dInt32(%dst) : (memref<?xi32>) -> ()
+ %cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
+ call @mcuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%t0 = index_cast %tx : index to i32
@@ -19,10 +20,9 @@ func @main() {
store %t0, %dst[%tx] : memref<?xi32>
gpu.terminator
}
- %U = memref_cast %dst : memref<?xi32> to memref<*xi32>
- call @print_memref_i32(%U) : (memref<*xi32>) -> ()
+ call @print_memref_i32(%cast_dst) : (memref<*xi32>) -> ()
return
}
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @print_memref_i32(%ptr : memref<*xi32>)
+func @mcuMemHostRegisterInt32(%memref : memref<*xi32>)
+func @print_memref_i32(%memref : memref<*xi32>)
diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index 9c191c5a1a4b..0efd1709cee3 100644
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -15,6 +15,7 @@
#include <cassert>
#include <numeric>
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"
@@ -79,15 +80,6 @@ extern "C" void mcuMemHostRegister(void *ptr, uint64_t sizeBytes) {
"MemHostRegister");
}
-// A struct that corresponds to how MLIR represents memrefs.
-template <typename T, int N> struct MemRefType {
- T *basePtr;
- T *data;
- int64_t offset;
- int64_t sizes[N];
- int64_t strides[N];
-};
-
// Allows to register a MemRef with the CUDA runtime. Initializes array with
// value. Helpful until we have transfer functions implemented.
template <typename T>
@@ -110,52 +102,16 @@ void mcuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
mcuMemHostRegister(pointer, count * sizeof(T));
}
-extern "C" void mcuMemHostRegisterMemRef1dFloat(float *allocated,
- float *aligned, int64_t offset,
- int64_t size, int64_t stride) {
- mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 1.23f);
-}
-
-extern "C" void mcuMemHostRegisterMemRef2dFloat(float *allocated,
- float *aligned, int64_t offset,
- int64_t size0, int64_t size1,
- int64_t stride0,
- int64_t stride1) {
- mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
- 1.23f);
-}
-
-extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
- float *aligned, int64_t offset,
- int64_t size0, int64_t size1,
- int64_t size2, int64_t stride0,
- int64_t stride1,
- int64_t stride2) {
- mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
- {stride0, stride1, stride2}, 1.23f);
-}
-
-extern "C" void mcuMemHostRegisterMemRef1dInt32(int32_t *allocated,
- int32_t *aligned,
- int64_t offset, int64_t size,
- int64_t stride) {
- mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 123);
-}
-
-extern "C" void mcuMemHostRegisterMemRef2dInt32(int32_t *allocated,
- int32_t *aligned,
- int64_t offset, int64_t size0,
- int64_t size1, int64_t stride0,
- int64_t stride1) {
- mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
- 123);
+extern "C" void mcuMemHostRegisterFloat(int64_t rank, void *ptr) {
+ auto *desc = static_cast<StridedMemRefType<float, 1> *>(ptr);
+ auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
+ auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
+ mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 1.23f);
}
-extern "C" void
-mcuMemHostRegisterMemRef3dInt32(int32_t *allocated, int32_t *aligned,
- int64_t offset, int64_t size0, int64_t size1,
- int64_t size2, int64_t stride0, int64_t stride1,
- int64_t stride2) {
- mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
- {stride0, stride1, stride2}, 123);
+extern "C" void mcuMemHostRegisterInt32(int64_t rank, void *ptr) {
+ auto *desc = static_cast<StridedMemRefType<int32_t, 1> *>(ptr);
+ auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
+ auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
+ mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 123);
}
More information about the Mlir-commits mailing list