[Mlir-commits] [mlir] ac74188 - [mlir][sparse] Adding isSorted bit to SparseTensorCOO
wren romano
llvmlistbot at llvm.org
Thu Sep 29 15:02:28 PDT 2022
Author: wren romano
Date: 2022-09-29T15:02:17-07:00
New Revision: ac741889c1448170f15a4f4bf93db6cc89518169
URL: https://github.com/llvm/llvm-project/commit/ac741889c1448170f15a4f4bf93db6cc89518169
DIFF: https://github.com/llvm/llvm-project/commit/ac741889c1448170f15a4f4bf93db6cc89518169.diff
LOG: [mlir][sparse] Adding isSorted bit to SparseTensorCOO
This is a followup to the refactoring of D133462, D133830, D133831, and D133833.
Depends On D133833
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D133839
Added:
Modified:
mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
index bef340bce6f8..1d45ea2112c0 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
@@ -48,6 +48,26 @@ struct Element final {
V value;
};
+/// Closure object for `operator<` on `Element` with a given rank.
+/// Elements are ordered lexicographically by their first `rank` indices;
+/// the element values are never consulted.
+template <typename V>
+struct ElementLT final {
+  /// Constructs a comparator over the first `rank` dimensions.  Marked
+  /// `explicit` so that a bare integer is never silently converted into
+  /// a comparator object.
+  explicit ElementLT(uint64_t rank) : rank(rank) {}
+
+  /// Compare two elements a la `operator<`.
+  ///
+  /// Returns true iff the indices of `e1` are lexicographically less than
+  /// those of `e2`.  Elements with identical indices compare as equal
+  /// (neither is less than the other), which makes this a strict weak
+  /// ordering suitable for `std::sort`.
+  ///
+  /// Precondition: the elements must both be valid for `rank`.
+  bool operator()(const Element<V> &e1, const Element<V> &e2) const {
+    for (uint64_t d = 0; d < rank; ++d) {
+      if (e1.indices[d] == e2.indices[d])
+        continue;
+      return e1.indices[d] < e2.indices[d];
+    }
+    return false;
+  }
+
+  const uint64_t rank;
+};
+
/// The type of callback functions which receive an element. We avoid
/// packaging the coordinates and value together as an `Element` object
/// because this helps keep code somewhat cleaner.
@@ -64,7 +84,8 @@ template <typename V>
class SparseTensorCOO final {
public:
SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
- : dimSizes(dimSizes) {
+ : dimSizes(dimSizes), isSorted(true), iteratorLocked(false),
+ iteratorPos(0) {
if (capacity) {
elements.reserve(capacity);
indices.reserve(capacity * getRank());
@@ -100,6 +121,9 @@ class SparseTensorCOO final {
/// Get the elements array.
const std::vector<Element<V>> &getElements() const { return elements; }
+  /// Builds the lexicographic `operator<` comparator for this COO's
+  /// elements, parameterized by the tensor's rank.
+  ElementLT<V> getElementLT() const {
+    const uint64_t tensorRank = getRank();
+    return ElementLT<V>(tensorRank);
+  }
+
/// Adds an element to the tensor. This method does not check whether
/// `ind` is already associated with a value, it adds it regardless.
/// Resolving such conflicts is left up to clients of the iterator
@@ -130,8 +154,11 @@ class SparseTensorCOO final {
elements[i].indices = newBase + (elements[i].indices - base);
base = newBase;
}
- // Add element as (pointer into shared index pool, value) pair.
- elements.emplace_back(base + size, val);
+ // Add the new element and update the sorted bit.
+ Element<V> addedElem(base + size, val);
+ if (!elements.empty() && isSorted)
+ isSorted = getElementLT()(elements.back(), addedElem);
+ elements.push_back(addedElem);
}
/// Sorts elements lexicographically by index. If an index is mapped to
@@ -140,18 +167,10 @@ class SparseTensorCOO final {
/// Asserts: is not in iterator mode.
void sort() {
assert(!iteratorLocked && "Attempt to sort() after startIterator()");
- // TODO: we may want to cache an `isSorted` bit, to avoid
- // unnecessary/redundant sorting.
- uint64_t rank = getRank();
- std::sort(elements.begin(), elements.end(),
- [rank](const Element<V> &e1, const Element<V> &e2) {
- for (uint64_t r = 0; r < rank; ++r) {
- if (e1.indices[r] == e2.indices[r])
- continue;
- return e1.indices[r] < e2.indices[r];
- }
- return false;
- });
+ if (isSorted)
+ return;
+ std::sort(elements.begin(), elements.end(), getElementLT());
+ isSorted = true;
}
/// Switch into iterator mode. If already in iterator mode, then
@@ -177,8 +196,9 @@ class SparseTensorCOO final {
const std::vector<uint64_t> dimSizes; // per-dimension sizes
std::vector<Element<V>> elements; // all COO elements
std::vector<uint64_t> indices; // shared index pool
- bool iteratorLocked = false;
- unsigned iteratorPos = 0;
+ bool isSorted;
+ bool iteratorLocked;
+ unsigned iteratorPos;
};
} // namespace sparse_tensor
More information about the Mlir-commits mailing list