[Mlir-commits] [mlir] [mlir][sparse] bug fix on all-dense lex insertion (PR #73987)

Aart Bik llvmlistbot at llvm.org
Thu Nov 30 13:14:52 PST 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/73987

Fixes a bug where values were appended after lexicographic insertion had already completed. Also a slight optimization: the all-dense check is now computed once at construction instead of on every lexInsert call.

>From 01ffc0df05bde4cbae25ad79c83d4bae1b026c18 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 30 Nov 2023 13:12:23 -0800
Subject: [PATCH] [mlir][sparse] bug fix on all-dense lex insertion

Fixes a bug where values were appended after lexicographic insertion
had already completed. Also a slight optimization: the all-dense check
is now computed once at construction instead of on every lexInsert call.
---
 .../ExecutionEngine/SparseTensor/Storage.h    | 20 ++++++++-----------
 .../ExecutionEngine/SparseTensor/Storage.cpp  | 14 ++++++++++---
 2 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 19c49e6c487dff7..01c5f2382ffe69c 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -186,6 +186,7 @@ class SparseTensorStorageBase {
 
 protected:
   const MapRef map; // non-owning pointers into dim2lvl/lvl2dim vectors
+  const bool allDense;
 };
 
 /// A memory-resident sparse tensor using a storage scheme based on
@@ -293,8 +294,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// Partially specialize lexicographical insertions based on template types.
   void lexInsert(const uint64_t *lvlCoords, V val) final {
     assert(lvlCoords);
-    bool allDense = std::all_of(getLvlTypes().begin(), getLvlTypes().end(),
-                                [](LevelType lt) { return isDenseLT(lt); });
     if (allDense) {
       uint64_t lvlRank = getLvlRank();
       uint64_t valIdx = 0;
@@ -363,10 +362,12 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
 
   /// Finalizes lexicographic insertions.
   void endLexInsert() final {
-    if (values.empty())
-      finalizeSegment(0);
-    else
-      endPath(0);
+    if (!allDense) {
+      if (values.empty())
+        finalizeSegment(0);
+      else
+        endPath(0);
+    }
   }
 
   /// Allocates a new COO object and initializes it with the contents.
@@ -705,7 +706,6 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
   // we reserve position/coordinate space based on all previous dense
   // levels, which works well up to first sparse level; but we should
   // really use nnz and dense/sparse distribution.
-  bool allDense = true;
   uint64_t sz = 1;
   for (uint64_t l = 0; l < lvlRank; l++) {
     if (isCompressedLvl(l)) {
@@ -713,23 +713,19 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
       positions[l].push_back(0);
       coordinates[l].reserve(sz);
       sz = 1;
-      allDense = false;
     } else if (isLooseCompressedLvl(l)) {
       positions[l].reserve(2 * sz + 1); // last one unused
       positions[l].push_back(0);
       coordinates[l].reserve(sz);
       sz = 1;
-      allDense = false;
     } else if (isSingletonLvl(l)) {
       coordinates[l].reserve(sz);
       sz = 1;
-      allDense = false;
     } else if (is2OutOf4Lvl(l)) {
-      assert(allDense && l == lvlRank - 1 && "unexpected 2:4 usage");
+      assert(l == lvlRank - 1 && "unexpected 2:4 usage");
       sz = detail::checkedMul(sz, lvlSizes[l]) / 2;
       coordinates[l].reserve(sz);
       values.reserve(sz);
-      allDense = false;
     } else { // Dense level.
       assert(isDenseLvl(l));
       sz = detail::checkedMul(sz, lvlSizes[l]);
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 7f8f76f8ec18901..0c7b3a228a65cf7 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -17,6 +17,13 @@
 
 using namespace mlir::sparse_tensor;
 
+static inline bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes) {
+  for (uint64_t l = 0; l < lvlRank; l++)
+    if (!isDenseLT(lvlTypes[l]))
+      return false;
+  return true;
+}
+
 SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
     const uint64_t *lvlSizes, const LevelType *lvlTypes,
@@ -26,15 +33,16 @@ SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
       lvlTypes(lvlTypes, lvlTypes + lvlRank),
       dim2lvlVec(dim2lvl, dim2lvl + lvlRank),
       lvl2dimVec(lvl2dim, lvl2dim + dimRank),
-      map(dimRank, lvlRank, dim2lvlVec.data(), lvl2dimVec.data()) {
+      map(dimRank, lvlRank, dim2lvlVec.data(), lvl2dimVec.data()),
+      allDense(isAllDense(lvlRank, lvlTypes)) {
   assert(dimSizes && lvlSizes && lvlTypes && dim2lvl && lvl2dim);
   // Validate dim-indexed parameters.
   assert(dimRank > 0 && "Trivial shape is unsupported");
-  for (uint64_t d = 0; d < dimRank; ++d)
+  for (uint64_t d = 0; d < dimRank; d++)
     assert(dimSizes[d] > 0 && "Dimension size zero has trivial storage");
   // Validate lvl-indexed parameters.
   assert(lvlRank > 0 && "Trivial shape is unsupported");
-  for (uint64_t l = 0; l < lvlRank; ++l) {
+  for (uint64_t l = 0; l < lvlRank; l++) {
     assert(lvlSizes[l] > 0 && "Level size zero has trivial storage");
     assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) ||
            isSingletonLvl(l) || is2OutOf4Lvl(l));



More information about the Mlir-commits mailing list