[Mlir-commits] [mlir] 8a10ee7 - [MLIR] DynamicMemRefType: iteration and access by indices
Michele Scuttari
llvmlistbot at llvm.org
Thu Aug 18 12:31:51 PDT 2022
Author: Michele Scuttari
Date: 2022-08-18T21:30:20+02:00
New Revision: 8a10ee7590a91c9b2e5c80b52822a3a4c3af1a15
URL: https://github.com/llvm/llvm-project/commit/8a10ee7590a91c9b2e5c80b52822a3a4c3af1a15
DIFF: https://github.com/llvm/llvm-project/commit/8a10ee7590a91c9b2e5c80b52822a3a4c3af1a15.diff
LOG: [MLIR] DynamicMemRefType: iteration and access by indices
The methods to perform such operations have been implemented for the DynamicMemRefType in a way similar to the implementation for StridedMemRefType. Until now, one could pass an unranked memref to the library and thus obtain a “dynamic” memref descriptor, but there was no way to operate on its content.
Added:
mlir/unittests/ExecutionEngine/DynamicMemRef.cpp
Modified:
mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
mlir/unittests/ExecutionEngine/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index e536ae8fe1154..f76e8c0eb0db4 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -36,6 +36,7 @@
#include <cassert>
#include <cstdint>
#include <initializer_list>
+#include <vector>
//===----------------------------------------------------------------------===//
// Codegen-compatible structures for Vector type.
@@ -209,13 +210,19 @@ struct StridedMemRefType<T, 0> {
template <typename T, int Rank>
class StridedMemrefIterator {
public:
+  // Standard iterator member types, so std::iterator_traits and the standard
+  // algorithms/containers can consume this iterator.
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T *;
+  using reference = T &;
+
StridedMemrefIterator(StridedMemRefType<T, Rank> &descriptor,
int64_t offset = 0)
- : offset(offset), descriptor(descriptor) {}
+ : offset(offset), descriptor(&descriptor) {}
StridedMemrefIterator<T, Rank> &operator++() {
int dim = Rank - 1;
- while (dim >= 0 && indices[dim] == (descriptor.sizes[dim] - 1)) {
- offset -= indices[dim] * descriptor.strides[dim];
+ while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
+ offset -= indices[dim] * descriptor->strides[dim];
indices[dim] = 0;
--dim;
}
@@ -224,17 +231,17 @@ class StridedMemrefIterator {
return *this;
}
++indices[dim];
- offset += descriptor.strides[dim];
+ offset += descriptor->strides[dim];
return *this;
}
- T &operator*() { return descriptor.data[offset]; }
- T *operator->() { return &descriptor.data[offset]; }
+ reference operator*() { return descriptor->data[offset]; }
+ pointer operator->() { return &descriptor->data[offset]; }
const std::array<int64_t, Rank> &getIndices() { return indices; }
bool operator==(const StridedMemrefIterator &other) const {
- return other.offset == offset && &other.descriptor == &descriptor;
+ return other.offset == offset && other.descriptor == descriptor;
}
bool operator!=(const StridedMemrefIterator &other) const {
@@ -245,16 +252,24 @@ class StridedMemrefIterator {
/// Offset in the buffer. This can be derived from the indices and the
/// descriptor.
int64_t offset = 0;
+
/// Array of indices in the multi-dimensional memref.
std::array<int64_t, Rank> indices = {};
+
/// Descriptor for the strided memref.
- StridedMemRefType<T, Rank> &descriptor;
+ StridedMemRefType<T, Rank> *descriptor;
};
/// Iterate over all elements in a 0-ranked strided memref.
template <typename T>
class StridedMemrefIterator<T, 0> {
public:
+  // Standard iterator member types, so std::iterator_traits and the standard
+  // algorithms/containers can consume this iterator.
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T *;
+  using reference = T &;
+
StridedMemrefIterator(StridedMemRefType<T, 0> &descriptor, int64_t offset = 0)
: elt(descriptor.data + offset) {}
@@ -263,8 +278,8 @@ class StridedMemrefIterator<T, 0> {
return *this;
}
- T &operator*() { return *elt; }
- T *operator->() { return elt; }
+ reference operator*() { return *elt; }
+ pointer operator->() { return elt; }
// There are no indices for a 0-ranked memref, but this API is provided for
// consistency with the general case.
@@ -301,10 +316,20 @@ struct UnrankedMemRefType {
//===----------------------------------------------------------------------===//
// DynamicMemRefType type.
//===----------------------------------------------------------------------===//
+template <typename T>
+class DynamicMemRefIterator;
+
// A reference to one of the StridedMemRef types.
template <typename T>
class DynamicMemRefType {
public:
+ int64_t rank; // Number of dimensions; 0 when built from a zero-ranked memref.
+ T *basePtr; // Copied verbatim from the wrapped memref's basePtr.
+ T *data; // Buffer that element accesses index into (data[offset + ...]).
+ int64_t offset; // Element offset into `data` of this memref's first element.
+ const int64_t *sizes; // `rank` per-dimension sizes; nullptr when rank is 0.
+ const int64_t *strides; // `rank` per-dimension strides, in elements; nullptr when rank is 0.
+
explicit DynamicMemRefType(const StridedMemRefType<T, 0> &memRef)
: rank(0), basePtr(memRef.basePtr), data(memRef.data),
offset(memRef.offset), sizes(nullptr), strides(nullptr) {}
@@ -322,12 +347,108 @@ class DynamicMemRefType {
strides = sizes + rank;
}
- int64_t rank;
- T *basePtr;
- T *data;
- int64_t offset;
- const int64_t *sizes;
- const int64_t *strides;
+  /// Returns a reference to the element addressed by `indices`, which must
+  /// hold exactly `rank` values; the i-th value indexes dimension i. Any type
+  /// providing `begin()` and `size()` works (std::array, std::vector, ...).
+  /// The values are read in a single forward pass, so random-access iterators
+  /// are not required.
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range &&indices) {
+    // Cast avoids a signed/unsigned comparison between size() and rank.
+    assert(static_cast<int64_t>(indices.size()) == rank &&
+           "indices should match rank in memref subscript");
+    int64_t curOffset = offset;
+    auto indexIt = indices.begin();
+    // For rank 0 the loop body never runs and data[offset] is returned.
+    for (int64_t dim = 0; dim < rank; ++dim, ++indexIt) {
+      int64_t currentIndex = *indexIt;
+      assert(currentIndex >= 0 && currentIndex < sizes[dim] &&
+             "Index overflow");
+      curOffset += currentIndex * strides[dim];
+    }
+    return data[curOffset];
+  }
+
+  /// Iterator to the first element; elements are visited in row-major order.
+  /// Iteration starts at the descriptor's `offset`, the same base used by
+  /// operator[] and operator*, so non-zero offsets are honored.
+  DynamicMemRefIterator<T> begin() { return {*this, offset}; }
+  /// Past-the-end iterator; the iterator uses offset -1 as its end marker.
+  DynamicMemRefIterator<T> end() { return {*this, -1}; }
+
+  // This operator[] is extremely slow and only for sugaring purposes.
+  /// Returns a view of the sub-memref at position `idx` of the outermost
+  /// dimension: rank is reduced by one and the result shares the same buffer.
+  DynamicMemRefType<T> operator[](int64_t idx) {
+    assert(rank > 0 && "can't make a subscript of a zero ranked array");
+    // Same bounds check as the full-index subscript performs per dimension.
+    assert(idx >= 0 && idx < sizes[0] && "Index overflow");
+
+    DynamicMemRefType<T> res(*this);
+    --res.rank;
+    res.offset += idx * res.strides[0];
+    // Drop the outermost dimension by advancing past its size and stride.
+    ++res.sizes;
+    ++res.strides;
+    return res;
+  }
+
+  // Dereferences a zero-ranked memref, e.g. the view obtained after applying
+  // the sugaring operator[] once per dimension, giving access to its single
+  // element.
+  T &operator*() {
+    assert(rank == 0 && "not a zero-ranked memRef");
+    T *element = data + offset;
+    return *element;
+  }
+};
+
+/// Forward iterator over all elements of a DynamicMemRefType, visiting them in
+/// row-major order (the innermost dimension varies fastest). The iterator
+/// stores a pointer to the descriptor, which must outlive it. A buffer offset
+/// of -1 marks the past-the-end position.
+template <typename T>
+class DynamicMemRefIterator {
+public:
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T *;
+  using reference = T &;
+
+  /// Singular iterator, required of forward iterators; it must not be
+  /// dereferenced or incremented.
+  DynamicMemRefIterator() = default;
+
+  /// Creates an iterator over `descriptor` positioned at element `offset` of
+  /// `descriptor.data`, with every index at zero. Passing -1 builds the end
+  /// iterator. If any dimension has size zero the memref holds no elements
+  /// and the iterator starts at the end position.
+  DynamicMemRefIterator(DynamicMemRefType<T> &descriptor, int64_t offset = 0)
+      : offset(offset), descriptor(&descriptor) {
+    indices.resize(descriptor.rank, 0);
+    // Without this, begin() of an empty memref would dereference an element
+    // that does not exist and operator++ would never reach end().
+    for (int64_t dim = 0; dim < descriptor.rank; ++dim) {
+      if (descriptor.sizes[dim] == 0) {
+        this->offset = -1;
+        break;
+      }
+    }
+  }
+
+  /// Advances to the next element: the innermost index is incremented and
+  /// carries propagate outwards. After the last element the iterator becomes
+  /// the end iterator.
+  DynamicMemRefIterator<T> &operator++() {
+    if (descriptor->rank == 0) {
+      // A zero-ranked memref holds a single element.
+      offset = -1;
+      return *this;
+    }
+
+    int64_t dim = descriptor->rank - 1;
+
+    // Reset every trailing dimension that is already at its last index.
+    while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
+      offset -= indices[dim] * descriptor->strides[dim];
+      indices[dim] = 0;
+      --dim;
+    }
+
+    if (dim < 0) {
+      // Every dimension wrapped around: iteration is complete.
+      offset = -1;
+      return *this;
+    }
+
+    ++indices[dim];
+    offset += descriptor->strides[dim];
+    return *this;
+  }
+
+  /// Post-increment, part of the iterator requirements.
+  DynamicMemRefIterator<T> operator++(int) {
+    DynamicMemRefIterator<T> previous(*this);
+    ++*this;
+    return previous;
+  }
+
+  reference operator*() const { return descriptor->data[offset]; }
+  pointer operator->() const { return &descriptor->data[offset]; }
+
+  /// Indices of the current element, one per dimension (empty for rank 0).
+  const std::vector<int64_t> &getIndices() const { return indices; }
+
+  /// Iterators are equal when they refer to the same descriptor object and the
+  /// same buffer offset.
+  bool operator==(const DynamicMemRefIterator &other) const {
+    return other.offset == offset && other.descriptor == descriptor;
+  }
+
+  bool operator!=(const DynamicMemRefIterator &other) const {
+    return !(*this == other);
+  }
+
+private:
+  /// Offset in the buffer. This can be derived from the indices and the
+  /// descriptor; -1 denotes the end position.
+  int64_t offset = 0;
+
+  /// Array of indices in the multi-dimensional memref.
+  std::vector<int64_t> indices = {};
+
+  /// Descriptor for the dynamic memref (not owned).
+  DynamicMemRefType<T> *descriptor = nullptr;
+};
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt
index d17acb6647f81..32722d0dd9582 100644
--- a/mlir/unittests/ExecutionEngine/CMakeLists.txt
+++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_unittest(MLIRExecutionEngineTests
+ DynamicMemRef.cpp
Invoke.cpp
)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
diff --git a/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp b/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp
new file mode 100644
index 0000000000000..5f4f012702468
--- /dev/null
+++ b/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp
@@ -0,0 +1,99 @@
+//===- DynamicMemRef.cpp ----------------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include "gmock/gmock.h"
+
+using namespace ::mlir;
+using namespace ::testing;
+
+TEST(DynamicMemRef, rankZero) {
+  int value = 57;
+
+  StridedMemRefType<int, 0> scalarMemRef;
+  scalarMemRef.basePtr = &value;
+  scalarMemRef.data = &value;
+  scalarMemRef.offset = 0;
+
+  DynamicMemRefType<int> dynamicMemRef(scalarMemRef);
+
+  // A zero-ranked memref yields exactly its single element when iterated.
+  llvm::SmallVector<int, 1> values(dynamicMemRef.begin(), dynamicMemRef.end());
+  EXPECT_THAT(values, ElementsAre(57));
+}
+
+TEST(DynamicMemRef, rankOne) {
+  std::array<int, 3> data;
+  int nextValue = 0;
+  for (int &element : data)
+    element = nextValue++;
+
+  StridedMemRefType<int, 1> memRef;
+  memRef.basePtr = data.data();
+  memRef.data = data.data();
+  memRef.offset = 0;
+  memRef.sizes[0] = 3;
+  memRef.strides[0] = 1;
+
+  DynamicMemRefType<int> dynamicMemRef(memRef);
+
+  // Iteration visits the elements in buffer order.
+  llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end());
+  EXPECT_THAT(values, ElementsAreArray(data));
+
+  // Subscripting a rank-1 memref yields a zero-ranked view; dereference it.
+  for (int64_t pos = 0, e = data.size(); pos < e; ++pos)
+    EXPECT_EQ(*dynamicMemRef[pos], data[pos]);
+}
+
+TEST(DynamicMemRef, rankTwo) {
+  std::array<int, 6> data;
+
+  for (size_t i = 0; i < data.size(); ++i) {
+    data[i] = i;
+  }
+
+  StridedMemRefType<int, 2> memRef;
+  memRef.basePtr = data.data();
+  memRef.data = data.data();
+  memRef.offset = 0;
+  memRef.sizes[0] = 2;
+  memRef.sizes[1] = 3;
+  memRef.strides[0] = 3;
+  memRef.strides[1] = 1;
+
+  DynamicMemRefType<int> dynamicMemRef(memRef);
+
+  llvm::SmallVector<int, 6> values(dynamicMemRef.begin(), dynamicMemRef.end());
+  EXPECT_THAT(values, ElementsAreArray(data));
+
+  // Both indexing forms (chained sugar and a full index range) must resolve
+  // (i, j) to the row-major position i * 3 + j given by the strides above.
+  for (int64_t i = 0; i < 2; ++i) {
+    for (int64_t j = 0; j < 3; ++j) {
+      int expected = data[i * 3 + j];
+      EXPECT_EQ(*dynamicMemRef[i][j], expected);
+      std::array<int64_t, 2> indices = {i, j};
+      EXPECT_EQ(dynamicMemRef[indices], expected);
+    }
+  }
+}
+
+TEST(DynamicMemRef, rankThree) {
+  std::array<int, 24> data;
+
+  for (size_t i = 0; i < data.size(); ++i) {
+    data[i] = i;
+  }
+
+  StridedMemRefType<int, 3> memRef;
+  memRef.basePtr = data.data();
+  memRef.data = data.data();
+  memRef.offset = 0;
+  memRef.sizes[0] = 2;
+  memRef.sizes[1] = 3;
+  memRef.sizes[2] = 4;
+  memRef.strides[0] = 12;
+  memRef.strides[1] = 4;
+  memRef.strides[2] = 1;
+
+  DynamicMemRefType<int> dynamicMemRef(memRef);
+
+  llvm::SmallVector<int, 24> values(dynamicMemRef.begin(), dynamicMemRef.end());
+  EXPECT_THAT(values, ElementsAreArray(data));
+
+  // Both indexing forms must resolve (i, j, k) to the row-major position
+  // i * 12 + j * 4 + k given by the strides above.
+  for (int64_t i = 0; i < 2; ++i) {
+    for (int64_t j = 0; j < 3; ++j) {
+      for (int64_t k = 0; k < 4; ++k) {
+        int expected = data[i * 12 + j * 4 + k];
+        EXPECT_EQ(*dynamicMemRef[i][j][k], expected);
+        std::array<int64_t, 3> indices = {i, j, k};
+        EXPECT_EQ(dynamicMemRef[indices], expected);
+      }
+    }
+  }
+}
\ No newline at end of file
More information about the Mlir-commits mailing list