[Mlir-commits] [mlir] 8a10ee7 - [MLIR] DynamicMemRefType: iteration and access by indices

Michele Scuttari llvmlistbot at llvm.org
Thu Aug 18 12:31:51 PDT 2022


Author: Michele Scuttari
Date: 2022-08-18T21:30:20+02:00
New Revision: 8a10ee7590a91c9b2e5c80b52822a3a4c3af1a15

URL: https://github.com/llvm/llvm-project/commit/8a10ee7590a91c9b2e5c80b52822a3a4c3af1a15
DIFF: https://github.com/llvm/llvm-project/commit/8a10ee7590a91c9b2e5c80b52822a3a4c3af1a15.diff

LOG: [MLIR] DynamicMemRefType: iteration and access by indices

Iteration and index-based access are now implemented for DynamicMemRefType, mirroring the existing implementation for StridedMemRefType. Previously, one could pass an unranked memref to the library and thus obtain a “dynamic” memref descriptor, but there was no way to operate on its contents.

Added: 
    mlir/unittests/ExecutionEngine/DynamicMemRef.cpp

Modified: 
    mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
    mlir/unittests/ExecutionEngine/CMakeLists.txt

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index e536ae8fe1154..f76e8c0eb0db4 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -36,6 +36,7 @@
 #include <cassert>
 #include <cstdint>
 #include <initializer_list>
+#include <vector>
 
 //===----------------------------------------------------------------------===//
 // Codegen-compatible structures for Vector type.
@@ -209,13 +210,19 @@ struct StridedMemRefType<T, 0> {
 template <typename T, int Rank>
 class StridedMemrefIterator {
 public:
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T *;
+  using reference = T &;
+
   StridedMemrefIterator(StridedMemRefType<T, Rank> &descriptor,
                         int64_t offset = 0)
-      : offset(offset), descriptor(descriptor) {}
+      : offset(offset), descriptor(&descriptor) {}
   StridedMemrefIterator<T, Rank> &operator++() {
     int dim = Rank - 1;
-    while (dim >= 0 && indices[dim] == (descriptor.sizes[dim] - 1)) {
-      offset -= indices[dim] * descriptor.strides[dim];
+    while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
+      offset -= indices[dim] * descriptor->strides[dim];
       indices[dim] = 0;
       --dim;
     }
@@ -224,17 +231,17 @@ class StridedMemrefIterator {
       return *this;
     }
     ++indices[dim];
-    offset += descriptor.strides[dim];
+    offset += descriptor->strides[dim];
     return *this;
   }
 
-  T &operator*() { return descriptor.data[offset]; }
-  T *operator->() { return &descriptor.data[offset]; }
+  reference operator*() { return descriptor->data[offset]; }
+  pointer operator->() { return &descriptor->data[offset]; }
 
   const std::array<int64_t, Rank> &getIndices() { return indices; }
 
   bool operator==(const StridedMemrefIterator &other) const {
-    return other.offset == offset && &other.descriptor == &descriptor;
+    return other.offset == offset && other.descriptor == descriptor;
   }
 
   bool operator!=(const StridedMemrefIterator &other) const {
@@ -245,16 +252,24 @@ class StridedMemrefIterator {
   /// Offset in the buffer. This can be derived from the indices and the
   /// descriptor.
   int64_t offset = 0;
+
   /// Array of indices in the multi-dimensional memref.
   std::array<int64_t, Rank> indices = {};
+
   /// Descriptor for the strided memref.
-  StridedMemRefType<T, Rank> &descriptor;
+  StridedMemRefType<T, Rank> *descriptor;
 };
 
 /// Iterate over all elements in a 0-ranked strided memref.
 template <typename T>
 class StridedMemrefIterator<T, 0> {
 public:
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T *;
+  using reference = T &;
+
   StridedMemrefIterator(StridedMemRefType<T, 0> &descriptor, int64_t offset = 0)
       : elt(descriptor.data + offset) {}
 
@@ -263,8 +278,8 @@ class StridedMemrefIterator<T, 0> {
     return *this;
   }
 
-  T &operator*() { return *elt; }
-  T *operator->() { return elt; }
+  reference operator*() { return *elt; }
+  pointer operator->() { return elt; }
 
   // There are no indices for a 0-ranked memref, but this API is provided for
   // consistency with the general case.
@@ -301,10 +316,20 @@ struct UnrankedMemRefType {
 //===----------------------------------------------------------------------===//
 // DynamicMemRefType type.
 //===----------------------------------------------------------------------===//
+template <typename T>
+class DynamicMemRefIterator;
+
 // A reference to one of the StridedMemRef types.
 template <typename T>
 class DynamicMemRefType {
 public:
+  int64_t rank;           // Number of dimensions (0 for a zero-ranked memref).
+  T *basePtr;             // Base pointer copied from the source descriptor.
+  T *data;                // Elements live at data[offset + sum(index * stride)].
+  int64_t offset;         // Added to every linearized element index into data.
+  const int64_t *sizes;   // One extent per dimension.
+  const int64_t *strides; // One stride per dimension, counted in elements.
+
   explicit DynamicMemRefType(const StridedMemRefType<T, 0> &memRef)
       : rank(0), basePtr(memRef.basePtr), data(memRef.data),
         offset(memRef.offset), sizes(nullptr), strides(nullptr) {}
@@ -322,12 +347,108 @@ class DynamicMemRefType {
     strides = sizes + rank;
   }
 
-  int64_t rank;
-  T *basePtr;
-  T *data;
-  int64_t offset;
-  const int64_t *sizes;
-  const int64_t *strides;
+  // Access the element at the given multi-dimensional `indices`; a
+  // zero-ranked memref takes an empty range and yields its single element.
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range &&indices) {
+    assert(static_cast<int64_t>(indices.size()) == rank &&
+           "indices should match rank in memref subscript");
+    int64_t curOffset = offset;
+    for (int64_t dim = rank - 1; dim >= 0; --dim) {
+      int64_t currentIndex = *(indices.begin() + dim);
+      assert(currentIndex >= 0 && currentIndex < sizes[dim] &&
+             "Index overflow");
+      curOffset += currentIndex * strides[dim];
+    }
+    return data[curOffset];
+  }
+  /// begin() starts at this memref's `offset`; end() is the -1 sentinel.
+  DynamicMemRefIterator<T> begin() { return {*this, offset}; }
+  DynamicMemRefIterator<T> end() { return {*this, -1}; }
+
+  // Sub-memref at position `idx` of the outermost dimension (sugar only).
+  DynamicMemRefType<T> operator[](int64_t idx) {
+    assert(rank > 0 && "can't make a subscript of a zero ranked array");
+    assert(idx >= 0 && idx < sizes[0] && "Index overflow");
+    DynamicMemRefType<T> res(*this);
+    --res.rank;
+    res.offset += idx * res.strides[0];
+    ++res.sizes;
+    ++res.strides;
+    return res;
+  }
+
+  // Access the single element of a zero-ranked memref, e.g. one obtained by
+  // applying operator[](int64_t) once per dimension.
+  T &operator*() {
+    assert(rank == 0 && "not a zero-ranked memRef");
+    return *(data + offset);
+  }
+};
+
+/// Iterate over all elements in a dynamic memref, last index fastest.
+template <typename T>
+class DynamicMemRefIterator {
+public:
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T *;
+  using reference = T &;
+
+  DynamicMemRefIterator(DynamicMemRefType<T> &descriptor, int64_t offset = 0)
+      : offset(offset), descriptor(&descriptor) {
+    indices.resize(descriptor.rank, 0);
+    for (int64_t dim = 0; dim < descriptor.rank; ++dim)
+      if (descriptor.sizes[dim] == 0) // No elements: begin() == end().
+        this->offset = -1;
+  }
+
+  DynamicMemRefIterator<T> &operator++() {
+    if (descriptor->rank == 0) {
+      offset = -1;
+      return *this;
+    }
+    // Wrap every trailing dimension that is already at its last index.
+    int64_t dim = descriptor->rank - 1;
+    while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) {
+      offset -= indices[dim] * descriptor->strides[dim];
+      indices[dim] = 0;
+      --dim;
+    }
+    if (dim < 0) {
+      offset = -1;
+      return *this;
+    }
+    ++indices[dim];
+    offset += descriptor->strides[dim];
+    return *this;
+  }
+
+  reference operator*() { return descriptor->data[offset]; }
+  pointer operator->() { return &descriptor->data[offset]; }
+
+  const std::vector<int64_t> &getIndices() const { return indices; }
+
+  bool operator==(const DynamicMemRefIterator &other) const {
+    return other.offset == offset && other.descriptor == descriptor;
+  }
+
+  bool operator!=(const DynamicMemRefIterator &other) const {
+    return !(*this == other);
+  }
+
+private:
+  /// Offset into the descriptor's `data` of the current element; -1 marks
+  /// the past-the-end position.
+  int64_t offset = 0;
+
+  /// Current multi-dimensional index, one entry per dimension.
+  std::vector<int64_t> indices = {};
+
+  /// Descriptor for the dynamic memref.
+  DynamicMemRefType<T> *descriptor;
 };
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt
index d17acb6647f81..32722d0dd9582 100644
--- a/mlir/unittests/ExecutionEngine/CMakeLists.txt
+++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_unittest(MLIRExecutionEngineTests
+  DynamicMemRef.cpp
   Invoke.cpp
 )
 get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)

diff --git a/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp b/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp
new file mode 100644
index 0000000000000..5f4f012702468
--- /dev/null
+++ b/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp
@@ -0,0 +1,99 @@
+//===- DynamicMemRef.cpp ----------------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include "gmock/gmock.h"
+
+using namespace ::mlir;
+using namespace ::testing;
+
+TEST(DynamicMemRef, rankZero) {
+  // A zero-ranked memref wraps exactly one scalar element.
+  int scalar = 57;
+  StridedMemRefType<int, 0> memRef;
+  memRef.basePtr = &scalar;
+  memRef.data = &scalar;
+  memRef.offset = 0;
+  DynamicMemRefType<int> dynamicMemRef(memRef);
+
+  // Iterating must visit that single element and nothing else.
+  llvm::SmallVector<int, 1> values(dynamicMemRef.begin(), dynamicMemRef.end());
+  EXPECT_THAT(values, ElementsAre(57));
+}
+
+TEST(DynamicMemRef, rankOne) {
+  std::array<int, 3> data;
+  for (size_t i = 0; i < data.size(); ++i)
+    data[i] = static_cast<int>(i);
+
+  // Contiguous 1-D memref of three elements.
+  StridedMemRefType<int, 1> memRef;
+  memRef.basePtr = data.data();
+  memRef.data = data.data();
+  memRef.offset = 0;
+  memRef.sizes[0] = 3;
+  memRef.strides[0] = 1;
+
+  DynamicMemRefType<int> dynamicMemRef(memRef);
+
+  // Iteration visits every element in storage order...
+  llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end());
+  EXPECT_THAT(values, ElementsAreArray(data));
+
+  // ...and subscripting then dereferencing reaches the same elements.
+  for (int64_t i = 0; i < 3; ++i)
+    EXPECT_EQ(*dynamicMemRef[i], data[i]);
+}
+
+TEST(DynamicMemRef, rankTwo) {
+  std::array<int, 6> data;
+  for (size_t i = 0; i < data.size(); ++i)
+    data[i] = static_cast<int>(i);
+
+  // Contiguous row-major 2x3 layout: iteration must follow storage order.
+  StridedMemRefType<int, 2> memRef;
+  memRef.basePtr = data.data();
+  memRef.data = data.data();
+  memRef.offset = 0;
+  memRef.sizes[0] = 2;
+  memRef.sizes[1] = 3;
+  memRef.strides[0] = 3;
+  memRef.strides[1] = 1;
+
+  DynamicMemRefType<int> dynamicMemRef(memRef);
+
+  // Gather every element through begin()/end().
+  llvm::SmallVector<int, 6> values(dynamicMemRef.begin(), dynamicMemRef.end());
+  EXPECT_THAT(values, ElementsAreArray(data));
+}
+
+TEST(DynamicMemRef, rankThree) {
+  std::array<int, 24> data;
+  for (size_t i = 0; i < data.size(); ++i)
+    data[i] = static_cast<int>(i);
+
+  // Contiguous row-major 2x3x4 layout: iteration must follow storage order.
+  StridedMemRefType<int, 3> memRef;
+  memRef.basePtr = data.data();
+  memRef.data = data.data();
+  memRef.offset = 0;
+  memRef.sizes[0] = 2;
+  memRef.sizes[1] = 3;
+  memRef.sizes[2] = 4;
+  memRef.strides[0] = 12;
+  memRef.strides[1] = 4;
+  memRef.strides[2] = 1;
+
+  DynamicMemRefType<int> dynamicMemRef(memRef);
+
+  llvm::SmallVector<int, 24> values(dynamicMemRef.begin(), dynamicMemRef.end());
+  EXPECT_THAT(values, ElementsAreArray(data));
+}


        


More information about the Mlir-commits mailing list