[Mlir-commits] [mlir] 222e0e5 - [MLIR] Helper class referencing MemRefType to unify runner implementations.
Christian Sigg
llvmlistbot at llvm.org
Tue May 26 07:32:47 PDT 2020
Author: Christian Sigg
Date: 2020-05-26T16:32:36+02:00
New Revision: 222e0e58a87649623b3d16ce3fef56a6a0555be3
URL: https://github.com/llvm/llvm-project/commit/222e0e58a87649623b3d16ce3fef56a6a0555be3
DIFF: https://github.com/llvm/llvm-project/commit/222e0e58a87649623b3d16ce3fef56a6a0555be3.diff
LOG: [MLIR] Helper class referencing MemRefType to unify runner implementations.
Summary:
Add DynamicMemRefType, which can reference either one of the statically ranked StridedMemRefType instantiations or an UnrankedMemRefType, so that runner utils only need to be implemented once.
There is definitely room for more cleanup and unification, but I will keep that for follow-ups.
Reviewers: nicolasvasilache
Reviewed By: nicolasvasilache
Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D80513
Added:
Modified:
mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
mlir/include/mlir/ExecutionEngine/RunnerUtils.h
mlir/lib/ExecutionEngine/RunnerUtils.cpp
mlir/test/mlir-cpu-runner/unranked_memref.mlir
mlir/test/mlir-cpu-runner/utils.mlir
mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index 8155820d6347..bc59d3de2086 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -33,12 +33,6 @@
#include <cstdint>
-template <int N>
-void dropFront(int64_t arr[N], int64_t *res) {
- for (unsigned i = 1; i < N; ++i)
- *(res + i - 1) = arr[i];
-}
-
//===----------------------------------------------------------------------===//
// Codegen-compatible structures for Vector type.
//===----------------------------------------------------------------------===//
@@ -129,6 +123,10 @@ struct StridedMemRefType {
res.basePtr = basePtr;
res.data = data;
res.offset = offset + idx * strides[0];
+ auto dropFront = [](const int64_t *arr, int64_t *res) {
+ for (unsigned i = 1; i < N; ++i)
+ res[i - 1] = arr[i];
+ };
dropFront<N>(sizes, res.sizes);
dropFront<N>(strides, res.strides);
return res;
@@ -164,6 +162,39 @@ struct UnrankedMemRefType {
void *descriptor;
};
+//===----------------------------------------------------------------------===//
+// DynamicMemRefType type.
+//===----------------------------------------------------------------------===//
+// A reference to one of the StridedMemRef types.
+template <typename T>
+class DynamicMemRefType {
+public:
+ explicit DynamicMemRefType(const StridedMemRefType<T, 0> &mem_ref)
+ : rank(0), basePtr(mem_ref.basePtr), data(mem_ref.data),
+ offset(mem_ref.offset), sizes(nullptr), strides(nullptr) {}
+ template <int N>
+ explicit DynamicMemRefType(const StridedMemRefType<T, N> &mem_ref)
+ : rank(N), basePtr(mem_ref.basePtr), data(mem_ref.data),
+ offset(mem_ref.offset), sizes(mem_ref.sizes), strides(mem_ref.strides) {
+ }
+ explicit DynamicMemRefType(const UnrankedMemRefType<T> &mem_ref)
+ : rank(mem_ref.rank) {
+ auto *desc = static_cast<StridedMemRefType<T, 1> *>(mem_ref.descriptor);
+ basePtr = desc->basePtr;
+ data = desc->data;
+ offset = desc->offset;
+ sizes = rank == 0 ? nullptr : desc->sizes;
+ strides = sizes + rank;
+ }
+
+ int64_t rank;
+ T *basePtr;
+ T *data;
+ int64_t offset;
+ const int64_t *sizes;
+ const int64_t *strides;
+};
+
//===----------------------------------------------------------------------===//
// Small runtime support "lib" for vector.print lowering during codegen.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/ExecutionEngine/RunnerUtils.h b/mlir/include/mlir/ExecutionEngine/RunnerUtils.h
index 5f239a4c146e..7729b9c88796 100644
--- a/mlir/include/mlir/ExecutionEngine/RunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/RunnerUtils.h
@@ -35,29 +35,35 @@
#include "mlir/ExecutionEngine/CRunnerUtils.h"
-template <typename StreamType, typename T, int N>
-void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
- static_assert(N > 0, "Expected N > 0");
- os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = " << N
- << " offset = " << V.offset << " sizes = [" << V.sizes[0];
- for (unsigned i = 1; i < N; ++i)
- os << ", " << V.sizes[i];
- os << "] strides = [" << V.strides[0];
- for (unsigned i = 1; i < N; ++i)
- os << ", " << V.strides[i];
+template <typename T, typename StreamType>
+void printMemRefMetaData(StreamType &os, const DynamicMemRefType<T> &V) {
+ os << "base@ = " << reinterpret_cast<void *>(V.data) << " rank = " << V.rank
+ << " offset = " << V.offset;
+ auto print = [&](const int64_t *ptr) {
+ if (V.rank == 0)
+ return;
+ os << ptr[0];
+ for (int64_t i = 1; i < V.rank; ++i)
+ os << ", " << ptr[i];
+ };
+ os << " sizes = [";
+ print(V.sizes);
+ os << "] strides = [";
+ print(V.strides);
os << "]";
}
-template <typename StreamType, typename T>
-void printMemRefMetaData(StreamType &os, StridedMemRefType<T, 0> &V) {
- os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = 0"
- << " offset = " << V.offset;
+template <typename StreamType, typename T, int N>
+void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
+ static_assert(N >= 0, "Expected N > 0");
+ os << "MemRef ";
+ printMemRefMetaData(os, DynamicMemRefType<T>(V));
}
-template <typename T, typename StreamType>
+template <typename StreamType, typename T>
void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType<T> &V) {
- os << "Unranked Memref rank = " << V.rank << " "
- << "descriptor@ = " << reinterpret_cast<void *>(V.descriptor) << "\n";
+ os << "Unranked MemRef ";
+ printMemRefMetaData(os, DynamicMemRefType<T>(V));
}
////////////////////////////////////////////////////////////////////////////////
@@ -118,88 +124,92 @@ std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v) {
return os;
}
-template <typename T, int N> struct MemRefDataPrinter {
- static void print(std::ostream &os, T *base, int64_t rank, int64_t offset,
- int64_t *sizes, int64_t *strides);
- static void printFirst(std::ostream &os, T *base, int64_t rank,
- int64_t offset, int64_t *sizes, int64_t *strides);
- static void printLast(std::ostream &os, T *base, int64_t rank, int64_t offset,
- int64_t *sizes, int64_t *strides);
-};
-
-template <typename T> struct MemRefDataPrinter<T, 0> {
- static void print(std::ostream &os, T *base, int64_t rank, int64_t offset,
- int64_t *sizes = nullptr, int64_t *strides = nullptr);
+template <typename T>
+struct MemRefDataPrinter {
+ static void print(std::ostream &os, T *base, int64_t dim, int64_t rank,
+ int64_t offset, const int64_t *sizes,
+ const int64_t *strides);
+ static void printFirst(std::ostream &os, T *base, int64_t dim, int64_t rank,
+ int64_t offset, const int64_t *sizes,
+ const int64_t *strides);
+ static void printLast(std::ostream &os, T *base, int64_t dim, int64_t rank,
+ int64_t offset, const int64_t *sizes,
+ const int64_t *strides);
};
-template <typename T, int N>
-void MemRefDataPrinter<T, N>::printFirst(std::ostream &os, T *base,
- int64_t rank, int64_t offset,
- int64_t *sizes, int64_t *strides) {
+template <typename T>
+void MemRefDataPrinter<T>::printFirst(std::ostream &os, T *base, int64_t dim,
+ int64_t rank, int64_t offset,
+ const int64_t *sizes,
+ const int64_t *strides) {
os << "[";
- MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset, sizes + 1,
- strides + 1);
+ print(os, base, dim - 1, rank, offset, sizes + 1, strides + 1);
// If single element, close square bracket and return early.
if (sizes[0] <= 1) {
os << "]";
return;
}
os << ", ";
- if (N > 1)
+ if (dim > 1)
os << "\n";
}
-template <typename T, int N>
-void MemRefDataPrinter<T, N>::print(std::ostream &os, T *base, int64_t rank,
- int64_t offset, int64_t *sizes,
- int64_t *strides) {
- printFirst(os, base, rank, offset, sizes, strides);
+template <typename T>
+void MemRefDataPrinter<T>::print(std::ostream &os, T *base, int64_t dim,
+ int64_t rank, int64_t offset,
+ const int64_t *sizes, const int64_t *strides) {
+ if (dim == 0) {
+ os << base[offset];
+ return;
+ }
+ printFirst(os, base, dim, rank, offset, sizes, strides);
for (unsigned i = 1; i + 1 < sizes[0]; ++i) {
- printSpace(os, rank - N + 1);
- MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset + i * strides[0],
- sizes + 1, strides + 1);
+ printSpace(os, rank - dim + 1);
+ print(os, base, dim - 1, rank, offset + i * strides[0], sizes + 1,
+ strides + 1);
os << ", ";
- if (N > 1)
+ if (dim > 1)
os << "\n";
}
if (sizes[0] <= 1)
return;
- printLast(os, base, rank, offset, sizes, strides);
+ printLast(os, base, dim, rank, offset, sizes, strides);
}
-template <typename T, int N>
-void MemRefDataPrinter<T, N>::printLast(std::ostream &os, T *base, int64_t rank,
- int64_t offset, int64_t *sizes,
- int64_t *strides) {
- printSpace(os, rank - N + 1);
- MemRefDataPrinter<T, N - 1>::print(os, base, rank,
- offset + (sizes[0] - 1) * (*strides),
- sizes + 1, strides + 1);
+template <typename T>
+void MemRefDataPrinter<T>::printLast(std::ostream &os, T *base, int64_t dim,
+ int64_t rank, int64_t offset,
+ const int64_t *sizes,
+ const int64_t *strides) {
+ printSpace(os, rank - dim + 1);
+ print(os, base, dim - 1, rank, offset + (sizes[0] - 1) * (*strides),
+ sizes + 1, strides + 1);
os << "]";
}
template <typename T>
-void MemRefDataPrinter<T, 0>::print(std::ostream &os, T *base, int64_t rank,
- int64_t offset, int64_t *sizes,
- int64_t *strides) {
- os << base[offset];
-}
-
-template <typename T, int N> void printMemRef(StridedMemRefType<T, N> &M) {
- static_assert(N > 0, "Expected N > 0");
+void printMemRef(const DynamicMemRefType<T> &M) {
printMemRefMetaData(std::cout, M);
std::cout << " data = " << std::endl;
- MemRefDataPrinter<T, N>::print(std::cout, M.data, N, M.offset, M.sizes,
- M.strides);
+ if (M.rank == 0)
+ std::cout << "[";
+ MemRefDataPrinter<T>::print(std::cout, M.data, M.rank, M.rank, M.offset,
+ M.sizes, M.strides);
+ if (M.rank == 0)
+ std::cout << "]";
std::cout << std::endl;
}
-template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
- printMemRefMetaData(std::cout, M);
- std::cout << " data = " << std::endl;
- std::cout << "[";
- MemRefDataPrinter<T, 0>::print(std::cout, M.data, 0, M.offset);
- std::cout << "]" << std::endl;
+template <typename T, int N>
+void printMemRef(StridedMemRefType<T, N> &M) {
+ std::cout << "Memref ";
+ printMemRef(DynamicMemRefType<T>(M));
+}
+
+template <typename T>
+void printMemRef(UnrankedMemRefType<T> &M) {
+ std::cout << "Unranked Memref ";
+ printMemRef(DynamicMemRefType<T>(M));
}
} // namespace impl
diff --git a/mlir/lib/ExecutionEngine/RunnerUtils.cpp b/mlir/lib/ExecutionEngine/RunnerUtils.cpp
index 7991eca61994..7497ebdacf68 100644
--- a/mlir/lib/ExecutionEngine/RunnerUtils.cpp
+++ b/mlir/lib/ExecutionEngine/RunnerUtils.cpp
@@ -24,57 +24,16 @@ extern "C" void _mlir_ciface_print_memref_vector_4x4xf32(
impl::printMemRef(*M);
}
-#define MEMREF_CASE(TYPE, RANK) \
- case RANK: \
- impl::printMemRef(*(static_cast<StridedMemRefType<TYPE, RANK> *>(ptr))); \
- break
-
extern "C" void _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M) {
- printUnrankedMemRefMetaData(std::cout, *M);
- int64_t rank = M->rank;
- void *ptr = M->descriptor;
-
- switch (rank) {
- MEMREF_CASE(int8_t, 0);
- MEMREF_CASE(int8_t, 1);
- MEMREF_CASE(int8_t, 2);
- MEMREF_CASE(int8_t, 3);
- MEMREF_CASE(int8_t, 4);
- default:
- assert(0 && "Unsupported rank to print");
- }
+ impl::printMemRef(*M);
}
extern "C" void _mlir_ciface_print_memref_i32(UnrankedMemRefType<int32_t> *M) {
- printUnrankedMemRefMetaData(std::cout, *M);
- int64_t rank = M->rank;
- void *ptr = M->descriptor;
-
- switch (rank) {
- MEMREF_CASE(int32_t, 0);
- MEMREF_CASE(int32_t, 1);
- MEMREF_CASE(int32_t, 2);
- MEMREF_CASE(int32_t, 3);
- MEMREF_CASE(int32_t, 4);
- default:
- assert(0 && "Unsupported rank to print");
- }
+ impl::printMemRef(*M);
}
extern "C" void _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M) {
- printUnrankedMemRefMetaData(std::cout, *M);
- int64_t rank = M->rank;
- void *ptr = M->descriptor;
-
- switch (rank) {
- MEMREF_CASE(float, 0);
- MEMREF_CASE(float, 1);
- MEMREF_CASE(float, 2);
- MEMREF_CASE(float, 3);
- MEMREF_CASE(float, 4);
- default:
- assert(0 && "Unsupported rank to print");
- }
+ impl::printMemRef(*M);
}
extern "C" void print_memref_i32(int64_t rank, void *ptr) {
diff --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
index aa54b56b06b7..0eb68ac03368 100644
--- a/mlir/test/mlir-cpu-runner/unranked_memref.mlir
+++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
@@ -1,25 +1,21 @@
// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext | FileCheck %s
-// CHECK: rank = 2
// CHECK: rank = 2
// CHECK-SAME: sizes = [10, 3]
// CHECK-SAME: strides = [3, 1]
// CHECK-COUNT-10: [10, 10, 10]
//
// CHECK: rank = 2
-// CHECK: rank = 2
// CHECK-SAME: sizes = [10, 3]
// CHECK-SAME: strides = [3, 1]
// CHECK-COUNT-10: [5, 5, 5]
//
// CHECK: rank = 2
-// CHECK: rank = 2
// CHECK-SAME: sizes = [10, 3]
// CHECK-SAME: strides = [3, 1]
// CHECK-COUNT-10: [2, 2, 2]
//
// CHECK: rank = 0
-// CHECK: rank = 0
// 122 is ASCII for 'z'.
// CHECK: [z]
func @main() -> () {
diff --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir
index d3ab6177eb65..65957400bf7f 100644
--- a/mlir/test/mlir-cpu-runner/utils.mlir
+++ b/mlir/test/mlir-cpu-runner/utils.mlir
@@ -12,8 +12,7 @@ func @print_0d() {
dealloc %A : memref<f32>
return
}
-// PRINT-0D: Unranked Memref rank = 0 descriptor@ = {{.*}}
-// PRINT-0D: Memref base@ = {{.*}} rank = 0 offset = 0 data =
+// PRINT-0D: Unranked Memref base@ = {{.*}} rank = 0 offset = 0 sizes = [] strides = [] data =
// PRINT-0D: [2]
func @print_1d() {
@@ -26,7 +25,7 @@ func @print_1d() {
dealloc %A : memref<16xf32>
return
}
-// PRINT-1D: Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data =
+// PRINT-1D: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data =
// PRINT-1D-NEXT: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
func @print_3d() {
@@ -43,7 +42,7 @@ func @print_3d() {
dealloc %A : memref<3x4x5xf32>
return
}
-// PRINT-3D: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [3, 4, 5] strides = [20, 5, 1] data =
+// PRINT-3D: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [3, 4, 5] strides = [20, 5, 1] data =
// PRINT-3D-COUNT-4: {{.*[[:space:]].*}}2, 2, 2, 2, 2
// PRINT-3D-COUNT-4: {{.*[[:space:]].*}}2, 2, 2, 2, 2
// PRINT-3D-COUNT-2: {{.*[[:space:]].*}}2, 2, 2, 2, 2
diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index dbe78a55c0b1..705fa9f00930 100644
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -83,10 +83,10 @@ extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
// Allows to register a MemRef with the CUDA runtime. Initializes array with
// value. Helpful until we have transfer functions implemented.
template <typename T>
-void mcuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
- llvm::ArrayRef<int64_t> strides, T value) {
- assert(sizes.size() == strides.size());
- llvm::SmallVector<int64_t, 4> denseStrides(strides.size());
+void mcuMemHostRegisterMemRef(const DynamicMemRefType<T> &mem_ref, T value) {
+ llvm::SmallVector<int64_t, 4> denseStrides(mem_ref.rank);
+ llvm::ArrayRef<int64_t> sizes(mem_ref.sizes, mem_ref.rank);
+ llvm::ArrayRef<int64_t> strides(mem_ref.strides, mem_ref.rank);
std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
std::multiplies<int64_t>());
@@ -98,20 +98,17 @@ void mcuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
denseStrides.back() = 1;
assert(strides == llvm::makeArrayRef(denseStrides));
+ auto *pointer = mem_ref.data + mem_ref.offset;
std::fill_n(pointer, count, value);
mgpuMemHostRegister(pointer, count * sizeof(T));
}
extern "C" void mcuMemHostRegisterFloat(int64_t rank, void *ptr) {
- auto *desc = static_cast<StridedMemRefType<float, 1> *>(ptr);
- auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
- auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
- mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 1.23f);
+ UnrankedMemRefType<float> mem_ref = {rank, ptr};
+ mcuMemHostRegisterMemRef(DynamicMemRefType<float>(mem_ref), 1.23f);
}
extern "C" void mcuMemHostRegisterInt32(int64_t rank, void *ptr) {
- auto *desc = static_cast<StridedMemRefType<int32_t, 1> *>(ptr);
- auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
- auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
- mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 123);
+ UnrankedMemRefType<int32_t> mem_ref = {rank, ptr};
+ mcuMemHostRegisterMemRef(DynamicMemRefType<int32_t>(mem_ref), 123);
}
More information about the Mlir-commits
mailing list