[Mlir-commits] [mlir] 9680ea5 - Add convenience C++ helper to manipulate ranked strided memref
Mehdi Amini
llvmlistbot at llvm.org
Wed Feb 10 10:58:20 PST 2021
Author: Mehdi Amini
Date: 2021-02-10T18:58:05Z
New Revision: 9680ea5c982e26212d7bb401532c8273cdeaa3e0
URL: https://github.com/llvm/llvm-project/commit/9680ea5c982e26212d7bb401532c8273cdeaa3e0
DIFF: https://github.com/llvm/llvm-project/commit/9680ea5c982e26212d7bb401532c8273cdeaa3e0.diff
LOG: Add convenience C++ helper to manipulate ranked strided memref
Reland 11f32a41c21 that was reverted in e49967fbd90 after fixing the build.
Differential Revision: https://reviews.llvm.org/D96192
Added:
mlir/include/mlir/ExecutionEngine/MemRefUtils.h
Modified:
mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
mlir/unittests/ExecutionEngine/Invoke.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index 15335370531e..0c2638307604 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -31,11 +31,15 @@
#define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
#endif // _WIN32
+#include <array>
+#include <cassert>
#include <cstdint>
+#include <initializer_list>
//===----------------------------------------------------------------------===//
// Codegen-compatible structures for Vector type.
//===----------------------------------------------------------------------===//
+namespace mlir {
namespace detail {
constexpr bool isPowerOf2(int N) { return (!(N & (N - 1))); }
@@ -65,9 +69,8 @@ struct Vector1D<T, Dim, /*IsPowerOf2=*/true> {
template <typename T, int Dim>
struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
Vector1D() {
- static_assert(detail::nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]),
- "size error");
- static_assert(detail::nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
+ static_assert(nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), "size error");
+ static_assert(nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
"size error");
}
inline T &operator[](unsigned i) { return vector[i]; }
@@ -75,9 +78,10 @@ struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
private:
T vector[Dim];
- char padding[detail::nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])];
+ char padding[nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])];
};
} // end namespace detail
+} // end namespace mlir
// N-D vectors recurse down to 1-D.
template <typename T, int Dim, int... Dims>
@@ -95,7 +99,9 @@ struct Vector {
// We insert explicit padding in to account for this.
template <typename T, int Dim>
struct Vector<T, Dim>
- : public detail::Vector1D<T, Dim, detail::isPowerOf2(sizeof(T[Dim]))> {};
+ : public mlir::detail::Vector1D<T, Dim,
+ mlir::detail::isPowerOf2(sizeof(T[Dim]))> {
+};
template <int D1, typename T>
using Vector1D = Vector<T, D1>;
@@ -115,6 +121,9 @@ void dropFront(int64_t arr[N], int64_t *res) {
//===----------------------------------------------------------------------===//
// Codegen-compatible structures for StridedMemRef type.
//===----------------------------------------------------------------------===//
+template <typename T, int Rank>
+class StridedMemrefIterator;
+
/// StridedMemRef descriptor type with static rank.
template <typename T, int N>
struct StridedMemRefType {
@@ -123,6 +132,23 @@ struct StridedMemRefType {
int64_t offset;
int64_t sizes[N];
int64_t strides[N];
+
+  /// Returns a reference to the element addressed by `indices`, a range
+  /// holding exactly one index per dimension. The linear position is the
+  /// descriptor `offset` plus the sum of `index * stride` over all dimensions.
+  template <typename Range>
+  T &operator[](Range indices) {
+    assert(indices.size() == N &&
+           "indices should match rank in memref subscript");
+    int64_t linearOffset = offset;
+    for (int dim = 0; dim < N; ++dim) {
+      const int64_t indexInDim = *(indices.begin() + dim);
+      assert(indexInDim < sizes[dim] && "Index overflow");
+      linearOffset += indexInDim * strides[dim];
+    }
+    return data[linearOffset];
+  }
+
+  /// Iterator to the first element. It starts at the descriptor `offset` so
+  /// that iteration visits the same elements as operator[].
+  StridedMemrefIterator<T, N> begin() { return {*this, offset}; }
+  /// Past-the-end iterator; -1 is the sentinel offset the iterator switches to
+  /// once it has been incremented past the last element.
+  StridedMemrefIterator<T, N> end() { return {*this, -1}; }
+
// This operator[] is extremely slow and only for sugaring purposes.
StridedMemRefType<T, N - 1> operator[](int64_t idx) {
StridedMemRefType<T, N - 1> res;
@@ -143,6 +169,17 @@ struct StridedMemRefType<T, 1> {
int64_t offset;
int64_t sizes[1];
int64_t strides[1];
+
+  /// Subscript with a range of indices: the range must hold exactly one
+  /// index, which is forwarded to the single-index operator[].
+  template <typename Range>
+  T &operator[](Range indices) {
+    assert(indices.size() == 1 &&
+           "indices should match rank in memref subscript");
+    auto onlyIndex = indices.begin();
+    return (*this)[*onlyIndex];
+  }
+
+  /// Iterator to the first element. It starts at the descriptor `offset` so
+  /// that iteration visits the same elements as operator[].
+  StridedMemrefIterator<T, 1> begin() { return {*this, offset}; }
+  /// Past-the-end iterator; -1 is the sentinel offset the iterator switches to
+  /// once it has been incremented past the last element.
+  StridedMemrefIterator<T, 1> end() { return {*this, -1}; }
+
T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
};
@@ -152,6 +189,99 @@ struct StridedMemRefType<T, 0> {
T *basePtr;
T *data;
int64_t offset;
+
+  /// Subscript for a 0-ranked memref: `indices` must be empty and the result
+  /// is the single element located at `data + offset`.
+  template <typename Range>
+  T &operator[](Range indices) {
+    assert((indices.size() == 0) &&
+           "Expect empty indices for 0-rank memref subscript");
+    T *element = data + offset;
+    return *element;
+  }
+
+  /// Iterator to the single element, located at `data + offset` exactly like
+  /// operator[] computes it.
+  StridedMemrefIterator<T, 0> begin() { return {*this, offset}; }
+  /// Past-the-end iterator: one element after the single element.
+  StridedMemrefIterator<T, 0> end() { return {*this, offset + 1}; }
+};
+
+/// Iterator walking every element of a ranked strided memref, advancing the
+/// innermost (last) dimension first. It tracks both the multi-dimensional
+/// indices and the matching linear offset into `descriptor.data`; an offset
+/// of -1 marks an iterator that has been incremented past the last element.
+template <typename T, int Rank>
+class StridedMemrefIterator {
+public:
+  StridedMemrefIterator(StridedMemRefType<T, Rank> &descriptor,
+                        int64_t offset = 0)
+      : offset(offset), descriptor(descriptor) {}
+
+  /// Advance to the next element, carrying into outer dimensions as needed.
+  StridedMemrefIterator<T, Rank> &operator++() {
+    for (int dim = Rank - 1; dim >= 0; --dim) {
+      if (indices[dim] != (descriptor.sizes[dim] - 1)) {
+        // This dimension still has room: step it and stop.
+        ++indices[dim];
+        offset += descriptor.strides[dim];
+        return *this;
+      }
+      // This dimension is exhausted: rewind it and carry into the next outer
+      // dimension.
+      offset -= indices[dim] * descriptor.strides[dim];
+      indices[dim] = 0;
+    }
+    // Every dimension wrapped around: the iteration is over.
+    offset = -1;
+    return *this;
+  }
+
+  T &operator*() { return descriptor.data[offset]; }
+  T *operator->() { return &descriptor.data[offset]; }
+
+  /// Indices of the element the iterator currently designates.
+  const std::array<int64_t, Rank> &getIndices() { return indices; }
+
+  /// Two iterators are equal when they refer to the same descriptor object
+  /// and hold the same linear offset.
+  bool operator==(const StridedMemrefIterator &rhs) const {
+    return rhs.offset == offset && &rhs.descriptor == &descriptor;
+  }
+
+  bool operator!=(const StridedMemrefIterator &rhs) const {
+    return !(*this == rhs);
+  }
+
+private:
+  /// Offset in the buffer. This can be derived from the indices and the
+  /// descriptor.
+  int64_t offset = 0;
+  /// Array of indices in the multi-dimensional memref.
+  std::array<int64_t, Rank> indices = {};
+  /// Descriptor for the strided memref.
+  StridedMemRefType<T, Rank> &descriptor;
+};
+
+/// Iterator specialization for 0-ranked strided memrefs. It is a thin wrapper
+/// around a pointer computed as `descriptor.data + offset`.
+template <typename T>
+class StridedMemrefIterator<T, 0> {
+public:
+  StridedMemrefIterator(StridedMemRefType<T, 0> &descriptor, int64_t offset = 0)
+      : elt(descriptor.data + offset) {}
+
+  /// Step the underlying pointer by one element.
+  StridedMemrefIterator<T, 0> &operator++() {
+    ++elt;
+    return *this;
+  }
+
+  T &operator*() { return *elt; }
+  T *operator->() { return elt; }
+
+  /// A 0-ranked memref has no indices; this accessor only exists so that the
+  /// API matches the general case. All iterators share one empty array.
+  const std::array<int64_t, 0> &getIndices() {
+    static const std::array<int64_t, 0> indices = {};
+    return indices;
+  }
+
+  /// Iterators compare equal when they point at the same address.
+  bool operator==(const StridedMemrefIterator &rhs) const {
+    return rhs.elt == elt;
+  }
+
+  bool operator!=(const StridedMemrefIterator &rhs) const {
+    return !(*this == rhs);
+  }
+
+private:
+  /// Pointer to the single element in the zero-ranked memref.
+  T *elt;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
new file mode 100644
index 000000000000..53227e9f4eae
--- /dev/null
+++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
@@ -0,0 +1,214 @@
+//===- MemRefUtils.h - Memref helpers to invoke MLIR JIT code ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utils for MLIR ABI interfacing with frameworks.
+//
+// The templated free functions below make it possible to allocate dense
+// contiguous buffers with shapes that interoperate properly with the MLIR
+// codegen ABI.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+
+#include "llvm/Support/raw_ostream.h"
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <functional>
+#include <initializer_list>
+#include <memory>
+
+#ifndef MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
+#define MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
+
+namespace mlir {
+using AllocFunType = llvm::function_ref<void *(size_t)>;
+
+namespace detail {
+
+/// Given a shape with sizes greater than 0 along all dimensions, returns the
+/// distance, in number of elements, between a slice in a dimension and the next
+/// slice in the same dimension.
+///   e.g. shape[3, 4, 5] -> strides[20, 5, 1]
+/// `shape` must hold exactly `N` sizes.
+template <size_t N>
+inline std::array<int64_t, N> makeStrides(ArrayRef<int64_t> shape) {
+  assert(shape.size() == N && "expect shape specification to match rank");
+  std::array<int64_t, N> res;
+  // Walk from the innermost dimension outwards, accumulating the number of
+  // elements spanned by all the dimensions already visited.
+  int64_t running = 1;
+  for (int64_t idx = static_cast<int64_t>(N) - 1; idx >= 0; --idx) {
+    // Reject negative sizes as well as zero, as the contract above requires.
+    assert(shape[idx] > 0 && "size must be positive for all shape dimensions");
+    res[idx] = running;
+    running *= shape[idx];
+  }
+  return res;
+}
+
+/// Build a `StridedMemRefType<T, N>` (N >= 1) from an allocation.
+/// `ptr` is stored as `basePtr`, `alignedPtr` as `data`, and the offset is 0.
+/// The sizes are copied from `shape`, whereas the strides are derived from
+/// `shapeAlloc`, which describes the (possibly padded) allocated extents.
+template <int N, typename T>
+typename std::enable_if<(N >= 1), StridedMemRefType<T, N>>::type
+makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef<int64_t> shape,
+                            ArrayRef<int64_t> shapeAlloc) {
+  assert(shape.size() == N);
+  assert(shapeAlloc.size() == N);
+  auto allocStrides = makeStrides<N>(shapeAlloc);
+  StridedMemRefType<T, N> result;
+  result.basePtr = ptr;
+  result.data = alignedPtr;
+  result.offset = 0;
+  std::copy(shape.begin(), shape.end(), result.sizes);
+  std::copy(allocStrides.begin(), allocStrides.end(), result.strides);
+  return result;
+}
+
+/// Build a `StridedMemRefType<T, 0>` from an allocation.
+/// `ptr` is stored as `basePtr`, `alignedPtr` as `data`, and the offset is 0.
+/// `shape` and `shapeAlloc` only exist for signature parity with the ranked
+/// overload and must both be empty.
+template <int N, typename T>
+typename std::enable_if<(N == 0), StridedMemRefType<T, 0>>::type
+makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef<int64_t> shape = {},
+                            ArrayRef<int64_t> shapeAlloc = {}) {
+  assert(shape.size() == N);
+  assert(shapeAlloc.size() == N);
+  StridedMemRefType<T, 0> result;
+  result.basePtr = ptr;
+  result.data = alignedPtr;
+  result.offset = 0;
+  return result;
+}
+
+/// Allocate storage for `nElements` of type T with an optional `alignment`.
+/// This replaces a portable `posix_memalign`.
+/// `alignment` must be a power of 2 and at least sizeof(T). By default the
+/// alignment is `nextPowerOf2(sizeof(T))`.
+/// Returns the pair (pointer returned by `allocFun`, aligned pointer inside
+/// that allocation). `allocFun` is asked for `alignment` extra bytes so that
+/// `nElements` elements always fit after the aligned pointer.
+template <typename T>
+std::pair<T *, T *>
+allocAligned(size_t nElements, AllocFunType allocFun = &::malloc,
+             llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>()) {
+  // `1ull` is at least 64 bits wide, so the shift is well-defined even where
+  // `unsigned long` is only 32 bits; sizeof(T) is a compile-time constant, so
+  // the check can be static.
+  static_assert(sizeof(T) < (1ull << 32), "Elemental type overflows");
+  auto size = nElements * sizeof(T);
+  auto desiredAlignment = alignment.getValueOr(nextPowerOf2(sizeof(T)));
+  assert((desiredAlignment & (desiredAlignment - 1)) == 0);
+  assert(desiredAlignment >= sizeof(T));
+  T *data = reinterpret_cast<T *>(allocFun(size + desiredAlignment));
+  // Round the raw address up to the next multiple of the alignment.
+  uintptr_t addr = reinterpret_cast<uintptr_t>(data);
+  uintptr_t rem = addr % desiredAlignment;
+  T *alignedData = (rem == 0)
+                       ? data
+                       : reinterpret_cast<T *>(addr + (desiredAlignment - rem));
+  assert(reinterpret_cast<uintptr_t>(alignedData) % desiredAlignment == 0);
+  return std::make_pair(data, alignedData);
+}
+
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// Public API
+//===----------------------------------------------------------------------===//
+
+/// Convenient callback to "visit" a memref element by element.
+/// This takes a reference to an individual element as well as the coordinates.
+/// It can be used in conjunction with a StridedMemrefIterator.
+template <typename T>
+using ElementWiseVisitor = llvm::function_ref<void(T &ptr, ArrayRef<int64_t>)>;
+
+/// Owning MemRef type that abstracts over the runtime type for ranked strided
+/// memref. The object holds a descriptor together with a deleter; the deleter
+/// is invoked on the descriptor when the object is destroyed or overwritten by
+/// a move-assignment. Copying is disabled, moving transfers ownership.
+template <typename T, int Rank>
+class OwningMemRef {
+public:
+  using DescriptorType = StridedMemRefType<T, Rank>;
+  using FreeFunType = std::function<void(DescriptorType)>;
+
+  /// Allocate a new dense StridedMemRef with a given `shape`. An optional
+  /// `shapeAlloc` array can be supplied to "pad" every dimension individually.
+  /// If an ElementWiseVisitor is provided, it will be used to initialize the
+  /// data, else the memory will be zero-initialized. The alloc and free method
+  /// used to manage the data allocation can be optionally provided, and default
+  /// to malloc/free.
+  OwningMemRef(
+      ArrayRef<int64_t> shape, ArrayRef<int64_t> shapeAlloc = {},
+      ElementWiseVisitor<T> init = {},
+      llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>(),
+      AllocFunType allocFun = &::malloc,
+      std::function<void(StridedMemRefType<T, Rank>)> freeFun =
+          [](StridedMemRefType<T, Rank> descriptor) {
+            // Release the pointer handed out by the allocator. The aligned
+            // `data` pointer may point inside the allocation and must not be
+            // passed to free().
+            ::free(descriptor.basePtr);
+          })
+      : freeFunc(freeFun) {
+    if (shapeAlloc.empty())
+      shapeAlloc = shape;
+    assert(shape.size() == Rank);
+    assert(shapeAlloc.size() == Rank);
+    for (unsigned i = 0; i < Rank; ++i)
+      assert(shape[i] <= shapeAlloc[i] &&
+             "shapeAlloc must be greater than or equal to shape");
+    int64_t nElements = 1;
+    for (int64_t s : shapeAlloc)
+      nElements *= s;
+    T *data, *alignedData;
+    std::tie(data, alignedData) =
+        detail::allocAligned<T>(nElements, allocFun, alignment);
+    descriptor = detail::makeStridedMemRefDescriptor<Rank>(data, alignedData,
+                                                           shape, shapeAlloc);
+    if (init) {
+      for (StridedMemrefIterator<T, Rank> it = descriptor.begin(),
+                                          end = descriptor.end();
+           it != end; ++it)
+        init(*it, it.getIndices());
+    } else {
+      // Only the elements are zeroed. The alignment slack sits in front of
+      // the aligned pointer, so adding it to the length here would write past
+      // the end of the allocation.
+      memset(descriptor.data, 0, nElements * sizeof(T));
+    }
+  }
+  /// Take ownership of an existing descriptor with a custom deleter.
+  OwningMemRef(DescriptorType descriptor, FreeFunType freeFunc)
+      : freeFunc(freeFunc), descriptor(descriptor) {}
+  ~OwningMemRef() {
+    if (freeFunc)
+      freeFunc(descriptor);
+  }
+  OwningMemRef(const OwningMemRef &) = delete;
+  OwningMemRef &operator=(const OwningMemRef &) = delete;
+  /// Move-assignment: releases the buffer currently owned (if any), then takes
+  /// over the descriptor and deleter of `other`, leaving `other` empty.
+  OwningMemRef &operator=(OwningMemRef &&other) {
+    if (this == &other)
+      return *this;
+    if (freeFunc)
+      freeFunc(descriptor);
+    freeFunc = other.freeFunc;
+    descriptor = other.descriptor;
+    other.freeFunc = nullptr;
+    memset(&other.descriptor, 0, sizeof(other.descriptor));
+    return *this;
+  }
+  /// Move-construction: takes over the descriptor and deleter of `other`,
+  /// leaving `other` empty so that its destructor releases nothing.
+  OwningMemRef(OwningMemRef &&other)
+      : freeFunc(other.freeFunc), descriptor(other.descriptor) {
+    other.freeFunc = nullptr;
+    memset(&other.descriptor, 0, sizeof(other.descriptor));
+  }
+
+  DescriptorType &operator*() { return descriptor; }
+  DescriptorType *operator->() { return &descriptor; }
+  /// Element access with one index per dimension, e.g. `memref[{i, j}]`.
+  T &operator[](std::initializer_list<int64_t> indices) {
+    return descriptor[std::move(indices)];
+  }
+
+private:
+  /// Custom deleter used to release the data buffer managed with the descriptor
+  /// below.
+  FreeFunType freeFunc;
+  /// The descriptor is an instance of StridedMemRefType<T, rank>.
+  DescriptorType descriptor;
+};
+
+} // namespace mlir
+
+#endif // MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
index c9abf2e108b8..29c59bdba857 100644
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
@@ -51,7 +52,7 @@ TEST(MLIRExecutionEngine, AddInteger) {
}
)mlir";
MLIRContext context;
- registerAllDialects(context.getDialectRegistry());
+ registerAllDialects(context);
OwningModuleRef module = parseSourceString(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
@@ -74,7 +75,7 @@ TEST(MLIRExecutionEngine, SubtractFloat) {
}
)mlir";
MLIRContext context;
- registerAllDialects(context.getDialectRegistry());
+ registerAllDialects(context);
OwningModuleRef module = parseSourceString(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
@@ -89,4 +90,163 @@ TEST(MLIRExecutionEngine, SubtractFloat) {
ASSERT_EQ(result, 42.f);
}
+TEST(NativeMemRefJit, ZeroRankMemref) { // Rank-0 memref: host-side access, a JIT-ed store, then iteration.
+ OwningMemRef<float, 0> A({});
+ A[{}] = 42.; // Write the single element through the empty-index subscript.
+ ASSERT_EQ(*A->data, 42); // The write must be visible through the descriptor's data pointer.
+ A[{}] = 0; // Reset so the 42 checked after invoke() can only come from the JIT-ed store.
+ std::string moduleStr = R"mlir(
+ func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
+ %cst42 = constant 42.0 : f32
+ store %cst42, %arg0[] : memref<f32>
+ return
+ }
+ )mlir";
+ MLIRContext context;
+ registerAllDialects(context);
+ auto module = parseSourceString(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+ auto jitOrError = ExecutionEngine::create(*module);
+ ASSERT_TRUE(!!jitOrError);
+ auto jit = std::move(jitOrError.get());
+
+ llvm::Error error = jit->invoke("zero_ranked", &*A); // The descriptor is passed by pointer.
+ ASSERT_TRUE(!error);
+ EXPECT_EQ((A[{}]), 42.);
+ for (float &elt : *A) // Every element visited by iteration must be the one the subscript returns.
+ EXPECT_EQ(&elt, &(A[{}]));
+}
+
+TEST(NativeMemRefJit, RankOneMemref) { // Rank-1 memref: iteration order, subscript, and a JIT-ed store at index 5.
+ int64_t shape[] = {9};
+ OwningMemRef<float, 1> A(shape);
+ int count = 1;
+ for (float &elt : *A) {
+ EXPECT_EQ(&elt, &(A[{count - 1}])); // The k-th visited element must be the one at index k.
+ elt = count++; // Fill with 1, 2, 3, ... in iteration order.
+ }
+
+ std::string moduleStr = R"mlir(
+ func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
+ %cst42 = constant 42.0 : f32
+ %cst5 = constant 5 : index
+ store %cst42, %arg0[%cst5] : memref<?xf32>
+ return
+ }
+ )mlir";
+ MLIRContext context;
+ registerAllDialects(context);
+ auto module = parseSourceString(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+ auto jitOrError = ExecutionEngine::create(*module);
+ ASSERT_TRUE(!!jitOrError);
+ auto jit = std::move(jitOrError.get());
+
+ llvm::Error error = jit->invoke("one_ranked", &*A);
+ ASSERT_TRUE(!error);
+ count = 1;
+ for (float &elt : *A) {
+ if (count == 6) // count is 1-based, so this is index 5, the slot the JIT-ed code overwrote.
+ EXPECT_EQ(elt, 42.);
+ else
+ EXPECT_EQ(elt, count); // All other slots keep the values written before the call.
+ count++;
+ }
+}
+
+TEST(NativeMemRefJit, BasicMemref) { // Rank-2 memref with padded allocation, visitor initialization, and JIT-ed stores.
+ constexpr int K = 3;
+ constexpr int M = 7;
+ // Prepare arguments beforehand.
+ auto init = [=](float &elt, ArrayRef<int64_t> indices) { // Visitor: each element receives M * row + column.
+ assert(indices.size() == 2);
+ elt = M * indices[0] + indices[1];
+ };
+ int64_t shape[] = {K, M};
+ int64_t shapeAlloc[] = {K + 1, M + 1}; // One extra slot per dimension; the row stride becomes M + 1 (asserted below).
+ OwningMemRef<float, 2> A(shape, shapeAlloc, init);
+ ASSERT_EQ(A->sizes[0], K);
+ ASSERT_EQ(A->sizes[1], M);
+ ASSERT_EQ(A->strides[0], M + 1);
+ ASSERT_EQ(A->strides[1], 1);
+ for (int i = 0; i < K; ++i)
+ for (int j = 0; j < M; ++j)
+ EXPECT_EQ((A[{i, j}]), i * M + j); // The subscript must see the values the visitor wrote.
+
+ std::string moduleStr = R"mlir(
+ func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
+ %x = constant 2 : index
+ %y = constant 1 : index
+ %cst42 = constant 42.0 : f32
+ store %cst42, %arg0[%y, %x] : memref<?x?xf32>
+ store %cst42, %arg1[%x, %y] : memref<?x?xf32>
+ return
+ }
+ )mlir";
+ MLIRContext context;
+ registerAllDialects(context);
+ OwningModuleRef module = parseSourceString(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+ auto jitOrError = ExecutionEngine::create(*module);
+ ASSERT_TRUE(!!jitOrError);
+ std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
+
+ llvm::Error error = jit->invoke("rank2_memref", &*A, &*A); // The same memref is passed as both arguments, so both stores land in A.
+ ASSERT_TRUE(!error);
+ EXPECT_EQ((A[{1, 2}]), 42.);
+ EXPECT_EQ((A[{2, 1}]), 42.);
+}
+
+// Helper called from the JIT-compiled code: multiplies every element of `memref` in place by `coefficient`.
+static void memref_multiply(::StridedMemRefType<float, 2> *memref,
+ int32_t coefficient) {
+ for (float &elt : *memref) // Uses the descriptor's begin()/end() iteration.
+ elt *= coefficient;
+}
+
+TEST(NativeMemRefJit, JITCallback) { // JIT-ed code calls back into the host function memref_multiply.
+ constexpr int K = 2;
+ constexpr int M = 2;
+ int64_t shape[] = {K, M};
+ int64_t shapeAlloc[] = {K + 1, M + 1}; // Allocation padded by one slot per dimension.
+ OwningMemRef<float, 2> A(shape, shapeAlloc);
+ int count = 1;
+ for (float &elt : *A) // Fill with 1, 2, 3, ... in iteration order.
+ elt = count++;
+
+ std::string moduleStr = R"mlir(
+ func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface }
+ func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
+ %unranked = memref_cast %arg0: memref<?x?xf32> to memref<*xf32>
+ call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
+ return
+ }
+ )mlir";
+ MLIRContext context;
+ registerAllDialects(context);
+ auto module = parseSourceString(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+ auto jitOrError = ExecutionEngine::create(*module);
+ ASSERT_TRUE(!!jitOrError);
+ auto jit = std::move(jitOrError.get());
+ // Define any extra symbols so they're available at runtime.
+ jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
+ llvm::orc::SymbolMap symbolMap;
+ symbolMap[interner("_mlir_ciface_callback")] = // Bind this symbol name to the host helper.
+ llvm::JITEvaluatedSymbol::fromPointer(memref_multiply);
+ return symbolMap;
+ });
+
+ int32_t coefficient = 3.;
+ llvm::Error error = jit->invoke("caller_for_callback", &*A, coefficient);
+ ASSERT_TRUE(!error);
+ count = 1;
+ for (float elt : *A) // Each element must now be its original value times the coefficient.
+ ASSERT_EQ(elt, coefficient * count++);
+}
+
#endif // _WIN32
More information about the Mlir-commits
mailing list