[Mlir-commits] [mlir] 8d8b566 - [mlir][sparse] Moving <P, I, V>-invariant parts of SparseTensorStorage to base

wren romano llvmlistbot at llvm.org
Fri Apr 8 11:44:25 PDT 2022


Author: wren romano
Date: 2022-04-08T11:44:17-07:00
New Revision: 8d8b566f0c668b7774cd50a9160a5eed3e71fd72

URL: https://github.com/llvm/llvm-project/commit/8d8b566f0c668b7774cd50a9160a5eed3e71fd72
DIFF: https://github.com/llvm/llvm-project/commit/8d8b566f0c668b7774cd50a9160a5eed3e71fd72.diff

LOG: [mlir][sparse] Moving <P,I,V>-invariant parts of SparseTensorStorage to base

This reorganization helps to clean up the changes needed for D122060.

Work towards fixing: https://github.com/llvm/llvm-project/issues/51652

Depends On D122625

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D122928

Added: 
    

Modified: 
    mlir/lib/ExecutionEngine/SparseTensorUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index fbc858466e064..576abbcd3ea6d 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -156,6 +156,8 @@ struct SparseTensorCOO {
   /// the given ordering and expects subsequent add() calls to honor
   /// that same ordering for the given indices. The result is a
   /// fully permuted coordinate scheme.
+  ///
+  /// Precondition: `sizes` and `perm` must be valid for `rank`.
   static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                 const uint64_t *sizes,
                                                 const uint64_t *perm,
@@ -175,12 +177,63 @@ struct SparseTensorCOO {
   unsigned iteratorPos;
 };
 
-/// Abstract base class of sparse tensor storage. Note that we use
-/// function overloading to implement "partial" method specialization.
+/// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
+/// takes responsibility for all the `<P,I,V>`-independent aspects
+/// of the tensor (e.g., shape, sparsity, permutation).  In addition,
+/// we use function overloading to implement "partial" method
+/// specialization, which the C-API relies on to catch type errors
+/// arising from our use of opaque pointers.
 class SparseTensorStorageBase {
 public:
-  /// Dimension size query.
-  virtual uint64_t getDimSize(uint64_t) const = 0;
+  /// Constructs a new storage object.  The `perm` maps the tensor's
+  /// semantic-ordering of dimensions to this object's storage-order.
+  /// The `szs` and `sparsity` arrays are already in storage-order.
+  ///
+  /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
+  SparseTensorStorageBase(const std::vector<uint64_t> &szs,
+                          const uint64_t *perm, const DimLevelType *sparsity)
+      : dimSizes(szs), rev(getRank()),
+        dimTypes(sparsity, sparsity + getRank()) {
+    const uint64_t rank = getRank();
+    // Validate parameters.
+    assert(rank > 0 && "Trivial shape is unsupported");
+    for (uint64_t r = 0; r < rank; r++) {
+      assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
+      assert((dimTypes[r] == DimLevelType::kDense ||
+              dimTypes[r] == DimLevelType::kCompressed) &&
+             "Unsupported DimLevelType");
+    }
+    // Construct the "reverse" (i.e., inverse) permutation.
+    for (uint64_t r = 0; r < rank; r++)
+      rev[perm[r]] = r;
+  }
+
+  virtual ~SparseTensorStorageBase() = default;
+
+  /// Get the rank of the tensor.
+  uint64_t getRank() const { return dimSizes.size(); }
+
+  /// Getter for the dimension-sizes array, in storage-order.
+  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
+
+  /// Safely lookup the size of the given (storage-order) dimension.
+  uint64_t getDimSize(uint64_t d) const {
+    assert(d < getRank());
+    return dimSizes[d];
+  }
+
+  /// Getter for the "reverse" permutation, which maps this object's
+  /// storage-order to the tensor's semantic-order.
+  const std::vector<uint64_t> &getRev() const { return rev; }
+
+  /// Getter for the dimension-types array, in storage-order.
+  const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
+
+  /// Safely check if the (storage-order) dimension uses compressed storage.
+  bool isCompressedDim(uint64_t d) const {
+    assert(d < getRank());
+    return (dimTypes[d] == DimLevelType::kCompressed);
+  }
 
   /// Overhead storage.
   virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
@@ -231,13 +284,15 @@ class SparseTensorStorageBase {
   /// Finishes insertion.
   virtual void endInsert() = 0;
 
-  virtual ~SparseTensorStorageBase() = default;
-
 private:
   static void fatal(const char *tp) {
     fprintf(stderr, "unsupported %s\n", tp);
     exit(1);
   }
+
+  const std::vector<uint64_t> dimSizes;
+  std::vector<uint64_t> rev;
+  const std::vector<DimLevelType> dimTypes;
 };
 
 /// A memory-resident sparse tensor using a storage scheme based on
@@ -252,15 +307,13 @@ class SparseTensorStorage : public SparseTensorStorageBase {
   /// Constructs a sparse tensor storage scheme with the given dimensions,
   /// permutation, and per-dimension dense/sparse annotations, using
   /// the coordinate scheme tensor for the initial contents if provided.
+  ///
+  /// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
                       const DimLevelType *sparsity,
-                      SparseTensorCOO<V> *tensor = nullptr)
-      : sizes(szs), rev(getRank()), idx(getRank()), pointers(getRank()),
-        indices(getRank()) {
-    uint64_t rank = getRank();
-    // Store "reverse" permutation.
-    for (uint64_t r = 0; r < rank; r++)
-      rev[perm[r]] = r;
+                      SparseTensorCOO<V> *coo = nullptr)
+      : SparseTensorStorageBase(szs, perm, sparsity), pointers(getRank()),
+        indices(getRank()), idx(getRank()) {
     // Provide hints on capacity of pointers and indices.
     // TODO: needs much fine-tuning based on actual sparsity; currently
     //       we reserve pointer/index space based on all previous dense
@@ -268,30 +321,26 @@ class SparseTensorStorage : public SparseTensorStorageBase {
     //       we should really use nnz and dense/sparse distribution.
     bool allDense = true;
     uint64_t sz = 1;
-    for (uint64_t r = 0; r < rank; r++) {
-      assert(sizes[r] > 0 && "Dimension size zero has trivial storage");
-      if (sparsity[r] == DimLevelType::kCompressed) {
+    for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
+      if (isCompressedDim(r)) {
+        // TODO: Take a parameter between 1 and `sizes[r]`, and multiply
+        // `sz` by that before reserving. (For now we just use 1.)
         pointers[r].reserve(sz + 1);
+        pointers[r].push_back(0);
         indices[r].reserve(sz);
         sz = 1;
         allDense = false;
-        // Prepare the pointer structure.  We cannot use `appendPointer`
-        // here, because `isCompressedDim` won't work until after this
-        // preparation has been done.
-        pointers[r].push_back(0);
-      } else {
-        assert(sparsity[r] == DimLevelType::kDense &&
-               "singleton not yet supported");
-        sz = checkedMul(sz, sizes[r]);
+      } else { // Dense dimension.
+        sz = checkedMul(sz, getDimSizes()[r]);
       }
     }
     // Then assign contents from coordinate scheme tensor if provided.
-    if (tensor) {
+    if (coo) {
       // Ensure both preconditions of `fromCOO`.
-      assert(tensor->getSizes() == sizes && "Tensor size mismatch");
-      tensor->sort();
+      assert(coo->getSizes() == getDimSizes() && "Tensor size mismatch");
+      coo->sort();
       // Now actually insert the `elements`.
-      const std::vector<Element<V>> &elements = tensor->getElements();
+      const std::vector<Element<V>> &elements = coo->getElements();
       uint64_t nnz = elements.size();
       values.reserve(nnz);
       fromCOO(elements, 0, nnz, 0);
@@ -302,15 +351,6 @@ class SparseTensorStorage : public SparseTensorStorageBase {
 
   ~SparseTensorStorage() override = default;
 
-  /// Get the rank of the tensor.
-  uint64_t getRank() const { return sizes.size(); }
-
-  /// Get the size of the given dimension of the tensor.
-  uint64_t getDimSize(uint64_t d) const override {
-    assert(d < getRank());
-    return sizes[d];
-  }
-
   /// Partially specialize these getter methods based on template types.
   void getPointers(std::vector<P> **out, uint64_t d) override {
     assert(d < getRank());
@@ -375,14 +415,18 @@ class SparseTensorStorage : public SparseTensorStorageBase {
 
   /// Returns this sparse tensor storage scheme as a new memory-resident
   /// sparse tensor in coordinate scheme with the given dimension order.
+  ///
+  /// Precondition: `perm` must be valid for `getRank()`.
   SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
     // Restore original order of the dimension sizes and allocate coordinate
     // scheme with desired new ordering specified in perm.
-    uint64_t rank = getRank();
+    const uint64_t rank = getRank();
+    const auto &rev = getRev();
+    const auto &sizes = getDimSizes();
     std::vector<uint64_t> orgsz(rank);
     for (uint64_t r = 0; r < rank; r++)
       orgsz[rev[r]] = sizes[r];
-    SparseTensorCOO<V> *tensor = SparseTensorCOO<V>::newSparseTensorCOO(
+    SparseTensorCOO<V> *coo = SparseTensorCOO<V>::newSparseTensorCOO(
         rank, orgsz.data(), perm, values.size());
     // Populate coordinate scheme restored from old ordering and changed with
     // new ordering. Rather than applying both reorderings during the recursion,
@@ -390,9 +434,12 @@ class SparseTensorStorage : public SparseTensorStorageBase {
     std::vector<uint64_t> reord(rank);
     for (uint64_t r = 0; r < rank; r++)
       reord[r] = perm[rev[r]];
-    toCOO(*tensor, reord, 0, 0);
-    assert(tensor->getElements().size() == values.size());
-    return tensor;
+    toCOO(*coo, reord, 0, 0);
+    // TODO: This assertion assumes there are no stored zeros,
+    // or if there are then that we don't filter them out.
+    // Cf., <https://github.com/llvm/llvm-project/issues/54179>
+    assert(coo->getElements().size() == values.size());
+    return coo;
   }
 
   /// Factory method. Constructs a sparse tensor storage scheme with the given
@@ -400,16 +447,18 @@ class SparseTensorStorage : public SparseTensorStorageBase {
   /// using the coordinate scheme tensor for the initial contents if provided.
   /// In the latter case, the coordinate scheme must respect the same
   /// permutation as is desired for the new sparse tensor storage.
+  ///
+  /// Precondition: `shape`, `perm`, and `sparsity` must be valid for `rank`.
   static SparseTensorStorage<P, I, V> *
   newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
-                  const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
+                  const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
     SparseTensorStorage<P, I, V> *n = nullptr;
-    if (tensor) {
-      assert(tensor->getRank() == rank);
+    if (coo) {
+      assert(coo->getRank() == rank && "Tensor rank mismatch");
+      const auto &coosz = coo->getSizes();
       for (uint64_t r = 0; r < rank; r++)
-        assert(shape[r] == 0 || shape[r] == tensor->getSizes()[perm[r]]);
-      n = new SparseTensorStorage<P, I, V>(tensor->getSizes(), perm, sparsity,
-                                           tensor);
+        assert(shape[r] == 0 || shape[r] == coosz[perm[r]]);
+      n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
     } else {
       std::vector<uint64_t> permsz(rank);
       for (uint64_t r = 0; r < rank; r++) {
@@ -426,7 +475,7 @@ class SparseTensorStorage : public SparseTensorStorageBase {
   /// checks that `pos` is representable in the `P` type; however, it
   /// does not check that `pos` is semantically valid (i.e., larger than
   /// the previous position and smaller than `indices[d].capacity()`).
-  inline void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
+  void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
     assert(isCompressedDim(d));
     assert(pos <= std::numeric_limits<P>::max() &&
            "Pointer value is too large for the P-type");
@@ -510,7 +559,9 @@ class SparseTensorStorage : public SparseTensorStorageBase {
       }
     } else {
       // Dense dimension.
-      for (uint64_t i = 0, sz = sizes[d], off = pos * sz; i < sz; i++) {
+      const uint64_t sz = getDimSizes()[d];
+      const uint64_t off = pos * sz;
+      for (uint64_t i = 0; i < sz; i++) {
         idx[reord[d]] = i;
         toCOO(tensor, reord, off + i, d + 1);
       }
@@ -524,11 +575,9 @@ class SparseTensorStorage : public SparseTensorStorageBase {
     if (isCompressedDim(d)) {
       appendPointer(d, indices[d].size(), count);
     } else { // Dense dimension.
-      const uint64_t sz = sizes[d];
+      const uint64_t sz = getDimSizes()[d];
       assert(sz >= full && "Segment is overfull");
-      // Assuming we checked for overflows in the constructor, then this
-      // multiply will never overflow.
-      count *= (sz - full);
+      count = checkedMul(count, sz - full);
       // For dense storage we must enumerate all the remaining coordinates
       // in this dimension (i.e., coordinates after the last non-zero
       // element), and either fill in their zero values or else recurse
@@ -574,19 +623,11 @@ class SparseTensorStorage : public SparseTensorStorageBase {
     return -1u;
   }
 
-  /// Returns true if dimension is compressed.
-  inline bool isCompressedDim(uint64_t d) const {
-    assert(d < getRank());
-    return (!pointers[d].empty());
-  }
-
 private:
-  const std::vector<uint64_t> sizes; // per-dimension sizes
-  std::vector<uint64_t> rev;   // "reverse" permutation
-  std::vector<uint64_t> idx;   // index cursor
   std::vector<std::vector<P>> pointers;
   std::vector<std::vector<I>> indices;
   std::vector<V> values;
+  std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
 };
 
 /// Helper to convert string to lower case.


        


More information about the Mlir-commits mailing list