[Mlir-commits] [mlir] [mlir][sparse] complete migration to dim2lvl/lvl2dim in library (PR #69268)

Aart Bik llvmlistbot at llvm.org
Mon Oct 16 17:36:01 PDT 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/69268

This final revision completes the migration to non-permutation support in the SparseTensor library. All dimension/level mappings are now controlled by the MapRef (forward via dim2lvl, backward via lvl2dim). Unused code, such as the SparseTensorEnumerator classes, has been removed, which simplifies subsequent testing of block sparsity.

>From 14874810c4cfecc762fdbdbe606dfa60669c4d18 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Mon, 16 Oct 2023 17:33:43 -0700
Subject: [PATCH] [mlir][sparse] complete migration to dim2lvl/lvl2dim in
 library

This final revision completes the migration to non-permutation
support in the SparseTensor library. All dimension/level mappings
are now controlled by the MapRef (forward via dim2lvl, backward
via lvl2dim). Unused code, such as the SparseTensorEnumerator
classes, has been removed, which simplifies subsequent testing
of block sparsity.
---
 .../mlir/ExecutionEngine/SparseTensor/File.h  |   7 +-
 .../ExecutionEngine/SparseTensor/Storage.h    | 361 ++++--------------
 .../ExecutionEngine/SparseTensorRuntime.cpp   |   5 +-
 3 files changed, 80 insertions(+), 293 deletions(-)

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index efc3f82d6a307ea..1b5f0553a3af959 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -201,10 +201,11 @@ class SparseTensorReader final {
                    const uint64_t *lvl2dim) {
     const uint64_t dimRank = getRank();
     MapRef map(dimRank, lvlRank, dim2lvl, lvl2dim);
-    auto *coo = readCOO<V>(map, lvlSizes);
+    auto *lvlCOO = readCOO<V>(map, lvlSizes);
     auto *tensor = SparseTensorStorage<P, I, V>::newFromCOO(
-        dimRank, getDimSizes(), lvlRank, lvlTypes, dim2lvl, lvl2dim, *coo);
-    delete coo;
+        dimRank, getDimSizes(), lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,
+        *lvlCOO);
+    delete lvlCOO;
     return tensor;
   }
 
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index bafc9baa7edde1e..f31ac7c1b7cc246 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -10,8 +10,6 @@
 //
 // * `SparseTensorStorageBase`
 // * `SparseTensorStorage<P, C, V>`
-// * `SparseTensorEnumeratorBase<V>`
-// * `SparseTensorEnumerator<P, C, V>`
 //
 //===----------------------------------------------------------------------===//
 
@@ -29,25 +27,17 @@ namespace mlir {
 namespace sparse_tensor {
 
 /// The type of callback functions which receive an element.
-template <typename V>
-using ElementConsumer =
-    const std::function<void(const std::vector<uint64_t> &, V)> &;
-
-// Forward references.
-template <typename V>
-class SparseTensorEnumeratorBase;
-template <typename P, typename C, typename V>
-class SparseTensorEnumerator;
+template <typename V> using ElementConsumer = const std::function<void(V)> &;
 
 //===----------------------------------------------------------------------===//
 //
-//  SparseTensorStorage
+//  SparseTensorStorage Classes
 //
 //===----------------------------------------------------------------------===//
 
 /// Abstract base class for `SparseTensorStorage<P,C,V>`. This class
 /// takes responsibility for all the `<P,C,V>`-independent aspects
-/// of the tensor (e.g., shape, sparsity, mapping). In addition,
+/// of the tensor (e.g., sizes, sparsity, mapping). In addition,
 /// we use function overloading to implement "partial" method
 /// specialization, which the C-API relies on to catch type errors
 /// arising from our use of opaque pointers.
@@ -55,7 +45,7 @@ class SparseTensorEnumerator;
 /// Because this class forms a bridge between the denotational semantics
 /// of "tensors" and the operational semantics of how we store and
 /// compute with them, it also distinguishes between two different
-/// coordinate spaces (and their associated rank, shape, sizes, etc).
+/// coordinate spaces (and their associated rank, sizes, etc).
 /// Denotationally, we have the *dimensions* of the tensor represented
 /// by this object.  Operationally, we have the *levels* of the storage
 /// representation itself.
@@ -139,10 +129,6 @@ class SparseTensorStorageBase {
   /// Safely checks if the level is unique.
   bool isUniqueLvl(uint64_t l) const { return isUniqueDLT(getLvlType(l)); }
 
-  /// Gets the level-to-dimension mapping.
-  // TODO: REMOVE THIS
-  const std::vector<uint64_t> &getLvl2Dim() const { return lvl2dimVec; }
-
   /// Gets positions-overhead storage for the given level.
 #define DECL_GETPOSITIONS(PNAME, P)                                            \
   virtual void getPositions(std::vector<P> **, uint64_t);
@@ -154,6 +140,7 @@ class SparseTensorStorageBase {
   virtual void getCoordinates(std::vector<C> **, uint64_t);
   MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATES)
 #undef DECL_GETCOORDINATES
+
   /// Gets the coordinate-value stored at the given level and position.
   virtual uint64_t getCrd(uint64_t lvl, uint64_t pos) const = 0;
 
@@ -241,17 +228,9 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// the contents from the COO. This ctor performs the same heuristic
   /// overhead-storage allocation as the ctor above.
   SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
-                      uint64_t lvlRank, const DimLevelType *lvlTypes,
-                      const uint64_t *dim2lvl, const uint64_t *lvl2dim,
-                      SparseTensorCOO<V> &lvlCOO);
-
-  /// Constructs a sparse tensor with the given encoding, and initializes
-  /// the contents from the enumerator. This ctor allocates exactly
-  /// the required amount of overhead storage, not using any heuristics.
-  SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
-                      uint64_t lvlRank, const DimLevelType *lvlTypes,
-                      const uint64_t *dim2lvl, const uint64_t *lvl2dim,
-                      SparseTensorEnumeratorBase<V> &lvlEnumerator);
+                      uint64_t lvlRank, const uint64_t *lvlSizes,
+                      const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
+                      const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO);
 
   /// Constructs a sparse tensor with the given encoding, and initializes
   /// the contents from the level buffers. This ctor allocates exactly
@@ -265,39 +244,27 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
                       const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
                       const uint64_t *lvl2dim, const intptr_t *lvlBufs);
 
-  /// Allocates a new empty sparse tensor. The preconditions/assertions
-  /// are as per the `SparseTensorStorageBase` ctor; which is to say,
-  /// the `dimSizes` and `lvlSizes` must both be "sizes" not "shapes",
-  /// since there's nowhere to reconstruct dynamic sizes from.
+  /// Allocates a new empty sparse tensor.
   static SparseTensorStorage<P, C, V> *
   newEmpty(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
            const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
            const uint64_t *dim2lvl, const uint64_t *lvl2dim, bool forwarding);
 
   /// Allocates a new sparse tensor and initializes it from the given COO.
-  /// The preconditions are as per the `SparseTensorStorageBase` ctor
-  /// (where we define `lvlSizes = lvlCOO.getDimSizes().data()`), but
-  /// using the following assertions in lieu of the base ctor's assertions:
-  //
-  // TODO: The ability to reconstruct dynamic dimensions-sizes does not
-  // easily generalize to arbitrary `lvl2dim` mappings.  When compiling
-  // MLIR programs to use this library, we should be able to generate
-  // code for effectively computing the reconstruction, but it's not clear
-  // that there's a feasible way to do so from within the library itself.
-  // Therefore, when we functionalize the `lvl2dim` mapping we'll have
-  // to update the type/preconditions of this factory too.
   static SparseTensorStorage<P, C, V> *
-  newFromCOO(uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
-             const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
-             const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO);
+  newFromCOO(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
+             const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
+             const uint64_t *dim2lvl, const uint64_t *lvl2dim,
+             SparseTensorCOO<V> &lvlCOO);
 
   /// Allocates a new sparse tensor and initialize it with the data stored level
   /// buffers directly.
-  static SparseTensorStorage<P, C, V> *packFromLvlBuffers(
-      uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
-      const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-      const uint64_t *src2lvl, // FIXME: dim2lvl
-      const uint64_t *lvl2dim, uint64_t srcRank, const intptr_t *buffers);
+  static SparseTensorStorage<P, C, V> *
+  packFromLvlBuffers(uint64_t dimRank, const uint64_t *dimSizes,
+                     uint64_t lvlRank, const uint64_t *lvlSizes,
+                     const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
+                     const uint64_t *lvl2dim, uint64_t srcRank,
+                     const intptr_t *buffers);
 
   ~SparseTensorStorage() final = default;
 
@@ -334,8 +301,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// Partially specialize lexicographical insertions based on template types.
   void lexInsert(const uint64_t *lvlCoords, V val) final {
     assert(lvlCoords);
-    // TODO: get rid of this! canonicalize all-dense "sparse" array into dense
-    // tensors.
+    // TODO: needed?
     bool allDense = std::all_of(getLvlTypes().begin(), getLvlTypes().end(),
                                 [](DimLevelType lt) { return isDenseDLT(lt); });
     if (allDense) {
@@ -411,23 +377,18 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
       endPath(0);
   }
 
-  /// Allocates a new COO object and initializes it with the contents
-  /// of this tensor under the given mapping from the `getDimSizes()`
-  /// coordinate-space to the `trgSizes` coordinate-space. Callers must
-  /// make sure to delete the COO when they're done with it.
-  SparseTensorCOO<V> *toCOO(uint64_t trgRank, const uint64_t *trgSizes,
-                            uint64_t srcRank,
-                            const uint64_t *src2trg, // FIXME: dim2lvl
-                            const uint64_t *lvl2dim) const {
-    // TODO: use MapRef here too for the translation
-    SparseTensorEnumerator<P, C, V> enumerator(*this, trgRank, trgSizes,
-                                               srcRank, src2trg);
-    auto *coo = new SparseTensorCOO<V>(trgRank, trgSizes, values.size());
-    enumerator.forallElements(
-        [&coo](const auto &trgCoords, V val) { coo->add(trgCoords, val); });
-    // TODO: This assertion assumes there are no stored zeros,
-    // or if there are then that we don't filter them out.
-    // <https://github.com/llvm/llvm-project/issues/54179>
+  /// Allocates a new COO object holding this tensor's elements in dimension
+  /// coordinates. Callers must delete the COO when they're done with it.
+  SparseTensorCOO<V> *toCOO() const {
+    std::vector<uint64_t> dimCoords(getDimRank());
+    std::vector<uint64_t> lvlCoords(getLvlRank());
+    auto *coo = new SparseTensorCOO<V>(getDimSizes(), values.size());
+    forallElements(
+        [this, &dimCoords, &lvlCoords, &coo](V val) {
+          map.pushbackward(lvlCoords.data(), dimCoords.data());
+          coo->add(dimCoords, val);
+        },
+        0, 0, lvlCoords);
     assert(coo->getElements().size() == values.size());
     return coo;
   }
@@ -525,27 +486,11 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     }
   }
 
-  /// Writes the given coordinate to `coordinates[lvl][pos]`.  This method
-  /// checks that `crd` is representable in the `C` type; however, it
-  /// does not check that `crd` is semantically valid (i.e., in bounds
-  /// for `dimSizes[lvl]` and not elsewhere occurring in the same segment).
-  void writeCrd(uint64_t lvl, uint64_t pos, uint64_t crd) {
-    assert(isCompressedDLT(getLvlType(lvl)) || isSingletonDLT(getLvlType(lvl)));
-    // Subscript assignment to `std::vector` requires that the `pos`-th
-    // entry has been initialized; thus we must be sure to check `size()`
-    // here, instead of `capacity()` as would be ideal.
-    assert(pos < coordinates[lvl].size());
-    coordinates[lvl][pos] = detail::checkOverflowCast<C>(crd);
-  }
-
   /// Computes the assembled-size associated with the `l`-th level,
   /// given the assembled-size associated with the `(l-1)`-th level.
   /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
   /// storage, as opposed to "level-sizes" which are the cardinality
   /// of possible coordinates for that level.
-  ///
-  /// Precondition: the `positions[l]` array must be fully initialized
-  /// before calling this method.
   uint64_t assembledSize(uint64_t parentSz, uint64_t l) const {
     const auto dlt = getLvlType(l); // Avoid redundant bounds checking.
     if (isCompressedDLT(dlt))
@@ -553,7 +498,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     if (isSingletonDLT(dlt))
       return parentSz; // New size is same as the parent.
     if (isDenseDLT(dlt))
-      return parentSz * getLvlSizes()[l];
+      return parentSz * getLvlSize(l);
     MLIR_SPARSETENSOR_FATAL("unsupported level type: %d\n",
                             static_cast<uint8_t>(dlt));
   }
@@ -561,11 +506,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// Initializes sparse tensor storage scheme from a memory-resident sparse
   /// tensor in coordinate scheme. This method prepares the positions and
   /// coordinates arrays under the given per-level dense/sparse annotations.
-  ///
-  /// Preconditions:
-  /// * the `lvlElements` must be lexicographically sorted.
-  /// * the coordinates of every element are valid for `getLvlSizes()`
-  ///   (i.e., equal rank and pointwise less-than).
   void fromCOO(const std::vector<Element<V>> &lvlElements, uint64_t lo,
                uint64_t hi, uint64_t l) {
     const uint64_t lvlRank = getLvlRank();
@@ -669,184 +609,46 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     return -1u;
   }
 
-  // Allow `SparseTensorEnumerator` to access the data-members (to avoid
-  // the cost of virtual-function dispatch in inner loops), without
-  // making them public to other client code.
-  friend class SparseTensorEnumerator<P, C, V>;
-
-  std::vector<std::vector<P>> positions;
-  std::vector<std::vector<C>> coordinates;
-  std::vector<V> values;
-  std::vector<uint64_t> lvlCursor; // cursor for lexicographic insertion.
-  SparseTensorCOO<V> *lvlCOO;      // COO used during forwarding
-};
-
-//===----------------------------------------------------------------------===//
-//
-//  SparseTensorEnumerator
-//
-//===----------------------------------------------------------------------===//
-
-/// A (higher-order) function object for enumerating the elements of some
-/// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
-/// method encapsulates the loop-nest for enumerating the elements of
-/// the source tensor (in whatever order is best for the source tensor),
-/// and applies a permutation to the coordinates before handing
-/// each element to the callback.  A single enumerator object can be
-/// freely reused for several calls to `forallElements`, just so long
-/// as each call is sequential with respect to one another.
-///
-/// N.B., this class stores a reference to the `SparseTensorStorageBase`
-/// passed to the constructor; thus, objects of this class must not
-/// outlive the sparse tensor they depend on.
-///
-/// Design Note: The reason we define this class instead of simply using
-/// `SparseTensorEnumerator<P,C,V>` is because we need to hide/generalize
-/// the `<P,C>` template parameters from MLIR client code (to simplify the
-/// type parameters used for direct sparse-to-sparse conversion).  And the
-/// reason we define the `SparseTensorEnumerator<P,C,V>` subclasses rather
-/// than simply using this class, is to avoid the cost of virtual-method
-/// dispatch within the loop-nest.
-template <typename V>
-class SparseTensorEnumeratorBase {
-public:
-  /// Constructs an enumerator which automatically applies the given
-  /// mapping from the source tensor's dimensions to the desired
-  /// target tensor dimensions.
-  ///
-  /// Preconditions:
-  /// * the `src` must have the same `V` value type.
-  /// * `trgSizes` must be valid for `trgRank`.
-  /// * `src2trg` must be valid for `srcRank`, and must map coordinates
-  ///   valid for `src.getDimSizes()` to coordinates valid for `trgSizes`.
-  ///
-  /// Asserts:
-  /// * `trgSizes` must be nonnull and must contain only nonzero sizes.
-  /// * `srcRank == src.getDimRank()`.
-  /// * `src2trg` must be nonnull.
-  SparseTensorEnumeratorBase(const SparseTensorStorageBase &src,
-                             uint64_t trgRank, const uint64_t *trgSizes,
-                             uint64_t srcRank, const uint64_t *src2trg)
-      : src(src), trgSizes(trgSizes, trgSizes + trgRank),
-        lvl2trg(src.getLvlRank()), trgCursor(trgRank) {
-    assert(trgSizes && "Received nullptr for target-sizes");
-    assert(src2trg && "Received nullptr for source-to-target mapping");
-    assert(srcRank == src.getDimRank() && "Source-rank mismatch");
-    for (uint64_t t = 0; t < trgRank; ++t)
-      assert(trgSizes[t] > 0 && "Target-size zero has trivial storage");
-    const auto &lvl2src = src.getLvl2Dim();
-    for (uint64_t lvlRank = src.getLvlRank(), l = 0; l < lvlRank; ++l)
-      lvl2trg[l] = src2trg[lvl2src[l]];
-  }
-
-  virtual ~SparseTensorEnumeratorBase() = default;
-
-  // We disallow copying to help avoid leaking the `src` reference.
-  // (In addition to avoiding the problem of slicing.)
-  SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
-  SparseTensorEnumeratorBase &
-  operator=(const SparseTensorEnumeratorBase &) = delete;
-
-  /// Gets the source's dimension-rank.
-  uint64_t getSrcDimRank() const { return src.getDimRank(); }
-
-  /// Gets the target's dimension-/level-rank.  (This is usually
-  /// "dimension-rank", though that may coincide with "level-rank"
-  /// depending on usage.)
-  uint64_t getTrgRank() const { return trgSizes.size(); }
-
-  /// Gets the target's dimension-/level-sizes.  (These are usually
-  /// "dimensions", though that may coincide with "level-rank" depending
-  /// on usage.)
-  const std::vector<uint64_t> &getTrgSizes() const { return trgSizes; }
-
-  /// Enumerates all elements of the source tensor, permutes their
-  /// coordinates, and passes the permuted element to the callback.
-  /// The callback must not store the cursor reference directly,
-  /// since this function reuses the storage.  Instead, the callback
-  /// must copy it if they want to keep it.
-  virtual void forallElements(ElementConsumer<V> yield) = 0;
-
-protected:
-  const SparseTensorStorageBase &src;
-  std::vector<uint64_t> trgSizes;  // in target order.
-  std::vector<uint64_t> lvl2trg;   // source-levels -> target-dims/lvls.
-  std::vector<uint64_t> trgCursor; // in target order.
-};
-
-template <typename P, typename C, typename V>
-class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
-  using Base = SparseTensorEnumeratorBase<V>;
-  using StorageImpl = SparseTensorStorage<P, C, V>;
-
-public:
-  /// Constructs an enumerator which automatically applies the given
-  /// mapping from the source tensor's dimensions to the desired
-  /// target tensor dimensions.
-  ///
-  /// Preconditions/assertions are as per the `SparseTensorEnumeratorBase` ctor.
-  SparseTensorEnumerator(const StorageImpl &src, uint64_t trgRank,
-                         const uint64_t *trgSizes, uint64_t srcRank,
-                         const uint64_t *src2trg)
-      : Base(src, trgRank, trgSizes, srcRank, src2trg) {}
-
-  ~SparseTensorEnumerator() final = default;
-
-  void forallElements(ElementConsumer<V> yield) final {
-    forallElements(yield, 0, 0);
-  }
-
-private:
-  // TODO: Once we functionalize the mappings, then we'll no longer
-  // be able to use the current approach of constructing `lvl2trg` in the
-  // ctor and using it to incrementally fill the `trgCursor` cursor as we
-  // recurse through `forallElements`.  Instead we'll want to incrementally
-  // fill a `lvlCursor` as we recurse, and then use `src.getLvl2Dim()`
-  // and `src2trg` to convert it just before yielding to the callback.
-  // It's probably most efficient to just store the `srcCursor` and
-  // `trgCursor` buffers in this object, but we may want to benchmark
-  // that against using `std::calloc` to stack-allocate them instead.
-  //
-  /// The recursive component of the public `forallElements`.
-  void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
-                      uint64_t l) {
-    // Recover the `<P,C,V>` type parameters of `src`.
-    const auto &src = static_cast<const StorageImpl &>(this->src);
-    if (l == src.getLvlRank()) {
-      assert(parentPos < src.values.size());
-      // TODO: <https://github.com/llvm/llvm-project/issues/54179>
-      yield(this->trgCursor, src.values[parentPos]);
+  void forallElements(ElementConsumer<V> yield, uint64_t parentPos, uint64_t l,
+                      std::vector<uint64_t> &lvlCoords) const {
+    if (l == getLvlRank()) {
+      assert(parentPos < values.size());
+      yield(values[parentPos]);
       return;
     }
-    uint64_t &cursorL = this->trgCursor[this->lvl2trg[l]];
-    const auto dlt = src.getLvlType(l); // Avoid redundant bounds checking.
-    if (isCompressedDLT(dlt)) {
+    if (isCompressedLvl(l)) {
       // Look up the bounds of the `l`-level segment determined by the
       // `(l - 1)`-level position `parentPos`.
-      const std::vector<P> &positionsL = src.positions[l];
+      const std::vector<P> &positionsL = positions[l];
       assert(parentPos + 1 < positionsL.size());
       const uint64_t pstart = static_cast<uint64_t>(positionsL[parentPos]);
       const uint64_t pstop = static_cast<uint64_t>(positionsL[parentPos + 1]);
       // Loop-invariant code for looking up the `l`-level coordinates.
-      const std::vector<C> &coordinatesL = src.coordinates[l];
+      const std::vector<C> &coordinatesL = coordinates[l];
       assert(pstop <= coordinatesL.size());
       for (uint64_t pos = pstart; pos < pstop; ++pos) {
-        cursorL = static_cast<uint64_t>(coordinatesL[pos]);
-        forallElements(yield, pos, l + 1);
+        lvlCoords[l] = static_cast<uint64_t>(coordinatesL[pos]);
+        forallElements(yield, pos, l + 1, lvlCoords);
       }
-    } else if (isSingletonDLT(dlt)) {
-      cursorL = src.getCrd(l, parentPos);
-      forallElements(yield, parentPos, l + 1);
+    } else if (isSingletonLvl(l)) {
+      lvlCoords[l] = getCrd(l, parentPos);
+      forallElements(yield, parentPos, l + 1, lvlCoords);
     } else { // Dense level.
-      assert(isDenseDLT(dlt));
-      const uint64_t sz = src.getLvlSizes()[l];
+      assert(isDenseLvl(l));
+      const uint64_t sz = getLvlSizes()[l];
       const uint64_t pstart = parentPos * sz;
       for (uint64_t c = 0; c < sz; ++c) {
-        cursorL = c;
-        forallElements(yield, pstart + c, l + 1);
+        lvlCoords[l] = c;
+        forallElements(yield, pstart + c, l + 1, lvlCoords);
       }
     }
   }
+
+  std::vector<std::vector<P>> positions;
+  std::vector<std::vector<C>> coordinates;
+  std::vector<V> values;
+  std::vector<uint64_t> lvlCursor; // cursor for lexicographic insertion.
+  SparseTensorCOO<V> *lvlCOO;      // COO used during forwarding
 };
 
 //===----------------------------------------------------------------------===//
@@ -868,41 +670,24 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newEmpty(
                                           !forwarding);
 }
 
-// TODO: MapRef
 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromCOO(
-    uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
-    const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
-    const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO) {
-  assert(dimShape && dim2lvl && lvl2dim);
-  const auto &lvlSizes = lvlCOO.getDimSizes();
-  assert(lvlRank == lvlSizes.size() && "Level-rank mismatch");
-  // Must reconstruct `dimSizes` from `lvlSizes`.  While this is easy
-  // enough to do when `lvl2dim` is a permutation, this approach will
-  // not work for more general mappings; so we will need to move this
-  // computation off to codegen.
-  std::vector<uint64_t> dimSizes(dimRank);
-  for (uint64_t l = 0; l < lvlRank; ++l) {
-    const uint64_t d = lvl2dim[l];
-    assert((dimShape[d] == 0 || dimShape[d] == lvlSizes[l]) &&
-           "Dimension sizes do not match expected shape");
-    dimSizes[d] = lvlSizes[l];
-  }
-  return new SparseTensorStorage<P, C, V>(dimRank, dimSizes.data(), lvlRank,
+    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
+    const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
+    const uint64_t *dim2lvl, const uint64_t *lvl2dim,
+    SparseTensorCOO<V> &lvlCOO) {
+  return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
                                           lvlTypes, dim2lvl, lvl2dim, lvlCOO);
 }
 
 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::packFromLvlBuffers(
-    uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
+    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
     const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-    const uint64_t *src2lvl, // FIXME: dim2lvl
-    const uint64_t *lvl2dim, uint64_t srcRank, const intptr_t *buffers) {
-  assert(dimShape && "Got nullptr for dimension shape");
-  auto *tensor =
-      new SparseTensorStorage<P, C, V>(dimRank, dimShape, lvlRank, lvlSizes,
-                                       lvlTypes, src2lvl, lvl2dim, buffers);
-  return tensor;
+    const uint64_t *dim2lvl, const uint64_t *lvl2dim, uint64_t srcRank,
+    const intptr_t *buffers) {
+  return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
+                                          lvlTypes, dim2lvl, lvl2dim, buffers);
 }
 
 //===----------------------------------------------------------------------===//
@@ -920,6 +705,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
     : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                           dim2lvl, lvl2dim) {
   lvlCOO = coo;
+  assert(!lvlCOO || lvlRank == lvlCOO->getRank());
   // Provide hints on capacity of positions and coordinates.
   // TODO: needs much fine-tuning based on actual sparsity; currently
   // we reserve position/coordinate space based on all previous dense
@@ -948,17 +734,16 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
     values.resize(sz, 0);
 }
 
-// TODO: share more code with forwarding methods?
 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V>::SparseTensorStorage( // NOLINT
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
-    const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
-    const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO)
-    : SparseTensorStorage(dimRank, dimSizes, lvlRank,
-                          lvlCOO.getDimSizes().data(), lvlTypes, dim2lvl,
-                          lvl2dim, nullptr, false) {
+    const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
+    const uint64_t *dim2lvl, const uint64_t *lvl2dim,
+    SparseTensorCOO<V> &lvlCOO)
+    : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
+                          dim2lvl, lvl2dim, nullptr, false) {
   // Ensure lvlCOO is sorted.
-  assert(lvlRank == lvlCOO.getDimSizes().size() && "Level-rank mismatch");
+  assert(lvlRank == lvlCOO.getRank());
   lvlCOO.sort();
   // Now actually insert the `elements`.
   const auto &elements = lvlCOO.getElements();
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index 6a4c0f292c5f81e..36d888a08de6d60 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -129,7 +129,8 @@ extern "C" {
       assert(ptr && "Received nullptr for SparseTensorCOO object");            \
       auto &coo = *static_cast<SparseTensorCOO<V> *>(ptr);                     \
       return SparseTensorStorage<P, C, V>::newFromCOO(                         \
-          dimRank, dimSizes, lvlRank, lvlTypes, dim2lvl, lvl2dim, coo);        \
+          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,    \
+          coo);                                                                \
     }                                                                          \
     case Action::kFromReader: {                                                \
       assert(ptr && "Received nullptr for SparseTensorReader object");         \
@@ -140,7 +141,7 @@ extern "C" {
     case Action::kToCOO: {                                                     \
       assert(ptr && "Received nullptr for SparseTensorStorage object");        \
       auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);        \
-      return tensor.toCOO(lvlRank, lvlSizes, dimRank, dim2lvl, lvl2dim);       \
+      return tensor.toCOO();                                                   \
     }                                                                          \
     case Action::kPack: {                                                      \
       assert(ptr && "Received nullptr for SparseTensorStorage object");        \



More information about the Mlir-commits mailing list