[Mlir-commits] [mlir] 8d5c1b4 - [mlir][CRunnerUtils] Fix iterators accessing MemRefs with non-zero offset

Alex Zinenko llvmlistbot at llvm.org
Thu Sep 14 04:14:19 PDT 2023


Author: Felix Schneider
Date: 2023-09-14T13:14:13+02:00
New Revision: 8d5c1b4562f880a61c9d9a2bddad73f584cdf311

URL: https://github.com/llvm/llvm-project/commit/8d5c1b4562f880a61c9d9a2bddad73f584cdf311
DIFF: https://github.com/llvm/llvm-project/commit/8d5c1b4562f880a61c9d9a2bddad73f584cdf311.diff

LOG: [mlir][CRunnerUtils] Fix iterators accessing MemRefs with non-zero offset

MemRef descriptors contain, among other fields, a field called `alignedPtr` (or `data`) and a field called `offset`. The MemRef's first element is located `offset` elements after `alignedPtr`. CRunnerUtils provides helper classes to iterate over a MemRef's elements, but they do not handle `offset` consistently, so accessing a MemRef with a non-zero `offset` via an iterator yields incorrect results.

The problem is that "offset" can be understood in two ways. First, it can mean the position of the MemRef's first element relative to `alignedPtr`, i.e., what the `offset` field means in the MemRef descriptor. Second, it can mean the position of some element within the MemRef relative to the MemRef's first element, which would more accurately be called something like `linearIndex`.

The `offset` fields of `StridedMemrefIterator` and `DynamicMemRefIterator` are interpreted the first way. Therefore, the offsets passed to the constructors of these classes must include the offset already stored in the descriptor, on top of any potential "shift" within the MemRef.
This patch takes care of that and adds some basic tests that catch problems with indexing MemRefs with an `offset`.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D157008

Added: 
    mlir/unittests/ExecutionEngine/StridedMemRef.cpp

Modified: 
    mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
    mlir/unittests/ExecutionEngine/CMakeLists.txt
    mlir/unittests/ExecutionEngine/DynamicMemRef.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index e7798b2136af07e..e8f429463cb0b9b 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -149,7 +149,7 @@ struct StridedMemRefType {
     return data[curOffset];
   }
 
-  StridedMemrefIterator<T, N> begin() { return {*this}; }
+  StridedMemrefIterator<T, N> begin() { return {*this, offset}; }
   StridedMemrefIterator<T, N> end() { return {*this, -1}; }
 
   // This operator[] is extremely slow and only for sugaring purposes.
@@ -181,7 +181,7 @@ struct StridedMemRefType<T, 1> {
     return (*this)[*indices.begin()];
   }
 
-  StridedMemrefIterator<T, 1> begin() { return {*this}; }
+  StridedMemrefIterator<T, 1> begin() { return {*this, offset}; }
   StridedMemrefIterator<T, 1> end() { return {*this, -1}; }
 
   T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
@@ -202,8 +202,8 @@ struct StridedMemRefType<T, 0> {
     return data[offset];
   }
 
-  StridedMemrefIterator<T, 0> begin() { return {*this}; }
-  StridedMemrefIterator<T, 0> end() { return {*this, 1}; }
+  StridedMemrefIterator<T, 0> begin() { return {*this, offset}; }
+  StridedMemrefIterator<T, 0> end() { return {*this, offset + 1}; }
 };
 
 /// Iterate over all elements in a strided memref.
@@ -364,7 +364,7 @@ class DynamicMemRefType {
     return data[curOffset];
   }
 
-  DynamicMemRefIterator<T> begin() { return {*this}; }
+  DynamicMemRefIterator<T> begin() { return {*this, offset}; }
   DynamicMemRefIterator<T> end() { return {*this, -1}; }
 
   // This operator[] is extremely slow and only for sugaring purposes.

diff  --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt
index 2676cb63cd2407c..383e172aa3f6670 100644
--- a/mlir/unittests/ExecutionEngine/CMakeLists.txt
+++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRExecutionEngineTests
   DynamicMemRef.cpp
+  StridedMemRef.cpp
   Invoke.cpp
 )
 get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)

diff  --git a/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp b/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp
index 5f4f01270246821..78a2440b59cbaab 100644
--- a/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp
+++ b/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp
@@ -96,4 +96,29 @@ TEST(DynamicMemRef, rankThree) {
 
   llvm::SmallVector<int, 24> values(dynamicMemRef.begin(), dynamicMemRef.end());
   EXPECT_THAT(values, ElementsAreArray(data));
-}
\ No newline at end of file
+}
+
+TEST(DynamicMemRef, rankOneWithOffset) {
+  constexpr int offset = 4;
+  std::array<int, 3 + offset> buffer;
+  // Fill with 0..6 so each value equals its index in `buffer`.
+  for (size_t i = 0; i < buffer.size(); ++i) {
+    buffer[i] = i;
+  }
+  // Three-element view that starts `offset` elements past `data`.
+  StridedMemRefType<int, 1> memRef;
+  memRef.basePtr = buffer.data();
+  memRef.data = buffer.data();
+  memRef.offset = offset;
+  memRef.sizes[0] = 3;
+  memRef.strides[0] = 1;
+
+  DynamicMemRefType<int> dynamicMemRef(memRef);
+  llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end());
+  // Bail out before values[i] can read past the end of a short result.
+  ASSERT_EQ(values.size(), 3u);
+  for (int64_t i = 0; i < 3; ++i) {
+    EXPECT_EQ(values[i], buffer[offset + i]);
+    EXPECT_EQ(*dynamicMemRef[i], buffer[offset + i]);
+  }
+}

diff  --git a/mlir/unittests/ExecutionEngine/StridedMemRef.cpp b/mlir/unittests/ExecutionEngine/StridedMemRef.cpp
new file mode 100644
index 000000000000000..f5ffcc8fc911f10
--- /dev/null
+++ b/mlir/unittests/ExecutionEngine/StridedMemRef.cpp
@@ -0,0 +1,41 @@
+//===- StridedMemRef.cpp ----------------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include "gmock/gmock.h"
+
+using namespace ::mlir;
+using namespace ::testing;
+
+TEST(StridedMemRef, rankOneWithOffset) {
+  std::array<int, 15> data;
+  // Fill with 0..14 so each value equals its index in `data`.
+  for (size_t i = 0; i < data.size(); ++i) {
+    data[i] = i;
+  }
+  // memRefA views data[0..9]; memRefB views data[5..14] via a non-zero offset.
+  StridedMemRefType<int, 1> memRefA;
+  memRefA.basePtr = data.data();
+  memRefA.data = data.data();
+  memRefA.offset = 0;
+  memRefA.sizes[0] = 10;
+  memRefA.strides[0] = 1;
+  StridedMemRefType<int, 1> memRefB = memRefA;
+  memRefB.offset = 5;
+
+  llvm::SmallVector<int, 10> valuesA(memRefA.begin(), memRefA.end());
+  llvm::SmallVector<int, 10> valuesB(memRefB.begin(), memRefB.end());
+  ASSERT_EQ(valuesA.size(), 10u);
+  ASSERT_EQ(valuesB.size(), 10u);
+  for (int64_t i = 0; i < 10; ++i) {
+    EXPECT_EQ(valuesA[i], i);
+    EXPECT_EQ(valuesA[i] + 5, valuesB[i]);
+  }
+}


        


More information about the Mlir-commits mailing list