[Mlir-commits] [mlir] 8198739 - Fix StridedMemRefType operator[] SFINAE to allow correctly selecting the `int64_t` overload for non-container operands

Mehdi Amini llvmlistbot at llvm.org
Wed Feb 10 12:02:26 PST 2021


Author: Mehdi Amini
Date: 2021-02-10T20:02:11Z
New Revision: 81987396ac2ceff56caaa19f54786834523f16db

URL: https://github.com/llvm/llvm-project/commit/81987396ac2ceff56caaa19f54786834523f16db
DIFF: https://github.com/llvm/llvm-project/commit/81987396ac2ceff56caaa19f54786834523f16db.diff

LOG: Fix StridedMemRefType operator[] SFINAE to allow correctly selecting the `int64_t` overload for non-container operands

Added: 
    

Modified: 
    mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
    mlir/unittests/ExecutionEngine/Invoke.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index 0c2638307604..a263fc2ad1d3 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -133,8 +133,9 @@ struct StridedMemRefType {
   int64_t sizes[N];
   int64_t strides[N];
 
-  template <typename Range>
-  T &operator[](Range indices) {
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range &&indices) {
     assert(indices.size() == N &&
            "indices should match rank in memref subscript");
     int64_t curOffset = offset;
@@ -170,7 +171,8 @@ struct StridedMemRefType<T, 1> {
   int64_t sizes[1];
   int64_t strides[1];
 
-  template <typename Range>
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
   T &operator[](Range indices) {
     assert(indices.size() == 1 &&
            "indices should match rank in memref subscript");
@@ -190,7 +192,8 @@ struct StridedMemRefType<T, 0> {
   T *data;
   int64_t offset;
 
-  template <typename Range>
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
   T &operator[](Range indices) {
     assert((indices.size() == 0) &&
            "Expect empty indices for 0-rank memref subscript");

diff  --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
index 29c59bdba857..9b7450e2a4f6 100644
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -171,10 +171,12 @@ TEST(NativeMemRefJit, BasicMemref) {
   ASSERT_EQ(A->sizes[1], M);
   ASSERT_EQ(A->strides[0], M + 1);
   ASSERT_EQ(A->strides[1], 1);
-  for (int i = 0; i < K; ++i)
-    for (int j = 0; j < M; ++j)
+  for (int i = 0; i < K; ++i) {
+    for (int j = 0; j < M; ++j) {
       EXPECT_EQ((A[{i, j}]), i * M + j);
-
+      EXPECT_EQ(&(A[{i, j}]), &((*A)[i][j]));
+    }
+  }
   std::string moduleStr = R"mlir(
   func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
     %x = constant 2 : index
@@ -196,7 +198,7 @@ TEST(NativeMemRefJit, BasicMemref) {
 
   llvm::Error error = jit->invoke("rank2_memref", &*A, &*A);
   ASSERT_TRUE(!error);
-  EXPECT_EQ((A[{1, 2}]), 42.);
+  EXPECT_EQ(((*A)[1][2]), 42.);
   EXPECT_EQ((A[{2, 1}]), 42.);
 }
 


        


More information about the Mlir-commits mailing list