[Mlir-commits] [mlir] 8d5c1b4 - [mlir][CRunnerUtils] Fix iterators accessing MemRefs with non-zero offset
Alex Zinenko
llvmlistbot at llvm.org
Thu Sep 14 04:14:19 PDT 2023
Author: Felix Schneider
Date: 2023-09-14T13:14:13+02:00
New Revision: 8d5c1b4562f880a61c9d9a2bddad73f584cdf311
URL: https://github.com/llvm/llvm-project/commit/8d5c1b4562f880a61c9d9a2bddad73f584cdf311
DIFF: https://github.com/llvm/llvm-project/commit/8d5c1b4562f880a61c9d9a2bddad73f584cdf311.diff
LOG: [mlir][CRunnerUtils] Fix iterators accessing MemRefs with non-zero offset
MemRef descriptors contain, among other fields, a field called `alignedPtr` (named `data` in the CRunnerUtils structs) and a field called `offset`. The actual buffer of the MemRef starts `offset` elements after `alignedPtr`. CRunnerUtils provides helper classes to iterate over a MemRef's elements, but they do not handle the `offset` consistently, so accessing a MemRef with a non-zero `offset` via an iterator leads to incorrect results.
The problem is that "offset" can be understood in two ways: firstly, as the offset of the beginning of the MemRef with respect to `alignedPtr`, i.e., what the `offset` field means in the MemRef descriptor; and secondly, as the offset of some element within the MemRef relative to the first element of the MemRef, which could more accurately be called something like `linearIndex`.
The `offset` fields within `StridedMemrefIterator` and `DynamicMemRefIterator` are interpreted the first way; therefore, the offsets passed to the constructors of these classes need to account for the offset already stored in the descriptor, on top of any potential "shift" within the MemRef.
This patch takes care of that and adds some basic tests that catch problems with indexing MemRefs with an `offset`.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D157008
Added:
mlir/unittests/ExecutionEngine/StridedMemRef.cpp
Modified:
mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
mlir/unittests/ExecutionEngine/CMakeLists.txt
mlir/unittests/ExecutionEngine/DynamicMemRef.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index e7798b2136af07e..e8f429463cb0b9b 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -149,7 +149,7 @@ struct StridedMemRefType {
return data[curOffset];
}
- StridedMemrefIterator<T, N> begin() { return {*this}; }
+ StridedMemrefIterator<T, N> begin() { return {*this, offset}; }
StridedMemrefIterator<T, N> end() { return {*this, -1}; }
// This operator[] is extremely slow and only for sugaring purposes.
@@ -181,7 +181,7 @@ struct StridedMemRefType<T, 1> {
return (*this)[*indices.begin()];
}
- StridedMemrefIterator<T, 1> begin() { return {*this}; }
+ StridedMemrefIterator<T, 1> begin() { return {*this, offset}; }
StridedMemrefIterator<T, 1> end() { return {*this, -1}; }
T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
@@ -202,8 +202,8 @@ struct StridedMemRefType<T, 0> {
return data[offset];
}
- StridedMemrefIterator<T, 0> begin() { return {*this}; }
- StridedMemrefIterator<T, 0> end() { return {*this, 1}; }
+ StridedMemrefIterator<T, 0> begin() { return {*this, offset}; }
+ StridedMemrefIterator<T, 0> end() { return {*this, offset + 1}; }
};
/// Iterate over all elements in a strided memref.
@@ -364,7 +364,7 @@ class DynamicMemRefType {
return data[curOffset];
}
- DynamicMemRefIterator<T> begin() { return {*this}; }
+ DynamicMemRefIterator<T> begin() { return {*this, offset}; }
DynamicMemRefIterator<T> end() { return {*this, -1}; }
// This operator[] is extremely slow and only for sugaring purposes.
diff --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt
index 2676cb63cd2407c..383e172aa3f6670 100644
--- a/mlir/unittests/ExecutionEngine/CMakeLists.txt
+++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIRExecutionEngineTests
DynamicMemRef.cpp
Invoke.cpp
+ StridedMemRef.cpp
)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
diff --git a/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp b/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp
index 5f4f01270246821..78a2440b59cbaab 100644
--- a/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp
+++ b/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp
@@ -96,4 +96,29 @@ TEST(DynamicMemRef, rankThree) {
llvm::SmallVector<int, 24> values(dynamicMemRef.begin(), dynamicMemRef.end());
EXPECT_THAT(values, ElementsAreArray(data));
-}
\ No newline at end of file
+}
+
+TEST(DynamicMemRef, rankOneWithOffset) {
+  constexpr int offset = 4;
+  constexpr int size = 3;
+  std::array<int, size + offset> buffer;
+  // buffer[i] == i, so each value identifies its position in the allocation.
+  for (size_t i = 0; i < buffer.size(); ++i)
+    buffer[i] = static_cast<int>(i);
+
+  StridedMemRefType<int, 1> memRef;
+  memRef.basePtr = buffer.data();
+  memRef.data = buffer.data();
+  memRef.offset = offset;
+  memRef.sizes[0] = size;
+  memRef.strides[0] = 1;
+  DynamicMemRefType<int> dynamicMemRef(memRef);
+
+  // ElementsAreArray also checks the length, so short iterations fail cleanly.
+  llvm::SmallVector<int, size> values(dynamicMemRef.begin(),
+                                      dynamicMemRef.end());
+  EXPECT_THAT(values, ElementsAreArray(buffer.begin() + offset, buffer.end()));
+  // operator[] must honor the descriptor offset as well.
+  for (int64_t i = 0; i < size; ++i)
+    EXPECT_EQ(*dynamicMemRef[i], buffer[offset + i]);
+}
diff --git a/mlir/unittests/ExecutionEngine/StridedMemRef.cpp b/mlir/unittests/ExecutionEngine/StridedMemRef.cpp
new file mode 100644
index 000000000000000..f5ffcc8fc911f10
--- /dev/null
+++ b/mlir/unittests/ExecutionEngine/StridedMemRef.cpp
@@ -0,0 +1,41 @@
+//===- StridedMemRef.cpp ----------------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "llvm/ADT/SmallVector.h"
+#include "gmock/gmock.h"
+#include <array>
+
+using namespace ::mlir;
+using namespace ::testing;
+
+TEST(StridedMemRef, rankOneWithOffset) {
+  constexpr int64_t offset = 5;
+  constexpr int64_t size = 10;
+  std::array<int, offset + size> data;
+  for (size_t i = 0; i < data.size(); ++i)
+    data[i] = static_cast<int>(i); // data[i] == i identifies each position
+
+  StridedMemRefType<int, 1> memRefA;
+  memRefA.basePtr = data.data();
+  memRefA.data = data.data();
+  memRefA.offset = 0;
+  memRefA.sizes[0] = size;
+  memRefA.strides[0] = 1;
+  // Same view shifted `offset` elements past `data`: must see data[5..14].
+  StridedMemRefType<int, 1> memRefB = memRefA;
+  memRefB.offset = offset;
+
+  // ElementsAreArray also checks the length, so short iterations fail cleanly.
+  llvm::SmallVector<int, size> valuesA(memRefA.begin(), memRefA.end());
+  llvm::SmallVector<int, size> valuesB(memRefB.begin(), memRefB.end());
+  EXPECT_THAT(valuesA, ElementsAreArray(data.begin(), data.begin() + size));
+  EXPECT_THAT(valuesB, ElementsAreArray(data.begin() + offset, data.end()));
+  for (int64_t i = 0; i < size; ++i)
+    EXPECT_EQ(memRefB[i], data[offset + i]); // operator[] honors offset too
+}
More information about the Mlir-commits mailing list