[Mlir-commits] [mlir] 8198739 - Fix StridedMemRefType operator[] SFINAE to allow correctly selecting the `int64_t` overload for non-container operands
Mehdi Amini
llvmlistbot at llvm.org
Wed Feb 10 12:02:26 PST 2021
Author: Mehdi Amini
Date: 2021-02-10T20:02:11Z
New Revision: 81987396ac2ceff56caaa19f54786834523f16db
URL: https://github.com/llvm/llvm-project/commit/81987396ac2ceff56caaa19f54786834523f16db
DIFF: https://github.com/llvm/llvm-project/commit/81987396ac2ceff56caaa19f54786834523f16db.diff
LOG: Fix StridedMemRefType operator[] SFINAE to allow correctly selecting the `int64_t` overload for non-container operands
Added:
Modified:
mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
mlir/unittests/ExecutionEngine/Invoke.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index 0c2638307604..a263fc2ad1d3 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -133,8 +133,9 @@ struct StridedMemRefType {
int64_t sizes[N];
int64_t strides[N];
- template <typename Range>
- T &operator[](Range indices) {
+ template <typename Range,
+ typename sfinae = decltype(std::declval<Range>().begin())>
+ T &operator[](Range &&indices) {
assert(indices.size() == N &&
"indices should match rank in memref subscript");
int64_t curOffset = offset;
@@ -170,7 +171,8 @@ struct StridedMemRefType<T, 1> {
int64_t sizes[1];
int64_t strides[1];
- template <typename Range>
+ template <typename Range,
+ typename sfinae = decltype(std::declval<Range>().begin())>
T &operator[](Range indices) {
assert(indices.size() == 1 &&
"indices should match rank in memref subscript");
@@ -190,7 +192,8 @@ struct StridedMemRefType<T, 0> {
T *data;
int64_t offset;
- template <typename Range>
+ template <typename Range,
+ typename sfinae = decltype(std::declval<Range>().begin())>
T &operator[](Range indices) {
assert((indices.size() == 0) &&
"Expect empty indices for 0-rank memref subscript");
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
index 29c59bdba857..9b7450e2a4f6 100644
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -171,10 +171,12 @@ TEST(NativeMemRefJit, BasicMemref) {
ASSERT_EQ(A->sizes[1], M);
ASSERT_EQ(A->strides[0], M + 1);
ASSERT_EQ(A->strides[1], 1);
- for (int i = 0; i < K; ++i)
- for (int j = 0; j < M; ++j)
+ for (int i = 0; i < K; ++i) {
+ for (int j = 0; j < M; ++j) {
EXPECT_EQ((A[{i, j}]), i * M + j);
-
+ EXPECT_EQ(&(A[{i, j}]), &((*A)[i][j]));
+ }
+ }
std::string moduleStr = R"mlir(
func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
%x = constant 2 : index
@@ -196,7 +198,7 @@ TEST(NativeMemRefJit, BasicMemref) {
llvm::Error error = jit->invoke("rank2_memref", &*A, &*A);
ASSERT_TRUE(!error);
- EXPECT_EQ((A[{1, 2}]), 42.);
+ EXPECT_EQ(((*A)[1][2]), 42.);
EXPECT_EQ((A[{2, 1}]), 42.);
}
More information about the Mlir-commits mailing list