[Mlir-commits] [mlir] 6860959 - [mlir][sparse] Improve sparse_tensor::detail::readCOOValue template

wren romano llvmlistbot at llvm.org
Thu Sep 29 15:26:41 PDT 2022


Author: wren romano
Date: 2022-09-29T15:26:29-07:00
New Revision: 68609598e45f3c864ad21c596bd62206c8000841

URL: https://github.com/llvm/llvm-project/commit/68609598e45f3c864ad21c596bd62206c8000841
DIFF: https://github.com/llvm/llvm-project/commit/68609598e45f3c864ad21c596bd62206c8000841.diff

LOG: [mlir][sparse] Improve sparse_tensor::detail::readCOOValue template

This is a followup to the refactoring of D133462, D133830, D133831, and D133833.

Depends On D133833

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133836

Added: 
    

Modified: 
    mlir/include/mlir/ExecutionEngine/SparseTensor/File.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index 5dd6c17364a8..fb00e1ecbfc8 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -154,50 +154,36 @@ class SparseTensorFile final {
 //===----------------------------------------------------------------------===//
 namespace detail {
 
-// Adds a value to a tensor in coordinate scheme. If is_symmetric_value is true,
-// also adds the value to its symmetric location.
-template <typename T, typename V>
-inline void addValue(T *coo, V value, const std::vector<uint64_t> indices,
-                     bool is_symmetric_value) {
-  // TODO: <https://github.com/llvm/llvm-project/issues/54179>
-  coo->add(indices, value);
-  // We currently chose to deal with symmetric matrices by fully constructing
-  // them. In the future, we may want to make symmetry implicit for storage
-  // reasons.
-  if (is_symmetric_value)
-    coo->add({indices[1], indices[0]}, value);
+template <typename T>
+struct is_complex final : public std::false_type {};
+
+template <typename T>
+struct is_complex<std::complex<T>> final : public std::true_type {};
+
+/// Reads an element of a non-complex type for the current indices in
+/// coordinate scheme.
+template <typename V>
+inline typename std::enable_if<!is_complex<V>::value, V>::type
+readCOOValue(char **linePtr, bool is_pattern) {
+  // The external formats always store these numerical values with the type
+  // double, but we cast these values to the sparse tensor object type.
+  // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
+  return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
 }
 
 /// Reads an element of a complex type for the current indices in
 /// coordinate scheme.
 template <typename V>
-inline void readCOOValue(SparseTensorCOO<std::complex<V>> *coo,
-                         const std::vector<uint64_t> indices, char **linePtr,
-                         bool is_pattern, bool add_symmetric_value) {
+inline typename std::enable_if<is_complex<V>::value, V>::type
+readCOOValue(char **linePtr, bool is_pattern) {
   // Read two values to make a complex. The external formats always store
   // numerical values with the type double, but we cast these values to the
   // sparse tensor object type. For a pattern tensor, we arbitrarily pick the
   // value 1 for all entries.
-  V re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-  V im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-  std::complex<V> value = {re, im};
-  addValue(coo, value, indices, add_symmetric_value);
-}
-
-// Reads an element of a non-complex type for the current indices in coordinate
-// scheme.
-template <typename V,
-          typename std::enable_if<
-              !std::is_same<std::complex<float>, V>::value &&
-              !std::is_same<std::complex<double>, V>::value>::type * = nullptr>
-inline void readCOOValue(SparseTensorCOO<V> *coo,
-                         const std::vector<uint64_t> indices, char **linePtr,
-                         bool is_pattern, bool is_symmetric_value) {
-  // The external formats always store these numerical values with the type
-  // double, but we cast these values to the sparse tensor object type.
-  // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
-  double value = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-  addValue(coo, value, indices, is_symmetric_value);
+  double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+  double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+  // Avoiding brace-notation since that forbids narrowing to `float`.
+  return V(re, im);
 }
 
 } // namespace detail
@@ -232,8 +218,14 @@ openSparseTensorCOO(const char *filename, uint64_t rank, const uint64_t *shape,
       // Add the 0-based index.
       indices[perm[r]] = idx - 1;
     }
-    detail::readCOOValue(coo, indices, &linePtr, stfile.isPattern(),
-                         stfile.isSymmetric() && indices[0] != indices[1]);
+    const V value = detail::readCOOValue<V>(&linePtr, stfile.isPattern());
+    // TODO: <https://github.com/llvm/llvm-project/issues/54179>
+    coo->add(indices, value);
+    // We currently chose to deal with symmetric matrices by fully
+    // constructing them.  In the future, we may want to make symmetry
+    // implicit for storage reasons.
+    if (stfile.isSymmetric() && indices[0] != indices[1])
+      coo->add({indices[1], indices[0]}, value);
   }
   // Close the file and return tensor.
   stfile.closeFile();


        


More information about the Mlir-commits mailing list