[Mlir-commits] [mlir] b43ae21 - Fix all-reduce int tests by host-registering memrefs.
Christian Sigg
llvmlistbot at llvm.org
Mon Mar 23 03:48:22 PDT 2020
Author: Christian Sigg
Date: 2020-03-23T11:48:13+01:00
New Revision: b43ae21e60823473d945defd7141b031658e1cf0
URL: https://github.com/llvm/llvm-project/commit/b43ae21e60823473d945defd7141b031658e1cf0
DIFF: https://github.com/llvm/llvm-project/commit/b43ae21e60823473d945defd7141b031658e1cf0.diff
LOG: Fix all-reduce int tests by host-registering memrefs.
Reduce the amount of boilerplate needed to register host memory.
Summary: Fix all-reduce int tests by host-registering memrefs.
Reviewers: herhut
Reviewed By: herhut
Subscribers: clementval, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D76563
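
[Editor's note] For readers skimming the diff below: each test now casts its statically shaped memrefs to dynamically shaped ones and passes them to the mcuMemHostRegister* wrappers before launching the kernel, so the host allocations are page-locked and directly accessible to the GPU. The sketch that follows is an illustrative, CUDA-free stand-in (note the Stub suffix and the print statement; it is not part of this change) showing how such a call arrives on the C++ side once the memref descriptor has been expanded into separate arguments, which is why the wrappers in the diff take a flattened (allocated, aligned, offset, sizes..., strides...) argument list.

  #include <cstdint>
  #include <cstdio>

  // Hypothetical stand-in for the 1-D int32 wrapper. The argument list mirrors
  // how the MLIR call
  //   call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
  // is lowered: the memref descriptor is expanded into allocated pointer,
  // aligned pointer, offset, size, and stride.
  extern "C" void mcuMemHostRegisterMemRef1dInt32Stub(int32_t *allocated,
                                                      int32_t *aligned,
                                                      int64_t offset,
                                                      int64_t size,
                                                      int64_t stride) {
    // The real wrapper fills the buffer and registers it with the CUDA driver;
    // here we only print which descriptor field each argument carries.
    std::printf("base=%p data=%p offset=%lld size=%lld stride=%lld\n",
                (void *)allocated, (void *)(aligned + offset),
                (long long)offset, (long long)size, (long long)stride);
  }

  int main() {
    int32_t sum[2] = {0, 0}; // stands in for %sum = alloc() : memref<2xi32>
    mcuMemHostRegisterMemRef1dInt32Stub(sum, sum, /*offset=*/0, /*size=*/2,
                                        /*stride=*/1);
  }
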
Added:
Modified:
mlir/test/mlir-cuda-runner/all-reduce-and.mlir
mlir/test/mlir-cuda-runner/all-reduce-max.mlir
mlir/test/mlir-cuda-runner/all-reduce-min.mlir
mlir/test/mlir-cuda-runner/all-reduce-or.mlir
mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
Removed:
################################################################################
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
index edf3e029b9e8..8eb2e0d72ec5 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
@@ -2,9 +2,7 @@
func @main() {
%data = alloc() : memref<2x6xi32>
- %sum_and = alloc() : memref<2xi32>
- %sum_or = alloc() : memref<2xi32>
- %sum_min = alloc() : memref<2xi32>
+ %sum = alloc() : memref<2xi32>
%cst0 = constant 0 : i32
%cst1 = constant 1 : i32
%cst2 = constant 2 : i32
@@ -25,7 +23,12 @@ func @main() {
%c4 = constant 4 : index
%c5 = constant 5 : index
%c6 = constant 6 : index
-
+
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+ call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+ call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -44,17 +47,19 @@ func @main() {
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1)
threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) {
%val = load %data[%bx, %tx] : memref<2x6xi32>
- %reduced_and = "gpu.all_reduce"(%val) ({}) { op = "and" } : (i32) -> (i32)
- store %reduced_and, %sum_and[%bx] : memref<2xi32>
+ %reduced = "gpu.all_reduce"(%val) ({}) { op = "and" } : (i32) -> (i32)
+ store %reduced, %sum[%bx] : memref<2xi32>
gpu.terminator
}
- %ptr_and = memref_cast %sum_and : memref<2xi32> to memref<*xi32>
- call @print_memref_i32(%ptr_and) : (memref<*xi32>) -> ()
+ %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
+ call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
// CHECK: [0, 2]
return
}
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
index 6ed27ccb9d4b..1213625fa9a0 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
@@ -23,7 +23,12 @@ func @main() {
%c4 = constant 4 : index
%c5 = constant 5 : index
%c6 = constant 6 : index
-
+
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+ call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+ call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -54,5 +59,7 @@ func @main() {
return
}
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
index 2165fe58ce49..f467c80027c2 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
@@ -23,7 +23,12 @@ func @main() {
%c4 = constant 4 : index
%c5 = constant 5 : index
%c6 = constant 6 : index
-
+
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+ call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+ call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -54,5 +59,7 @@ func @main() {
return
}
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
index 2091c22356f5..3135970620dc 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
@@ -23,7 +23,12 @@ func @main() {
%c4 = constant 4 : index
%c5 = constant 5 : index
%c6 = constant 6 : index
-
+
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+ call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+ call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -54,5 +59,7 @@ func @main() {
return
}
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
index 153164128b16..913088f1164e 100644
--- a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
@@ -23,7 +23,12 @@ func @main() {
%c4 = constant 4 : index
%c5 = constant 5 : index
%c6 = constant 6 : index
-
+
+ %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+ call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+ call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -54,5 +59,7 @@ func @main() {
return
}
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
index 2c8eced2d4be..c989db9b1cb9 100644
--- a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
+++ b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
@@ -25,6 +25,13 @@ func @main() {
%c5 = constant 5 : index
%c6 = constant 6 : index
+ %cast_data = memref_cast %data : memref<2x6xf32> to memref<?x?xf32>
+ call @mcuMemHostRegisterMemRef2dFloat(%cast_data) : (memref<?x?xf32>) -> ()
+ %cast_sum = memref_cast %sum : memref<2xf32> to memref<?xf32>
+ call @mcuMemHostRegisterMemRef1dFloat(%cast_sum) : (memref<?xf32>) -> ()
+ %cast_mul = memref_cast %mul : memref<2xf32> to memref<?xf32>
+ call @mcuMemHostRegisterMemRef1dFloat(%cast_mul) : (memref<?xf32>) -> ()
+
store %cst0, %data[%c0, %c0] : memref<2x6xf32>
store %cst1, %data[%c0, %c1] : memref<2x6xf32>
store %cst2, %data[%c0, %c2] : memref<2x6xf32>
@@ -61,4 +68,6 @@ func @main() {
return
}
+func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegisterMemRef2dFloat(%ptr : memref<?x?xf32>)
func @print_memref_f32(memref<*xf32>)
diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index 350d9869373a..9c191c5a1a4b 100644
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -15,6 +15,7 @@
#include <cassert>
#include <numeric>
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"
#include "cuda.h"
@@ -89,24 +90,39 @@ template <typename T, int N> struct MemRefType {
// Allows to register a MemRef with the CUDA runtime. Initializes array with
// value. Helpful until we have transfer functions implemented.
-template <typename T, int N>
-void mcuMemHostRegisterMemRef(const MemRefType<T, N> *arg, T value) {
- auto count = std::accumulate(arg->sizes, arg->sizes + N, 1,
- std::multiplies<int64_t>());
- std::fill_n(arg->data, count, value);
- mcuMemHostRegister(arg->data, count * sizeof(T));
+template <typename T>
+void mcuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
+ llvm::ArrayRef<int64_t> strides, T value) {
+ assert(sizes.size() == strides.size());
+ llvm::SmallVector<int64_t, 4> denseStrides(strides.size());
+
+ std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
+ std::multiplies<int64_t>());
+ auto count = denseStrides.front();
+
+ // Only densely packed tensors are currently supported.
+ std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
+ denseStrides.end());
+ denseStrides.back() = 1;
+ assert(strides == llvm::makeArrayRef(denseStrides));
+
+ std::fill_n(pointer, count, value);
+ mcuMemHostRegister(pointer, count * sizeof(T));
}
extern "C" void mcuMemHostRegisterMemRef1dFloat(float *allocated,
float *aligned, int64_t offset,
int64_t size, int64_t stride) {
- MemRefType<float, 1> descriptor;
- descriptor.basePtr = allocated;
- descriptor.data = aligned;
- descriptor.offset = offset;
- descriptor.sizes[0] = size;
- descriptor.strides[0] = stride;
- mcuMemHostRegisterMemRef(&descriptor, 1.23f);
+ mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 1.23f);
+}
+
+extern "C" void mcuMemHostRegisterMemRef2dFloat(float *allocated,
+ float *aligned, int64_t offset,
+ int64_t size0, int64_t size1,
+ int64_t stride0,
+ int64_t stride1) {
+ mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
+ 1.23f);
}
extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
@@ -115,15 +131,31 @@ extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
int64_t size2, int64_t stride0,
int64_t stride1,
int64_t stride2) {
- MemRefType<float, 3> descriptor;
- descriptor.basePtr = allocated;
- descriptor.data = aligned;
- descriptor.offset = offset;
- descriptor.sizes[0] = size0;
- descriptor.strides[0] = stride0;
- descriptor.sizes[1] = size1;
- descriptor.strides[1] = stride1;
- descriptor.sizes[2] = size2;
- descriptor.strides[2] = stride2;
- mcuMemHostRegisterMemRef(&descriptor, 1.23f);
+ mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
+ {stride0, stride1, stride2}, 1.23f);
+}
+
+extern "C" void mcuMemHostRegisterMemRef1dInt32(int32_t *allocated,
+ int32_t *aligned,
+ int64_t offset, int64_t size,
+ int64_t stride) {
+ mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 123);
+}
+
+extern "C" void mcuMemHostRegisterMemRef2dInt32(int32_t *allocated,
+ int32_t *aligned,
+ int64_t offset, int64_t size0,
+ int64_t size1, int64_t stride0,
+ int64_t stride1) {
+ mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
+ 123);
+}
+
+extern "C" void
+mcuMemHostRegisterMemRef3dInt32(int32_t *allocated, int32_t *aligned,
+ int64_t offset, int64_t size0, int64_t size1,
+ int64_t size2, int64_t stride0, int64_t stride1,
+ int64_t stride2) {
+ mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
+ {stride0, stride1, stride2}, 123);
}
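
[Editor's note] The new templated mcuMemHostRegisterMemRef above computes the element count from the sizes and asserts that the passed strides describe a densely packed row-major buffer before registering it. Below is a minimal standalone sketch of that size/stride arithmetic (LLVM's ArrayRef/SmallVector swapped for std::vector, and the registration call replaced by a return value, so it compiles and runs without CUDA). For a memref<2x6xi32> with strides {6, 1}, the suffix products of the sizes are {12, 6}, so the count is 12; rotating left and appending 1 gives the expected dense strides {6, 1}, which match.

  #include <algorithm>
  #include <cassert>
  #include <cstdint>
  #include <cstdio>
  #include <functional>
  #include <numeric>
  #include <vector>

  // Returns the number of elements if `strides` describes a densely packed
  // row-major layout for `sizes`; asserts otherwise (mirroring the template
  // in cuda-runtime-wrappers.cpp).
  int64_t denseElementCount(const std::vector<int64_t> &sizes,
                            const std::vector<int64_t> &strides) {
    assert(sizes.size() == strides.size());
    std::vector<int64_t> denseStrides(sizes.size());
    // Suffix products of the sizes, e.g. {2, 6} -> {12, 6}.
    std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
                     std::multiplies<int64_t>());
    int64_t count = denseStrides.front();
    // Shift left and set the last stride to 1 to obtain the expected dense
    // strides, e.g. {12, 6} -> {6, 1}.
    std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
                denseStrides.end());
    denseStrides.back() = 1;
    assert(strides == denseStrides);
    return count;
  }

  int main() {
    // The 2x6 data buffer used by the all-reduce tests: 12 dense elements.
    std::printf("count = %lld\n",
                (long long)denseElementCount({2, 6}, {6, 1})); // count = 12
  }
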