[Mlir-commits] [mlir] b43ae21 - Fix all-reduce int tests by host-registering memrefs.

Christian Sigg llvmlistbot at llvm.org
Mon Mar 23 03:48:22 PDT 2020


Author: Christian Sigg
Date: 2020-03-23T11:48:13+01:00
New Revision: b43ae21e60823473d945defd7141b031658e1cf0

URL: https://github.com/llvm/llvm-project/commit/b43ae21e60823473d945defd7141b031658e1cf0
DIFF: https://github.com/llvm/llvm-project/commit/b43ae21e60823473d945defd7141b031658e1cf0.diff

LOG: Fix all-reduce int tests by host-registering memrefs.

Reduce the amount of boilerplate needed to register host memory.

Summary: Fix all-reduce int tests by host-registering memrefs.

Reviewers: herhut

Reviewed By: herhut

Subscribers: clementval, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D76563

Added: 
    

Modified: 
    mlir/test/mlir-cuda-runner/all-reduce-and.mlir
    mlir/test/mlir-cuda-runner/all-reduce-max.mlir
    mlir/test/mlir-cuda-runner/all-reduce-min.mlir
    mlir/test/mlir-cuda-runner/all-reduce-or.mlir
    mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
    mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
    mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
index edf3e029b9e8..8eb2e0d72ec5 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
@@ -2,9 +2,7 @@
 
 func @main() {
   %data = alloc() : memref<2x6xi32>
-  %sum_and = alloc() : memref<2xi32>
-  %sum_or = alloc() : memref<2xi32>
-  %sum_min = alloc() : memref<2xi32>
+  %sum = alloc() : memref<2xi32>
   %cst0 = constant 0 : i32
   %cst1 = constant 1 : i32
   %cst2 = constant 2 : i32
@@ -25,7 +23,12 @@ func @main() {
   %c4 = constant 4 : index
   %c5 = constant 5 : index
   %c6 = constant 6 : index
-    
+
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
   store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -44,17 +47,19 @@ func @main() {
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) {
     %val = load %data[%bx, %tx] : memref<2x6xi32>
-    %reduced_and = "gpu.all_reduce"(%val) ({}) { op = "and" } : (i32) -> (i32)
-    store %reduced_and, %sum_and[%bx] : memref<2xi32>
+    %reduced = "gpu.all_reduce"(%val) ({}) { op = "and" } : (i32) -> (i32)
+    store %reduced, %sum[%bx] : memref<2xi32>
     gpu.terminator
   }
 
-  %ptr_and = memref_cast %sum_and : memref<2xi32> to memref<*xi32>
-  call @print_memref_i32(%ptr_and) : (memref<*xi32>) -> ()
+  %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
+  call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
   // CHECK: [0, 2]
 
   return
 }
 
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
index 6ed27ccb9d4b..1213625fa9a0 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
@@ -23,7 +23,12 @@ func @main() {
   %c4 = constant 4 : index
   %c5 = constant 5 : index
   %c6 = constant 6 : index
-    
+
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
   store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -54,5 +59,7 @@ func @main() {
   return
 }
 
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
index 2165fe58ce49..f467c80027c2 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
@@ -23,7 +23,12 @@ func @main() {
   %c4 = constant 4 : index
   %c5 = constant 5 : index
   %c6 = constant 6 : index
-    
+
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
   store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -54,5 +59,7 @@ func @main() {
   return
 }
 
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
index 2091c22356f5..3135970620dc 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
@@ -23,7 +23,12 @@ func @main() {
   %c4 = constant 4 : index
   %c5 = constant 5 : index
   %c6 = constant 6 : index
-    
+
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
   store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -54,5 +59,7 @@ func @main() {
   return
 }
 
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
index 153164128b16..913088f1164e 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
@@ -23,7 +23,12 @@ func @main() {
   %c4 = constant 4 : index
   %c5 = constant 5 : index
   %c6 = constant 6 : index
-    
+
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
   store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -54,5 +59,7 @@ func @main() {
   return
 }
 
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
 func @print_memref_i32(memref<*xi32>)
 

diff  --git a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
index 2c8eced2d4be..c989db9b1cb9 100644
--- a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
+++ b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
@@ -25,6 +25,13 @@ func @main() {
   %c5 = constant 5 : index
   %c6 = constant 6 : index
 
+  %cast_data = memref_cast %data : memref<2x6xf32> to memref<?x?xf32>
+  call @mcuMemHostRegisterMemRef2dFloat(%cast_data) : (memref<?x?xf32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xf32> to memref<?xf32>
+  call @mcuMemHostRegisterMemRef1dFloat(%cast_sum) : (memref<?xf32>) -> ()
+  %cast_mul = memref_cast %mul : memref<2xf32> to memref<?xf32>
+  call @mcuMemHostRegisterMemRef1dFloat(%cast_mul) : (memref<?xf32>) -> ()
+
   store %cst0, %data[%c0, %c0] : memref<2x6xf32>
   store %cst1, %data[%c0, %c1] : memref<2x6xf32>
   store %cst2, %data[%c0, %c2] : memref<2x6xf32>
@@ -61,4 +68,6 @@ func @main() {
   return
 }
 
+func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegisterMemRef2dFloat(%ptr : memref<?x?xf32>)
 func @print_memref_f32(memref<*xf32>)

diff  --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index 350d9869373a..9c191c5a1a4b 100644
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -15,6 +15,7 @@
 #include <cassert>
 #include <numeric>
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/raw_ostream.h"
 
 #include "cuda.h"
@@ -89,24 +90,39 @@ template <typename T, int N> struct MemRefType {
 
 // Allows to register a MemRef with the CUDA runtime. Initializes array with
 // value. Helpful until we have transfer functions implemented.
-template <typename T, int N>
-void mcuMemHostRegisterMemRef(const MemRefType<T, N> *arg, T value) {
-  auto count = std::accumulate(arg->sizes, arg->sizes + N, 1,
-                               std::multiplies<int64_t>());
-  std::fill_n(arg->data, count, value);
-  mcuMemHostRegister(arg->data, count * sizeof(T));
+template <typename T>
+void mcuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
+                              llvm::ArrayRef<int64_t> strides, T value) {
+  assert(sizes.size() == strides.size());
+  llvm::SmallVector<int64_t, 4> denseStrides(strides.size());
+
+  std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
+                   std::multiplies<int64_t>());
+  auto count = denseStrides.front();
+
+  // Only densely packed tensors are currently supported.
+  std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
+              denseStrides.end());
+  denseStrides.back() = 1;
+  assert(strides == llvm::makeArrayRef(denseStrides));
+
+  std::fill_n(pointer, count, value);
+  mcuMemHostRegister(pointer, count * sizeof(T));
 }
 
 extern "C" void mcuMemHostRegisterMemRef1dFloat(float *allocated,
                                                 float *aligned, int64_t offset,
                                                 int64_t size, int64_t stride) {
-  MemRefType<float, 1> descriptor;
-  descriptor.basePtr = allocated;
-  descriptor.data = aligned;
-  descriptor.offset = offset;
-  descriptor.sizes[0] = size;
-  descriptor.strides[0] = stride;
-  mcuMemHostRegisterMemRef(&descriptor, 1.23f);
+  mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 1.23f);
+}
+
+extern "C" void mcuMemHostRegisterMemRef2dFloat(float *allocated,
+                                                float *aligned, int64_t offset,
+                                                int64_t size0, int64_t size1,
+                                                int64_t stride0,
+                                                int64_t stride1) {
+  mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
+                           1.23f);
 }
 
 extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
@@ -115,15 +131,31 @@ extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
                                                 int64_t size2, int64_t stride0,
                                                 int64_t stride1,
                                                 int64_t stride2) {
-  MemRefType<float, 3> descriptor;
-  descriptor.basePtr = allocated;
-  descriptor.data = aligned;
-  descriptor.offset = offset;
-  descriptor.sizes[0] = size0;
-  descriptor.strides[0] = stride0;
-  descriptor.sizes[1] = size1;
-  descriptor.strides[1] = stride1;
-  descriptor.sizes[2] = size2;
-  descriptor.strides[2] = stride2;
-  mcuMemHostRegisterMemRef(&descriptor, 1.23f);
+  mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
+                           {stride0, stride1, stride2}, 1.23f);
+}
+
+extern "C" void mcuMemHostRegisterMemRef1dInt32(int32_t *allocated,
+                                                int32_t *aligned,
+                                                int64_t offset, int64_t size,
+                                                int64_t stride) {
+  mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 123);
+}
+
+extern "C" void mcuMemHostRegisterMemRef2dInt32(int32_t *allocated,
+                                                int32_t *aligned,
+                                                int64_t offset, int64_t size0,
+                                                int64_t size1, int64_t stride0,
+                                                int64_t stride1) {
+  mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
+                           123);
+}
+
+extern "C" void
+mcuMemHostRegisterMemRef3dInt32(int32_t *allocated, int32_t *aligned,
+                                int64_t offset, int64_t size0, int64_t size1,
+                                int64_t size2, int64_t stride0, int64_t stride1,
+                                int64_t stride2) {
+  mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
+                           {stride0, stride1, stride2}, 123);
 }


        


More information about the Mlir-commits mailing list