[Mlir-commits] [mlir] 11f32a4 - Add convenience C++ helper to manipulate ranked strided memref

Mehdi Amini llvmlistbot at llvm.org
Wed Feb 10 09:40:57 PST 2021


Author: Mehdi Amini
Date: 2021-02-10T17:40:36Z
New Revision: 11f32a41c2144aeec80d1dce8cc6908fa91794a3

URL: https://github.com/llvm/llvm-project/commit/11f32a41c2144aeec80d1dce8cc6908fa91794a3
DIFF: https://github.com/llvm/llvm-project/commit/11f32a41c2144aeec80d1dce8cc6908fa91794a3.diff

LOG: Add convenience C++ helper to manipulate ranked strided memref

Differential Revision: https://reviews.llvm.org/D96192

Added: 
    mlir/include/mlir/ExecutionEngine/MemRefUtils.h

Modified: 
    mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
    mlir/unittests/ExecutionEngine/Invoke.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index 15335370531e..805c98a7e82b 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -31,11 +31,15 @@
 #define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
 #endif // _WIN32
 
+#include <array>
+#include <cassert>
 #include <cstdint>
+#include <initializer_list>
+#include <utility>
 
 //===----------------------------------------------------------------------===//
 // Codegen-compatible structures for Vector type.
 //===----------------------------------------------------------------------===//
+namespace mlir {
 namespace detail {
 
 constexpr bool isPowerOf2(int N) { return (!(N & (N - 1))); }
@@ -65,9 +69,8 @@ struct Vector1D<T, Dim, /*IsPowerOf2=*/true> {
 template <typename T, int Dim>
 struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
   Vector1D() {
-    static_assert(detail::nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]),
-                  "size error");
-    static_assert(detail::nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
+    static_assert(nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), "size error");
+    static_assert(nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
                   "size error");
   }
   inline T &operator[](unsigned i) { return vector[i]; }
@@ -75,9 +78,10 @@ struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
 
 private:
   T vector[Dim];
-  char padding[detail::nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])];
+  char padding[nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])];
 };
 } // end namespace detail
+} // end namespace mlir
 
 // N-D vectors recurse down to 1-D.
 template <typename T, int Dim, int... Dims>
@@ -95,7 +99,9 @@ struct Vector {
 // We insert explicit padding in to account for this.
+// The 1-D base is selected on whether sizeof(T[Dim]) is a power of two; the
+// non-power-of-two Vector1D specialization carries the explicit padding.
 template <typename T, int Dim>
 struct Vector<T, Dim>
-    : public detail::Vector1D<T, Dim, detail::isPowerOf2(sizeof(T[Dim]))> {};
+    : public mlir::detail::Vector1D<T, Dim,
+                                    mlir::detail::isPowerOf2(sizeof(T[Dim]))> {
+};
 
 template <int D1, typename T>
 using Vector1D = Vector<T, D1>;
@@ -115,6 +121,9 @@ void dropFront(int64_t arr[N], int64_t *res) {
 //===----------------------------------------------------------------------===//
 // Codegen-compatible structures for StridedMemRef type.
 //===----------------------------------------------------------------------===//
+/// Forward declaration: the StridedMemRefType specializations below return
+/// this iterator from their begin()/end() members.
+template <typename T, int Rank>
+class StridedMemrefIterator;
+
 /// StridedMemRef descriptor type with static rank.
 template <typename T, int N>
 struct StridedMemRefType {
@@ -123,6 +132,23 @@ struct StridedMemRefType {
   int64_t offset;
   int64_t sizes[N];
   int64_t strides[N];
+
+  /// Subscript with a full set of `N` indices and return a reference to the
+  /// addressed element. `Range` is any type exposing `size()` and `begin()`
+  /// (e.g. `std::initializer_list<int64_t>` or `ArrayRef<int64_t>`).
+  /// The `begin()` constraint removes this template from overload resolution
+  /// for plain integers, which would otherwise be an exact match and hijack
+  /// the scalar `operator[](int64_t)` overload below.
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range indices) {
+    assert(indices.size() == N &&
+           "indices should match rank in memref subscript");
+    int64_t curOffset = offset;
+    // Walk the indices front to back so any forward range is accepted.
+    auto indexIt = indices.begin();
+    for (int dim = 0; dim < N; ++dim, ++indexIt) {
+      int64_t currentIndex = *indexIt;
+      assert(currentIndex >= 0 && currentIndex < sizes[dim] &&
+             "Index out of bounds");
+      curOffset += currentIndex * strides[dim];
+    }
+    return data[curOffset];
+  }
+
+  /// Iterate over every element, starting at the element addressed by
+  /// `offset` (the same base `operator[]` uses). `end()` carries the -1
+  /// offset the iterator reaches after the last element.
+  /// NOTE: a memref with a zero-sized dimension is not handled; begin() does
+  /// not compare equal to end() in that case.
+  StridedMemrefIterator<T, N> begin() { return {*this, offset}; }
+  StridedMemrefIterator<T, N> end() { return {*this, -1}; }
+
   // This operator[] is extremely slow and only for sugaring purposes.
   StridedMemRefType<T, N - 1> operator[](int64_t idx) {
     StridedMemRefType<T, N - 1> res;
@@ -143,6 +169,17 @@ struct StridedMemRefType<T, 1> {
   int64_t offset;
   int64_t sizes[1];
   int64_t strides[1];
+
+  /// Subscript with a one-element range of indices. The `begin()` constraint
+  /// keeps this template from hijacking `operator[](int64_t)` below when the
+  /// caller passes a plain integer (the template would be an exact match).
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range indices) {
+    assert(indices.size() == 1 &&
+           "indices should match rank in memref subscript");
+    return (*this)[*indices.begin()];
+  }
+
+  /// Iterate over every element, starting at the element addressed by
+  /// `offset` (the same base the scalar subscript uses). `end()` carries the
+  /// -1 offset the iterator reaches after the last element.
+  /// NOTE: a zero-sized memref is not handled; begin() != end() for it.
+  StridedMemrefIterator<T, 1> begin() { return {*this, offset}; }
+  StridedMemrefIterator<T, 1> end() { return {*this, -1}; }
+
 T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
 };
 
@@ -152,6 +189,99 @@ struct StridedMemRefType<T, 0> {
   T *basePtr;
   T *data;
   int64_t offset;
+
+  /// Subscript a 0-ranked memref with an empty range of indices and return
+  /// its single element. `size() == 0` is used instead of `empty()` because
+  /// `std::initializer_list` (what OwningMemRef forwards) has no `empty()`.
+  /// The `begin()` constraint keeps non-range arguments out of this overload.
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range indices) {
+    assert(indices.size() == 0 &&
+           "Expect empty indices for 0-rank memref subscript");
+    return data[offset];
+  }
+
+  /// Iterate over the single element at `data[offset]`; `end()` points one
+  /// element past it.
+  StridedMemrefIterator<T, 0> begin() { return {*this, offset}; }
+  StridedMemrefIterator<T, 0> end() { return {*this, offset + 1}; }
+};
+
+/// Iterate over all elements in a strided memref.
+/// Elements are visited in row-major order (the last index varies fastest).
+/// The iterator tracks both the multi-dimensional `indices` and the matching
+/// linear `offset` into `descriptor.data`; once the last element has been
+/// passed, `offset` becomes -1, which is the value end iterators carry.
+template <typename T, int Rank>
+class StridedMemrefIterator {
+public:
+  /// Create an iterator positioned at linear `offset` in `descriptor.data`
+  /// with all indices at 0, so `offset` should be the position of the element
+  /// at indices (0, ..., 0).
+  StridedMemrefIterator(StridedMemRefType<T, Rank> &descriptor,
+                        int64_t offset = 0)
+      : offset(offset), descriptor(descriptor) {}
+  /// Advance to the next element, carrying into outer dimensions when an
+  /// inner dimension reaches its last index.
+  StridedMemrefIterator<T, Rank> &operator++() {
+    int dim = Rank - 1;
+    // Reset every trailing dimension that is already at its last index and
+    // undo its contribution to the linear offset.
+    while (dim >= 0 && indices[dim] == (descriptor.sizes[dim] - 1)) {
+      offset -= indices[dim] * descriptor.strides[dim];
+      indices[dim] = 0;
+      --dim;
+    }
+    // Every dimension wrapped around: we are past the last element.
+    if (dim < 0) {
+      offset = -1;
+      return *this;
+    }
+    ++indices[dim];
+    offset += descriptor.strides[dim];
+    return *this;
+  }
+
+  // Dereferencing does not modify the iterator, so these are const members;
+  // the element itself stays mutable through the descriptor reference.
+  T &operator*() const { return descriptor.data[offset]; }
+  T *operator->() const { return &descriptor.data[offset]; }
+
+  /// Coordinates of the element the iterator currently points to.
+  const std::array<int64_t, Rank> &getIndices() const { return indices; }
+
+  bool operator==(const StridedMemrefIterator &other) const {
+    return other.offset == offset && &other.descriptor == &descriptor;
+  }
+
+  bool operator!=(const StridedMemrefIterator &other) const {
+    return !(*this == other);
+  }
+
+private:
+  /// Offset in the buffer. This can be derived from the indices and the
+  /// descriptor.
+  int64_t offset = 0;
+  /// Array of indices in the multi-dimensional memref.
+  std::array<int64_t, Rank> indices = {};
+  /// Descriptor for the strided memref.
+  StridedMemRefType<T, Rank> &descriptor;
+};
+
+/// Iterate over all elements in a 0-ranked strided memref.
+/// The iterator is a plain pointer into `descriptor.data`; advancing it moves
+/// one element forward, so the past-the-end iterator is the element pointer
+/// plus one.
+template <typename T>
+class StridedMemrefIterator<T, 0> {
+public:
+  /// Point at `descriptor.data + offset`.
+  StridedMemrefIterator(StridedMemRefType<T, 0> &descriptor, int64_t offset = 0)
+      : elt(descriptor.data + offset) {}
+
+  StridedMemrefIterator<T, 0> &operator++() {
+    ++elt;
+    return *this;
+  }
+
+  // Const for parity with the ranked iterator; the element stays mutable.
+  T &operator*() const { return *elt; }
+  T *operator->() const { return elt; }
+
+  // There are no indices for a 0-ranked memref, but this API is provided for
+  // consistency with the general case.
+  const std::array<int64_t, 0> &getIndices() const {
+    // Since this is a 0-array of indices we can keep a single global const
+    // copy.
+    static const std::array<int64_t, 0> indices = {};
+    return indices;
+  }
+
+  bool operator==(const StridedMemrefIterator &other) const {
+    return other.elt == elt;
+  }
+
+  bool operator!=(const StridedMemrefIterator &other) const {
+    return !(*this == other);
+  }
+
+private:
+  /// Pointer to the single element in the zero-ranked memref.
+  T *elt;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
new file mode 100644
index 000000000000..53227e9f4eae
--- /dev/null
+++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
@@ -0,0 +1,214 @@
+//===- MemRefUtils.h - Memref helpers to invoke MLIR JIT code ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utils for MLIR ABI interfacing with frameworks.
+//
+// The templated free functions below make it possible to allocate dense
+// contiguous buffers with shapes that interoperate properly with the MLIR
+// codegen ABI.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+
+#include "llvm/Support/raw_ostream.h"
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <functional>
+#include <initializer_list>
+#include <memory>
+
+#ifndef MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
+#define MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
+
+namespace mlir {
+/// Allocation callback used by the helpers below: takes a size in bytes and
+/// returns a pointer to the allocated memory (`::malloc` is the default).
+using AllocFunType = llvm::function_ref<void *(size_t)>;
+
+namespace detail {
+
+/// Given a shape with sizes greater than 0 along all dimensions, returns the
+/// distance, in number of elements, between a slice in a dimension and the next
+/// slice in the same dimension, i.e. the row-major strides of a contiguous
+/// buffer of that shape.
+///    e.g. shape[3, 4, 5] -> strides[20, 5, 1]
+template <size_t N>
+inline std::array<int64_t, N> makeStrides(ArrayRef<int64_t> shape) {
+  assert(shape.size() == N && "expect shape specification to match rank");
+  std::array<int64_t, N> res;
+  int64_t running = 1;
+  // Signed index so that the loop terminates after dimension 0 (and does not
+  // run at all for N == 0).
+  for (int64_t idx = static_cast<int64_t>(N) - 1; idx >= 0; --idx) {
+    assert(shape[idx] > 0 && "size must be positive for all shape dimensions");
+    res[idx] = running;
+    running *= shape[idx];
+  }
+  return res;
+}
+
+/// Build the rank-`N` (N >= 1) `StridedMemRefType<T, N>` for a dense buffer,
+/// following the MLIR codegen ABI (kept in sync with codegen conventions).
+/// `ptr` is stored as the allocated base pointer and `alignedPtr` as the data
+/// pointer; the offset is 0. Sizes come from `shape`, while strides are
+/// computed from `shapeAlloc`, which may exceed `shape` in any dimension to
+/// allocate "more aligned", padded data.
+template <int N, typename T>
+typename std::enable_if<(N >= 1), StridedMemRefType<T, N>>::type
+makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef<int64_t> shape,
+                            ArrayRef<int64_t> shapeAlloc) {
+  assert(shape.size() == N);
+  assert(shapeAlloc.size() == N);
+  const auto allocStrides = makeStrides<N>(shapeAlloc);
+  StridedMemRefType<T, N> result;
+  result.basePtr = ptr;
+  result.data = alignedPtr;
+  result.offset = 0;
+  std::copy(shape.begin(), shape.end(), result.sizes);
+  std::copy(allocStrides.begin(), allocStrides.end(), result.strides);
+  return result;
+}
+
+/// Build the rank-0 `StridedMemRefType<T, 0>` following the MLIR codegen ABI:
+/// `ptr` becomes the base pointer, `alignedPtr` the data pointer, and the
+/// offset is 0. `shape` and `shapeAlloc` exist only for signature parity with
+/// the ranked overload and must be empty.
+template <int N, typename T>
+typename std::enable_if<(N == 0), StridedMemRefType<T, 0>>::type
+makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef<int64_t> shape = {},
+                            ArrayRef<int64_t> shapeAlloc = {}) {
+  assert(shape.size() == N);
+  assert(shapeAlloc.size() == N);
+  // Aggregate initialization in member order: basePtr, data, offset.
+  return StridedMemRefType<T, 0>{ptr, alignedPtr, 0};
+}
+
+/// Allocate room for `nElements` values of type T through `allocFun` and
+/// return a pair of pointers: the pointer `allocFun` returned (the one to hand
+/// back to the matching deallocation function) and the first address inside
+/// that allocation that is a multiple of the alignment.
+/// This replaces a portable `posix_memalign`.
+/// `alignment` must be a power of 2 and greater than or equal to sizeof(T);
+/// it defaults to `nextPowerOf2(sizeof(T))`. The request is over-sized by the
+/// alignment so that `nElements` values always fit past the aligned pointer.
+template <typename T>
+std::pair<T *, T *>
+allocAligned(size_t nElements, AllocFunType allocFun = &::malloc,
+             llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>()) {
+  // 64-bit literal: `1ul << 32` is undefined where `unsigned long` is 32 bits.
+  assert(sizeof(T) < (1ull << 32) && "Elemental type overflows");
+  auto size = nElements * sizeof(T);
+  auto desiredAlignment = alignment.getValueOr(nextPowerOf2(sizeof(T)));
+  assert((desiredAlignment & (desiredAlignment - 1)) == 0 &&
+         "alignment must be a power of 2");
+  assert(desiredAlignment >= sizeof(T) &&
+         "alignment must be at least the size of the element type");
+  T *data = reinterpret_cast<T *>(allocFun(size + desiredAlignment));
+  assert(data && "allocation failed");
+  uintptr_t addr = reinterpret_cast<uintptr_t>(data);
+  uintptr_t rem = addr % desiredAlignment;
+  T *alignedData = (rem == 0)
+                       ? data
+                       : reinterpret_cast<T *>(addr + (desiredAlignment - rem));
+  assert(reinterpret_cast<uintptr_t>(alignedData) % desiredAlignment == 0);
+  return std::make_pair(data, alignedData);
+}
+
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// Public API
+//===----------------------------------------------------------------------===//
+
+/// Callback used to "visit" a memref element by element: it receives a
+/// mutable reference to one element together with that element's coordinates.
+/// It can be used in conjunction with a StridedMemrefIterator.
+template <typename T>
+using ElementWiseVisitor =
+    llvm::function_ref<void(T &element, ArrayRef<int64_t> indices)>;
+
+/// Owning MemRef type that abstracts over the runtime type for ranked strided
+/// memref. The buffer is released by the stored free callback when the object
+/// is destroyed; the type is move-only.
+template <typename T, int Rank>
+class OwningMemRef {
+public:
+  using DescriptorType = StridedMemRefType<T, Rank>;
+  using FreeFunType = std::function<void(DescriptorType)>;
+
+  /// Allocate a new dense StridedMemRef with a given `shape`. An optional
+  /// `shapeAlloc` array can be supplied to "pad" every dimension individually.
+  /// If an ElementWiseVisitor is provided, it will be used to initialize the
+  /// data, else the memory will be zero-initialized. The alloc and free method
+  /// used to manage the data allocation can be optionally provided, and default
+  /// to malloc/free. The default free callback releases `basePtr`, the pointer
+  /// `allocFun` actually returned: `data` may have been advanced for alignment
+  /// and must not be passed to `free`.
+  OwningMemRef(
+      ArrayRef<int64_t> shape, ArrayRef<int64_t> shapeAlloc = {},
+      ElementWiseVisitor<T> init = {},
+      llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>(),
+      AllocFunType allocFun = &::malloc,
+      std::function<void(StridedMemRefType<T, Rank>)> freeFun =
+          [](StridedMemRefType<T, Rank> descriptor) {
+            ::free(descriptor.basePtr);
+          })
+      : freeFunc(freeFun) {
+    if (shapeAlloc.empty())
+      shapeAlloc = shape;
+    assert(shape.size() == Rank);
+    assert(shapeAlloc.size() == Rank);
+    for (unsigned i = 0; i < Rank; ++i)
+      assert(shape[i] <= shapeAlloc[i] &&
+             "shapeAlloc must be greater than or equal to shape");
+    int64_t nElements = 1;
+    for (int64_t s : shapeAlloc)
+      nElements *= s;
+    T *data, *alignedData;
+    std::tie(data, alignedData) =
+        detail::allocAligned<T>(nElements, allocFun, alignment);
+    descriptor = detail::makeStridedMemRefDescriptor<Rank>(data, alignedData,
+                                                           shape, shapeAlloc);
+    if (init) {
+      for (StridedMemrefIterator<T, Rank> it = descriptor.begin(),
+                                          end = descriptor.end();
+           it != end; ++it)
+        init(*it, it.getIndices());
+    } else {
+      // Only `nElements` values are guaranteed to fit past the aligned data
+      // pointer: the alignment slack allocAligned adds may lie partly *before*
+      // `descriptor.data`, so zeroing it from `data` would overrun the buffer.
+      memset(descriptor.data, 0, nElements * sizeof(T));
+    }
+  }
+  /// Take ownership of an existing descriptor with a custom deleter.
+  OwningMemRef(DescriptorType descriptor, FreeFunType freeFunc)
+      : freeFunc(freeFunc), descriptor(descriptor) {}
+  ~OwningMemRef() {
+    if (freeFunc)
+      freeFunc(descriptor);
+  }
+  OwningMemRef(const OwningMemRef &) = delete;
+  OwningMemRef &operator=(const OwningMemRef &) = delete;
+  /// Release the buffer currently owned by `*this`, then take over `other`'s
+  /// descriptor and deleter. `other` is left empty (no deleter, zeroed
+  /// descriptor) so its destructor does not free the buffer a second time.
+  OwningMemRef &operator=(OwningMemRef &&other) {
+    if (this == &other)
+      return *this;
+    if (freeFunc)
+      freeFunc(descriptor);
+    freeFunc = std::move(other.freeFunc);
+    descriptor = other.descriptor;
+    // A moved-from std::function is in an unspecified state: clear it.
+    other.freeFunc = nullptr;
+    other.descriptor = DescriptorType();
+    return *this;
+  }
+  /// Take over `other`'s descriptor and deleter, leaving `other` empty.
+  OwningMemRef(OwningMemRef &&other)
+      : freeFunc(std::move(other.freeFunc)), descriptor(other.descriptor) {
+    other.freeFunc = nullptr;
+    other.descriptor = DescriptorType();
+  }
+
+  DescriptorType &operator*() { return descriptor; }
+  DescriptorType *operator->() { return &descriptor; }
+  T &operator[](std::initializer_list<int64_t> indices) {
+    return descriptor[std::move(indices)];
+  }
+
+private:
+  /// Custom deleter used to release the data buffer managed by the descriptor
+  /// below.
+  FreeFunType freeFunc;
+  /// The descriptor is an instance of StridedMemRefType<T, rank>.
+  DescriptorType descriptor;
+};
+
+} // namespace mlir
+
+#endif // MLIR_EXECUTIONENGINE_MEMREFUTILS_H_

diff  --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
index c9abf2e108b8..6afbb08a8fab 100644
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/ExecutionEngine/CRunnerUtils.h"
 #include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/MemRefUtils.h"
 #include "mlir/ExecutionEngine/RunnerUtils.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/InitAllDialects.h"
@@ -89,4 +90,163 @@ TEST(MLIRExecutionEngine, SubtractFloat) {
   ASSERT_EQ(result, 42.f);
 }
 
+// Writes a rank-0 OwningMemRef through `operator[]` with an empty index list,
+// runs JIT-compiled code that stores 42.0 into it, and checks the new value as
+// well as that every element visited by iteration is that single element.
+TEST(NativeMemRefJit, ZeroRankMemref) {
+  OwningMemRef<float, 0> A({});
+  A[{}] = 42.;
+  ASSERT_EQ(*A->data, 42);
+  // Reset so the JIT store below is observable.
+  A[{}] = 0;
+  std::string moduleStr = R"mlir(
+  func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
+    %cst42 = constant 42.0 : f32
+    store %cst42, %arg0[] : memref<f32>
+    return
+  }
+  )mlir";
+  MLIRContext context;
+  registerAllDialects(context.getDialectRegistry());
+  auto module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  ASSERT_TRUE(!!jitOrError);
+  auto jit = std::move(jitOrError.get());
+
+  llvm::Error error = jit->invoke("zero_ranked", &*A);
+  ASSERT_TRUE(!error);
+  EXPECT_EQ((A[{}]), 42.);
+  for (float &elt : *A)
+    EXPECT_EQ(&elt, &(A[{}]));
+}
+
+// Fills a 9-element rank-1 memref with 1..9 by iteration (checking that
+// iteration order matches subscripting), runs JIT code that stores 42.0 at
+// index 5, and checks that only that element changed.
+TEST(NativeMemRefJit, RankOneMemref) {
+  int64_t shape[] = {9};
+  OwningMemRef<float, 1> A(shape);
+  int count = 1;
+  for (float &elt : *A) {
+    EXPECT_EQ(&elt, &(A[{count - 1}]));
+    elt = count++;
+  }
+
+  std::string moduleStr = R"mlir(
+  func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
+    %cst42 = constant 42.0 : f32
+    %cst5 = constant 5 : index
+    store %cst42, %arg0[%cst5] : memref<?xf32>
+    return
+  }
+  )mlir";
+  MLIRContext context;
+  registerAllDialects(context.getDialectRegistry());
+  auto module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  ASSERT_TRUE(!!jitOrError);
+  auto jit = std::move(jitOrError.get());
+
+  llvm::Error error = jit->invoke("one_ranked", &*A);
+  ASSERT_TRUE(!error);
+  count = 1;
+  for (float &elt : *A) {
+    // Index 5 is the 6th element visited.
+    if (count == 6)
+      EXPECT_EQ(elt, 42.);
+    else
+      EXPECT_EQ(elt, count);
+    count++;
+  }
+}
+
+// Allocates a KxM memref padded to (K+1)x(M+1), initializes it through an
+// ElementWiseVisitor, checks sizes/strides and contents, then passes the same
+// memref as both arguments to JIT code storing 42.0 at [1, 2] and [2, 1].
+TEST(NativeMemRefJit, BasicMemref) {
+  constexpr int K = 3;
+  constexpr int M = 7;
+  // Prepare arguments beforehand.
+  auto init = [=](float &elt, ArrayRef<int64_t> indices) {
+    assert(indices.size() == 2);
+    elt = M * indices[0] + indices[1];
+  };
+  int64_t shape[] = {K, M};
+  int64_t shapeAlloc[] = {K + 1, M + 1};
+  OwningMemRef<float, 2> A(shape, shapeAlloc, init);
+  ASSERT_EQ(A->sizes[0], K);
+  ASSERT_EQ(A->sizes[1], M);
+  // Strides follow the padded allocation shape, not the logical shape.
+  ASSERT_EQ(A->strides[0], M + 1);
+  ASSERT_EQ(A->strides[1], 1);
+  for (int i = 0; i < K; ++i)
+    for (int j = 0; j < M; ++j)
+      EXPECT_EQ((A[{i, j}]), i * M + j);
+
+  std::string moduleStr = R"mlir(
+  func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
+    %x = constant 2 : index
+    %y = constant 1 : index
+    %cst42 = constant 42.0 : f32
+    store %cst42, %arg0[%y, %x] : memref<?x?xf32>
+    store %cst42, %arg1[%x, %y] : memref<?x?xf32>
+    return
+  }
+  )mlir";
+  MLIRContext context;
+  registerAllDialects(context.getDialectRegistry());
+  OwningModuleRef module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  ASSERT_TRUE(!!jitOrError);
+  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
+
+  llvm::Error error = jit->invoke("rank2_memref", &*A, &*A);
+  ASSERT_TRUE(!error);
+  EXPECT_EQ((A[{1, 2}]), 42.);
+  EXPECT_EQ((A[{2, 1}]), 42.);
+}
+
+// Callback invoked from JIT-compiled code through the `_mlir_ciface_callback`
+// symbol: multiplies every element of the rank-2 memref in place by
+// `coefficient`.
+static void memref_multiply(::StridedMemRefType<float, 2> *memref,
+                            int32_t coefficient) {
+  ::StridedMemRefType<float, 2> &view = *memref;
+  for (auto it = view.begin(), last = view.end(); it != last; ++it)
+    *it *= coefficient;
+}
+
+// Registers memref_multiply as the external `_mlir_ciface_callback` symbol and
+// checks that JIT code forwarding a padded 2x2 memref to it scales every
+// element by the coefficient.
+// NOTE(review): `%unranked` in the module below is never used.
+TEST(NativeMemRefJit, JITCallback) {
+  constexpr int K = 2;
+  constexpr int M = 2;
+  int64_t shape[] = {K, M};
+  int64_t shapeAlloc[] = {K + 1, M + 1};
+  OwningMemRef<float, 2> A(shape, shapeAlloc);
+  int count = 1;
+  for (float &elt : *A)
+    elt = count++;
+
+  std::string moduleStr = R"mlir(
+  func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32)  attributes { llvm.emit_c_interface } 
+  func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
+    %unranked = memref_cast %arg0: memref<?x?xf32> to memref<*xf32>
+    call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
+    return
+  }
+  )mlir";
+  MLIRContext context;
+  registerAllDialects(context.getDialectRegistry());
+  auto module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  ASSERT_TRUE(!!jitOrError);
+  auto jit = std::move(jitOrError.get());
+  // Define any extra symbols so they're available at runtime.
+  jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
+    llvm::orc::SymbolMap symbolMap;
+    symbolMap[interner("_mlir_ciface_callback")] =
+        llvm::JITEvaluatedSymbol::fromPointer(memref_multiply);
+    return symbolMap;
+  });
+
+  // The double literal 3. is converted to the int32_t value 3.
+  int32_t coefficient = 3.;
+  llvm::Error error = jit->invoke("caller_for_callback", &*A, coefficient);
+  ASSERT_TRUE(!error);
+  count = 1;
+  for (float elt : *A)
+    ASSERT_EQ(elt, coefficient * count++);
+}
+
 #endif // _WIN32


        


More information about the Mlir-commits mailing list