[Mlir-commits] [mlir] 753fe33 - [mlir][sparse] Factoring out an enumerator over elements of SparseTensorStorage

wren romano llvmlistbot at llvm.org
Thu May 12 17:06:03 PDT 2022


Author: wren romano
Date: 2022-05-12T17:05:56-07:00
New Revision: 753fe330c1d6cfebac07ecd385fb2dcf63a0f6c9

URL: https://github.com/llvm/llvm-project/commit/753fe330c1d6cfebac07ecd385fb2dcf63a0f6c9
DIFF: https://github.com/llvm/llvm-project/commit/753fe330c1d6cfebac07ecd385fb2dcf63a0f6c9.diff

LOG: [mlir][sparse] Factoring out an enumerator over elements of SparseTensorStorage

Work towards fixing: https://github.com/llvm/llvm-project/issues/51652

Depends On D122928

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D122060

Added: 
    

Modified: 
    mlir/lib/ExecutionEngine/SparseTensorUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index a7d636749c92c..336c500ce1d33 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -27,6 +27,7 @@
 #include <cstdlib>
 #include <cstring>
 #include <fstream>
+#include <functional>
 #include <iostream>
 #include <limits>
 #include <numeric>
@@ -94,6 +95,13 @@ struct Element {
   V value;
 };
 
+/// Signature of the callbacks that receive enumerated elements: the first
+/// argument holds the element's coordinates (indices) and the second its
+/// value.  Passing the two separately, instead of bundling them into an
+/// `Element` object, keeps the calling code somewhat cleaner.
+template <typename V>
+using ElementConsumer =
+    std::function<void(const std::vector<uint64_t> &, V)> const &;
+
 /// A memory-resident sparse tensor in coordinate scheme (collection of
 /// elements). This data structure is used to read a sparse tensor from
 /// any external format into memory and sort the elements lexicographically
@@ -220,6 +228,7 @@ class SparseTensorStorageBase {
                           const uint64_t *perm, const DimLevelType *sparsity)
       : dimSizes(szs), rev(getRank()),
         dimTypes(sparsity, sparsity + getRank()) {
+    assert(perm && sparsity);
     const uint64_t rank = getRank();
     // Validate parameters.
     assert(rank > 0 && "Trivial shape is unsupported");
@@ -310,6 +319,16 @@ class SparseTensorStorageBase {
   /// Finishes insertion.
   virtual void endInsert() = 0;
 
+protected:
+  // A polymorphic class that has data members must not be publicly
+  // copyable: copying through a base reference would "slice" off the
+  // derived part.  Copy-construction is therefore kept, but protected, so
+  // only subclasses can use it.
+  // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
+  SparseTensorStorageBase(SparseTensorStorageBase const &) = default;
+  // Copy-assignment is deleted explicitly for clarity; it would be
+  // implicitly deleted anyway, since this class has `const` data members
+  // (e.g., `dimTypes`).
+  SparseTensorStorageBase &operator=(SparseTensorStorageBase const &) = delete;
+
 private:
   static void fatal(const char *tp) {
     fprintf(stderr, "unsupported %s\n", tp);
@@ -321,6 +340,10 @@ class SparseTensorStorageBase {
   const std::vector<DimLevelType> dimTypes;
 };
 
+// Forward declaration: `SparseTensorStorage` names this class as a friend
+// and uses it in `toCOO`, but it is only defined further below.
+template <typename P, typename I, typename V> class SparseTensorEnumerator;
+
 /// A memory-resident sparse tensor using a storage scheme based on
 /// per-dimension sparse/dense annotations. This data structure provides a
 /// bufferized form of a sparse tensor type. In contrast to generating setup
@@ -443,24 +466,13 @@ class SparseTensorStorage : public SparseTensorStorageBase {
   /// sparse tensor in coordinate scheme with the given dimension order.
   ///
   /// Precondition: `perm` must be valid for `getRank()`.
-  SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
-    // Restore original order of the dimension sizes and allocate coordinate
-    // scheme with desired new ordering specified in perm.
-    const uint64_t rank = getRank();
-    const auto &rev = getRev();
-    const auto &sizes = getDimSizes();
-    std::vector<uint64_t> orgsz(rank);
-    for (uint64_t r = 0; r < rank; r++)
-      orgsz[rev[r]] = sizes[r];
-    SparseTensorCOO<V> *coo = SparseTensorCOO<V>::newSparseTensorCOO(
-        rank, orgsz.data(), perm, values.size());
-    // Populate coordinate scheme restored from old ordering and changed with
-    // new ordering. Rather than applying both reorderings during the recursion,
-    // we compute the combine permutation in advance.
-    std::vector<uint64_t> reord(rank);
-    for (uint64_t r = 0; r < rank; r++)
-      reord[r] = perm[rev[r]];
-    toCOO(*coo, reord, 0, 0);
+  SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
+    SparseTensorEnumerator<P, I, V> enumerator(*this, getRank(), perm);
+    SparseTensorCOO<V> *coo =
+        new SparseTensorCOO<V>(enumerator.permutedSizes(), values.size());
+    enumerator.forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
+      coo->add(ind, val);
+    });
     // TODO: This assertion assumes there are no stored zeros,
     // or if there are then that we don't filter them out.
     // Cf., <https://github.com/llvm/llvm-project/issues/54179>
@@ -543,9 +555,10 @@ class SparseTensorStorage : public SparseTensorStorageBase {
   ///     and pointwise less-than).
   void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
                uint64_t hi, uint64_t d) {
+    uint64_t rank = getRank();
+    assert(d <= rank && hi <= elements.size());
     // Once dimensions are exhausted, insert the numerical values.
-    assert(d <= getRank() && hi <= elements.size());
-    if (d == getRank()) {
+    if (d == rank) {
       assert(lo < hi);
       values.push_back(elements[lo].value);
       return;
@@ -569,31 +582,6 @@ class SparseTensorStorage : public SparseTensorStorageBase {
     finalizeSegment(d, full);
   }
 
-  /// Stores the sparse tensor storage scheme into a memory-resident sparse
-  /// tensor in coordinate scheme.
-  void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord,
-             uint64_t pos, uint64_t d) {
-    assert(d <= getRank());
-    if (d == getRank()) {
-      assert(pos < values.size());
-      tensor.add(idx, values[pos]);
-    } else if (isCompressedDim(d)) {
-      // Sparse dimension.
-      for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
-        idx[reord[d]] = indices[d][ii];
-        toCOO(tensor, reord, ii, d + 1);
-      }
-    } else {
-      // Dense dimension.
-      const uint64_t sz = getDimSizes()[d];
-      const uint64_t off = pos * sz;
-      for (uint64_t i = 0; i < sz; i++) {
-        idx[reord[d]] = i;
-        toCOO(tensor, reord, off + i, d + 1);
-      }
-    }
-  }
-
   /// Finalize the sparse pointer structure at this dimension.
   void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
     if (count == 0)
@@ -649,13 +637,151 @@ class SparseTensorStorage : public SparseTensorStorageBase {
     return -1u;
   }
 
-private:
+  // Allow `SparseTensorEnumerator` to access the data-members (to avoid
+  // the cost of virtual-function dispatch in inner loops), without
+  // making them public to other client code.
+  friend class SparseTensorEnumerator<P, I, V>;
+
+  // Per-dimension pointer arrays.  For a compressed dimension `d`, the
+  // segment belonging to parent position `p` occupies the positions
+  // `[pointers[d][p], pointers[d][p + 1])` of `indices[d]`.
   std::vector<std::vector<P>> pointers;
+  // Per-dimension coordinates (indices) of the entries stored for each
+  // compressed dimension, addressed by the positions described above.
   std::vector<std::vector<I>> indices;
+  // The stored values, addressed by the position reached after the last
+  // dimension has been traversed.
   std::vector<V> values;
   std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
 };
 
+/// Base class for enumerating, under a dimension permutation, all the
+/// elements of a `SparseTensorStorage`.  Calling `forallElements` runs the
+/// loop-nest that visits the source tensor's elements (in whatever order
+/// suits the source storage scheme).  It permutes each element's
+/// coordinates into the target ordering and then hands the element to a
+/// callback.  One enumerator may be reused for any number of
+/// `forallElements` calls, provided those calls do not overlap.
+///
+/// N.B., the enumerator keeps a reference to the `SparseTensorStorageBase`
+/// given to its constructor, so it must not outlive that tensor.
+///
+/// Design Note: This base class exists so that MLIR client code need only
+/// mention the value type `V`, not the `<P,I>` template parameters.  That
+/// simplifies the type parameters used for direct sparse-to-sparse
+/// conversion.  The `SparseTensorEnumerator<P,I,V>` subclasses exist so
+/// that the loop-nest itself avoids the cost of virtual-method dispatch.
+template <typename V>
+class SparseTensorEnumeratorBase {
+public:
+  /// Builds an enumerator whose permutation `perm` maps the semantic
+  /// ordering of dimensions onto the desired target ordering.
+  ///
+  /// Preconditions:
+  /// * `tensor` has the same value type `V`.
+  /// * `perm` is a valid permutation for `rank`.
+  SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
+                             uint64_t rank, const uint64_t *perm)
+      : src(tensor), permsz(tensor.getRev().size()), reord(permsz.size()),
+        cursor(permsz.size()) {
+    assert(perm && "Received nullptr for permutation");
+    assert(rank == getRank() && "Permutation rank mismatch");
+    // Both of these are indexed by source storage-order; `storageToSemantic`
+    // maps each storage dimension to its semantic dimension.
+    const auto &storageToSemantic = src.getRev();
+    const auto &storageSizes = src.getDimSizes();
+    for (uint64_t srcDim = 0; srcDim < rank; ++srcDim) {
+      const uint64_t tgtDim = perm[storageToSemantic[srcDim]];
+      reord[srcDim] = tgtDim;
+      permsz[tgtDim] = storageSizes[srcDim];
+    }
+  }
+
+  virtual ~SparseTensorEnumeratorBase() = default;
+
+  // Copying is disallowed, both to keep the `src` reference from leaking
+  // into further objects and to rule out slicing.
+  SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
+  SparseTensorEnumeratorBase &
+  operator=(const SparseTensorEnumeratorBase &) = delete;
+
+  /// The rank shared by the source and target tensors.  The two ranks
+  /// coincide because only permutations are supported.  Supporting other
+  /// dimension mappings would require separate source/target accessors.
+  uint64_t getRank() const { return permsz.size(); }
+
+  /// The dimension sizes of the target tensor (i.e., in target order).
+  const std::vector<uint64_t> &permutedSizes() const { return permsz; }
+
+  /// Visits every element of the source tensor, permuting its indices
+  /// into target order before passing it to `yield`.  The index vector
+  /// handed to the callback is storage owned and reused by this
+  /// enumerator.  A callback that needs to retain the indices must
+  /// therefore copy them.
+  virtual void forallElements(ElementConsumer<V> yield) = 0;
+
+protected:
+  const SparseTensorStorageBase &src; // the tensor being enumerated.
+  std::vector<uint64_t> permsz;       // dimension sizes, in target order.
+  std::vector<uint64_t> reord;        // source storage-order -> target order.
+  std::vector<uint64_t> cursor;       // current indices, in target order.
+};
+
+/// The concrete enumerator for `SparseTensorStorage<P,I,V>`.  Because it
+/// knows the `<P,I>` parameters (and is a friend of the storage class), the
+/// loop-nest reads `pointers`/`indices`/`values` directly instead of going
+/// through virtual methods.
+template <typename P, typename I, typename V>
+class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
+  using Base = SparseTensorEnumeratorBase<V>;
+
+public:
+  /// Constructs an enumerator with the given permutation for mapping
+  /// the semantic-ordering of dimensions to the desired target-ordering.
+  ///
+  /// Precondition: `perm` must be valid for `rank`.
+  SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
+                         uint64_t rank, const uint64_t *perm)
+      : Base(tensor, rank, perm) {}
+
+  ~SparseTensorEnumerator() final override = default;
+
+  void forallElements(ElementConsumer<V> yield) final override {
+    forallElements(yield, 0, 0);
+  }
+
+private:
+  /// The recursive component of the public `forallElements`.  Enumerates
+  /// all elements below storage-level `d`.  `parentPos` is the position
+  /// selected by the coordinates already fixed for the levels before `d`.
+  void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
+                      uint64_t d) {
+    // Recover the `<P,I,V>` type parameters of `src`.
+    const auto &src =
+        static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
+    if (d == Base::getRank()) {
+      assert(parentPos < src.values.size() &&
+             "Value position is out of bounds");
+      // TODO: <https://github.com/llvm/llvm-project/issues/54179>
+      yield(this->cursor, src.values[parentPos]);
+    } else if (src.isCompressedDim(d)) {
+      // Look up the bounds of the `d`-level segment determined by the
+      // `d-1`-level position `parentPos`.
+      const std::vector<P> &pointers_d = src.pointers[d];
+      assert(parentPos + 1 < pointers_d.size() &&
+             "Parent pointer position is out of bounds");
+      const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
+      const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
+      // Loop-invariant code for looking up the `d`-level coordinates/indices.
+      const std::vector<I> &indices_d = src.indices[d];
+      // Compare `pstop` itself rather than `pstop - 1`.  An empty segment
+      // may have `pstop == 0` (e.g., a leading empty row), and the unsigned
+      // subtraction would wrap to UINT64_MAX and spuriously fail.
+      assert(pstop <= indices_d.size() && "Index position is out of bounds");
+      uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
+      for (uint64_t pos = pstart; pos < pstop; pos++) {
+        cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
+        forallElements(yield, pos, d + 1);
+      }
+    } else { // Dense dimension.
+      const uint64_t sz = src.getDimSizes()[d];
+      const uint64_t pstart = parentPos * sz;
+      uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
+      for (uint64_t i = 0; i < sz; i++) {
+        cursor_reord_d = i;
+        forallElements(yield, pstart + i, d + 1);
+      }
+    }
+  }
+};
+
 /// Helper to convert string to lower case.
 static char *toLower(char *token) {
   for (char *c = token; *c; c++)


        


More information about the Mlir-commits mailing list