[Mlir-commits] [mlir] [mlir][sparse] refactoring sparse runtime lib into less paths (PR #85332)

Aart Bik llvmlistbot at llvm.org
Thu Mar 14 16:20:25 PDT 2024


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/85332

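Two constructors could easily be refactored into one now that much of the previously deprecated code has been removed. Also renames `packFromLvlBuffers` to `newFromBuffers`, for consistency with `newEmpty` and `newFromCOO`.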

>From 2a477234f9c50fff428a0bd6be22f7ab1d64c04d Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 14 Mar 2024 16:18:01 -0700
Subject: [PATCH] [mlir][sparse] refactoring sparse runtime lib into less paths

Two constructors could easily be refactored into one
now that much of the previously deprecated code has been removed.
Also renames packFromLvlBuffers to newFromBuffers for consistency.
---
 .../mlir/ExecutionEngine/SparseTensor/File.h  |  2 +-
 .../ExecutionEngine/SparseTensor/Storage.h    | 88 +++++++------------
 .../ExecutionEngine/SparseTensorRuntime.cpp   |  2 +-
 3 files changed, 34 insertions(+), 58 deletions(-)

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index f927b82628b1a6..714e664dd0f4eb 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -206,7 +206,7 @@ class SparseTensorReader final {
     auto *lvlCOO = readCOO<V>(map, lvlSizes);
     auto *tensor = SparseTensorStorage<P, I, V>::newFromCOO(
         dimRank, getDimSizes(), lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,
-        *lvlCOO);
+        lvlCOO);
     delete lvlCOO;
     return tensor;
   }
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index b207fc1ee104d3..773957a8b51162 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -201,33 +201,18 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
 
 public:
   /// Constructs a sparse tensor with the given encoding, and allocates
-  /// overhead storage according to some simple heuristics. When the
-  /// `bool` argument is true and `lvlTypes` are all dense, then this
-  /// ctor will also initialize the values array with zeros. That
-  /// argument should be true when an empty tensor is intended; whereas
-  /// it should usually be false when the ctor will be followed up by
-  /// some other form of initialization.
+  /// overhead storage according to some simple heuristics. When `lvlCOO`
+  /// is non-null, the sparse tensor is initialized with the contents of that
+  /// (sorted in place) COO. Otherwise, an empty sparse tensor results.
   SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                       uint64_t lvlRank, const uint64_t *lvlSizes,
                       const LevelType *lvlTypes, const uint64_t *dim2lvl,
-                      const uint64_t *lvl2dim, SparseTensorCOO<V> *lvlCOO,
-                      bool initializeValuesIfAllDense);
+                      const uint64_t *lvl2dim, SparseTensorCOO<V> *lvlCOO);
 
   /// Constructs a sparse tensor with the given encoding, and initializes
-  /// the contents from the COO. This ctor performs the same heuristic
-  /// overhead-storage allocation as the ctor above.
-  SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
-                      uint64_t lvlRank, const uint64_t *lvlSizes,
-                      const LevelType *lvlTypes, const uint64_t *dim2lvl,
-                      const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO);
-
-  /// Constructs a sparse tensor with the given encoding, and initializes
-  /// the contents from the level buffers. This ctor allocates exactly
-  /// the required amount of overhead storage, not using any heuristics.
-  /// It assumes that the data provided by `lvlBufs` can be directly used to
-  /// interpret the result sparse tensor and performs *NO* integrity test on the
-  /// input data. It also assume that the trailing COO coordinate buffer is
-  /// passed in as a single AoS memory.
+  /// the contents from the level buffers. The constructor assumes that the
+  /// data provided by `lvlBufs` can be directly used to interpret the result
+  /// sparse tensor and performs no integrity test on the input data.
   SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                       uint64_t lvlRank, const uint64_t *lvlSizes,
                       const LevelType *lvlTypes, const uint64_t *dim2lvl,
@@ -244,16 +229,14 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   newFromCOO(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
              const uint64_t *lvlSizes, const LevelType *lvlTypes,
              const uint64_t *dim2lvl, const uint64_t *lvl2dim,
-             SparseTensorCOO<V> &lvlCOO);
+             SparseTensorCOO<V> *lvlCOO);
 
-  /// Allocates a new sparse tensor and initialize it with the data stored level
-  /// buffers directly.
+  /// Allocates a new sparse tensor and initializes it from the given buffers.
   static SparseTensorStorage<P, C, V> *
-  packFromLvlBuffers(uint64_t dimRank, const uint64_t *dimSizes,
-                     uint64_t lvlRank, const uint64_t *lvlSizes,
-                     const LevelType *lvlTypes, const uint64_t *dim2lvl,
-                     const uint64_t *lvl2dim, uint64_t srcRank,
-                     const intptr_t *buffers);
+  newFromBuffers(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
+                 const uint64_t *lvlSizes, const LevelType *lvlTypes,
+                 const uint64_t *dim2lvl, const uint64_t *lvl2dim,
+                 uint64_t srcRank, const intptr_t *buffers);
 
   ~SparseTensorStorage() final = default;
 
@@ -563,9 +546,9 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newEmpty(
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
     const uint64_t *lvlSizes, const LevelType *lvlTypes,
     const uint64_t *dim2lvl, const uint64_t *lvl2dim) {
+  SparseTensorCOO<V> *noLvlCOO = nullptr;
   return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
-                                          lvlTypes, dim2lvl, lvl2dim, nullptr,
-                                          true);
+                                          lvlTypes, dim2lvl, lvl2dim, noLvlCOO);
 }
 
 template <typename P, typename C, typename V>
@@ -573,13 +556,14 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromCOO(
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
     const uint64_t *lvlSizes, const LevelType *lvlTypes,
     const uint64_t *dim2lvl, const uint64_t *lvl2dim,
-    SparseTensorCOO<V> &lvlCOO) {
+    SparseTensorCOO<V> *lvlCOO) {
+  assert(lvlCOO);
   return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
                                           lvlTypes, dim2lvl, lvl2dim, lvlCOO);
 }
 
 template <typename P, typename C, typename V>
-SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::packFromLvlBuffers(
+SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromBuffers(
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
     const uint64_t *lvlSizes, const LevelType *lvlTypes,
     const uint64_t *dim2lvl, const uint64_t *lvl2dim, uint64_t srcRank,
@@ -599,10 +583,9 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
     const uint64_t *lvlSizes, const LevelType *lvlTypes,
     const uint64_t *dim2lvl, const uint64_t *lvl2dim,
-    SparseTensorCOO<V> *lvlCOO, bool initializeValuesIfAllDense)
+    SparseTensorCOO<V> *lvlCOO)
     : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                           dim2lvl, lvl2dim) {
-  assert(!lvlCOO || lvlRank == lvlCOO->getRank());
   // Provide hints on capacity of positions and coordinates.
   // TODO: needs much fine-tuning based on actual sparsity; currently
   // we reserve position/coordinate space based on all previous dense
@@ -633,27 +616,20 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
       sz = detail::checkedMul(sz, lvlSizes[l]);
     }
   }
-  if (allDense && initializeValuesIfAllDense)
+  if (lvlCOO) {
+    // New from COO: ensure it is sorted.
+    assert(lvlCOO->getRank() == lvlRank);
+    lvlCOO->sort();
+    // Now actually insert the `elements`.
+    const auto &elements = lvlCOO->getElements();
+    const uint64_t nse = elements.size();
+    assert(values.size() == 0);
+    values.reserve(nse);
+    fromCOO(elements, 0, nse, 0);
+  } else if (allDense) {
+    // New empty (all dense): zero-fill the values.
     values.resize(sz, 0);
-}
-
-template <typename P, typename C, typename V>
-SparseTensorStorage<P, C, V>::SparseTensorStorage( // NOLINT
-    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
-    const uint64_t *lvlSizes, const LevelType *lvlTypes,
-    const uint64_t *dim2lvl, const uint64_t *lvl2dim,
-    SparseTensorCOO<V> &lvlCOO)
-    : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
-                          dim2lvl, lvl2dim, nullptr, false) {
-  // Ensure lvlCOO is sorted.
-  assert(lvlRank == lvlCOO.getRank());
-  lvlCOO.sort();
-  // Now actually insert the `elements`.
-  const auto &elements = lvlCOO.getElements();
-  const uint64_t nse = elements.size();
-  assert(values.size() == 0);
-  values.reserve(nse);
-  fromCOO(elements, 0, nse, 0);
+  }
 }
 
 template <typename P, typename C, typename V>
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index 731abcbbf1f39e..8835056099d234 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -127,7 +127,7 @@ extern "C" {
     case Action::kPack: {                                                      \
       assert(ptr && "Received nullptr for SparseTensorStorage object");        \
       intptr_t *buffers = static_cast<intptr_t *>(ptr);                        \
-      return SparseTensorStorage<P, C, V>::packFromLvlBuffers(                 \
+      return SparseTensorStorage<P, C, V>::newFromBuffers(                     \
           dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,    \
           dimRank, buffers);                                                   \
     }                                                                          \



More information about the Mlir-commits mailing list