[Mlir-commits] [mlir] 753fe33 - [mlir][sparse] Factoring out an enumerator over elements of SparseTensorStorage
wren romano
llvmlistbot at llvm.org
Thu May 12 17:06:03 PDT 2022
Author: wren romano
Date: 2022-05-12T17:05:56-07:00
New Revision: 753fe330c1d6cfebac07ecd385fb2dcf63a0f6c9
URL: https://github.com/llvm/llvm-project/commit/753fe330c1d6cfebac07ecd385fb2dcf63a0f6c9
DIFF: https://github.com/llvm/llvm-project/commit/753fe330c1d6cfebac07ecd385fb2dcf63a0f6c9.diff
LOG: [mlir][sparse] Factoring out an enumerator over elements of SparseTensorStorage
Work towards fixing: https://github.com/llvm/llvm-project/issues/51652
Depends On D122928
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D122060
Added:
Modified:
mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
Removed:
################################################################################
diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index a7d636749c92c..336c500ce1d33 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -27,6 +27,7 @@
#include <cstdlib>
#include <cstring>
#include <fstream>
+#include <functional>
#include <iostream>
#include <limits>
#include <numeric>
@@ -94,6 +95,13 @@ struct Element {
V value;
};
+/// The type of callback functions which receive an element. We avoid
+/// packaging the coordinates and value together as an `Element` object
+/// because this helps keep code somewhat cleaner.
+template <typename V>
+using ElementConsumer =
+ const std::function<void(const std::vector<uint64_t> &, V)> &;
+
/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
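
For illustration, a callback conforming to `ElementConsumer<double>` can be an ordinary lambda; the following standalone sketch (hypothetical, not part of this commit) prints each element's indices and value:

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

int main() {
  // A consumer matching ElementConsumer<double>'s underlying signature:
  // it receives the (permuted) indices and the element's value.
  std::function<void(const std::vector<uint64_t> &, double)> yield =
      [](const std::vector<uint64_t> &indices, double value) {
        for (uint64_t i : indices)
          std::cout << i << ' ';
        std::cout << "-> " << value << '\n';
      };
  yield({1, 2}, 3.14); // as if yielding the element at coordinates (1,2)
}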
@@ -220,6 +228,7 @@ class SparseTensorStorageBase {
const uint64_t *perm, const DimLevelType *sparsity)
: dimSizes(szs), rev(getRank()),
dimTypes(sparsity, sparsity + getRank()) {
+ assert(perm && sparsity);
const uint64_t rank = getRank();
// Validate parameters.
assert(rank > 0 && "Trivial shape is unsupported");
@@ -310,6 +319,16 @@ class SparseTensorStorageBase {
/// Finishes insertion.
virtual void endInsert() = 0;
+protected:
+ // Since this class is virtual, we must disallow public copying in
+ // order to avoid "slicing". And since it has data members that derived
+ // classes must be able to copy, we make copying protected rather than
+ // deleting it outright.
+ // <https://github.com/isocpp/CppCoreGuidelines/blob/master/CppCoreGuidelines.md#Rc-copy-virtual>
+ SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
+ // Copy-assignment would be implicitly deleted (because `dimSizes`
+ // is const), so we explicitly delete it for clarity.
+ SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
+
private:
static void fatal(const char *tp) {
fprintf(stderr, "unsupported %s\n", tp);
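
The slicing problem that motivates the protected copy constructor: copying a polymorphic object through a base-typed value silently discards the derived part. A generic sketch (hypothetical names, not from this file):

#include <vector>

struct Base {
  virtual ~Base() = default;
  int x = 0;
};

struct Derived : Base {
  std::vector<int> extra{1, 2, 3};
};

int main() {
  Derived d;
  Base b = d; // "slicing": only the Base subobject is copied; `extra` is lost.
  (void)b;
}

Making the base's copy constructor protected turns `Base b = d;` into a compile error for client code, while derived classes can still invoke it from their own copy constructors.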
@@ -321,6 +340,10 @@ class SparseTensorStorageBase {
const std::vector<DimLevelType> dimTypes;
};
+// Forward.
+template <typename P, typename I, typename V>
+class SparseTensorEnumerator;
+
/// A memory-resident sparse tensor using a storage scheme based on
/// per-dimension sparse/dense annotations. This data structure provides a
/// bufferized form of a sparse tensor type. In contrast to generating setup
@@ -443,24 +466,13 @@ class SparseTensorStorage : public SparseTensorStorageBase {
/// sparse tensor in coordinate scheme with the given dimension order.
///
/// Precondition: `perm` must be valid for `getRank()`.
- SparseTensorCOO<V> *toCOO(const uint64_t *perm) {
- // Restore original order of the dimension sizes and allocate coordinate
- // scheme with desired new ordering specified in perm.
- const uint64_t rank = getRank();
- const auto &rev = getRev();
- const auto &sizes = getDimSizes();
- std::vector<uint64_t> orgsz(rank);
- for (uint64_t r = 0; r < rank; r++)
- orgsz[rev[r]] = sizes[r];
- SparseTensorCOO<V> *coo = SparseTensorCOO<V>::newSparseTensorCOO(
- rank, orgsz.data(), perm, values.size());
- // Populate coordinate scheme restored from old ordering and changed with
- // new ordering. Rather than applying both reorderings during the recursion,
- // we compute the combine permutation in advance.
- std::vector<uint64_t> reord(rank);
- for (uint64_t r = 0; r < rank; r++)
- reord[r] = perm[rev[r]];
- toCOO(*coo, reord, 0, 0);
+ SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
+ SparseTensorEnumerator<P, I, V> enumerator(*this, getRank(), perm);
+ SparseTensorCOO<V> *coo =
+ new SparseTensorCOO<V>(enumerator.permutedSizes(), values.size());
+ enumerator.forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
+ coo->add(ind, val);
+ });
// TODO: This assertion assumes there are no stored zeros,
// or if there are then that we don't filter them out.
// Cf., <https://github.com/llvm/llvm-project/issues/54179>
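
The rewritten `toCOO` keeps its original signature, so call sites are unchanged; for example (a sketch assuming a rank-2 `SparseTensorStorage<uint64_t, uint64_t, double> *tensor` built elsewhere via this file's API):

const uint64_t perm[] = {0, 1}; // identity permutation, valid for rank 2
SparseTensorCOO<double> *coo = tensor->toCOO(perm);
// ... consume the coordinate-scheme elements ...
delete coo;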
@@ -543,9 +555,10 @@ class SparseTensorStorage : public SparseTensorStorageBase {
/// and pointwise less-than).
void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
uint64_t hi, uint64_t d) {
+ uint64_t rank = getRank();
+ assert(d <= rank && hi <= elements.size());
// Once dimensions are exhausted, insert the numerical values.
- assert(d <= getRank() && hi <= elements.size());
- if (d == getRank()) {
+ if (d == rank) {
assert(lo < hi);
values.push_back(elements[lo].value);
return;
@@ -569,31 +582,6 @@ class SparseTensorStorage : public SparseTensorStorageBase {
finalizeSegment(d, full);
}
- /// Stores the sparse tensor storage scheme into a memory-resident sparse
- /// tensor in coordinate scheme.
- void toCOO(SparseTensorCOO<V> &tensor, std::vector<uint64_t> &reord,
- uint64_t pos, uint64_t d) {
- assert(d <= getRank());
- if (d == getRank()) {
- assert(pos < values.size());
- tensor.add(idx, values[pos]);
- } else if (isCompressedDim(d)) {
- // Sparse dimension.
- for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
- idx[reord[d]] = indices[d][ii];
- toCOO(tensor, reord, ii, d + 1);
- }
- } else {
- // Dense dimension.
- const uint64_t sz = getDimSizes()[d];
- const uint64_t off = pos * sz;
- for (uint64_t i = 0; i < sz; i++) {
- idx[reord[d]] = i;
- toCOO(tensor, reord, off + i, d + 1);
- }
- }
- }
-
/// Finalize the sparse pointer structure at this dimension.
void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
if (count == 0)
@@ -649,13 +637,151 @@ class SparseTensorStorage : public SparseTensorStorageBase {
return -1u;
}
-private:
+ // Allow `SparseTensorEnumerator` to access the data-members (to avoid
+ // the cost of virtual-function dispatch in inner loops), without
+ // making them public to other client code.
+ friend class SparseTensorEnumerator<P, I, V>;
+
std::vector<std::vector<P>> pointers;
std::vector<std::vector<I>> indices;
std::vector<V> values;
std::vector<uint64_t> idx; // index cursor for lexicographic insertion.
};
+/// A (higher-order) function object for enumerating the elements of some
+/// `SparseTensorStorage` under a permutation. That is, the `forallElements`
+/// method encapsulates the loop-nest for enumerating the elements of
+/// the source tensor (in whatever order is best for the source tensor),
+/// and applies a permutation to the coordinates/indices before handing
+/// each element to the callback. A single enumerator object can be
+/// freely reused for several calls to `forallElements`, provided the
+/// calls are sequential with respect to one another (not concurrent).
+///
+/// N.B., this class stores a reference to the `SparseTensorStorageBase`
+/// passed to the constructor; thus, objects of this class must not
+/// outlive the sparse tensor they depend on.
+///
+/// Design Note: The reason we define this class instead of simply using
+/// `SparseTensorEnumerator<P,I,V>` is because we need to hide/generalize
+/// the `<P,I>` template parameters from MLIR client code (to simplify the
+/// type parameters used for direct sparse-to-sparse conversion). And the
+/// reason we define the `SparseTensorEnumerator<P,I,V>` subclasses rather
+/// than simply using this class, is to avoid the cost of virtual-method
+/// dispatch within the loop-nest.
+template <typename V>
+class SparseTensorEnumeratorBase {
+public:
+ /// Constructs an enumerator with the given permutation for mapping
+ /// the semantic-ordering of dimensions to the desired target-ordering.
+ ///
+ /// Preconditions:
+ /// * the `tensor` must have the same `V` value type.
+ /// * `perm` must be valid for `rank`.
+ SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
+ uint64_t rank, const uint64_t *perm)
+ : src(tensor), permsz(src.getRev().size()), reord(getRank()),
+ cursor(getRank()) {
+ assert(perm && "Received nullptr for permutation");
+ assert(rank == getRank() && "Permutation rank mismatch");
+ const auto &rev = src.getRev(); // source stg-order -> semantic-order
+ const auto &sizes = src.getDimSizes(); // in source storage-order
+ for (uint64_t s = 0; s < rank; s++) { // `s` source storage-order
+ uint64_t t = perm[rev[s]]; // `t` target-order
+ reord[s] = t;
+ permsz[t] = sizes[s];
+ }
+ }
+
+ virtual ~SparseTensorEnumeratorBase() = default;
+
+ // We disallow copying to help avoid leaking the `src` reference.
+ // (In addition to avoiding the problem of slicing.)
+ SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
+ SparseTensorEnumeratorBase &
+ operator=(const SparseTensorEnumeratorBase &) = delete;
+
+ /// Returns the source/target tensor's rank. (The source-rank and
+ /// target-rank are always equal since we only support permutations.
+ /// Though once we add support for other dimension mappings, this
+ /// method will have to be split in two.)
+ uint64_t getRank() const { return permsz.size(); }
+
+ /// Returns the target tensor's dimension sizes.
+ const std::vector<uint64_t> &permutedSizes() const { return permsz; }
+
+ /// Enumerates all elements of the source tensor, permutes their
+ /// indices, and passes the permuted element to the callback.
+ /// The callback must not store the cursor reference directly,
+/// since this function reuses that storage across elements. Instead,
+/// the callback must copy the cursor if it wants to keep it.
+ virtual void forallElements(ElementConsumer<V> yield) = 0;
+
+protected:
+ const SparseTensorStorageBase &src;
+ std::vector<uint64_t> permsz; // in target order.
+ std::vector<uint64_t> reord; // source storage-order -> target order.
+ std::vector<uint64_t> cursor; // in target order.
+};
+
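
That warning about the cursor is load-bearing: `forallElements` reuses one index buffer for every element, so a callback that retained the reference would see it mutate. A minimal standalone sketch (hypothetical) of the safe collect-by-copy pattern:

#include <cstdint>
#include <utility>
#include <vector>

using IndexVec = std::vector<uint64_t>;

int main() {
  std::vector<std::pair<IndexVec, double>> collected;
  // A consumer that copies the cursor. `emplace_back(indices, value)`
  // copy-constructs the IndexVec, so later overwrites of the shared
  // cursor cannot corrupt what was collected.
  auto yield = [&collected](const IndexVec &indices, double value) {
    collected.emplace_back(indices, value);
  };
  IndexVec cursor = {0, 0};
  yield(cursor, 1.0); // enumerator yields the element at (0,0)...
  cursor[1] = 5;      // ...then reuses the same buffer for the next one.
  yield(cursor, 2.0);
  // collected now holds {(0,0),1.0} and {(0,5),2.0}, unaffected by reuse.
}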
+template <typename P, typename I, typename V>
+class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
+ using Base = SparseTensorEnumeratorBase<V>;
+
+public:
+ /// Constructs an enumerator with the given permutation for mapping
+ /// the semantic-ordering of dimensions to the desired target-ordering.
+ ///
+ /// Precondition: `perm` must be valid for `rank`.
+ SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
+ uint64_t rank, const uint64_t *perm)
+ : Base(tensor, rank, perm) {}
+
+ ~SparseTensorEnumerator() final override = default;
+
+ void forallElements(ElementConsumer<V> yield) final override {
+ forallElements(yield, 0, 0);
+ }
+
+private:
+ /// The recursive component of the public `forallElements`.
+ void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
+ uint64_t d) {
+ // Recover the `<P,I,V>` type parameters of `src`.
+ const auto &src =
+ static_cast<const SparseTensorStorage<P, I, V> &>(this->src);
+ if (d == Base::getRank()) {
+ assert(parentPos < src.values.size() &&
+ "Value position is out of bounds");
+ // TODO: <https://github.com/llvm/llvm-project/issues/54179>
+ yield(this->cursor, src.values[parentPos]);
+ } else if (src.isCompressedDim(d)) {
+ // Look up the bounds of the `d`-level segment determined by the
+ // `d-1`-level position `parentPos`.
+ const std::vector<P> &pointers_d = src.pointers[d];
+ assert(parentPos + 1 < pointers_d.size() &&
+ "Parent pointer position is out of bounds");
+ const uint64_t pstart = static_cast<uint64_t>(pointers_d[parentPos]);
+ const uint64_t pstop = static_cast<uint64_t>(pointers_d[parentPos + 1]);
+ // Loop-invariant code for looking up the `d`-level coordinates/indices.
+ const std::vector<I> &indices_d = src.indices[d];
+ assert(pstop - 1 < indices_d.size() && "Index position is out of bounds");
+ uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
+ for (uint64_t pos = pstart; pos < pstop; pos++) {
+ cursor_reord_d = static_cast<uint64_t>(indices_d[pos]);
+ forallElements(yield, pos, d + 1);
+ }
+ } else { // Dense dimension.
+ const uint64_t sz = src.getDimSizes()[d];
+ const uint64_t pstart = parentPos * sz;
+ uint64_t &cursor_reord_d = this->cursor[this->reord[d]];
+ for (uint64_t i = 0; i < sz; i++) {
+ cursor_reord_d = i;
+ forallElements(yield, pstart + i, d + 1);
+ }
+ }
+ }
+};
+
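
Concretely, for a rank-2 tensor with a dense dimension 0 and a compressed dimension 1 (a CSR-like layout), the recursion above reduces to two nested loops over `pointers[1]` and `indices[1]`. A standalone sketch with hard-coded data (assuming P = I = uint64_t and the identity permutation):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Matrix:  [ 0 7 0 ]
  //          [ 5 0 9 ]
  std::vector<uint64_t> pointers1 = {0, 1, 3}; // row r's segment: [p1[r], p1[r+1])
  std::vector<uint64_t> indices1 = {1, 0, 2};  // column index of each stored value
  std::vector<double> values = {7.0, 5.0, 9.0};
  const uint64_t nrows = 2;
  for (uint64_t r = 0; r < nrows; r++)       // dense dim 0: parentPos == r
    for (uint64_t p = pointers1[r]; p < pointers1[r + 1]; p++) // compressed dim 1
      std::cout << '(' << r << ',' << indices1[p] << ") = " << values[p] << '\n';
  // Prints: (0,1) = 7  (1,0) = 5  (1,2) = 9
}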
/// Helper to convert string to lower case.
static char *toLower(char *token) {
for (char *c = token; *c; c++)