[Mlir-commits] [mlir] 62adfed - Unrank mcuMemHostRegister tensor argument.

Christian Sigg llvmlistbot at llvm.org
Tue May 19 04:59:10 PDT 2020


Author: Christian Sigg
Date: 2020-05-19T13:58:54+02:00
New Revision: 62adfed30a125b0057d28b570b353bce4d23df83

URL: https://github.com/llvm/llvm-project/commit/62adfed30a125b0057d28b570b353bce4d23df83
DIFF: https://github.com/llvm/llvm-project/commit/62adfed30a125b0057d28b570b353bce4d23df83.diff

LOG: Unrank mcuMemHostRegister memref argument.

Reviewers: herhut

Reviewed By: herhut

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D80118

Added: 
    

Modified: 
    mlir/test/mlir-cuda-runner/all-reduce-and.mlir
    mlir/test/mlir-cuda-runner/all-reduce-max.mlir
    mlir/test/mlir-cuda-runner/all-reduce-min.mlir
    mlir/test/mlir-cuda-runner/all-reduce-op.mlir
    mlir/test/mlir-cuda-runner/all-reduce-or.mlir
    mlir/test/mlir-cuda-runner/all-reduce-region.mlir
    mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
    mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
    mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
    mlir/test/mlir-cuda-runner/shuffle.mlir
    mlir/test/mlir-cuda-runner/two-modules.mlir
    mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
index 8eb2e0d72ec5..d3ad7a802537 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
@@ -24,10 +24,10 @@ func @main() {
   %c5 = constant 5 : index
   %c6 = constant 6 : index
 
-  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
-  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
-  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
-  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+  call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+  call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -52,14 +52,12 @@ func @main() {
     gpu.terminator
   }
 
-  %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+  call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
   // CHECK: [0, 2]
 
   return
 }
 
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
index 1213625fa9a0..ae2f6c3d6b3e 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
@@ -24,10 +24,10 @@ func @main() {
   %c5 = constant 5 : index
   %c6 = constant 6 : index
 
-  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
-  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
-  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
-  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+  call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+  call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -52,14 +52,12 @@ func @main() {
     gpu.terminator
   }
 
-  %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+  call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
   // CHECK: [16, 11]
 
   return
 }
 
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
index f467c80027c2..0cd4f11daf10 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
@@ -24,10 +24,10 @@ func @main() {
   %c5 = constant 5 : index
   %c6 = constant 6 : index
 
-  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
-  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
-  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
-  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+  call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+  call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -52,14 +52,12 @@ func @main() {
     gpu.terminator
   }
 
-  %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+  call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
   // CHECK: [0, 2]
 
   return
 }
 
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
index 60acd91f5e75..eb522d2910a6 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
@@ -8,7 +8,8 @@ func @main() {
   %sx = dim %dst, 2 : memref<?x?x?xf32>
   %sy = dim %dst, 1 : memref<?x?x?xf32>
   %sz = dim %dst, 0 : memref<?x?x?xf32>
-  call @mcuMemHostRegisterMemRef3dFloat(%dst) : (memref<?x?x?xf32>) -> ()
+  %cast_dst = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
+  call @mcuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz) {
     %t0 = muli %tz, %block_y : index
@@ -21,10 +22,9 @@ func @main() {
     store %sum, %dst[%tz, %ty, %tx] : memref<?x?x?xf32>
     gpu.terminator
   }
-  %U = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
-  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%cast_dst) : (memref<*xf32>) -> ()
   return
 }
 
-func @mcuMemHostRegisterMemRef3dFloat(%ptr : memref<?x?x?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
index 3135970620dc..cc9eae9e8b66 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
@@ -24,10 +24,10 @@ func @main() {
   %c5 = constant 5 : index
   %c6 = constant 6 : index
 
-  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
-  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
-  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
-  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+  call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+  call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -52,14 +52,12 @@ func @main() {
     gpu.terminator
   }
 
-  %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+  call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
   // CHECK: [31, 15]
 
   return
 }
 
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
index 6d967d155930..69499215707e 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
@@ -6,7 +6,8 @@ func @main() {
   %dst = memref_cast %arg : memref<35xf32> to memref<?xf32>
   %one = constant 1 : index
   %sx = dim %dst, 0 : memref<?xf32>
-  call @mcuMemHostRegisterMemRef1dFloat(%dst) : (memref<?xf32>) -> ()
+  %cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
+  call @mcuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
     %val = index_cast %tx : index to i32
@@ -19,10 +20,9 @@ func @main() {
     store %res, %dst[%tx] : memref<?xf32>
     gpu.terminator
   }
-  %U = memref_cast %dst : memref<?xf32> to memref<*xf32>
-  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%cast_dst) : (memref<*xf32>) -> ()
   return
 }
 
-func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(memref<*xf32>)

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
index 913088f1164e..a32c4d3eb93e 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
@@ -24,10 +24,10 @@ func @main() {
   %c5 = constant 5 : index
   %c6 = constant 6 : index
 
-  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
-  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
-  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
-  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
+  call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
+  call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
 
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -52,14 +52,12 @@ func @main() {
     gpu.terminator
   }
 
-  %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+  call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> ()
   // CHECK: [31, 1]
 
   return
 }
 
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
+func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
index bf882761f148..242cc9c28c00 100644
--- a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
+++ b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
@@ -16,8 +16,8 @@ func @main() {
   %arg0 = alloc() : memref<5xf32>
   %21 = constant 5 : i32
   %22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
-  call @mcuMemHostRegisterMemRef1dFloat(%22) : (memref<?xf32>) -> ()
   %23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
+  call @mcuMemHostRegisterFloat(%23) : (memref<*xf32>) -> ()
   call @print_memref_f32(%23) : (memref<*xf32>) -> ()
   %24 = constant 1.0 : f32
   call @other_func(%24, %22) : (f32, memref<?xf32>) -> ()
@@ -25,5 +25,5 @@ func @main() {
   return
 }
 
-func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)

diff  --git a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
index c989db9b1cb9..a7b143f760a7 100644
--- a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
+++ b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
@@ -25,12 +25,12 @@ func @main() {
   %c5 = constant 5 : index
   %c6 = constant 6 : index
 
-  %cast_data = memref_cast %data : memref<2x6xf32> to memref<?x?xf32>
-  call @mcuMemHostRegisterMemRef2dFloat(%cast_data) : (memref<?x?xf32>) -> ()
-  %cast_sum = memref_cast %sum : memref<2xf32> to memref<?xf32>
-  call @mcuMemHostRegisterMemRef1dFloat(%cast_sum) : (memref<?xf32>) -> ()
-  %cast_mul = memref_cast %mul : memref<2xf32> to memref<?xf32>
-  call @mcuMemHostRegisterMemRef1dFloat(%cast_mul) : (memref<?xf32>) -> ()
+  %cast_data = memref_cast %data : memref<2x6xf32> to memref<*xf32>
+  call @mcuMemHostRegisterFloat(%cast_data) : (memref<*xf32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xf32> to memref<*xf32>
+  call @mcuMemHostRegisterFloat(%cast_sum) : (memref<*xf32>) -> ()
+  %cast_mul = memref_cast %mul : memref<2xf32> to memref<*xf32>
+  call @mcuMemHostRegisterFloat(%cast_mul) : (memref<*xf32>) -> ()
 
   store %cst0, %data[%c0, %c0] : memref<2x6xf32>
   store %cst1, %data[%c0, %c1] : memref<2x6xf32>
@@ -57,17 +57,14 @@ func @main() {
     gpu.terminator
   }
 
-  %ptr_sum = memref_cast %sum : memref<2xf32> to memref<*xf32>
-  call @print_memref_f32(%ptr_sum) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%cast_sum) : (memref<*xf32>) -> ()
   // CHECK: [31, 39]
 
-  %ptr_mul = memref_cast %mul : memref<2xf32> to memref<*xf32>
-  call @print_memref_f32(%ptr_mul) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%cast_mul) : (memref<*xf32>) -> ()
   // CHECK: [0, 27720]
 
   return
 }
 
-func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
-func @mcuMemHostRegisterMemRef2dFloat(%ptr : memref<?x?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(memref<*xf32>)

diff  --git a/mlir/test/mlir-cuda-runner/shuffle.mlir b/mlir/test/mlir-cuda-runner/shuffle.mlir
index 0067ef440571..09fbef0095d8 100644
--- a/mlir/test/mlir-cuda-runner/shuffle.mlir
+++ b/mlir/test/mlir-cuda-runner/shuffle.mlir
@@ -6,7 +6,8 @@ func @main() {
   %dst = memref_cast %arg : memref<13xf32> to memref<?xf32>
   %one = constant 1 : index
   %sx = dim %dst, 0 : memref<?xf32>
-  call @mcuMemHostRegisterMemRef1dFloat(%dst) : (memref<?xf32>) -> ()
+  %cast_dest = memref_cast %dst : memref<?xf32> to memref<*xf32>
+  call @mcuMemHostRegisterFloat(%cast_dest) : (memref<*xf32>) -> ()
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
     %t0 = index_cast %tx : index to i32
@@ -22,10 +23,9 @@ func @main() {
     store %value, %dst[%tx] : memref<?xf32>
     gpu.terminator
   }
-  %U = memref_cast %dst : memref<?xf32> to memref<*xf32>
-  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%cast_dest) : (memref<*xf32>) -> ()
   return
 }
 
-func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)

diff  --git a/mlir/test/mlir-cuda-runner/two-modules.mlir b/mlir/test/mlir-cuda-runner/two-modules.mlir
index 0f01b36f5cee..68c936596315 100644
--- a/mlir/test/mlir-cuda-runner/two-modules.mlir
+++ b/mlir/test/mlir-cuda-runner/two-modules.mlir
@@ -6,7 +6,8 @@ func @main() {
   %dst = memref_cast %arg : memref<13xi32> to memref<?xi32>
   %one = constant 1 : index
   %sx = dim %dst, 0 : memref<?xi32>
-  call @mcuMemHostRegisterMemRef1dInt32(%dst) : (memref<?xi32>) -> ()
+  %cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
+  call @mcuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
     %t0 = index_cast %tx : index to i32
@@ -19,10 +20,9 @@ func @main() {
     store %t0, %dst[%tx] : memref<?xi32>
     gpu.terminator
   }
-  %U = memref_cast %dst : memref<?xi32> to memref<*xi32>
-  call @print_memref_i32(%U) : (memref<*xi32>) -> ()
+  call @print_memref_i32(%cast_dst) : (memref<*xi32>) -> ()
   return
 }
 
-func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
-func @print_memref_i32(%ptr : memref<*xi32>)
+func @mcuMemHostRegisterInt32(%memref : memref<*xi32>)
+func @print_memref_i32(%memref : memref<*xi32>)

diff  --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index 9c191c5a1a4b..0efd1709cee3 100644
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -15,6 +15,7 @@
 #include <cassert>
 #include <numeric>
 
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -79,15 +80,6 @@ extern "C" void mcuMemHostRegister(void *ptr, uint64_t sizeBytes) {
                    "MemHostRegister");
 }
 
-// A struct that corresponds to how MLIR represents memrefs.
-template <typename T, int N> struct MemRefType {
-  T *basePtr;
-  T *data;
-  int64_t offset;
-  int64_t sizes[N];
-  int64_t strides[N];
-};
-
 // Allows to register a MemRef with the CUDA runtime. Initializes array with
 // value. Helpful until we have transfer functions implemented.
 template <typename T>
@@ -110,52 +102,16 @@ void mcuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
   mcuMemHostRegister(pointer, count * sizeof(T));
 }
 
-extern "C" void mcuMemHostRegisterMemRef1dFloat(float *allocated,
-                                                float *aligned, int64_t offset,
-                                                int64_t size, int64_t stride) {
-  mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 1.23f);
-}
-
-extern "C" void mcuMemHostRegisterMemRef2dFloat(float *allocated,
-                                                float *aligned, int64_t offset,
-                                                int64_t size0, int64_t size1,
-                                                int64_t stride0,
-                                                int64_t stride1) {
-  mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
-                           1.23f);
-}
-
-extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
-                                                float *aligned, int64_t offset,
-                                                int64_t size0, int64_t size1,
-                                                int64_t size2, int64_t stride0,
-                                                int64_t stride1,
-                                                int64_t stride2) {
-  mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
-                           {stride0, stride1, stride2}, 1.23f);
-}
-
-extern "C" void mcuMemHostRegisterMemRef1dInt32(int32_t *allocated,
-                                                int32_t *aligned,
-                                                int64_t offset, int64_t size,
-                                                int64_t stride) {
-  mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 123);
-}
-
-extern "C" void mcuMemHostRegisterMemRef2dInt32(int32_t *allocated,
-                                                int32_t *aligned,
-                                                int64_t offset, int64_t size0,
-                                                int64_t size1, int64_t stride0,
-                                                int64_t stride1) {
-  mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
-                           123);
+extern "C" void mcuMemHostRegisterFloat(int64_t rank, void *ptr) {
+  auto *desc = static_cast<StridedMemRefType<float, 1> *>(ptr);
+  auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
+  auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
+  mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 1.23f);
 }
 
-extern "C" void
-mcuMemHostRegisterMemRef3dInt32(int32_t *allocated, int32_t *aligned,
-                                int64_t offset, int64_t size0, int64_t size1,
-                                int64_t size2, int64_t stride0, int64_t stride1,
-                                int64_t stride2) {
-  mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
-                           {stride0, stride1, stride2}, 123);
+extern "C" void mcuMemHostRegisterInt32(int64_t rank, void *ptr) {
+  auto *desc = static_cast<StridedMemRefType<int32_t, 1> *>(ptr);
+  auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
+  auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
+  mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 123);
 }


        


More information about the Mlir-commits mailing list