[Mlir-commits] [mlir] 2045cca - [mlir][sparse] add a forwarding insertion to SparseTensorStorage (#68939)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 12 21:03:12 PDT 2023


Author: Aart Bik
Date: 2023-10-12T21:03:07-07:00
New Revision: 2045cca0c3d27f046c96257abfa11c769ce9b1ce

URL: https://github.com/llvm/llvm-project/commit/2045cca0c3d27f046c96257abfa11c769ce9b1ce
DIFF: https://github.com/llvm/llvm-project/commit/2045cca0c3d27f046c96257abfa11c769ce9b1ce.diff

LOG: [mlir][sparse] add a forwarding insertion to SparseTensorStorage (#68939)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
    mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
    mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
    mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/sparse_expand.mlir
    mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index ca9555248130f08..f1643d66c26a195 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -143,11 +143,10 @@ constexpr bool isComplexPrimaryType(PrimaryType valTy) {
 /// The actions performed by @newSparseTensor.
 enum class Action : uint32_t {
   kEmpty = 0,
-  // newSparseTensor no longer handles `kFromFile=1`, so we leave this
-  // number reserved to help catch any code that still needs updating.
+  kEmptyForward = 1,
   kFromCOO = 2,
   kSparseToSparse = 3,
-  kEmptyCOO = 4,
+  kFuture = 4, // not used
   kToCOO = 5,
   kToIterator = 6,
   kPack = 7,

diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 607be1cbf956a7d..0d95c60a08689d2 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -33,7 +33,6 @@
     assert((isCompressedDLT(dlt) || isSingletonDLT(dlt)) &&                    \
            "Level is neither compressed nor singleton");                       \
   } while (false)
-#define ASSERT_DENSE_DLT(dlt) assert(isDenseDLT(dlt) && "Level is not dense");
 
 namespace mlir {
 namespace sparse_tensor {
@@ -44,6 +43,12 @@ class SparseTensorEnumeratorBase;
 template <typename P, typename C, typename V>
 class SparseTensorEnumerator;
 
+//===----------------------------------------------------------------------===//
+//
+//  SparseTensorStorage
+//
+//===----------------------------------------------------------------------===//
+
 /// Abstract base class for `SparseTensorStorage<P,C,V>`. This class
 /// takes responsibility for all the `<P,C,V>`-independent aspects
 /// of the tensor (e.g., shape, sparsity, mapping). In addition,
@@ -97,7 +102,7 @@ class SparseTensorStorageBase {
 
   /// Safely looks up the size of the given tensor-dimension.
   uint64_t getDimSize(uint64_t d) const {
-    assert(d < getDimRank() && "Dimension is out of bounds");
+    assert(d < getDimRank());
     return dimSizes[d];
   }
 
@@ -106,7 +111,7 @@ class SparseTensorStorageBase {
 
   /// Safely looks up the size of the given storage-level.
   uint64_t getLvlSize(uint64_t l) const {
-    assert(l < getLvlRank() && "Level is out of bounds");
+    assert(l < getLvlRank());
     return lvlSizes[l];
   }
 
@@ -115,7 +120,7 @@ class SparseTensorStorageBase {
 
   /// Safely looks up the type of the given level.
   DimLevelType getLvlType(uint64_t l) const {
-    assert(l < getLvlRank() && "Level is out of bounds");
+    assert(l < getLvlRank());
     return lvlTypes[l];
   }
 
@@ -173,6 +178,13 @@ class SparseTensorStorageBase {
   MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETVALUES)
 #undef DECL_GETVALUES
 
+  /// Element-wise forwarding insertions. The first argument is the
+  /// dimension-coordinates for the value being inserted.
+#define DECL_FORWARDINGINSERT(VNAME, V)                                        \
+  virtual void forwardingInsert(const uint64_t *, V);
+  MLIR_SPARSETENSOR_FOREVERY_V(DECL_FORWARDINGINSERT)
+#undef DECL_FORWARDINGINSERT
+
   /// Element-wise insertion in lexicographic coordinate order. The first
   /// argument is the level-coordinates for the value being inserted.
 #define DECL_LEXINSERT(VNAME, V) virtual void lexInsert(const uint64_t *, V);
@@ -182,24 +194,17 @@ class SparseTensorStorageBase {
   /// Expanded insertion.  Note that this method resets the
   /// values/filled-switch array back to all-zero/false while only
   /// iterating over the nonzero elements.
-  ///
-  /// Arguments:
-  /// * `lvlCoords` the level-coordinates shared by the values being inserted.
-  /// * `values` a map from last-level coordinates to their associated value.
-  /// * `filled` a map from last-level coordinates to bool, indicating
-  ///   whether `values` contains a valid value to be inserted.
-  /// * `added` a map from `[0..count)` to last-level coordinates for
-  ///   which `filled` is true and `values` contains the assotiated value.
-  /// * `count` the size of `added`.
-  /// * `expsz` the size of the expanded vector (verification only).
 #define DECL_EXPINSERT(VNAME, V)                                               \
   virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t,        \
                          uint64_t);
   MLIR_SPARSETENSOR_FOREVERY_V(DECL_EXPINSERT)
 #undef DECL_EXPINSERT
 
-  /// Finishes insertion.
-  virtual void endInsert() = 0;
+  /// Finalizes forwarding insertions.
+  virtual void endForwardingInsert() = 0;
+
+  /// Finalizes lexicographic insertions.
+  virtual void endLexInsert() = 0;
 
 private:
   const std::vector<uint64_t> dimSizes;
@@ -207,6 +212,8 @@ class SparseTensorStorageBase {
   const std::vector<DimLevelType> lvlTypes;
   const std::vector<uint64_t> dim2lvlVec;
   const std::vector<uint64_t> lvl2dimVec;
+
+protected:
   const MapRef map; // non-owning pointers into dim2lvl/lvl2dim vectors
 };
 
@@ -229,7 +236,8 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
                       const uint64_t *lvl2dim)
       : SparseTensorStorageBase(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                                 dim2lvl, lvl2dim),
-        positions(lvlRank), coordinates(lvlRank), lvlCursor(lvlRank) {}
+        positions(lvlRank), coordinates(lvlRank), lvlCursor(lvlRank), lvlCOO() {
+  }
 
 public:
   /// Constructs a sparse tensor with the given encoding, and allocates
@@ -242,11 +250,12 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                       uint64_t lvlRank, const uint64_t *lvlSizes,
                       const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
-                      const uint64_t *lvl2dim, bool initializeValuesIfAllDense);
+                      const uint64_t *lvl2dim, SparseTensorCOO<V> *coo,
+                      bool initializeValuesIfAllDense);
 
   /// Constructs a sparse tensor with the given encoding, and initializes
   /// the contents from the COO. This ctor performs the same heuristic
-  /// overhead-storage allocation as the ctor taking a `bool`.
+  /// overhead-storage allocation as the ctor above.
   SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                       uint64_t lvlRank, const DimLevelType *lvlTypes,
                       const uint64_t *dim2lvl, const uint64_t *lvl2dim,
@@ -279,10 +288,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   static SparseTensorStorage<P, C, V> *
   newEmpty(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
            const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-           const uint64_t *dim2lvl, const uint64_t *lvl2dim) {
-    return new SparseTensorStorage<P, C, V>(
-        dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, true);
-  }
+           const uint64_t *dim2lvl, const uint64_t *lvl2dim, bool forwarding);
 
   /// Allocates a new sparse tensor and initializes it from the given COO.
   /// The preconditions are as per the `SparseTensorStorageBase` ctor
@@ -303,19 +309,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
 
   /// Allocates a new sparse tensor and initializes it with the contents
   /// of another sparse tensor.
-  ///
-  /// Preconditions:
-  /// * as per the `SparseTensorStorageBase` ctor.
-  /// * `src2lvl` must be valid for `srcRank`, must map coordinates valid
-  ///    for `source.getDimSizes()` to coordinates valid for `lvlSizes`,
-  ///    and therefore must be the inverse of `lvl2dim`.
-  /// * `source` must have the same value type `V`.
-  ///
-  /// Asserts:
-  /// * `dimRank` and `lvlRank` are nonzero.
-  /// * `srcRank == source.getDimRank()`.
-  /// * `lvlSizes` contains only nonzero sizes.
-  /// * `source.getDimSizes()` is a refinement of `dimShape`.
   //
   // TODO: The `dimRank` and `dimShape` arguments are only used for
   // verifying that the source tensor has the expected shape.  So if we
@@ -337,10 +330,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
 
   /// Allocates a new sparse tensor and initialize it with the data stored level
   /// buffers directly.
-  ///
-  /// Precondition:
-  /// * as per the `SparseTensorStorageBase` ctor.
-  /// * the data integrity stored in `buffers` is guaranteed by users already.
   static SparseTensorStorage<P, C, V> *packFromLvlBuffers(
       uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
       const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
@@ -352,12 +341,12 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// Partially specialize these getter methods based on template types.
   void getPositions(std::vector<P> **out, uint64_t lvl) final {
     assert(out && "Received nullptr for out parameter");
-    assert(lvl < getLvlRank() && "Level is out of bounds");
+    assert(lvl < getLvlRank());
     *out = &positions[lvl];
   }
   void getCoordinates(std::vector<C> **out, uint64_t lvl) final {
     assert(out && "Received nullptr for out parameter");
-    assert(lvl < getLvlRank() && "Level is out of bounds");
+    assert(lvl < getLvlRank());
     *out = &coordinates[lvl];
   }
   void getValues(std::vector<V> **out) final {
@@ -365,15 +354,23 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     *out = &values;
   }
 
+  /// Returns coordinate at given position.
   uint64_t getCrd(uint64_t lvl, uint64_t pos) const final {
     ASSERT_COMPRESSED_OR_SINGLETON_LVL(lvl);
-    assert(pos < coordinates[lvl].size() && "Position is out of bounds");
+    assert(pos < coordinates[lvl].size());
     return coordinates[lvl][pos]; // Converts the stored `C` into `uint64_t`.
   }
 
+  /// Partially specialize forwarding insertions based on template types.
+  void forwardingInsert(const uint64_t *dimCoords, V val) final {
+    assert(dimCoords && lvlCOO);
+    map.pushforward(dimCoords, lvlCursor.data());
+    lvlCOO->add(lvlCursor, val);
+  }
+
   /// Partially specialize lexicographical insertions based on template types.
   void lexInsert(const uint64_t *lvlCoords, V val) final {
-    assert(lvlCoords && "Received nullptr for level-coordinates");
+    assert(lvlCoords);
     // TODO: get rid of this! canonicalize all-dense "sparse" array into dense
     // tensors.
     bool allDense = std::all_of(getLvlTypes().begin(), getLvlTypes().end(),
@@ -429,8 +426,22 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     }
   }
 
+  /// Finalizes forwarding insertions.
+  void endForwardingInsert() final {
+    // Ensure lvlCOO is sorted.
+    assert(lvlCOO);
+    lvlCOO->sort();
+    // Now actually insert the `elements`.
+    const auto &elements = lvlCOO->getElements();
+    const uint64_t nse = elements.size();
+    assert(values.size() == 0);
+    values.reserve(nse);
+    fromCOO(elements, 0, nse, 0);
+    delete lvlCOO;
+  }
+
   /// Finalizes lexicographic insertions.
-  void endInsert() final {
+  void endLexInsert() final {
     if (values.empty())
       finalizeSegment(0);
     else
@@ -533,7 +544,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// does not check that `pos` is semantically valid (i.e., larger than
   /// the previous position and smaller than `coordinates[lvl].capacity()`).
   void appendPos(uint64_t lvl, uint64_t pos, uint64_t count = 1) {
-    assert(isCompressedLvl(lvl) && "Level is not compressed");
+    assert(isCompressedLvl(lvl));
     positions[lvl].insert(positions[lvl].end(), count,
                           detail::checkOverflowCast<P>(pos));
   }
@@ -552,7 +563,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) {
       coordinates[lvl].push_back(detail::checkOverflowCast<C>(crd));
     } else { // Dense level.
-      ASSERT_DENSE_DLT(dlt);
+      assert(isDenseDLT(dlt));
       assert(crd >= full && "Coordinate was already filled");
       if (crd == full)
         return; // Short-circuit, since it'll be a nop.
@@ -572,7 +583,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     // Subscript assignment to `std::vector` requires that the `pos`-th
     // entry has been initialized; thus we must be sure to check `size()`
     // here, instead of `capacity()` as would be ideal.
-    assert(pos < coordinates[lvl].size() && "Position is out of bounds");
+    assert(pos < coordinates[lvl].size());
     coordinates[lvl][pos] = detail::checkOverflowCast<C>(crd);
   }
 
@@ -644,7 +655,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     } else if (isSingletonDLT(dlt)) {
       return; // Nothing to finalize.
     } else {  // Dense dimension.
-      ASSERT_DENSE_DLT(dlt);
+      assert(isDenseDLT(dlt));
       const uint64_t sz = getLvlSizes()[l];
       assert(sz >= full && "Segment is overfull");
       count = detail::checkedMul(count, sz - full);
@@ -663,7 +674,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   void endPath(uint64_t diffLvl) {
     const uint64_t lvlRank = getLvlRank();
     const uint64_t lastLvl = lvlRank - 1;
-    assert(diffLvl <= lvlRank && "Level-diff is out of bounds");
+    assert(diffLvl <= lvlRank);
     const uint64_t stop = lvlRank - diffLvl;
     for (uint64_t i = 0; i < stop; ++i) {
       const uint64_t l = lastLvl - i;
@@ -676,7 +687,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   void insPath(const uint64_t *lvlCoords, uint64_t diffLvl, uint64_t full,
                V val) {
     const uint64_t lvlRank = getLvlRank();
-    assert(diffLvl <= lvlRank && "Level-diff is out of bounds");
+    assert(diffLvl <= lvlRank);
     for (uint64_t l = diffLvl; l < lvlRank; ++l) {
       const uint64_t c = lvlCoords[l];
       appendCrd(l, full, c);
@@ -716,11 +727,17 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   std::vector<std::vector<C>> coordinates;
   std::vector<V> values;
   std::vector<uint64_t> lvlCursor; // cursor for lexicographic insertion.
+  SparseTensorCOO<V> *lvlCOO;      // COO used during forwarding
 };
 
 #undef ASSERT_COMPRESSED_OR_SINGLETON_LVL
 
 //===----------------------------------------------------------------------===//
+//
+//  SparseTensorEnumerator
+//
+//===----------------------------------------------------------------------===//
+
 /// A (higher-order) function object for enumerating the elements of some
 /// `SparseTensorStorage` under a permutation.  That is, the `forallElements`
 /// method encapsulates the loop-nest for enumerating the elements of
@@ -808,7 +825,6 @@ class SparseTensorEnumeratorBase {
   std::vector<uint64_t> trgCursor; // in target order.
 };
 
-//===----------------------------------------------------------------------===//
 template <typename P, typename C, typename V>
 class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
   using Base = SparseTensorEnumeratorBase<V>;
@@ -848,8 +864,7 @@ class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
     // Recover the `<P,C,V>` type parameters of `src`.
     const auto &src = static_cast<const StorageImpl &>(this->src);
     if (l == src.getLvlRank()) {
-      assert(parentPos < src.values.size() &&
-             "Value position is out of bounds");
+      assert(parentPos < src.values.size());
       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
       yield(this->trgCursor, src.values[parentPos]);
       return;
@@ -860,13 +875,12 @@ class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
       // Look up the bounds of the `l`-level segment determined by the
       // `(l - 1)`-level position `parentPos`.
       const std::vector<P> &positionsL = src.positions[l];
-      assert(parentPos + 1 < positionsL.size() &&
-             "Parent position is out of bounds");
+      assert(parentPos + 1 < positionsL.size());
       const uint64_t pstart = static_cast<uint64_t>(positionsL[parentPos]);
       const uint64_t pstop = static_cast<uint64_t>(positionsL[parentPos + 1]);
       // Loop-invariant code for looking up the `l`-level coordinates.
       const std::vector<C> &coordinatesL = src.coordinates[l];
-      assert(pstop <= coordinatesL.size() && "Stop position is out of bounds");
+      assert(pstop <= coordinatesL.size());
       for (uint64_t pos = pstart; pos < pstop; ++pos) {
         cursorL = static_cast<uint64_t>(coordinatesL[pos]);
         forallElements(yield, pos, l + 1);
@@ -875,7 +889,7 @@ class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
       cursorL = src.getCrd(l, parentPos);
       forallElements(yield, parentPos, l + 1);
     } else { // Dense level.
-      ASSERT_DENSE_DLT(dlt);
+      assert(isDenseDLT(dlt));
       const uint64_t sz = src.getLvlSizes()[l];
       const uint64_t pstart = parentPos * sz;
       for (uint64_t c = 0; c < sz; ++c) {
@@ -887,6 +901,11 @@ class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
 };
 
 //===----------------------------------------------------------------------===//
+//
+//  SparseTensorNNZ
+//
+//===----------------------------------------------------------------------===//
+
 /// Statistics regarding the number of nonzero subtensors in
 /// a source tensor, for direct sparse=>sparse conversion a la
 /// <https://arxiv.org/abs/2001.02609>.
@@ -959,7 +978,23 @@ class SparseTensorNNZ final {
 };
 
 //===----------------------------------------------------------------------===//
-// Definitions of the ctors and factories of `SparseTensorStorage<P,C,V>`.
+//
+//  SparseTensorStorage Factories
+//
+//===----------------------------------------------------------------------===//
+
+template <typename P, typename C, typename V>
+SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newEmpty(
+    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
+    const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
+    const uint64_t *dim2lvl, const uint64_t *lvl2dim, bool forwarding) {
+  SparseTensorCOO<V> *lvlCOO = nullptr;
+  if (forwarding)
+    lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes);
+  return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
+                                          lvlTypes, dim2lvl, lvl2dim, lvlCOO,
+                                          !forwarding);
+}
 
 // TODO: MapRef
 template <typename P, typename C, typename V>
@@ -967,8 +1002,7 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromCOO(
     uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
     const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
     const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO) {
-  assert(dimShape && "Got nullptr for dimension shape");
-  assert(lvl2dim && "Got nullptr for level-to-dimension mapping");
+  assert(dimShape && dim2lvl && lvl2dim);
   const auto &lvlSizes = lvlCOO.getDimSizes();
   assert(lvlRank == lvlSizes.size() && "Level-rank mismatch");
   // Must reconstruct `dimSizes` from `lvlSizes`.  While this is easy
@@ -1026,14 +1060,21 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::packFromLvlBuffers(
   return tensor;
 }
 
+//===----------------------------------------------------------------------===//
+//
+//  SparseTensorStorage Constructors
+//
+//===----------------------------------------------------------------------===//
+
 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V>::SparseTensorStorage(
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
     const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-    const uint64_t *dim2lvl, const uint64_t *lvl2dim,
+    const uint64_t *dim2lvl, const uint64_t *lvl2dim, SparseTensorCOO<V> *coo,
     bool initializeValuesIfAllDense)
     : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                           dim2lvl, lvl2dim) {
+  lvlCOO = coo;
   // Provide hints on capacity of positions and coordinates.
   // TODO: needs much fine-tuning based on actual sparsity; currently
   // we reserve position/coordinate space based on all previous dense
@@ -1054,7 +1095,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
       sz = 1;
       allDense = false;
     } else { // Dense level.
-      ASSERT_DENSE_DLT(dlt);
+      assert(isDenseDLT(dlt));
       sz = detail::checkedMul(sz, lvlSizes[l]);
     }
   }
@@ -1062,6 +1103,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
     values.resize(sz, 0);
 }
 
+// TODO: share more code with forwarding methods?
 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V>::SparseTensorStorage( // NOLINT
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
@@ -1069,14 +1111,14 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage( // NOLINT
     const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO)
     : SparseTensorStorage(dimRank, dimSizes, lvlRank,
                           lvlCOO.getDimSizes().data(), lvlTypes, dim2lvl,
-                          lvl2dim, false) {
+                          lvl2dim, nullptr, false) {
+  // Ensure lvlCOO is sorted.
   assert(lvlRank == lvlCOO.getDimSizes().size() && "Level-rank mismatch");
-  // Ensure the preconditions of `fromCOO`.  (One is already ensured by
-  // using `lvlSizes = lvlCOO.getDimSizes()` in the ctor above.)
   lvlCOO.sort();
   // Now actually insert the `elements`.
   const auto &elements = lvlCOO.getElements();
   const uint64_t nse = elements.size();
+  assert(values.size() == 0);
   values.reserve(nse);
   fromCOO(elements, 0, nse, 0);
 }
@@ -1123,7 +1165,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
       if (isCompressedDLT(dlt) || isSingletonDLT(dlt))
         coordinates[l].resize(parentSz, 0);
       else
-        ASSERT_DENSE_DLT(dlt); // Future-proofing.
+        assert(isDenseDLT(dlt));
     }
     values.resize(parentSz, 0); // Both allocate and zero-initialize.
   }
@@ -1137,7 +1179,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
         // however, it's semantically invalid here since that entry
         // does not represent a segment of `coordinates[l]`.  Moreover, that
         // entry must be immutable for `assembledSize` to remain valid.
-        assert(parentPos < parentSz && "Parent position is out of bounds");
+        assert(parentPos < parentSz);
         const uint64_t currentPos = positions[l][parentPos];
         // This increment won't overflow the `P` type, since it can't
         // exceed the original value of `positions[l][parentPos+1]`
@@ -1150,12 +1192,12 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
         writeCrd(l, parentPos, lvlCoords[l]);
         // the new parentPos equals the old parentPos.
       } else { // Dense level.
-        ASSERT_DENSE_DLT(dlt);
+        assert(isDenseDLT(dlt));
         parentPos = parentPos * getLvlSizes()[l] + lvlCoords[l];
       }
       parentSz = assembledSize(parentSz, l);
     }
-    assert(parentPos < values.size() && "Value position is out of bounds");
+    assert(parentPos < values.size());
     values[parentPos] = val;
   });
   // The finalizeYieldPos loop
@@ -1175,8 +1217,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
     } else {
       // Both dense and singleton are no-ops for the finalizeYieldPos loop.
       // This assertion is for future-proofing.
-      assert((isDenseDLT(dlt) || isSingletonDLT(dlt)) &&
-             "Level is neither dense nor singleton");
+      assert((isDenseDLT(dlt) || isSingletonDLT(dlt)));
     }
     parentSz = assembledSize(parentSz, l);
   }
@@ -1210,7 +1251,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
       positions[l].assign(posPtr, posPtr + parentSz + 1);
       coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
     } else {
-      assert(isDenseLvl(l) && "Level is not dense");
+      assert(isDenseLvl(l));
     }
     parentSz = assembledSize(parentSz, l);
   }
@@ -1235,8 +1276,6 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
   values.assign(valPtr, valPtr + parentSz);
 }
 
-#undef ASSERT_DENSE_DLT
-
 } // namespace sparse_tensor
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
index e723a354345849d..f9312c866f36317 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -37,7 +37,6 @@ extern "C" {
 //
 //===----------------------------------------------------------------------===//
 
-/// The @newSparseTensor function for constructing a new sparse tensor.
 /// This is the "swiss army knife" method for materializing sparse
 /// tensors into the computation.  The types of the `ptr` argument and
 /// the result depend on the action, as explained in the following table
@@ -45,14 +44,13 @@ extern "C" {
 /// a coordinate-scheme object, and "Iterator" means an iterator object).
 ///
 /// Action:         `ptr`:          Returns:
-/// kEmpty          unused          STS, empty
-/// kEmptyCOO       unused          COO, empty
-/// kFromFile       char* filename  STS, read from the file
+/// kEmpty          -               STS, empty
+/// kEmptyForward   -               STS, empty, with forwarding COO
 /// kFromCOO        COO             STS, copied from the COO source
-/// kToCOO          STS             COO, copied from the STS source
 /// kSparseToSparse STS             STS, copied from the STS source
-/// kToIterator     STS             Iterator, call @getNext to use and
-///                                 @delSparseTensorIterator to free.
+/// kToCOO          STS             COO, copied from the STS source
+/// kToIterator     STS             Iterator (@getNext/@delSparseTensorIterator)
+/// kPack           buffers         STS, from level buffers
 MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensor( // NOLINT
     StridedMemRefType<index_type, 1> *dimSizesRef,
     StridedMemRefType<index_type, 1> *lvlSizesRef,
@@ -84,19 +82,15 @@ MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSEPOSITIONS)
 MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATES)
 #undef DECL_SPARSECOORDINATES
 
-/// Coordinate-scheme method for adding a new element.
-/// TODO: remove dim2lvl
-#define DECL_ADDELT(VNAME, V)                                                  \
-  MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_addElt##VNAME(                   \
-      void *lvlCOO, StridedMemRefType<V, 0> *vref,                             \
-      StridedMemRefType<index_type, 1> *dimCoordsRef,                          \
-      StridedMemRefType<index_type, 1> *dim2lvlRef);
-MLIR_SPARSETENSOR_FOREVERY_V(DECL_ADDELT)
-#undef DECL_ADDELT
+/// Tensor-storage method for a dim to lvl forwarding insertion.
+#define DECL_FORWARDINGINSERT(VNAME, V)                                        \
+  MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_forwardingInsert##VNAME(          \
+      void *tensor, StridedMemRefType<V, 0> *vref,                             \
+      StridedMemRefType<index_type, 1> *dimCoordsRef);                         \
+  MLIR_SPARSETENSOR_FOREVERY_V(DECL_FORWARDINGINSERT)
+#undef DECL_FORWARDINGINSERT
 
 /// Coordinate-scheme method for getting the next element while iterating.
-/// The `cref` argument uses the same coordinate-space as the `iter` (which
-/// can be either dim- or lvl-coords, depending on context).
 #define DECL_GETNEXT(VNAME, V)                                                 \
   MLIR_CRUNNERUTILS_EXPORT bool _mlir_ciface_getNext##VNAME(                   \
       void *iter, StridedMemRefType<index_type, 1> *cref,                      \
@@ -185,8 +179,11 @@ MLIR_CRUNNERUTILS_EXPORT index_type sparseLvlSize(void *tensor, index_type l);
 /// Tensor-storage method to get the size of the given dimension.
 MLIR_CRUNNERUTILS_EXPORT index_type sparseDimSize(void *tensor, index_type d);
 
+/// Tensor-storage method to finalize forwarding insertions.
+MLIR_CRUNNERUTILS_EXPORT void endForwardingInsert(void *tensor);
+
 /// Tensor-storage method to finalize lexicographic insertions.
-MLIR_CRUNNERUTILS_EXPORT void endInsert(void *tensor);
+MLIR_CRUNNERUTILS_EXPORT void endLexInsert(void *tensor);
 
 /// Coordinate-scheme method to write to file in extended FROSTT format.
 #define DECL_OUTSPARSETENSOR(VNAME, V)                                         \

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 8e2dbcf864f97ba..ce3b49915319ceb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -596,7 +596,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
                   ConversionPatternRewriter &rewriter) const override {
     if (op.getHasInserts()) {
       // Finalize any pending insertions.
-      StringRef name = "endInsert";
+      StringRef name = "endLexInsert";
       createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
                      EmitCInterface::Off);
     }

diff  --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 1d654cae3b4b125..050dff2da1fa476 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -80,6 +80,13 @@ MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES)
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETVALUES)
 #undef IMPL_GETVALUES
 
+#define IMPL_FORWARDINGINSERT(VNAME, V)                                        \
+  void SparseTensorStorageBase::forwardingInsert(const uint64_t *, V) {        \
+    FATAL_PIV("forwardingInsert" #VNAME);                                      \
+  }
+MLIR_SPARSETENSOR_FOREVERY_V(IMPL_FORWARDINGINSERT)
+#undef IMPL_FORWARDINGINSERT
+
 #define IMPL_LEXINSERT(VNAME, V)                                               \
   void SparseTensorStorageBase::lexInsert(const uint64_t *, V) {               \
     FATAL_PIV("lexInsert" #VNAME);                                             \

diff  --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index 83ceecaf5a30ee1..cd1b663578a48ce 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -177,9 +177,16 @@ extern "C" {
 #define CASE(p, c, v, P, C, V)                                                 \
   if (posTp == (p) && crdTp == (c) && valTp == (v)) {                          \
     switch (action) {                                                          \
-    case Action::kEmpty:                                                       \
+    case Action::kEmpty: {                                                     \
       return SparseTensorStorage<P, C, V>::newEmpty(                           \
-          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim);   \
+          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,    \
+          false);                                                              \
+    }                                                                          \
+    case Action::kEmptyForward: {                                              \
+      return SparseTensorStorage<P, C, V>::newEmpty(                           \
+          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,    \
+          true);                                                               \
+    }                                                                          \
     case Action::kFromCOO: {                                                   \
       assert(ptr && "Received nullptr for SparseTensorCOO object");            \
       auto &coo = *static_cast<SparseTensorCOO<V> *>(ptr);                     \
@@ -193,8 +200,9 @@ extern "C" {
           dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,    \
           dimRank, tensor);                                                    \
     }                                                                          \
-    case Action::kEmptyCOO:                                                    \
-      return new SparseTensorCOO<V>(lvlRank, lvlSizes);                        \
+    case Action::kFuture: {                                                    \
+      break;                                                                   \
+    }                                                                          \
     case Action::kToCOO: {                                                     \
       assert(ptr && "Received nullptr for SparseTensorStorage object");        \
       auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);        \
@@ -405,29 +413,20 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
 #undef IMPL_SPARSECOORDINATES
 #undef IMPL_GETOVERHEAD
 
-// TODO: use MapRef here for translation of coordinates
-// TODO: remove dim2lvl
-#define IMPL_ADDELT(VNAME, V)                                                  \
-  void *_mlir_ciface_addElt##VNAME(                                            \
-      void *lvlCOO, StridedMemRefType<V, 0> *vref,                             \
-      StridedMemRefType<index_type, 1> *dimCoordsRef,                          \
-      StridedMemRefType<index_type, 1> *dim2lvlRef) {                          \
-    assert(lvlCOO &&vref);                                                     \
+#define IMPL_FORWARDINGINSERT(VNAME, V)                                        \
+  void _mlir_ciface_forwardingInsert##VNAME(                                   \
+      void *t, StridedMemRefType<V, 0> *vref,                                  \
+      StridedMemRefType<index_type, 1> *dimCoordsRef) {                        \
+    assert(t &&vref);                                                          \
     ASSERT_NO_STRIDE(dimCoordsRef);                                            \
-    ASSERT_NO_STRIDE(dim2lvlRef);                                              \
-    const uint64_t rank = MEMREF_GET_USIZE(dimCoordsRef);                      \
-    ASSERT_USIZE_EQ(dim2lvlRef, rank);                                         \
     const index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef);            \
-    const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);                \
-    std::vector<index_type> lvlCoords(rank);                                   \
-    for (uint64_t d = 0; d < rank; ++d)                                        \
-      lvlCoords[dim2lvl[d]] = dimCoords[d];                                    \
-    V *value = MEMREF_GET_PAYLOAD(vref);                                       \
-    static_cast<SparseTensorCOO<V> *>(lvlCOO)->add(lvlCoords, *value);         \
-    return lvlCOO;                                                             \
+    assert(dimCoords);                                                         \
+    const V *value = MEMREF_GET_PAYLOAD(vref);                                 \
+    static_cast<SparseTensorStorageBase *>(t)->forwardingInsert(dimCoords,     \
+                                                                *value);       \
   }
-MLIR_SPARSETENSOR_FOREVERY_V(IMPL_ADDELT)
-#undef IMPL_ADDELT
+MLIR_SPARSETENSOR_FOREVERY_V(IMPL_FORWARDINGINSERT)
+#undef IMPL_FORWARDINGINSERT
 
 // NOTE: the `cref` argument uses the same coordinate-space as the `iter`
 // (which can be either dim- or lvl-coords, depending on context).
@@ -692,8 +691,12 @@ index_type sparseDimSize(void *tensor, index_type d) {
   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
 }
 
-void endInsert(void *tensor) {
-  return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
+void endForwardingInsert(void *tensor) {
+  return static_cast<SparseTensorStorageBase *>(tensor)->endForwardingInsert();
+}
+
+void endLexInsert(void *tensor) {
+  return static_cast<SparseTensorStorageBase *>(tensor)->endLexInsert();
 }
 
 #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 29093a055ab2e04..96300a98a6a4bc5 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -296,7 +296,7 @@ func.func @sparse_reconstruct(%arg0: tensor<128xf32, #SparseVector>) -> tensor<1
 
 // CHECK-LABEL: func @sparse_reconstruct_ins(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>
-//       CHECK: call @endInsert(%[[A]]) : (!llvm.ptr<i8>) -> ()
+//       CHECK: call @endLexInsert(%[[A]]) : (!llvm.ptr<i8>) -> ()
 //       CHECK: return %[[A]] : !llvm.ptr<i8>
 func.func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tensor<128xf32, #SparseVector> {
   %0 = sparse_tensor.load %arg0 hasInserts : tensor<128xf32, #SparseVector>

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
index d19d7fe2871d674..9d8db10aa423022 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -62,7 +62,7 @@
 // CHECK-CONVERT: memref.dealloc %[[A]] : memref<?xf64>
 // CHECK-CONVERT: memref.dealloc %[[B]] : memref<?xi1>
 // CHECK-CONVERT: memref.dealloc %[[C]] : memref<?xindex>
-// CHECK-CONVERT: call @endInsert
+// CHECK-CONVERT: call @endLexInsert
 //
 func.func @kernel(%arga: tensor<?x?xf64, #DCSC>) -> tensor<?xf64, #SV> {
   %c0 = arith.constant 0 : index
@@ -115,7 +115,7 @@ func.func @kernel(%arga: tensor<?x?xf64, #DCSC>) -> tensor<?xf64, #SV> {
 // CHECK-CONVERT: memref.dealloc %[[A]] : memref<?xf64>
 // CHECK-CONVERT: memref.dealloc %[[B]] : memref<?xi1>
 // CHECK-CONVERT: memref.dealloc %[[C]] : memref<?xindex>
-// CHECK-CONVERT: call @endInsert
+// CHECK-CONVERT: call @endLexInsert
 //
 func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
                    %B: tensor<2x4xf64, #CSR>) -> tensor<8x4xf64, #CSR> {
@@ -163,7 +163,7 @@ func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
 // CHECK-CONVERT: memref.dealloc %[[A]] : memref<?xf64>
 // CHECK-CONVERT: memref.dealloc %[[B]] : memref<?xi1>
 // CHECK-CONVERT: memref.dealloc %[[C]] : memref<?xindex>
-// CHECK-CONVERT: call @endInsert
+// CHECK-CONVERT: call @endLexInsert
 //
 func.func @matmul2(%A: tensor<8x2xf64, #CSC>,
                    %B: tensor<2x4xf64, #CSC>) -> tensor<8x4xf64, #CSC> {

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 7d852ca9cc1aa26..8ecbc1da965a156 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -112,7 +112,7 @@
 // CHECK:           memref.dealloc %[[VAL_20]] : memref<300xf64>
 // CHECK:           memref.dealloc %[[VAL_22]] : memref<300xi1>
 // CHECK:           memref.dealloc %[[VAL_24]] : memref<300xindex>
-// CHECK:           call @endInsert(%[[VAL_19]]) : (!llvm.ptr<i8>) -> ()
+// CHECK:           call @endLexInsert(%[[VAL_19]]) : (!llvm.ptr<i8>) -> ()
 // CHECK:           return %[[VAL_19]] : !llvm.ptr<i8>
 // CHECK:       }
 func.func @fill_zero_after_alloc(%arg0: tensor<100x200xf64, #DCSR>,


        


More information about the Mlir-commits mailing list