[Mlir-commits] [mlir] ac74188 - [mlir][sparse] Adding isSorted bit to SparseTensorCOO

wren romano llvmlistbot at llvm.org
Thu Sep 29 15:02:28 PDT 2022


Author: wren romano
Date: 2022-09-29T15:02:17-07:00
New Revision: ac741889c1448170f15a4f4bf93db6cc89518169

URL: https://github.com/llvm/llvm-project/commit/ac741889c1448170f15a4f4bf93db6cc89518169
DIFF: https://github.com/llvm/llvm-project/commit/ac741889c1448170f15a4f4bf93db6cc89518169.diff

LOG: [mlir][sparse] Adding isSorted bit to SparseTensorCOO

This is a follow-up to the refactoring in D133462, D133830, D133831, and D133833.

Depends On D133833

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133839

Added: 
    

Modified: 
    mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
index bef340bce6f8..1d45ea2112c0 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
@@ -48,6 +48,26 @@ struct Element final {
   V value;
 };
 
+/// Closure object implementing lexicographic `operator<` on `Element` indices.
+template <typename V>
+struct ElementLT final {
+  explicit ElementLT(uint64_t rank) : rank(rank) {}
+
+  /// Compare two elements lexicographically by their first `rank` indices.
+  /// Values are ignored; identical index tuples compare as not-less (false).
+  /// Precondition: the elements must both be valid for `rank`.
+  bool operator()(const Element<V> &e1, const Element<V> &e2) const {
+    for (uint64_t d = 0; d < rank; ++d) {
+      if (e1.indices[d] == e2.indices[d])
+        continue;
+      return e1.indices[d] < e2.indices[d];
+    }
+    return false;
+  }
+
+  const uint64_t rank; // number of leading index coordinates compared
+};
+
+
 /// The type of callback functions which receive an element.  We avoid
 /// packaging the coordinates and value together as an `Element` object
 /// because this helps keep code somewhat cleaner.
@@ -64,7 +84,8 @@ template <typename V>
 class SparseTensorCOO final {
 public:
   SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
-      : dimSizes(dimSizes) {
+      : dimSizes(dimSizes), isSorted(true), iteratorLocked(false),
+        iteratorPos(0) {
     if (capacity) {
       elements.reserve(capacity);
       indices.reserve(capacity * getRank());
@@ -100,6 +121,9 @@ class SparseTensorCOO final {
   /// Get the elements array.
   const std::vector<Element<V>> &getElements() const { return elements; }
 
+  /// Returns the `operator<` closure object for the COO's element type.
+  ElementLT<V> getElementLT() const { return ElementLT<V>(getRank()); }
+
   /// Adds an element to the tensor.  This method does not check whether
   /// `ind` is already associated with a value, it adds it regardless.
   /// Resolving such conflicts is left up to clients of the iterator
@@ -130,8 +154,11 @@ class SparseTensorCOO final {
         elements[i].indices = newBase + (elements[i].indices - base);
       base = newBase;
     }
-    // Add element as (pointer into shared index pool, value) pair.
-    elements.emplace_back(base + size, val);
+    // Add the new element and update the sorted bit.
+    Element<V> addedElem(base + size, val);
+    if (!elements.empty() && isSorted)
+      isSorted = getElementLT()(elements.back(), addedElem);
+    elements.push_back(addedElem);
   }
 
   /// Sorts elements lexicographically by index.  If an index is mapped to
@@ -140,18 +167,10 @@ class SparseTensorCOO final {
   /// Asserts: is not in iterator mode.
   void sort() {
     assert(!iteratorLocked && "Attempt to sort() after startIterator()");
-    // TODO: we may want to cache an `isSorted` bit, to avoid
-    // unnecessary/redundant sorting.
-    uint64_t rank = getRank();
-    std::sort(elements.begin(), elements.end(),
-              [rank](const Element<V> &e1, const Element<V> &e2) {
-                for (uint64_t r = 0; r < rank; ++r) {
-                  if (e1.indices[r] == e2.indices[r])
-                    continue;
-                  return e1.indices[r] < e2.indices[r];
-                }
-                return false;
-              });
+    if (isSorted)
+      return;
+    std::sort(elements.begin(), elements.end(), getElementLT());
+    isSorted = true;
   }
 
   /// Switch into iterator mode.  If already in iterator mode, then
@@ -177,8 +196,9 @@ class SparseTensorCOO final {
   const std::vector<uint64_t> dimSizes; // per-dimension sizes
   std::vector<Element<V>> elements;     // all COO elements
   std::vector<uint64_t> indices;        // shared index pool
-  bool iteratorLocked = false;
-  unsigned iteratorPos = 0;
+  bool isSorted;
+  bool iteratorLocked;
+  unsigned iteratorPos;
 };
 
 } // namespace sparse_tensor


        


More information about the Mlir-commits mailing list