[Mlir-commits] [mlir] 222e0e5 - [MLIR] Helper class referencing MemRefType to unify runner implementations.

Christian Sigg llvmlistbot at llvm.org
Tue May 26 07:32:47 PDT 2020


Author: Christian Sigg
Date: 2020-05-26T16:32:36+02:00
New Revision: 222e0e58a87649623b3d16ce3fef56a6a0555be3

URL: https://github.com/llvm/llvm-project/commit/222e0e58a87649623b3d16ce3fef56a6a0555be3
DIFF: https://github.com/llvm/llvm-project/commit/222e0e58a87649623b3d16ce3fef56a6a0555be3.diff

LOG: [MLIR] Helper class referencing MemRefType to unify runner implementations.

Summary:
Add DynamicMemRefType, which can reference either one of the statically ranked StridedMemRefType instantiations or an UnrankedMemRefType, so that runner utils only need to be implemented once.

There is definitely room for more cleanup and unification, but I will leave that for follow-ups.

Reviewers: nicolasvasilache

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D80513

Added: 
    

Modified: 
    mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
    mlir/include/mlir/ExecutionEngine/RunnerUtils.h
    mlir/lib/ExecutionEngine/RunnerUtils.cpp
    mlir/test/mlir-cpu-runner/unranked_memref.mlir
    mlir/test/mlir-cpu-runner/utils.mlir
    mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index 8155820d6347..bc59d3de2086 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -33,12 +33,6 @@
 
 #include <cstdint>
 
-template <int N>
-void dropFront(int64_t arr[N], int64_t *res) {
-  for (unsigned i = 1; i < N; ++i)
-    *(res + i - 1) = arr[i];
-}
-
 //===----------------------------------------------------------------------===//
 // Codegen-compatible structures for Vector type.
 //===----------------------------------------------------------------------===//
@@ -129,6 +123,10 @@ struct StridedMemRefType {
     res.basePtr = basePtr;
     res.data = data;
     res.offset = offset + idx * strides[0];
+    auto dropFront = [](const int64_t *arr, int64_t *res) {
+      for (unsigned i = 1; i < N; ++i)
+        res[i - 1] = arr[i];
+    };
     dropFront<N>(sizes, res.sizes);
     dropFront<N>(strides, res.strides);
     return res;
@@ -164,6 +162,39 @@ struct UnrankedMemRefType {
   void *descriptor;
 };
 
+//===----------------------------------------------------------------------===//
+// DynamicMemRefType type.
+//===----------------------------------------------------------------------===//
+// A reference to one of the StridedMemRef types.
+template <typename T>
+class DynamicMemRefType {
+public:
+  explicit DynamicMemRefType(const StridedMemRefType<T, 0> &mem_ref)
+      : rank(0), basePtr(mem_ref.basePtr), data(mem_ref.data),
+        offset(mem_ref.offset), sizes(nullptr), strides(nullptr) {}
+  template <int N>
+  explicit DynamicMemRefType(const StridedMemRefType<T, N> &mem_ref)
+      : rank(N), basePtr(mem_ref.basePtr), data(mem_ref.data),
+        offset(mem_ref.offset), sizes(mem_ref.sizes), strides(mem_ref.strides) {
+  }
+  explicit DynamicMemRefType(const UnrankedMemRefType<T> &mem_ref)
+      : rank(mem_ref.rank) {
+    auto *desc = static_cast<StridedMemRefType<T, 1> *>(mem_ref.descriptor);
+    basePtr = desc->basePtr;
+    data = desc->data;
+    offset = desc->offset;
+    sizes = rank == 0 ? nullptr : desc->sizes;
+    strides = sizes + rank;
+  }
+
+  int64_t rank;
+  T *basePtr;
+  T *data;
+  int64_t offset;
+  const int64_t *sizes;
+  const int64_t *strides;
+};
+
 //===----------------------------------------------------------------------===//
 // Small runtime support "lib" for vector.print lowering during codegen.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/ExecutionEngine/RunnerUtils.h b/mlir/include/mlir/ExecutionEngine/RunnerUtils.h
index 5f239a4c146e..7729b9c88796 100644
--- a/mlir/include/mlir/ExecutionEngine/RunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/RunnerUtils.h
@@ -35,29 +35,35 @@
 
 #include "mlir/ExecutionEngine/CRunnerUtils.h"
 
-template <typename StreamType, typename T, int N>
-void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
-  static_assert(N > 0, "Expected N > 0");
-  os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = " << N
-     << " offset = " << V.offset << " sizes = [" << V.sizes[0];
-  for (unsigned i = 1; i < N; ++i)
-    os << ", " << V.sizes[i];
-  os << "] strides = [" << V.strides[0];
-  for (unsigned i = 1; i < N; ++i)
-    os << ", " << V.strides[i];
+template <typename T, typename StreamType>
+void printMemRefMetaData(StreamType &os, const DynamicMemRefType<T> &V) {
+  os << "base@ = " << reinterpret_cast<void *>(V.data) << " rank = " << V.rank
+     << " offset = " << V.offset;
+  auto print = [&](const int64_t *ptr) {
+    if (V.rank == 0)
+      return;
+    os << ptr[0];
+    for (int64_t i = 1; i < V.rank; ++i)
+      os << ", " << ptr[i];
+  };
+  os << " sizes = [";
+  print(V.sizes);
+  os << "] strides = [";
+  print(V.strides);
   os << "]";
 }
 
-template <typename StreamType, typename T>
-void printMemRefMetaData(StreamType &os, StridedMemRefType<T, 0> &V) {
-  os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = 0"
-     << " offset = " << V.offset;
+template <typename StreamType, typename T, int N>
+void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
+  static_assert(N >= 0, "Expected N > 0");
+  os << "MemRef ";
+  printMemRefMetaData(os, DynamicMemRefType<T>(V));
 }
 
-template <typename T, typename StreamType>
+template <typename StreamType, typename T>
 void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType<T> &V) {
-  os << "Unranked Memref rank = " << V.rank << " "
-     << "descriptor@ = " << reinterpret_cast<void *>(V.descriptor) << "\n";
+  os << "Unranked MemRef ";
+  printMemRefMetaData(os, DynamicMemRefType<T>(V));
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -118,88 +124,92 @@ std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v) {
   return os;
 }
 
-template <typename T, int N> struct MemRefDataPrinter {
-  static void print(std::ostream &os, T *base, int64_t rank, int64_t offset,
-                    int64_t *sizes, int64_t *strides);
-  static void printFirst(std::ostream &os, T *base, int64_t rank,
-                         int64_t offset, int64_t *sizes, int64_t *strides);
-  static void printLast(std::ostream &os, T *base, int64_t rank, int64_t offset,
-                        int64_t *sizes, int64_t *strides);
-};
-
-template <typename T> struct MemRefDataPrinter<T, 0> {
-  static void print(std::ostream &os, T *base, int64_t rank, int64_t offset,
-                    int64_t *sizes = nullptr, int64_t *strides = nullptr);
+template <typename T>
+struct MemRefDataPrinter {
+  static void print(std::ostream &os, T *base, int64_t dim, int64_t rank,
+                    int64_t offset, const int64_t *sizes,
+                    const int64_t *strides);
+  static void printFirst(std::ostream &os, T *base, int64_t dim, int64_t rank,
+                         int64_t offset, const int64_t *sizes,
+                         const int64_t *strides);
+  static void printLast(std::ostream &os, T *base, int64_t dim, int64_t rank,
+                        int64_t offset, const int64_t *sizes,
+                        const int64_t *strides);
 };
 
-template <typename T, int N>
-void MemRefDataPrinter<T, N>::printFirst(std::ostream &os, T *base,
-                                         int64_t rank, int64_t offset,
-                                         int64_t *sizes, int64_t *strides) {
+template <typename T>
+void MemRefDataPrinter<T>::printFirst(std::ostream &os, T *base, int64_t dim,
+                                      int64_t rank, int64_t offset,
+                                      const int64_t *sizes,
+                                      const int64_t *strides) {
   os << "[";
-  MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset, sizes + 1,
-                                     strides + 1);
+  print(os, base, dim - 1, rank, offset, sizes + 1, strides + 1);
   // If single element, close square bracket and return early.
   if (sizes[0] <= 1) {
     os << "]";
     return;
   }
   os << ", ";
-  if (N > 1)
+  if (dim > 1)
     os << "\n";
 }
 
-template <typename T, int N>
-void MemRefDataPrinter<T, N>::print(std::ostream &os, T *base, int64_t rank,
-                                    int64_t offset, int64_t *sizes,
-                                    int64_t *strides) {
-  printFirst(os, base, rank, offset, sizes, strides);
+template <typename T>
+void MemRefDataPrinter<T>::print(std::ostream &os, T *base, int64_t dim,
+                                 int64_t rank, int64_t offset,
+                                 const int64_t *sizes, const int64_t *strides) {
+  if (dim == 0) {
+    os << base[offset];
+    return;
+  }
+  printFirst(os, base, dim, rank, offset, sizes, strides);
   for (unsigned i = 1; i + 1 < sizes[0]; ++i) {
-    printSpace(os, rank - N + 1);
-    MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset + i * strides[0],
-                                       sizes + 1, strides + 1);
+    printSpace(os, rank - dim + 1);
+    print(os, base, dim - 1, rank, offset + i * strides[0], sizes + 1,
+          strides + 1);
     os << ", ";
-    if (N > 1)
+    if (dim > 1)
       os << "\n";
   }
   if (sizes[0] <= 1)
     return;
-  printLast(os, base, rank, offset, sizes, strides);
+  printLast(os, base, dim, rank, offset, sizes, strides);
 }
 
-template <typename T, int N>
-void MemRefDataPrinter<T, N>::printLast(std::ostream &os, T *base, int64_t rank,
-                                        int64_t offset, int64_t *sizes,
-                                        int64_t *strides) {
-  printSpace(os, rank - N + 1);
-  MemRefDataPrinter<T, N - 1>::print(os, base, rank,
-                                     offset + (sizes[0] - 1) * (*strides),
-                                     sizes + 1, strides + 1);
+template <typename T>
+void MemRefDataPrinter<T>::printLast(std::ostream &os, T *base, int64_t dim,
+                                     int64_t rank, int64_t offset,
+                                     const int64_t *sizes,
+                                     const int64_t *strides) {
+  printSpace(os, rank - dim + 1);
+  print(os, base, dim - 1, rank, offset + (sizes[0] - 1) * (*strides),
+        sizes + 1, strides + 1);
   os << "]";
 }
 
 template <typename T>
-void MemRefDataPrinter<T, 0>::print(std::ostream &os, T *base, int64_t rank,
-                                    int64_t offset, int64_t *sizes,
-                                    int64_t *strides) {
-  os << base[offset];
-}
-
-template <typename T, int N> void printMemRef(StridedMemRefType<T, N> &M) {
-  static_assert(N > 0, "Expected N > 0");
+void printMemRef(const DynamicMemRefType<T> &M) {
   printMemRefMetaData(std::cout, M);
   std::cout << " data = " << std::endl;
-  MemRefDataPrinter<T, N>::print(std::cout, M.data, N, M.offset, M.sizes,
-                                 M.strides);
+  if (M.rank == 0)
+    std::cout << "[";
+  MemRefDataPrinter<T>::print(std::cout, M.data, M.rank, M.rank, M.offset,
+                              M.sizes, M.strides);
+  if (M.rank == 0)
+    std::cout << "]";
   std::cout << std::endl;
 }
 
-template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
-  printMemRefMetaData(std::cout, M);
-  std::cout << " data = " << std::endl;
-  std::cout << "[";
-  MemRefDataPrinter<T, 0>::print(std::cout, M.data, 0, M.offset);
-  std::cout << "]" << std::endl;
+template <typename T, int N>
+void printMemRef(StridedMemRefType<T, N> &M) {
+  std::cout << "Memref ";
+  printMemRef(DynamicMemRefType<T>(M));
+}
+
+template <typename T>
+void printMemRef(UnrankedMemRefType<T> &M) {
+  std::cout << "Unranked Memref ";
+  printMemRef(DynamicMemRefType<T>(M));
 }
 } // namespace impl
 

diff  --git a/mlir/lib/ExecutionEngine/RunnerUtils.cpp b/mlir/lib/ExecutionEngine/RunnerUtils.cpp
index 7991eca61994..7497ebdacf68 100644
--- a/mlir/lib/ExecutionEngine/RunnerUtils.cpp
+++ b/mlir/lib/ExecutionEngine/RunnerUtils.cpp
@@ -24,57 +24,16 @@ extern "C" void _mlir_ciface_print_memref_vector_4x4xf32(
   impl::printMemRef(*M);
 }
 
-#define MEMREF_CASE(TYPE, RANK)                                                \
-  case RANK:                                                                   \
-    impl::printMemRef(*(static_cast<StridedMemRefType<TYPE, RANK> *>(ptr)));   \
-    break
-
 extern "C" void _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M) {
-  printUnrankedMemRefMetaData(std::cout, *M);
-  int64_t rank = M->rank;
-  void *ptr = M->descriptor;
-
-  switch (rank) {
-    MEMREF_CASE(int8_t, 0);
-    MEMREF_CASE(int8_t, 1);
-    MEMREF_CASE(int8_t, 2);
-    MEMREF_CASE(int8_t, 3);
-    MEMREF_CASE(int8_t, 4);
-  default:
-    assert(0 && "Unsupported rank to print");
-  }
+  impl::printMemRef(*M);
 }
 
 extern "C" void _mlir_ciface_print_memref_i32(UnrankedMemRefType<int32_t> *M) {
-  printUnrankedMemRefMetaData(std::cout, *M);
-  int64_t rank = M->rank;
-  void *ptr = M->descriptor;
-
-  switch (rank) {
-    MEMREF_CASE(int32_t, 0);
-    MEMREF_CASE(int32_t, 1);
-    MEMREF_CASE(int32_t, 2);
-    MEMREF_CASE(int32_t, 3);
-    MEMREF_CASE(int32_t, 4);
-  default:
-    assert(0 && "Unsupported rank to print");
-  }
+  impl::printMemRef(*M);
 }
 
 extern "C" void _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M) {
-  printUnrankedMemRefMetaData(std::cout, *M);
-  int64_t rank = M->rank;
-  void *ptr = M->descriptor;
-
-  switch (rank) {
-    MEMREF_CASE(float, 0);
-    MEMREF_CASE(float, 1);
-    MEMREF_CASE(float, 2);
-    MEMREF_CASE(float, 3);
-    MEMREF_CASE(float, 4);
-  default:
-    assert(0 && "Unsupported rank to print");
-  }
+  impl::printMemRef(*M);
 }
 
 extern "C" void print_memref_i32(int64_t rank, void *ptr) {

diff  --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
index aa54b56b06b7..0eb68ac03368 100644
--- a/mlir/test/mlir-cpu-runner/unranked_memref.mlir
+++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir
@@ -1,25 +1,21 @@
 // RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext | FileCheck %s
 
-// CHECK: rank = 2
 // CHECK: rank = 2
 // CHECK-SAME: sizes = [10, 3]
 // CHECK-SAME: strides = [3, 1]
 // CHECK-COUNT-10: [10, 10, 10]
 //
 // CHECK: rank = 2
-// CHECK: rank = 2
 // CHECK-SAME: sizes = [10, 3]
 // CHECK-SAME: strides = [3, 1]
 // CHECK-COUNT-10: [5, 5, 5]
 //
 // CHECK: rank = 2
-// CHECK: rank = 2
 // CHECK-SAME: sizes = [10, 3]
 // CHECK-SAME: strides = [3, 1]
 // CHECK-COUNT-10: [2, 2, 2]
 //
 // CHECK: rank = 0
-// CHECK: rank = 0
 // 122 is ASCII for 'z'.
 // CHECK: [z]
 func @main() -> () {

diff  --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir
index d3ab6177eb65..65957400bf7f 100644
--- a/mlir/test/mlir-cpu-runner/utils.mlir
+++ b/mlir/test/mlir-cpu-runner/utils.mlir
@@ -12,8 +12,7 @@ func @print_0d() {
   dealloc %A : memref<f32>
   return
 }
-// PRINT-0D: Unranked Memref rank = 0 descriptor@ = {{.*}}
-// PRINT-0D: Memref base@ = {{.*}} rank = 0 offset = 0 data =
+// PRINT-0D: Unranked Memref base@ = {{.*}} rank = 0 offset = 0 sizes = [] strides = [] data =
 // PRINT-0D: [2]
 
 func @print_1d() {
@@ -26,7 +25,7 @@ func @print_1d() {
   dealloc %A : memref<16xf32>
   return
 }
-// PRINT-1D: Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data =
+// PRINT-1D: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data =
 // PRINT-1D-NEXT: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
 
 func @print_3d() {
@@ -43,7 +42,7 @@ func @print_3d() {
   dealloc %A : memref<3x4x5xf32>
   return
 }
-// PRINT-3D: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [3, 4, 5] strides = [20, 5, 1] data =
+// PRINT-3D: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [3, 4, 5] strides = [20, 5, 1] data =
 // PRINT-3D-COUNT-4: {{.*[[:space:]].*}}2,    2,    2,    2,    2
 // PRINT-3D-COUNT-4: {{.*[[:space:]].*}}2,    2,    2,    2,    2
 // PRINT-3D-COUNT-2: {{.*[[:space:]].*}}2,    2,    2,    2,    2

diff  --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index dbe78a55c0b1..705fa9f00930 100644
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -83,10 +83,10 @@ extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
 // Allows to register a MemRef with the CUDA runtime. Initializes array with
 // value. Helpful until we have transfer functions implemented.
 template <typename T>
-void mcuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
-                              llvm::ArrayRef<int64_t> strides, T value) {
-  assert(sizes.size() == strides.size());
-  llvm::SmallVector<int64_t, 4> denseStrides(strides.size());
+void mcuMemHostRegisterMemRef(const DynamicMemRefType<T> &mem_ref, T value) {
+  llvm::SmallVector<int64_t, 4> denseStrides(mem_ref.rank);
+  llvm::ArrayRef<int64_t> sizes(mem_ref.sizes, mem_ref.rank);
+  llvm::ArrayRef<int64_t> strides(mem_ref.strides, mem_ref.rank);
 
   std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
                    std::multiplies<int64_t>());
@@ -98,20 +98,17 @@ void mcuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
   denseStrides.back() = 1;
   assert(strides == llvm::makeArrayRef(denseStrides));
 
+  auto *pointer = mem_ref.data + mem_ref.offset;
   std::fill_n(pointer, count, value);
   mgpuMemHostRegister(pointer, count * sizeof(T));
 }
 
 extern "C" void mcuMemHostRegisterFloat(int64_t rank, void *ptr) {
-  auto *desc = static_cast<StridedMemRefType<float, 1> *>(ptr);
-  auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
-  auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
-  mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 1.23f);
+  UnrankedMemRefType<float> mem_ref = {rank, ptr};
+  mcuMemHostRegisterMemRef(DynamicMemRefType<float>(mem_ref), 1.23f);
 }
 
 extern "C" void mcuMemHostRegisterInt32(int64_t rank, void *ptr) {
-  auto *desc = static_cast<StridedMemRefType<int32_t, 1> *>(ptr);
-  auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
-  auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
-  mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 123);
+  UnrankedMemRefType<int32_t> mem_ref = {rank, ptr};
+  mcuMemHostRegisterMemRef(DynamicMemRefType<int32_t>(mem_ref), 123);
 }


        


More information about the Mlir-commits mailing list