[Mlir-commits] [mlir] d18bfb2 - [mlir][sparse] Add readCOOElement for reading a sparse tensor element from files.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Oct 16 07:25:29 PDT 2022
Author: bixia1
Date: 2022-10-16T07:25:21-07:00
New Revision: d18bfb23f3121c9d36da6d90e1884b3de88c64de
URL: https://github.com/llvm/llvm-project/commit/d18bfb23f3121c9d36da6d90e1884b3de88c64de
DIFF: https://github.com/llvm/llvm-project/commit/d18bfb23f3121c9d36da6d90e1884b3de88c64de.diff
LOG: [mlir][sparse] Add readCOOElement for reading a sparse tensor element from files.
Use the new routine in openSparseTensorCOO and getSparseTensorReaderNext.
Reviewed By: aartbik, wrengr
Differential Revision: https://reviews.llvm.org/D135732
Added:
Modified:
mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index f5420912de3e..2382e42d8511 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -33,6 +33,44 @@
namespace mlir {
namespace sparse_tensor {
+namespace detail {
+
+template <typename T>
+struct is_complex final : public std::false_type {}; // Primary template: not std::complex.
+
+template <typename T>
+struct is_complex<std::complex<T>> final : public std::true_type {}; // Matches std::complex<T>.
+
+/// Reads an element of a non-complex type for the current indices in
+/// coordinate scheme. Unless `is_pattern`, `*linePtr` is advanced past it.
+template <typename V>
+inline std::enable_if_t<!is_complex<V>::value, V>
+readCOOValue(char **linePtr, bool is_pattern) {
+ // The external formats always store these numerical values with the type
+ // double, but we cast these values to the sparse tensor object type.
+ // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
+ return is_pattern ? 1.0 : strtod(*linePtr, linePtr); // Parse errors are not checked.
+}
+
+/// Reads an element of a complex type for the current indices in
+/// coordinate scheme. Unless `is_pattern`, `*linePtr` is advanced past both parts.
+template <typename V>
+inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr,
+ bool is_pattern) {
+ // Read two values to make a complex. The external formats always store
+ // numerical values with the type double, but we cast these values to the
+ // sparse tensor object type. For a pattern tensor, we arbitrarily pick the
+ // value 1 for both the real and imaginary parts (i.e. 1+1i).
+ double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr); // Real part.
+ double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr); // Imaginary part; parse errors unchecked.
+ // Avoiding brace-notation since that forbids narrowing to `float`.
+ return V(re, im);
+}
+
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+
// TODO: benchmark whether to keep various methods inline vs moving them
// off to the cpp file.
@@ -132,6 +170,31 @@ class SparseTensorReader final {
/// valid after parsing the header.
void assertMatchesShape(uint64_t rank, const uint64_t *shape) const;
+ /// Reads a sparse tensor element from the next line in the input file and
+ /// returns the value of the element. Stores the coordinates of the element
+ /// to the `indices` array, converted from the file's 1-based to 0-based.
+ template <typename V>
+ V readCOOElement(uint64_t rank, uint64_t *indices, // `rank` must equal getRank().
+ const uint64_t *perm = nullptr) { // If set, coordinate r goes to indices[perm[r]].
+ assert(rank == getRank() && "Rank mismatch");
+ char *linePtr = readLine();
+ if (perm)
+ for (uint64_t r = 0; r < rank; ++r) {
+ // Parse the 1-based index.
+ uint64_t idx = strtoul(linePtr, &linePtr, 10);
+ // Store the 0-based index (NOTE(review): unvalidated; idx == 0 wraps around).
+ indices[perm[r]] = idx - 1;
+ }
+ else
+ for (uint64_t r = 0; r < rank; ++r) {
+ // Parse the 1-based index.
+ uint64_t idx = strtoul(linePtr, &linePtr, 10);
+ // Store the 0-based index (NOTE(review): unvalidated; idx == 0 wraps around).
+ indices[r] = idx - 1;
+ }
+ return detail::readCOOValue<V>(&linePtr, isPattern()); // Value follows the indices.
+ }
+
private:
/// Reads the MME header of a general sparse matrix of type real.
void readMMEHeader();
@@ -152,41 +215,6 @@ class SparseTensorReader final {
};
//===----------------------------------------------------------------------===//
-namespace detail {
-
-template <typename T>
-struct is_complex final : public std::false_type {};
-
-template <typename T>
-struct is_complex<std::complex<T>> final : public std::true_type {};
-
-/// Reads an element of a non-complex type for the current indices in
-/// coordinate scheme.
-template <typename V>
-inline std::enable_if_t<!is_complex<V>::value, V>
-readCOOValue(char **linePtr, bool is_pattern) {
- // The external formats always store these numerical values with the type
- // double, but we cast these values to the sparse tensor object type.
- // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
- return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-}
-
-/// Reads an element of a complex type for the current indices in
-/// coordinate scheme.
-template <typename V>
-inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr,
- bool is_pattern) {
- // Read two values to make a complex. The external formats always store
- // numerical values with the type double, but we cast these values to the
- // sparse tensor object type. For a pattern tensor, we arbitrarily pick the
- // value 1 for all entries.
- double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
- double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
- // Avoiding brace-notation since that forbids narrowing to `float`.
- return V(re, im);
-}
-
-} // namespace detail
/// Reads a sparse tensor with the given filename into a memory-resident
/// sparse tensor in coordinate scheme.
@@ -211,14 +239,7 @@ openSparseTensorCOO(const char *filename, uint64_t rank, const uint64_t *shape,
// Read all nonzero elements.
std::vector<uint64_t> indices(rank);
for (uint64_t k = 0; k < nnz; ++k) {
- char *linePtr = stfile.readLine();
- for (uint64_t r = 0; r < rank; ++r) {
- // Parse the 1-based index.
- uint64_t idx = strtoul(linePtr, &linePtr, 10);
- // Add the 0-based index.
- indices[perm[r]] = idx - 1;
- }
- const V value = detail::readCOOValue<V>(&linePtr, stfile.isPattern());
+ const V value = stfile.readCOOElement<V>(rank, indices.data(), perm);
// TODO: <https://github.com/llvm/llvm-project/issues/54179>
coo->add(indices, value);
// We currently chose to deal with symmetric matrices by fully
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index 0191fd144a17..721241eb98f1 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -626,13 +626,8 @@ void delSparseTensorReader(void *p) {
index_type *indices = iref->data + iref->offset; \
SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p); \
index_type rank = stfile->getRank(); \
- char *linePtr = stfile->readLine(); \
- for (index_type r = 0; r < rank; ++r) { \
- uint64_t idx = strtoul(linePtr, &linePtr, 10); \
- indices[r] = idx - 1; \
- } \
V *value = vref->data + vref->offset; \
- *value = detail::readCOOValue<V>(&linePtr, stfile->isPattern()); \
+ *value = stfile->readCOOElement<V>(rank, indices); \
}
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
#undef IMPL_GETNEXT
More information about the Mlir-commits mailing list