[Mlir-commits] [mlir] d18bfb2 - [mlir][sparse] Add readCOOElement for reading a sparse tensor element from files.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Oct 16 07:25:29 PDT 2022


Author: bixia1
Date: 2022-10-16T07:25:21-07:00
New Revision: d18bfb23f3121c9d36da6d90e1884b3de88c64de

URL: https://github.com/llvm/llvm-project/commit/d18bfb23f3121c9d36da6d90e1884b3de88c64de
DIFF: https://github.com/llvm/llvm-project/commit/d18bfb23f3121c9d36da6d90e1884b3de88c64de.diff

LOG: [mlir][sparse] Add readCOOElement for reading a sparse tensor element from files.

Use the new routine in openSparseTensorCOO and getSparseTensorReaderNext.

Reviewed By: aartbik, wrengr

Differential Revision: https://reviews.llvm.org/D135732

Added: 
    

Modified: 
    mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
    mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index f5420912de3e..2382e42d8511 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -33,6 +33,44 @@
 namespace mlir {
 namespace sparse_tensor {
 
+namespace detail {
+/// `is_complex<T>::value` is true when `T` is `std::complex<U>` for some `U`.
+template <typename T>
+struct is_complex final : public std::false_type {};
+
+template <typename T>
+struct is_complex<std::complex<T>> final : public std::true_type {};
+
+/// Parses a non-complex value from `*linePtr` (advancing it) via `strtod`.
+/// For a pattern tensor nothing is parsed and 1 (converted to `V`) is returned.
+template <typename V>
+inline std::enable_if_t<!is_complex<V>::value, V>
+readCOOValue(char **linePtr, bool is_pattern) {
+  // The external formats always store these numerical values with the type
+  // double, but we cast these values to the sparse tensor object type.
+  // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
+  return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+}
+
+/// Parses a complex value from `*linePtr` (advancing it) as the real part
+/// followed by the imaginary part. For a pattern tensor, returns `V(1, 1)`.
+template <typename V>
+inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr,
+                                                              bool is_pattern) {
+  // Read two values to make a complex. The external formats always store
+  // numerical values with the type double, but we cast these values to the
+  // sparse tensor object type. For a pattern tensor, we arbitrarily pick the
+  // value 1 for all entries.
+  double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+  double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+  // Avoiding brace-notation since that forbids narrowing to `float`.
+  return V(re, im);
+}
+
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+
 // TODO: benchmark whether to keep various methods inline vs moving them
 // off to the cpp file.
 
@@ -132,6 +170,31 @@ class SparseTensorReader final {
   /// valid after parsing the header.
   void assertMatchesShape(uint64_t rank, const uint64_t *shape) const;
 
+  /// Reads one element from the next line of the input file and returns its
+  /// value. The 1-based coordinates are stored 0-based into `indices`, placing
+  /// coordinate `r` at `indices[perm[r]]` (or `indices[r]` if `perm` is null).
+  template <typename V>
+  V readCOOElement(uint64_t rank, uint64_t *indices,
+                   const uint64_t *perm = nullptr) {
+    assert(rank == getRank() && "Rank mismatch");
+    char *linePtr = readLine();
+    if (perm)
+      for (uint64_t r = 0; r < rank; ++r) {
+        // Parse the 1-based index (strtoul saturates at ULONG_MAX).
+        uint64_t idx = strtoul(linePtr, &linePtr, 10);
+        // Store the 0-based index; an index of 0 would wrap around.
+        indices[perm[r]] = idx - 1;
+      }
+    else
+      for (uint64_t r = 0; r < rank; ++r) {
+        // Parse the 1-based index (strtoul saturates at ULONG_MAX).
+        uint64_t idx = strtoul(linePtr, &linePtr, 10);
+        // Store the 0-based index; an index of 0 would wrap around.
+        indices[r] = idx - 1;
+      }
+    return detail::readCOOValue<V>(&linePtr, isPattern());
+  }
+
 private:
   /// Reads the MME header of a general sparse matrix of type real.
   void readMMEHeader();
@@ -152,41 +215,6 @@ class SparseTensorReader final {
 };
 
 //===----------------------------------------------------------------------===//
-namespace detail {
-
-template <typename T>
-struct is_complex final : public std::false_type {};
-
-template <typename T>
-struct is_complex<std::complex<T>> final : public std::true_type {};
-
-/// Reads an element of a non-complex type for the current indices in
-/// coordinate scheme.
-template <typename V>
-inline std::enable_if_t<!is_complex<V>::value, V>
-readCOOValue(char **linePtr, bool is_pattern) {
-  // The external formats always store these numerical values with the type
-  // double, but we cast these values to the sparse tensor object type.
-  // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
-  return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-}
-
-/// Reads an element of a complex type for the current indices in
-/// coordinate scheme.
-template <typename V>
-inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr,
-                                                              bool is_pattern) {
-  // Read two values to make a complex. The external formats always store
-  // numerical values with the type double, but we cast these values to the
-  // sparse tensor object type. For a pattern tensor, we arbitrarily pick the
-  // value 1 for all entries.
-  double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-  double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-  // Avoiding brace-notation since that forbids narrowing to `float`.
-  return V(re, im);
-}
-
-} // namespace detail
 
 /// Reads a sparse tensor with the given filename into a memory-resident
 /// sparse tensor in coordinate scheme.
@@ -211,14 +239,7 @@ openSparseTensorCOO(const char *filename, uint64_t rank, const uint64_t *shape,
   // Read all nonzero elements.
   std::vector<uint64_t> indices(rank);
   for (uint64_t k = 0; k < nnz; ++k) {
-    char *linePtr = stfile.readLine();
-    for (uint64_t r = 0; r < rank; ++r) {
-      // Parse the 1-based index.
-      uint64_t idx = strtoul(linePtr, &linePtr, 10);
-      // Add the 0-based index.
-      indices[perm[r]] = idx - 1;
-    }
-    const V value = detail::readCOOValue<V>(&linePtr, stfile.isPattern());
+    const V value = stfile.readCOOElement<V>(rank, indices.data(), perm);
     // TODO: <https://github.com/llvm/llvm-project/issues/54179>
     coo->add(indices, value);
     // We currently chose to deal with symmetric matrices by fully

diff  --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index 0191fd144a17..721241eb98f1 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -626,13 +626,8 @@ void delSparseTensorReader(void *p) {
     index_type *indices = iref->data + iref->offset;                           \
     SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p);         \
     index_type rank = stfile->getRank();                                       \
-    char *linePtr = stfile->readLine();                                        \
-    for (index_type r = 0; r < rank; ++r) {                                    \
-      uint64_t idx = strtoul(linePtr, &linePtr, 10);                           \
-      indices[r] = idx - 1;                                                    \
-    }                                                                          \
     V *value = vref->data + vref->offset;                                      \
-    *value = detail::readCOOValue<V>(&linePtr, stfile->isPattern());           \
+    *value = stfile->readCOOElement<V>(rank, indices);                         \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
 #undef IMPL_GETNEXT


        


More information about the Mlir-commits mailing list