[Mlir-commits] [mlir] [mlir][sparse] complete migration to dim2lvl/lvl2dim in library (PR #69268)

Peiming Liu llvmlistbot at llvm.org
Tue Oct 17 09:32:54 PDT 2023


================
@@ -669,184 +599,48 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     return -1u;
   }
 
-  // Allow `SparseTensorEnumerator` to access the data-members (to avoid
-  // the cost of virtual-function dispatch in inner loops), without
-  // making them public to other client code.
-  friend class SparseTensorEnumerator<P, C, V>;
-
-  std::vector<std::vector<P>> positions;
-  std::vector<std::vector<C>> coordinates;
-  std::vector<V> values;
-  std::vector<uint64_t> lvlCursor; // cursor for lexicographic insertion.
-  SparseTensorCOO<V> *lvlCOO;      // COO used during forwarding
-};
-
-//===----------------------------------------------------------------------===//
-//
-//  SparseTensorEnumerator
-//
-//===----------------------------------------------------------------------===//
-
-/// A (higher-order) function object for enumerating the elements of some
-/// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
-/// method encapsulates the loop-nest for enumerating the elements of
-/// the source tensor (in whatever order is best for the source tensor),
-/// and applies a permutation to the coordinates before handing
-/// each element to the callback.  A single enumerator object can be
-/// freely reused for several calls to `forallElements`, just so long
-/// as each call is sequential with respect to one another.
-///
-/// N.B., this class stores a reference to the `SparseTensorStorageBase`
-/// passed to the constructor; thus, objects of this class must not
-/// outlive the sparse tensor they depend on.
-///
-/// Design Note: The reason we define this class instead of simply using
-/// `SparseTensorEnumerator<P,C,V>` is because we need to hide/generalize
-/// the `<P,C>` template parameters from MLIR client code (to simplify the
-/// type parameters used for direct sparse-to-sparse conversion).  And the
-/// reason we define the `SparseTensorEnumerator<P,C,V>` subclasses rather
-/// than simply using this class, is to avoid the cost of virtual-method
-/// dispatch within the loop-nest.
-template <typename V>
-class SparseTensorEnumeratorBase {
-public:
-  /// Constructs an enumerator which automatically applies the given
-  /// mapping from the source tensor's dimensions to the desired
-  /// target tensor dimensions.
-  ///
-  /// Preconditions:
-  /// * the `src` must have the same `V` value type.
-  /// * `trgSizes` must be valid for `trgRank`.
-  /// * `src2trg` must be valid for `srcRank`, and must map coordinates
-  ///   valid for `src.getDimSizes()` to coordinates valid for `trgSizes`.
-  ///
-  /// Asserts:
-  /// * `trgSizes` must be nonnull and must contain only nonzero sizes.
-  /// * `srcRank == src.getDimRank()`.
-  /// * `src2trg` must be nonnull.
-  SparseTensorEnumeratorBase(const SparseTensorStorageBase &src,
-                             uint64_t trgRank, const uint64_t *trgSizes,
-                             uint64_t srcRank, const uint64_t *src2trg)
-      : src(src), trgSizes(trgSizes, trgSizes + trgRank),
-        lvl2trg(src.getLvlRank()), trgCursor(trgRank) {
-    assert(trgSizes && "Received nullptr for target-sizes");
-    assert(src2trg && "Received nullptr for source-to-target mapping");
-    assert(srcRank == src.getDimRank() && "Source-rank mismatch");
-    for (uint64_t t = 0; t < trgRank; ++t)
-      assert(trgSizes[t] > 0 && "Target-size zero has trivial storage");
-    const auto &lvl2src = src.getLvl2Dim();
-    for (uint64_t lvlRank = src.getLvlRank(), l = 0; l < lvlRank; ++l)
-      lvl2trg[l] = src2trg[lvl2src[l]];
-  }
-
-  virtual ~SparseTensorEnumeratorBase() = default;
-
-  // We disallow copying to help avoid leaking the `src` reference.
-  // (In addition to avoiding the problem of slicing.)
-  SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
-  SparseTensorEnumeratorBase &
-  operator=(const SparseTensorEnumeratorBase &) = delete;
-
-  /// Gets the source's dimension-rank.
-  uint64_t getSrcDimRank() const { return src.getDimRank(); }
-
-  /// Gets the target's dimension-/level-rank.  (This is usually
-  /// "dimension-rank", though that may coincide with "level-rank"
-  /// depending on usage.)
-  uint64_t getTrgRank() const { return trgSizes.size(); }
-
-  /// Gets the target's dimension-/level-sizes.  (These are usually
-  /// "dimensions", though that may coincide with "level-rank" depending
-  /// on usage.)
-  const std::vector<uint64_t> &getTrgSizes() const { return trgSizes; }
-
-  /// Enumerates all elements of the source tensor, permutes their
-  /// coordinates, and passes the permuted element to the callback.
-  /// The callback must not store the cursor reference directly,
-  /// since this function reuses the storage.  Instead, the callback
-  /// must copy it if they want to keep it.
-  virtual void forallElements(ElementConsumer<V> yield) = 0;
-
-protected:
-  const SparseTensorStorageBase &src;
-  std::vector<uint64_t> trgSizes;  // in target order.
-  std::vector<uint64_t> lvl2trg;   // source-levels -> target-dims/lvls.
-  std::vector<uint64_t> trgCursor; // in target order.
-};
-
-template <typename P, typename C, typename V>
-class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
-  using Base = SparseTensorEnumeratorBase<V>;
-  using StorageImpl = SparseTensorStorage<P, C, V>;
-
-public:
-  /// Constructs an enumerator which automatically applies the given
-  /// mapping from the source tensor's dimensions to the desired
-  /// target tensor dimensions.
-  ///
-  /// Preconditions/assertions are as per the `SparseTensorEnumeratorBase` ctor.
-  SparseTensorEnumerator(const StorageImpl &src, uint64_t trgRank,
-                         const uint64_t *trgSizes, uint64_t srcRank,
-                         const uint64_t *src2trg)
-      : Base(src, trgRank, trgSizes, srcRank, src2trg) {}
-
-  ~SparseTensorEnumerator() final = default;
-
-  void forallElements(ElementConsumer<V> yield) final {
-    forallElements(yield, 0, 0);
-  }
-
-private:
-  // TODO: Once we functionalize the mappings, then we'll no longer
-  // be able to use the current approach of constructing `lvl2trg` in the
-  // ctor and using it to incrementally fill the `trgCursor` cursor as we
-  // recurse through `forallElements`.  Instead we'll want to incrementally
-  // fill a `lvlCursor` as we recurse, and then use `src.getLvl2Dim()`
-  // and `src2trg` to convert it just before yielding to the callback.
-  // It's probably most efficient to just store the `srcCursor` and
-  // `trgCursor` buffers in this object, but we may want to benchmark
-  // that against using `std::calloc` to stack-allocate them instead.
-  //
-  /// The recursive component of the public `forallElements`.
-  void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
-                      uint64_t l) {
-    // Recover the `<P,C,V>` type parameters of `src`.
-    const auto &src = static_cast<const StorageImpl &>(this->src);
-    if (l == src.getLvlRank()) {
-      assert(parentPos < src.values.size());
-      // TODO: <https://github.com/llvm/llvm-project/issues/54179>
-      yield(this->trgCursor, src.values[parentPos]);
+  // Performs forall on level entries and inserts into dim COO.
+  void toCOO(uint64_t parentPos, uint64_t l, std::vector<uint64_t> &dimCoords) {
+    if (l == getLvlRank()) {
+      map.pushbackward(lvlCursor.data(), dimCoords.data());
+      assert(coo);
+      assert(parentPos < values.size());
+      coo->add(dimCoords, values[parentPos]);
       return;
     }
-    uint64_t &cursorL = this->trgCursor[this->lvl2trg[l]];
-    const auto dlt = src.getLvlType(l); // Avoid redundant bounds checking.
-    if (isCompressedDLT(dlt)) {
+    if (isCompressedLvl(l)) {
       // Look up the bounds of the `l`-level segment determined by the
       // `(l - 1)`-level position `parentPos`.
-      const std::vector<P> &positionsL = src.positions[l];
+      const std::vector<P> &positionsL = positions[l];
       assert(parentPos + 1 < positionsL.size());
       const uint64_t pstart = static_cast<uint64_t>(positionsL[parentPos]);
       const uint64_t pstop = static_cast<uint64_t>(positionsL[parentPos + 1]);
       // Loop-invariant code for looking up the `l`-level coordinates.
-      const std::vector<C> &coordinatesL = src.coordinates[l];
+      const std::vector<C> &coordinatesL = coordinates[l];
       assert(pstop <= coordinatesL.size());
       for (uint64_t pos = pstart; pos < pstop; ++pos) {
-        cursorL = static_cast<uint64_t>(coordinatesL[pos]);
-        forallElements(yield, pos, l + 1);
+        lvlCursor[l] = static_cast<uint64_t>(coordinatesL[pos]);
+        toCOO(pos, l + 1, dimCoords);
       }
-    } else if (isSingletonDLT(dlt)) {
-      cursorL = src.getCrd(l, parentPos);
-      forallElements(yield, parentPos, l + 1);
+    } else if (isSingletonLvl(l)) {
+      lvlCursor[l] = getCrd(l, parentPos);
+      toCOO(parentPos, l + 1, dimCoords);
     } else { // Dense level.
-      assert(isDenseDLT(dlt));
-      const uint64_t sz = src.getLvlSizes()[l];
+      assert(isDenseLvl(l));
+      const uint64_t sz = getLvlSizes()[l];
       const uint64_t pstart = parentPos * sz;
       for (uint64_t c = 0; c < sz; ++c) {
-        cursorL = c;
-        forallElements(yield, pstart + c, l + 1);
+        lvlCursor[l] = c;
+        toCOO(pstart + c, l + 1, dimCoords);
       }
     }
   }
+
+  std::vector<std::vector<P>> positions;
+  std::vector<std::vector<C>> coordinates;
+  std::vector<V> values;
+  std::vector<uint64_t> lvlCursor;
+  SparseTensorCOO<V> *coo;
----------------
PeimingLiu wrote:

How do you guarantee that the `coo` pointer won't leak? Are we relying on assumptions about how users call the APIs?

https://github.com/llvm/llvm-project/pull/69268


More information about the Mlir-commits mailing list