[Mlir-commits] [mlir] 1dfb9a6 - [mlir][sparse] LICM for SparseTensorReader::readCOO

wren romano llvmlistbot at llvm.org
Fri Dec 2 11:13:43 PST 2022


Author: wren romano
Date: 2022-12-02T11:13:36-08:00
New Revision: 1dfb9a64f1f0fd24eefb19a2ed0b1ae3039db413

URL: https://github.com/llvm/llvm-project/commit/1dfb9a64f1f0fd24eefb19a2ed0b1ae3039db413
DIFF: https://github.com/llvm/llvm-project/commit/1dfb9a64f1f0fd24eefb19a2ed0b1ae3039db413.diff

LOG: [mlir][sparse] LICM for SparseTensorReader::readCOO

This commit makes two related changes.  First, we adjust `readCOOValue` to take the `IsPattern` bool as a template parameter rather than a function argument.  Second, we factor `readCOOLoop` out of `readCOO`, and template it on `IsPattern` and `IsSymmetric`.  Together, these changes move all the assertions and header-dependent conditionals out of the main for-loop of `readCOO` (manual loop-invariant code motion, or LICM).  The only remaining conditional is in the `IsSymmetric=true` variant: checking whether the element is on the diagonal or not (which cannot be lifted out of the loop).

Depends On D138363

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D138365

Added: 
    

Modified: 
    mlir/include/mlir/ExecutionEngine/SparseTensor/File.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index 3734f96889ff..fa36b7e393ee 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -42,32 +42,46 @@ struct is_complex final : public std::false_type {};
 template <typename T>
 struct is_complex<std::complex<T>> final : public std::true_type {};
 
-/// Reads an element of a non-complex type for the current indices in
-/// coordinate scheme.
-template <typename V>
-inline std::enable_if_t<!is_complex<V>::value, V>
-readCOOValue(char **linePtr, bool is_pattern) {
+/// Returns an element-value of non-complex type.  If `IsPattern` is true,
+/// returns the placeholder value 1 without reading the line buffer.  If
+/// false, parses a double at `*linePtr` and advances `*linePtr` past it.
+template <typename V, bool IsPattern>
+inline std::enable_if_t<!is_complex<V>::value, V> readCOOValue(char **linePtr) {
   // The external formats always store these numerical values with the type
   // double, but we cast these values to the sparse tensor object type.
   // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
-  return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+  if constexpr (IsPattern)
+    return 1.0;
+  return strtod(*linePtr, linePtr);
 }
 
-/// Reads an element of a complex type for the current indices in
-/// coordinate scheme.
-template <typename V>
-inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr,
-                                                              bool is_pattern) {
+/// Returns an element-value of complex type.  If `IsPattern` is true,
+/// returns the placeholder (1, 1) without reading the line buffer.  If
+/// false, parses two doubles (real, then imaginary) and advances `*linePtr`.
+template <typename V, bool IsPattern>
+inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr) {
   // Read two values to make a complex. The external formats always store
   // numerical values with the type double, but we cast these values to the
   // sparse tensor object type. For a pattern tensor, we arbitrarily pick the
   // value 1 for all entries.
-  double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
-  double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
+  if constexpr (IsPattern)
+    return V(1.0, 1.0);
+  double re = strtod(*linePtr, linePtr);
+  double im = strtod(*linePtr, linePtr);
   // Avoiding brace-notation since that forbids narrowing to `float`.
   return V(re, im);
 }
 
+/// Returns an element-value, dispatching on the runtime flag `is_pattern`
+/// to `readCOOValue<V, true>` (placeholder value, no read) or to
+/// `readCOOValue<V, false>` (parses the value at `*linePtr`).
+template <typename V>
+inline V readCOOValue(char **linePtr, bool is_pattern) {
+  if (is_pattern)
+    return readCOOValue<V, true>(linePtr);
+  return readCOOValue<V, false>(linePtr);
+}
+
 } // namespace detail
 
 //===----------------------------------------------------------------------===//
@@ -249,6 +263,18 @@ class SparseTensorReader final {
   /// Precondition: `indices` is valid for `getRank()`.
   char *readCOOIndices(uint64_t *indices);
 
+  /// The internal implementation of `readCOO`: reads `getNNZ()` elements
+  /// into `lvlCOO`.  We template over `IsPattern` and `IsSymmetric` to hoist
+  /// those checks out of the loop (LICM) without duplicating source code.
+  //
+  // TODO: We currently take the `dim2lvl` argument as a `PermutationRef`
+  // since that's what `readCOO` creates.  Once we update `readCOO` to
+  // functionalize the mapping, then this helper will just take that
+  // same function.
+  template <typename V, bool IsPattern, bool IsSymmetric>
+  void readCOOLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
+                   SparseTensorCOO<V> *lvlCOO);
+
   /// Reads the MME header of a general sparse matrix of type real.
   void readMMEHeader();
 
@@ -282,36 +308,51 @@ SparseTensorCOO<V> *SparseTensorReader::readCOO(uint64_t lvlRank,
   assert(lvlRank == dimRank && "Rank mismatch");
   detail::PermutationRef d2l(dimRank, dim2lvl);
   // Prepare a COO object with the number of nonzeros as initial capacity.
-  const uint64_t nnz = getNNZ();
-  auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, nnz);
-  // Read all nonzero elements.
+  auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, getNNZ());
+  // Do some manual LICM, to avoid assertions in the for-loop.
+  const bool IsPattern = isPattern();
+  const bool IsSymmetric = (isSymmetric() && dimRank == 2);
+  if (IsPattern && IsSymmetric)
+    readCOOLoop<V, true, true>(lvlRank, d2l, lvlCOO);
+  else if (IsPattern)
+    readCOOLoop<V, true, false>(lvlRank, d2l, lvlCOO);
+  else if (IsSymmetric)
+    readCOOLoop<V, false, true>(lvlRank, d2l, lvlCOO);
+  else
+    readCOOLoop<V, false, false>(lvlRank, d2l, lvlCOO);
+  // Close the file and return the COO.
+  closeFile();
+  return lvlCOO;
+}
+
+template <typename V, bool IsPattern, bool IsSymmetric>
+void SparseTensorReader::readCOOLoop(uint64_t lvlRank,
+                                     detail::PermutationRef dim2lvl,
+                                     SparseTensorCOO<V> *lvlCOO) {
+  const uint64_t dimRank = getRank();
   std::vector<uint64_t> dimInd(dimRank);
   std::vector<uint64_t> lvlInd(lvlRank);
-  // Do some manual LICM, to avoid assertions in the for-loop.
-  const bool addSymmetric = (isSymmetric() && dimRank == 2);
-  const bool isPattern_ = isPattern();
-  for (uint64_t k = 0; k < nnz; ++k) {
-    // We inline `readCOOElement` here in order to avoid redundant
-    // assertions, since they're guaranteed by the call to `isValid()`
-    // and the construction of `dimInd` above.
+  for (uint64_t nnz = getNNZ(), k = 0; k < nnz; ++k) {
+    // We inline `readCOOElement` here in order to avoid redundant
+    // assertions, since they're guaranteed by the caller `readCOO` (via
+    // its call to `isValid()`) and by the construction of `dimInd` above.
     char *linePtr = readCOOIndices(dimInd.data());
-    const V value = detail::readCOOValue<V>(&linePtr, isPattern_);
-    d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
+    const V value = detail::readCOOValue<V, IsPattern>(&linePtr);
+    dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data());
     // TODO: <https://github.com/llvm/llvm-project/issues/54179>
     lvlCOO->add(lvlInd, value);
     // We currently chose to deal with symmetric matrices by fully
     // constructing them.  In the future, we may want to make symmetry
     // implicit for storage reasons.
-    if (addSymmetric && dimInd[0] != dimInd[1]) {
-      // Must recompute `lvlInd`, since arbitrary mappings don't preserve swap.
-      std::swap(dimInd[0], dimInd[1]);
-      d2l.pushforward(dimRank, dimInd.data(), lvlInd.data());
-      lvlCOO->add(lvlInd, value);
-    }
+    if constexpr (IsSymmetric) {
+      if (dimInd[0] != dimInd[1]) {
+        // Must recompute `lvlInd`, since arbitrary maps don't preserve swap.
+        std::swap(dimInd[0], dimInd[1]);
+        dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data());
+        lvlCOO->add(lvlInd, value);
+      }
+    }
   }
-  // Close the file and return the COO.
-  closeFile();
-  return lvlCOO;
 }
 
 /// Writes the sparse tensor to `filename` in extended FROSTT format.


        


More information about the Mlir-commits mailing list