[Mlir-commits] [mlir] [mlir][sparse] complete migration to dim2lvl/lvl2dim in library (PR #69268)
Peiming Liu
llvmlistbot at llvm.org
Tue Oct 17 09:32:54 PDT 2023
================
@@ -669,184 +599,48 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
return -1u;
}
- // Allow `SparseTensorEnumerator` to access the data-members (to avoid
- // the cost of virtual-function dispatch in inner loops), without
- // making them public to other client code.
- friend class SparseTensorEnumerator<P, C, V>;
-
- std::vector<std::vector<P>> positions;
- std::vector<std::vector<C>> coordinates;
- std::vector<V> values;
- std::vector<uint64_t> lvlCursor; // cursor for lexicographic insertion.
- SparseTensorCOO<V> *lvlCOO; // COO used during forwarding
-};
-
-//===----------------------------------------------------------------------===//
-//
-// SparseTensorEnumerator
-//
-//===----------------------------------------------------------------------===//
-
-/// A (higher-order) function object for enumerating the elements of some
-/// `SparseTensorStorage` under a permutation. That is, the `forallElements`
-/// method encapsulates the loop-nest for enumerating the elements of
-/// the source tensor (in whatever order is best for the source tensor),
-/// and applies a permutation to the coordinates before handing
-/// each element to the callback. A single enumerator object can be
-/// freely reused for several calls to `forallElements`, just so long
-/// as each call is sequential with respect to one another.
-///
-/// N.B., this class stores a reference to the `SparseTensorStorageBase`
-/// passed to the constructor; thus, objects of this class must not
-/// outlive the sparse tensor they depend on.
-///
-/// Design Note: The reason we define this class instead of simply using
-/// `SparseTensorEnumerator<P,C,V>` is because we need to hide/generalize
-/// the `<P,C>` template parameters from MLIR client code (to simplify the
-/// type parameters used for direct sparse-to-sparse conversion). And the
-/// reason we define the `SparseTensorEnumerator<P,C,V>` subclasses rather
-/// than simply using this class, is to avoid the cost of virtual-method
-/// dispatch within the loop-nest.
-template <typename V>
-class SparseTensorEnumeratorBase {
-public:
- /// Constructs an enumerator which automatically applies the given
- /// mapping from the source tensor's dimensions to the desired
- /// target tensor dimensions.
- ///
- /// Preconditions:
- /// * the `src` must have the same `V` value type.
- /// * `trgSizes` must be valid for `trgRank`.
- /// * `src2trg` must be valid for `srcRank`, and must map coordinates
- /// valid for `src.getDimSizes()` to coordinates valid for `trgSizes`.
- ///
- /// Asserts:
- /// * `trgSizes` must be nonnull and must contain only nonzero sizes.
- /// * `srcRank == src.getDimRank()`.
- /// * `src2trg` must be nonnull.
- SparseTensorEnumeratorBase(const SparseTensorStorageBase &src,
- uint64_t trgRank, const uint64_t *trgSizes,
- uint64_t srcRank, const uint64_t *src2trg)
- : src(src), trgSizes(trgSizes, trgSizes + trgRank),
- lvl2trg(src.getLvlRank()), trgCursor(trgRank) {
- assert(trgSizes && "Received nullptr for target-sizes");
- assert(src2trg && "Received nullptr for source-to-target mapping");
- assert(srcRank == src.getDimRank() && "Source-rank mismatch");
- for (uint64_t t = 0; t < trgRank; ++t)
- assert(trgSizes[t] > 0 && "Target-size zero has trivial storage");
- const auto &lvl2src = src.getLvl2Dim();
- for (uint64_t lvlRank = src.getLvlRank(), l = 0; l < lvlRank; ++l)
- lvl2trg[l] = src2trg[lvl2src[l]];
- }
-
- virtual ~SparseTensorEnumeratorBase() = default;
-
- // We disallow copying to help avoid leaking the `src` reference.
- // (In addition to avoiding the problem of slicing.)
- SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
- SparseTensorEnumeratorBase &
- operator=(const SparseTensorEnumeratorBase &) = delete;
-
- /// Gets the source's dimension-rank.
- uint64_t getSrcDimRank() const { return src.getDimRank(); }
-
- /// Gets the target's dimension-/level-rank. (This is usually
- /// "dimension-rank", though that may coincide with "level-rank"
- /// depending on usage.)
- uint64_t getTrgRank() const { return trgSizes.size(); }
-
- /// Gets the target's dimension-/level-sizes. (These are usually
- /// "dimensions", though that may coincide with "level-rank" depending
- /// on usage.)
- const std::vector<uint64_t> &getTrgSizes() const { return trgSizes; }
-
- /// Enumerates all elements of the source tensor, permutes their
- /// coordinates, and passes the permuted element to the callback.
- /// The callback must not store the cursor reference directly,
- /// since this function reuses the storage. Instead, the callback
- /// must copy it if they want to keep it.
- virtual void forallElements(ElementConsumer<V> yield) = 0;
-
-protected:
- const SparseTensorStorageBase &src;
- std::vector<uint64_t> trgSizes; // in target order.
- std::vector<uint64_t> lvl2trg; // source-levels -> target-dims/lvls.
- std::vector<uint64_t> trgCursor; // in target order.
-};
-
-template <typename P, typename C, typename V>
-class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
- using Base = SparseTensorEnumeratorBase<V>;
- using StorageImpl = SparseTensorStorage<P, C, V>;
-
-public:
- /// Constructs an enumerator which automatically applies the given
- /// mapping from the source tensor's dimensions to the desired
- /// target tensor dimensions.
- ///
- /// Preconditions/assertions are as per the `SparseTensorEnumeratorBase` ctor.
- SparseTensorEnumerator(const StorageImpl &src, uint64_t trgRank,
- const uint64_t *trgSizes, uint64_t srcRank,
- const uint64_t *src2trg)
- : Base(src, trgRank, trgSizes, srcRank, src2trg) {}
-
- ~SparseTensorEnumerator() final = default;
-
- void forallElements(ElementConsumer<V> yield) final {
- forallElements(yield, 0, 0);
- }
-
-private:
- // TODO: Once we functionalize the mappings, then we'll no longer
- // be able to use the current approach of constructing `lvl2trg` in the
- // ctor and using it to incrementally fill the `trgCursor` cursor as we
- // recurse through `forallElements`. Instead we'll want to incrementally
- // fill a `lvlCursor` as we recurse, and then use `src.getLvl2Dim()`
- // and `src2trg` to convert it just before yielding to the callback.
- // It's probably most efficient to just store the `srcCursor` and
- // `trgCursor` buffers in this object, but we may want to benchmark
- // that against using `std::calloc` to stack-allocate them instead.
- //
- /// The recursive component of the public `forallElements`.
- void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
- uint64_t l) {
- // Recover the `<P,C,V>` type parameters of `src`.
- const auto &src = static_cast<const StorageImpl &>(this->src);
- if (l == src.getLvlRank()) {
- assert(parentPos < src.values.size());
- // TODO: <https://github.com/llvm/llvm-project/issues/54179>
- yield(this->trgCursor, src.values[parentPos]);
+ // Performs forall on level entries and inserts into dim COO.
+ void toCOO(uint64_t parentPos, uint64_t l, std::vector<uint64_t> &dimCoords) {
+ if (l == getLvlRank()) {
+ map.pushbackward(lvlCursor.data(), dimCoords.data());
+ assert(coo);
+ assert(parentPos < values.size());
+ coo->add(dimCoords, values[parentPos]);
return;
}
- uint64_t &cursorL = this->trgCursor[this->lvl2trg[l]];
- const auto dlt = src.getLvlType(l); // Avoid redundant bounds checking.
- if (isCompressedDLT(dlt)) {
+ if (isCompressedLvl(l)) {
// Look up the bounds of the `l`-level segment determined by the
// `(l - 1)`-level position `parentPos`.
- const std::vector<P> &positionsL = src.positions[l];
+ const std::vector<P> &positionsL = positions[l];
assert(parentPos + 1 < positionsL.size());
const uint64_t pstart = static_cast<uint64_t>(positionsL[parentPos]);
const uint64_t pstop = static_cast<uint64_t>(positionsL[parentPos + 1]);
// Loop-invariant code for looking up the `l`-level coordinates.
- const std::vector<C> &coordinatesL = src.coordinates[l];
+ const std::vector<C> &coordinatesL = coordinates[l];
assert(pstop <= coordinatesL.size());
for (uint64_t pos = pstart; pos < pstop; ++pos) {
- cursorL = static_cast<uint64_t>(coordinatesL[pos]);
- forallElements(yield, pos, l + 1);
+ lvlCursor[l] = static_cast<uint64_t>(coordinatesL[pos]);
+ toCOO(pos, l + 1, dimCoords);
}
- } else if (isSingletonDLT(dlt)) {
- cursorL = src.getCrd(l, parentPos);
- forallElements(yield, parentPos, l + 1);
+ } else if (isSingletonLvl(l)) {
+ lvlCursor[l] = getCrd(l, parentPos);
+ toCOO(parentPos, l + 1, dimCoords);
} else { // Dense level.
- assert(isDenseDLT(dlt));
- const uint64_t sz = src.getLvlSizes()[l];
+ assert(isDenseLvl(l));
+ const uint64_t sz = getLvlSizes()[l];
const uint64_t pstart = parentPos * sz;
for (uint64_t c = 0; c < sz; ++c) {
- cursorL = c;
- forallElements(yield, pstart + c, l + 1);
+ lvlCursor[l] = c;
+ toCOO(pstart + c, l + 1, dimCoords);
}
}
}
+
+ std::vector<std::vector<P>> positions;
+ std::vector<std::vector<C>> coordinates;
+ std::vector<V> values;
+ std::vector<uint64_t> lvlCursor;
+ SparseTensorCOO<V> *coo;
----------------
PeimingLiu wrote:
How do you guarantee that it wouldn't leak? Do we have some assumptions here about how users call the APIs?
https://github.com/llvm/llvm-project/pull/69268
More information about the Mlir-commits
mailing list