[Mlir-commits] [mlir] [mlir][sparse] refactor dim2lvl/lvl2dim passing into MapRef (PR #68649)

Aart Bik llvmlistbot at llvm.org
Mon Oct 9 17:57:25 PDT 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/68649

This revision refactors all "Swiss army knife" entry points to pass the dim2lvl/lvl2dim mappings, so that the callee can construct a MapRef (shown here for the SparseTensorStorage class). This is the next step towards completely centralizing mapping code into a single MapRef class.
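
To make the new calling convention concrete, here is a minimal usage sketch (not part of the patch): an identity permutation on a 2-d tensor, with illustrative sizes, level types, and template arguments. Only the extra dim2lvl parameter alongside the existing lvl2dim is what this patch actually adds; everything else below is hypothetical.

  // Hypothetical usage sketch; values and template arguments are made up.
  uint64_t dimSizes[] = {10, 20};
  uint64_t lvlSizes[] = {10, 20};
  DimLevelType lvlTypes[] = {DimLevelType::Dense, DimLevelType::Compressed};
  uint64_t dim2lvl[] = {0, 1}; // dimension d is stored at level d (identity)
  uint64_t lvl2dim[] = {0, 1}; // inverse mapping
  // The callee now receives both maps and constructs the MapRef itself,
  // instead of every caller doing its own dim<->lvl translation.
  auto *tensor = SparseTensorStorage<uint64_t, uint64_t, double>::newEmpty(
      /*dimRank=*/2, dimSizes, /*lvlRank=*/2, lvlSizes, lvlTypes,
      dim2lvl, lvl2dim);
  // ... use tensor ...
  delete tensor; // the factory returns an owning raw pointer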

From e737b8e1f5759c9eac63c7ecc87f622e45c3d512 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Mon, 9 Oct 2023 17:51:52 -0700
Subject: [PATCH] [mlir][sparse] refactor dim2lvl/lvl2dim passing into MapRef

This revision refactors all "Swiss army knife" entry points
to pass the dim2lvl/lvl2dim mappings, so that the callee can
construct a MapRef (shown here for the SparseTensorStorage
class). This is the next step towards completely centralizing
mapping code into a single MapRef class.
---
 .../mlir/ExecutionEngine/SparseTensor/File.h  |   2 +-
 .../ExecutionEngine/SparseTensor/Storage.h    | 209 +++++++-----------
 .../ExecutionEngine/SparseTensor/Storage.cpp  |  11 +-
 .../ExecutionEngine/SparseTensorRuntime.cpp   |  21 +-
 4 files changed, 99 insertions(+), 144 deletions(-)

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index 9157bfa7e773239..efc3f82d6a307ea 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -203,7 +203,7 @@ class SparseTensorReader final {
     MapRef map(dimRank, lvlRank, dim2lvl, lvl2dim);
     auto *coo = readCOO<V>(map, lvlSizes);
     auto *tensor = SparseTensorStorage<P, I, V>::newFromCOO(
-        dimRank, getDimSizes(), lvlRank, lvlTypes, lvl2dim, *coo);
+        dimRank, getDimSizes(), lvlRank, lvlTypes, dim2lvl, lvl2dim, *coo);
     delete coo;
     return tensor;
   }
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 0407bccaae8790c..303a41bc471d5d9 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -24,13 +24,8 @@
 #include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h"
 #include "mlir/ExecutionEngine/SparseTensor/COO.h"
 #include "mlir/ExecutionEngine/SparseTensor/ErrorHandling.h"
+#include "mlir/ExecutionEngine/SparseTensor/MapRef.h"
 
-#define ASSERT_VALID_DIM(d)                                                    \
-  assert(d < getDimRank() && "Dimension is out of bounds");
-#define ASSERT_VALID_LVL(l)                                                    \
-  assert(l < getLvlRank() && "Level is out of bounds");
-#define ASSERT_COMPRESSED_LVL(l)                                               \
-  assert(isCompressedLvl(l) && "Level is not compressed");
 #define ASSERT_COMPRESSED_OR_SINGLETON_LVL(l)                                  \
   do {                                                                         \
     const DimLevelType dlt = getLvlType(l);                                    \
@@ -49,9 +44,9 @@ class SparseTensorEnumeratorBase;
 template <typename P, typename C, typename V>
 class SparseTensorEnumerator;
 
-/// Abstract base class for `SparseTensorStorage<P,C,V>`.  This class
+/// Abstract base class for `SparseTensorStorage<P,C,V>`. This class
 /// takes responsibility for all the `<P,C,V>`-independent aspects
-/// of the tensor (e.g., shape, sparsity, permutation).  In addition,
+/// of the tensor (e.g., shape, sparsity, mapping). In addition,
 /// we use function overloading to implement "partial" method
 /// specialization, which the C-API relies on to catch type errors
 /// arising from our use of opaque pointers.
@@ -62,24 +57,20 @@ class SparseTensorEnumerator;
 /// coordinate spaces (and their associated rank, shape, sizes, etc).
 /// Denotationally, we have the *dimensions* of the tensor represented
 /// by this object.  Operationally, we have the *levels* of the storage
-/// representation itself.  We use this "dimension" vs "level" terminology
-/// throughout, since alternative terminology like "tensor-dimension",
-/// "original-dimension", "storage-dimension", etc, is both more verbose
-/// and prone to introduce confusion whenever the qualifiers are dropped.
-/// Where necessary, we use "axis" as the generic term.
+/// representation itself.
 ///
 /// The *size* of an axis is the cardinality of possible coordinate
 /// values along that axis (regardless of which coordinates have stored
-/// element values).  As such, each size must be non-zero since if any
+/// element values). As such, each size must be non-zero since if any
 /// axis has size-zero then the whole tensor would have trivial storage
-/// (since there are no possible coordinates).  Thus we use the plural
+/// (since there are no possible coordinates). Thus we use the plural
 /// term *sizes* for a collection of non-zero cardinalities, and use
-/// this term whenever referring to run-time cardinalities.  Whereas we
+/// this term whenever referring to run-time cardinalities. Whereas we
 /// use the term *shape* for a collection of compile-time cardinalities,
 /// where zero is used to indicate cardinalities which are dynamic (i.e.,
-/// unknown/unspecified at compile-time).  At run-time, these dynamic
+/// unknown/unspecified at compile-time). At run-time, these dynamic
 /// cardinalities will be inferred from or checked against sizes otherwise
-/// specified.  Thus, dynamic cardinalities always have an "immutable but
+/// specified. Thus, dynamic cardinalities always have an "immutable but
 /// unknown" value; so the term "dynamic" should not be taken to indicate
 /// run-time mutability.
 class SparseTensorStorageBase {
@@ -89,25 +80,10 @@ class SparseTensorStorageBase {
 
 public:
   /// Constructs a new sparse-tensor storage object with the given encoding.
-  ///
-  /// Preconditions:
-  /// * `dimSizes`, `lvlSizes`, `lvlTypes`, and `lvl2dim` must be nonnull.
-  /// * `dimSizes` must be valid for `dimRank`.
-  /// * `lvlSizes`, `lvlTypes`, and `lvl2dim` must be valid for `lvlRank`.
-  /// * `lvl2dim` must map `lvlSizes`-coordinates to `dimSizes`-coordinates.
-  ///
-  /// Asserts:
-  /// * `dimRank` and `lvlRank` are nonzero.
-  /// * `dimSizes` and `lvlSizes` contain only nonzero sizes.
   SparseTensorStorageBase(uint64_t dimRank, const uint64_t *dimSizes,
                           uint64_t lvlRank, const uint64_t *lvlSizes,
-                          const DimLevelType *lvlTypes,
+                          const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
                           const uint64_t *lvl2dim);
-  // NOTE: For the most part we only need the `dimRank`.  But we need
-  // `dimSizes` for `toCOO` to support the identity permutation nicely
-  // (i.e., without the caller needing to already know the tensor's
-  // dimension-sizes; e.g., as in `fromMLIRSparseTensor`).
-
   virtual ~SparseTensorStorageBase() = default;
 
   /// Gets the number of tensor-dimensions.
@@ -121,7 +97,7 @@ class SparseTensorStorageBase {
 
   /// Safely looks up the size of the given tensor-dimension.
   uint64_t getDimSize(uint64_t d) const {
-    ASSERT_VALID_DIM(d);
+    assert(d < getDimRank() && "Dimension is out of bounds");
     return dimSizes[d];
   }
 
@@ -130,19 +106,16 @@ class SparseTensorStorageBase {
 
   /// Safely looks up the size of the given storage-level.
   uint64_t getLvlSize(uint64_t l) const {
-    ASSERT_VALID_LVL(l);
+    assert(l < getLvlRank() && "Level is out of bounds");
     return lvlSizes[l];
   }
 
-  /// Gets the level-to-dimension mapping.
-  const std::vector<uint64_t> &getLvl2Dim() const { return lvl2dim; }
-
   /// Gets the level-types array.
   const std::vector<DimLevelType> &getLvlTypes() const { return lvlTypes; }
 
   /// Safely looks up the type of the given level.
   DimLevelType getLvlType(uint64_t l) const {
-    ASSERT_VALID_LVL(l);
+    assert(l < getLvlRank() && "Level is out of bounds");
     return lvlTypes[l];
   }
 
@@ -165,6 +138,10 @@ class SparseTensorStorageBase {
   /// Safely checks if the level is unique.
   bool isUniqueLvl(uint64_t l) const { return isUniqueDLT(getLvlType(l)); }
 
+  /// Gets the level-to-dimension mapping.
+  // TODO: REMOVE THIS
+  const std::vector<uint64_t> &getLvl2Dim() const { return lvl2dimVec; }
+
   /// Allocates a new enumerator.  Callers must make sure to delete
   /// the enumerator when they're done with it. The first argument
   /// is the out-parameter for storing the newly allocated enumerator;
@@ -228,12 +205,14 @@ class SparseTensorStorageBase {
   const std::vector<uint64_t> dimSizes;
   const std::vector<uint64_t> lvlSizes;
   const std::vector<DimLevelType> lvlTypes;
-  const std::vector<uint64_t> lvl2dim;
+  const std::vector<uint64_t> dim2lvlVec;
+  const std::vector<uint64_t> lvl2dimVec;
+  const MapRef map; // non-owning pointers into dim2lvl/lvl2dim vectors
 };
 
 /// A memory-resident sparse tensor using a storage scheme based on
-/// per-level sparse/dense annotations.  This data structure provides
-/// a bufferized form of a sparse tensor type.  In contrast to generating
+/// per-level sparse/dense annotations. This data structure provides
+/// a bufferized form of a sparse tensor type. In contrast to generating
 /// setup methods for each differently annotated sparse tensor, this
 /// method provides a convenient "one-size-fits-all" solution that simply
 /// takes an input tensor and annotations to implement all required setup
@@ -244,58 +223,45 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// Beware that the object is not necessarily guaranteed to be in a
   /// valid state after this constructor alone; e.g., `isCompressedLvl(l)`
   /// doesn't entail `!(positions[l].empty())`.
-  ///
-  /// Preconditions/assertions are as per the `SparseTensorStorageBase` ctor.
   SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                       uint64_t lvlRank, const uint64_t *lvlSizes,
-                      const DimLevelType *lvlTypes, const uint64_t *lvl2dim)
+                      const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
+                      const uint64_t *lvl2dim)
       : SparseTensorStorageBase(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
-                                lvl2dim),
+                                dim2lvl, lvl2dim),
         positions(lvlRank), coordinates(lvlRank), lvlCursor(lvlRank) {}
 
 public:
   /// Constructs a sparse tensor with the given encoding, and allocates
-  /// overhead storage according to some simple heuristics.  When the
+  /// overhead storage according to some simple heuristics. When the
   /// `bool` argument is true and `lvlTypes` are all dense, then this
-  /// ctor will also initialize the values array with zeros.  That
+  /// ctor will also initialize the values array with zeros. That
   /// argument should be true when an empty tensor is intended; whereas
   /// it should usually be false when the ctor will be followed up by
   /// some other form of initialization.
-  ///
-  /// Preconditions/assertions are as per the `SparseTensorStorageBase` ctor.
   SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                       uint64_t lvlRank, const uint64_t *lvlSizes,
-                      const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
-                      bool initializeValuesIfAllDense);
+                      const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
+                      const uint64_t *lvl2dim, bool initializeValuesIfAllDense);
 
   /// Constructs a sparse tensor with the given encoding, and initializes
-  /// the contents from the COO.  This ctor performs the same heuristic
-  /// overhead-storage allocation as the ctor taking a `bool`, and
-  /// has the same preconditions/assertions (where we define `lvlSizes =
-  /// lvlCOO.getDimSizes().data()`), with the following addition:
-  ///
-  /// Asserts:
-  /// * `lvlRank == lvlCOO.getRank()`.
+  /// the contents from the COO. This ctor performs the same heuristic
+  /// overhead-storage allocation as the ctor taking a `bool`.
   SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                       uint64_t lvlRank, const DimLevelType *lvlTypes,
-                      const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO);
+                      const uint64_t *dim2lvl, const uint64_t *lvl2dim,
+                      SparseTensorCOO<V> &lvlCOO);
 
   /// Constructs a sparse tensor with the given encoding, and initializes
-  /// the contents from the enumerator.  This ctor allocates exactly
+  /// the contents from the enumerator. This ctor allocates exactly
   /// the required amount of overhead storage, not using any heuristics.
-  /// Preconditions/assertions are as per the `SparseTensorStorageBase`
-  /// ctor (where we define `lvlSizes = lvlEnumerator.getTrgSizes().data()`),
-  /// with the following addition:
-  ///
-  /// Asserts:
-  /// * `lvlRank == lvlEnumerator.getTrgRank()`.
   SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                       uint64_t lvlRank, const DimLevelType *lvlTypes,
-                      const uint64_t *lvl2dim,
+                      const uint64_t *dim2lvl, const uint64_t *lvl2dim,
                       SparseTensorEnumeratorBase<V> &lvlEnumerator);
 
   /// Constructs a sparse tensor with the given encoding, and initializes
-  /// the contents from the level buffers.  This ctor allocates exactly
+  /// the contents from the level buffers. This ctor allocates exactly
   /// the required amount of overhead storage, not using any heuristics.
   /// It assumes that the data provided by `lvlBufs` can be directly used to
   /// interpret the result sparse tensor and performs *NO* integrity test on the
@@ -303,8 +269,8 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// passed in as a single AoS memory.
   SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                       uint64_t lvlRank, const uint64_t *lvlSizes,
-                      const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
-                      const intptr_t *lvlBufs);
+                      const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
+                      const uint64_t *lvl2dim, const intptr_t *lvlBufs);
 
   /// Allocates a new empty sparse tensor. The preconditions/assertions
   /// are as per the `SparseTensorStorageBase` ctor; which is to say,
@@ -313,21 +279,15 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   static SparseTensorStorage<P, C, V> *
   newEmpty(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
            const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-           const uint64_t *lvl2dim) {
-    return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank,
-                                            lvlSizes, lvlTypes, lvl2dim, true);
+           const uint64_t *dim2lvl, const uint64_t *lvl2dim) {
+    return new SparseTensorStorage<P, C, V>(
+        dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, true);
   }
 
   /// Allocates a new sparse tensor and initializes it from the given COO.
   /// The preconditions are as per the `SparseTensorStorageBase` ctor
   /// (where we define `lvlSizes = lvlCOO.getDimSizes().data()`), but
   /// using the following assertions in lieu of the base ctor's assertions:
-  ///
-  /// Asserts:
-  /// * `dimRank` and `lvlRank` are nonzero.
-  /// * `lvlRank == lvlCOO.getRank()`.
-  /// * `lvlCOO.getDimSizes()` under the `lvl2dim` mapping is a refinement
-  ///   of `dimShape`.
   //
   // TODO: The ability to reconstruct dynamic dimensions-sizes does not
   // easily generalize to arbitrary `lvl2dim` mappings.  When compiling
@@ -338,8 +298,8 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   // to update the type/preconditions of this factory too.
   static SparseTensorStorage<P, C, V> *
   newFromCOO(uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
-             const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
-             SparseTensorCOO<V> &lvlCOO);
+             const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
+             const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO);
 
   /// Allocates a new sparse tensor and initializes it with the contents
   /// of another sparse tensor.
@@ -370,8 +330,9 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   static SparseTensorStorage<P, C, V> *
   newFromSparseTensor(uint64_t dimRank, const uint64_t *dimShape,
                       uint64_t lvlRank, const uint64_t *lvlSizes,
-                      const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
-                      uint64_t srcRank, const uint64_t *src2lvl,
+                      const DimLevelType *lvlTypes,
+                      const uint64_t *src2lvl, // FIXME: dim2lvl,
+                      const uint64_t *lvl2dim, uint64_t srcRank,
                       const SparseTensorStorageBase &source);
 
   /// Allocates a new sparse tensor and initialize it with the data stored level
@@ -380,24 +341,23 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// Precondition:
   /// * as per the `SparseTensorStorageBase` ctor.
   /// * the data integrity stored in `buffers` is guaranteed by users already.
-  static SparseTensorStorage<P, C, V> *
-  packFromLvlBuffers(uint64_t dimRank, const uint64_t *dimShape,
-                     uint64_t lvlRank, const uint64_t *lvlSizes,
-                     const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
-                     uint64_t srcRank, const uint64_t *src2lvl,
-                     const intptr_t *buffers);
+  static SparseTensorStorage<P, C, V> *packFromLvlBuffers(
+      uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
+      const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
+      const uint64_t *src2lvl, // FIXME: dim2lvl
+      const uint64_t *lvl2dim, uint64_t srcRank, const intptr_t *buffers);
 
   ~SparseTensorStorage() final = default;
 
   /// Partially specialize these getter methods based on template types.
   void getPositions(std::vector<P> **out, uint64_t lvl) final {
     assert(out && "Received nullptr for out parameter");
-    ASSERT_VALID_LVL(lvl);
+    assert(lvl < getLvlRank() && "Level is out of bounds");
     *out = &positions[lvl];
   }
   void getCoordinates(std::vector<C> **out, uint64_t lvl) final {
     assert(out && "Received nullptr for out parameter");
-    ASSERT_VALID_LVL(lvl);
+    assert(lvl < getLvlRank() && "Level is out of bounds");
     *out = &coordinates[lvl];
   }
   void getValues(std::vector<V> **out) final {
@@ -477,12 +437,12 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
 
   /// Allocates a new COO object and initializes it with the contents
   /// of this tensor under the given mapping from the `getDimSizes()`
-  /// coordinate-space to the `trgSizes` coordinate-space.  Callers must
+  /// coordinate-space to the `trgSizes` coordinate-space. Callers must
   /// make sure to delete the COO when they're done with it.
-  ///
-  /// Preconditions/assertions are as per the `SparseTensorEnumerator` ctor.
   SparseTensorCOO<V> *toCOO(uint64_t trgRank, const uint64_t *trgSizes,
-                            uint64_t srcRank, const uint64_t *src2trg) const {
+                            uint64_t srcRank,
+                            const uint64_t *src2trg, // FIXME: dim2lvl
+                            const uint64_t *lvl2dim) const {
     // We inline `newEnumerator` to avoid virtual dispatch and allocation.
     // TODO: use MapRef here too for the translation
     SparseTensorEnumerator<P, C, V> enumerator(*this, trgRank, trgSizes,
@@ -503,7 +463,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// does not check that `pos` is semantically valid (i.e., larger than
   /// the previous position and smaller than `coordinates[lvl].capacity()`).
   void appendPos(uint64_t lvl, uint64_t pos, uint64_t count = 1) {
-    ASSERT_COMPRESSED_LVL(lvl);
+    assert(isCompressedLvl(lvl) && "Level is not compressed");
     positions[lvl].insert(positions[lvl].end(), count,
                           detail::checkOverflowCast<P>(pos));
   }
@@ -689,9 +649,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
 };
 
 #undef ASSERT_COMPRESSED_OR_SINGLETON_LVL
-#undef ASSERT_COMPRESSED_LVL
-#undef ASSERT_VALID_LVL
-#undef ASSERT_VALID_DIM
 
 //===----------------------------------------------------------------------===//
 /// A (higher-order) function object for enumerating the elements of some
@@ -934,11 +891,12 @@ class SparseTensorNNZ final {
 //===----------------------------------------------------------------------===//
 // Definitions of the ctors and factories of `SparseTensorStorage<P,C,V>`.
 
+// TODO: MapRef
 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromCOO(
     uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
-    const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
-    SparseTensorCOO<V> &lvlCOO) {
+    const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
+    const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO) {
   assert(dimShape && "Got nullptr for dimension shape");
   assert(lvl2dim && "Got nullptr for level-to-dimension mapping");
   const auto &lvlSizes = lvlCOO.getDimSizes();
@@ -955,14 +913,15 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromCOO(
     dimSizes[d] = lvlSizes[l];
   }
   return new SparseTensorStorage<P, C, V>(dimRank, dimSizes.data(), lvlRank,
-                                          lvlTypes, lvl2dim, lvlCOO);
+                                          lvlTypes, dim2lvl, lvl2dim, lvlCOO);
 }
 
 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromSparseTensor(
     uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
     const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-    const uint64_t *lvl2dim, uint64_t srcRank, const uint64_t *src2lvl,
+    const uint64_t *src2lvl, // dim2lvl
+    const uint64_t *lvl2dim, uint64_t srcRank,
     const SparseTensorStorageBase &source) {
   // Verify that the `source` dimensions match the expected `dimShape`.
   assert(dimShape && "Got nullptr for dimension shape");
@@ -977,8 +936,9 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromSparseTensor(
 #endif
   SparseTensorEnumeratorBase<V> *lvlEnumerator;
   source.newEnumerator(&lvlEnumerator, lvlRank, lvlSizes, srcRank, src2lvl);
-  auto *tensor = new SparseTensorStorage<P, C, V>(
-      dimRank, dimSizes.data(), lvlRank, lvlTypes, lvl2dim, *lvlEnumerator);
+  auto *tensor = new SparseTensorStorage<P, C, V>(dimRank, dimSizes.data(),
+                                                  lvlRank, lvlTypes, src2lvl,
+                                                  lvl2dim, *lvlEnumerator);
   delete lvlEnumerator;
   return tensor;
 }
@@ -987,11 +947,12 @@ template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::packFromLvlBuffers(
     uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
     const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-    const uint64_t *lvl2dim, uint64_t srcRank, const uint64_t *src2lvl,
-    const intptr_t *buffers) {
+    const uint64_t *src2lvl, // FIXME: dim2lvl
+    const uint64_t *lvl2dim, uint64_t srcRank, const intptr_t *buffers) {
   assert(dimShape && "Got nullptr for dimension shape");
-  auto *tensor = new SparseTensorStorage<P, C, V>(
-      dimRank, dimShape, lvlRank, lvlSizes, lvlTypes, lvl2dim, buffers);
+  auto *tensor =
+      new SparseTensorStorage<P, C, V>(dimRank, dimShape, lvlRank, lvlSizes,
+                                       lvlTypes, src2lvl, lvl2dim, buffers);
   return tensor;
 }
 
@@ -999,9 +960,10 @@ template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V>::SparseTensorStorage(
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
     const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-    const uint64_t *lvl2dim, bool initializeValuesIfAllDense)
+    const uint64_t *dim2lvl, const uint64_t *lvl2dim,
+    bool initializeValuesIfAllDense)
     : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
-                          lvl2dim) {
+                          dim2lvl, lvl2dim) {
   // Provide hints on capacity of positions and coordinates.
   // TODO: needs much fine-tuning based on actual sparsity; currently
   // we reserve position/coordinate space based on all previous dense
@@ -1012,8 +974,6 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
   for (uint64_t l = 0; l < lvlRank; ++l) {
     const DimLevelType dlt = lvlTypes[l]; // Avoid redundant bounds checking.
     if (isCompressedDLT(dlt)) {
-      // TODO: Take a parameter between 1 and `lvlSizes[l]`, and multiply
-      // `sz` by that before reserving. (For now we just use 1.)
       positions[l].reserve(sz + 1);
       positions[l].push_back(0);
       coordinates[l].reserve(sz);
@@ -1035,11 +995,11 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V>::SparseTensorStorage( // NOLINT
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
-    const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
-    SparseTensorCOO<V> &lvlCOO)
+    const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
+    const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO)
     : SparseTensorStorage(dimRank, dimSizes, lvlRank,
-                          lvlCOO.getDimSizes().data(), lvlTypes, lvl2dim,
-                          false) {
+                          lvlCOO.getDimSizes().data(), lvlTypes, dim2lvl,
+                          lvl2dim, false) {
   assert(lvlRank == lvlCOO.getDimSizes().size() && "Level-rank mismatch");
   // Ensure the preconditions of `fromCOO`.  (One is already ensured by
   // using `lvlSizes = lvlCOO.getDimSizes()` in the ctor above.)
@@ -1054,10 +1014,10 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage( // NOLINT
 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V>::SparseTensorStorage(
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
-    const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
-    SparseTensorEnumeratorBase<V> &lvlEnumerator)
+    const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
+    const uint64_t *lvl2dim, SparseTensorEnumeratorBase<V> &lvlEnumerator)
     : SparseTensorStorage(dimRank, dimSizes, lvlRank,
-                          lvlEnumerator.getTrgSizes().data(), lvlTypes,
+                          lvlEnumerator.getTrgSizes().data(), lvlTypes, dim2lvl,
                           lvl2dim) {
   assert(lvlRank == lvlEnumerator.getTrgRank() && "Level-rank mismatch");
   {
@@ -1137,7 +1097,6 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
       // Can't check all of them, but at least we can check the last one.
       assert(positions[l][parentSz - 1] == positions[l][parentSz] &&
              "Positions got corrupted");
-      // TODO: optimize this by using `memmove` or similar.
       for (uint64_t n = 0; n < parentSz; ++n) {
         const uint64_t parentPos = parentSz - n;
         positions[l][parentPos] = positions[l][parentPos - 1];
@@ -1157,9 +1116,9 @@ template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V>::SparseTensorStorage(
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
     const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-    const uint64_t *lvl2dim, const intptr_t *lvlBufs)
+    const uint64_t *dim2lvl, const uint64_t *lvl2dim, const intptr_t *lvlBufs)
     : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
-                          lvl2dim) {
+                          dim2lvl, lvl2dim) {
   uint64_t trailCOOLen = 0, parentSz = 1, bufIdx = 0;
   for (uint64_t l = 0; l < lvlRank; l++) {
     if (!isUniqueLvl(l) && isCompressedLvl(l)) {
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 199e4205a61a25b..1d654cae3b4b125 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -20,15 +20,14 @@ using namespace mlir::sparse_tensor;
 SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
     const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-    const uint64_t *lvl2dim)
+    const uint64_t *dim2lvl, const uint64_t *lvl2dim)
     : dimSizes(dimSizes, dimSizes + dimRank),
       lvlSizes(lvlSizes, lvlSizes + lvlRank),
       lvlTypes(lvlTypes, lvlTypes + lvlRank),
-      lvl2dim(lvl2dim, lvl2dim + lvlRank) {
-  assert(dimSizes && "Got nullptr for dimension sizes");
-  assert(lvlSizes && "Got nullptr for level sizes");
-  assert(lvlTypes && "Got nullptr for level types");
-  assert(lvl2dim && "Got nullptr for level-to-dimension mapping");
+      dim2lvlVec(dim2lvl, dim2lvl + dimRank),
+      lvl2dimVec(lvl2dim, lvl2dim + lvlRank),
+      map(dimRank, lvlRank, dim2lvlVec.data(), lvl2dimVec.data()) {
+  assert(dimSizes && lvlSizes && lvlTypes && dim2lvl && lvl2dim);
   // Validate dim-indexed parameters.
   assert(dimRank > 0 && "Trivial shape is unsupported");
   for (uint64_t d = 0; d < dimRank; ++d)
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index 8340fe7dcf925be..bc6d4ad2c740189 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -179,39 +179,39 @@ extern "C" {
     switch (action) {                                                          \
     case Action::kEmpty:                                                       \
       return SparseTensorStorage<P, C, V>::newEmpty(                           \
-          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, lvl2dim);            \
+          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim);   \
     case Action::kFromCOO: {                                                   \
       assert(ptr && "Received nullptr for SparseTensorCOO object");            \
       auto &coo = *static_cast<SparseTensorCOO<V> *>(ptr);                     \
       return SparseTensorStorage<P, C, V>::newFromCOO(                         \
-          dimRank, dimSizes, lvlRank, lvlTypes, lvl2dim, coo);                 \
+          dimRank, dimSizes, lvlRank, lvlTypes, dim2lvl, lvl2dim, coo);        \
     }                                                                          \
     case Action::kSparseToSparse: {                                            \
       assert(ptr && "Received nullptr for SparseTensorStorage object");        \
       auto &tensor = *static_cast<SparseTensorStorageBase *>(ptr);             \
       return SparseTensorStorage<P, C, V>::newFromSparseTensor(                \
-          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, lvl2dim, dimRank,    \
-          dim2lvl, tensor);                                                    \
+          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,    \
+          dimRank, tensor);                                                    \
     }                                                                          \
     case Action::kEmptyCOO:                                                    \
       return new SparseTensorCOO<V>(lvlRank, lvlSizes);                        \
     case Action::kToCOO: {                                                     \
       assert(ptr && "Received nullptr for SparseTensorStorage object");        \
       auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);        \
-      return tensor.toCOO(lvlRank, lvlSizes, dimRank, dim2lvl);                \
+      return tensor.toCOO(lvlRank, lvlSizes, dimRank, dim2lvl, lvl2dim);       \
     }                                                                          \
     case Action::kToIterator: {                                                \
       assert(ptr && "Received nullptr for SparseTensorStorage object");        \
       auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);        \
-      auto *coo = tensor.toCOO(lvlRank, lvlSizes, dimRank, dim2lvl);           \
+      auto *coo = tensor.toCOO(lvlRank, lvlSizes, dimRank, dim2lvl, lvl2dim);  \
       return new SparseTensorIterator<V>(coo);                                 \
     }                                                                          \
     case Action::kPack: {                                                      \
       assert(ptr && "Received nullptr for SparseTensorStorage object");        \
       intptr_t *buffers = static_cast<intptr_t *>(ptr);                        \
       return SparseTensorStorage<P, C, V>::packFromLvlBuffers(                 \
-          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, lvl2dim, dimRank,    \
-          dim2lvl, buffers);                                                   \
+          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,    \
+          dimRank, buffers);                                                   \
     }                                                                          \
     }                                                                          \
     MLIR_SPARSETENSOR_FATAL("unknown action: %d\n",                            \
@@ -250,9 +250,6 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
   const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
   const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
 
-  // Prepare map.
-  // TODO: start using MapRef map(dimRank, lvlRank, dim2lvl, lvl2dim) below
-
   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
   // This is safe because of the static_assert above.
   if (posTp == OverheadType::kIndex)
@@ -403,7 +400,7 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
 #undef IMPL_GETOVERHEAD
 
 // TODO: use MapRef here for translation of coordinates
-// TOOD: remove dim2lvl
+// TODO: remove dim2lvl
 #define IMPL_ADDELT(VNAME, V)                                                  \
   void *_mlir_ciface_addElt##VNAME(                                            \
       void *lvlCOO, StridedMemRefType<V, 0> *vref,                             \
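
As a postscript for readers new to the dim/lvl terminology: conceptually, MapRef just bundles the dim2lvl/lvl2dim pair so that all coordinate translation lives in one place. A rough, hypothetical sketch of the permutation case follows; the real class lives in mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h and its actual interface may differ.

  // Illustrative only; the real MapRef also handles non-permutation maps.
  struct PermutationMapRef {
    uint64_t dimRank, lvlRank;
    const uint64_t *dim2lvl; // dim2lvl[d] = level at which dimension d is stored
    const uint64_t *lvl2dim; // lvl2dim[l] = dimension stored at level l
    // Forward map: dimension coordinates -> level coordinates.
    void dimToLvl(const uint64_t *dimCrds, uint64_t *lvlCrds) const {
      for (uint64_t l = 0; l < lvlRank; ++l)
        lvlCrds[l] = dimCrds[lvl2dim[l]];
    }
    // Backward map: level coordinates -> dimension coordinates.
    void lvlToDim(const uint64_t *lvlCrds, uint64_t *dimCrds) const {
      for (uint64_t d = 0; d < dimRank; ++d)
        dimCrds[d] = lvlCrds[dim2lvl[d]];
    }
  };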


