[Mlir-commits] [mlir] [mlir][sparse] use shared value storage between wrapped iterator and the wrapper. (PR #80046)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 30 11:15:31 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Peiming Liu (PeimingLiu)
<details>
<summary>Changes</summary>
---
Patch is 21.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/80046.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp (+80-96)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h (+59-26)
``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index e43896942d7f..11f3fea335d5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -205,34 +205,48 @@ static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
namespace {
-class TrivialIterator : public SparseIterator {
- Value getLoopLo(OpBuilder &b, Location l) const {
- // Dense loop are traversed by coordinate, delinearize the position to get
- // the coordinate.
- if (randomAccessible())
- return SUBI(itPos, posLo);
- return itPos;
+// The iterator that traverses a concrete sparse tensor level. High-level
+// abstract iterators wrap it to achieve more complex goals (such as collapsing
+// several levels). It also owns the common storage that holds the mlir::Values
+// for itself as well as for its wrappers.
+class ConcreteIterator : public SparseIterator {
+protected:
+ ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
+ unsigned itValCnt)
+ : SparseIterator(kind, stl.tid, stl.lvl, itValCnt, itValsStorage),
+ stl(stl) {
+ // Allocate enough storage for iterator values.
+ itValsStorage.resize(itValCnt);
}
public:
- TrivialIterator(const SparseTensorLevel &stl,
- const IterKind kind = IterKind::kTrivial)
- : SparseIterator(kind, stl.tid, stl.lvl, itPos), stl(stl) {}
-
// For LLVM-style RTTI.
static bool classof(const SparseIterator *from) {
return from->kind == IterKind::kTrivial;
}
bool randomAccessible() const override { return isDenseLT(stl.getLT()); };
- bool iteratableByFor() const override { return true; };
+ bool iteratableByFor() const override { return kind != IterKind::kDedup; };
Value upperBound(OpBuilder &b, Location l) const override {
return stl.size();
};
+protected:
+ // Owner of the storage; all wrappers built on top of a concrete iterator
+ // share the same storage so that the iterator values are always
+ // synchronized.
+ SmallVector<Value> itValsStorage;
+ const SparseTensorLevel &stl;
+};
+
+class TrivialIterator : public ConcreteIterator {
+public:
+ TrivialIterator(const SparseTensorLevel &stl)
+ : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
+
SmallVector<Value> serialize() const override {
SmallVector<Value> ret;
- ret.push_back(itPos);
+ ret.push_back(getItPos());
if (randomAccessible()) {
// Loop high is implicit (defined by `upperBound()`) for random-access
// iterator, but we need to memorize posLo for linearization.
@@ -252,10 +266,10 @@ class TrivialIterator : public SparseIterator {
posHi = vs.back();
};
- ValuePair getCurPosition() const override { return {itPos, nullptr}; }
+ ValuePair getCurPosition() const override { return {getItPos(), nullptr}; }
- void genInit(OpBuilder &b, Location l,
- const SparseIterator *parent) override {
+ void genInitImpl(OpBuilder &b, Location l,
+ const SparseIterator *parent) override {
Value pos = C_IDX(0);
Value hi = nullptr;
if (parent)
@@ -269,25 +283,25 @@ class TrivialIterator : public SparseIterator {
ValuePair genForCond(OpBuilder &b, Location l) override {
if (randomAccessible())
return {deref(b, l), upperBound(b, l)};
- return std::make_pair(getLoopLo(b, l), posHi);
+ return std::make_pair(getItPos(), posHi);
}
Value genNotEnd(OpBuilder &b, Location l) override {
// We used the first level bound as the bound the collapsed set of levels.
- return CMPI(ult, itPos, posHi);
+ return CMPI(ult, getItPos(), posHi);
}
Value deref(OpBuilder &b, Location l) override {
if (randomAccessible()) {
- updateCrd(SUBI(itPos, posLo));
+ updateCrd(SUBI(getItPos(), posLo));
} else {
- updateCrd(stl.peekCrdAt(b, l, itPos));
+ updateCrd(stl.peekCrdAt(b, l, getItPos()));
}
return getCrd();
};
- ValueRange forward(OpBuilder &b, Location l) override {
- seek(ADDI(itPos, C_IDX(1)));
+ ValueRange forwardImpl(OpBuilder &b, Location l) override {
+ seek(ADDI(getItPos(), C_IDX(1)));
return getItVals();
}
@@ -305,20 +319,17 @@ class TrivialIterator : public SparseIterator {
updateCrd(crd);
}
- Value itPos; // the position that represent the iterator
-
+ Value getItPos() const { return getItVals().front(); }
Value posLo, posHi;
- const SparseTensorLevel &stl;
};
-class DedupIterator : public SparseIterator {
+class DedupIterator : public ConcreteIterator {
private:
Value genSegmentHigh(OpBuilder &b, Location l, Value pos);
public:
DedupIterator(const SparseTensorLevel &stl)
- : SparseIterator(IterKind::kDedup, stl.tid, stl.lvl, posAndSegHi),
- stl(stl) {
+ : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
assert(!stl.isUnique());
}
// For LLVM-style RTTI.
@@ -326,16 +337,10 @@ class DedupIterator : public SparseIterator {
return from->kind == IterKind::kDedup;
}
- bool randomAccessible() const override { return false; };
- bool iteratableByFor() const override { return false; };
- Value upperBound(OpBuilder &b, Location l) const override {
- return stl.size();
- };
-
ValuePair getCurPosition() const override { return {getPos(), getSegHi()}; }
- void genInit(OpBuilder &b, Location l,
- const SparseIterator *parent) override {
+ void genInitImpl(OpBuilder &b, Location l,
+ const SparseIterator *parent) override {
Value pos = C_IDX(0);
Value hi = nullptr;
@@ -369,18 +374,16 @@ class DedupIterator : public SparseIterator {
return getCrd();
};
- ValueRange forward(OpBuilder &b, Location l) override {
+ ValueRange forwardImpl(OpBuilder &b, Location l) override {
Value nxPos = getSegHi(); // forward the position to the next segment.
seek({nxPos, genSegmentHigh(b, l, nxPos)});
return getItVals();
}
- Value getPos() const { return posAndSegHi[0]; }
- Value getSegHi() const { return posAndSegHi[1]; }
+ Value getPos() const { return getItVals()[0]; }
+ Value getSegHi() const { return getItVals()[1]; }
Value posHi;
- Value posAndSegHi[2]; // position and segment high
- const SparseTensorLevel &stl;
};
//
@@ -424,8 +427,8 @@ class FilterIterator : public SparseIterator {
void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
ValuePair getCurPosition() const override { return wrap->getCurPosition(); }
- void genInit(OpBuilder &b, Location l,
- const SparseIterator *parent) override {
+ void genInitImpl(OpBuilder &b, Location l,
+ const SparseIterator *parent) override {
wrap->genInit(b, l, parent);
if (!randomAccessible()) {
// TODO: we can skip this when stride == 1 and offset == 0, we can also
@@ -451,9 +454,9 @@ class FilterIterator : public SparseIterator {
updateCrd(crd);
}
- ValueRange forward(OpBuilder &b, Location l) override;
+ ValueRange forwardImpl(OpBuilder &b, Location l) override;
- const Value offset, stride, size;
+ Value offset, stride, size;
std::unique_ptr<SparseIterator> wrap;
};
@@ -467,7 +470,7 @@ class NonEmptySubSectIterator : public SparseIterator {
std::unique_ptr<SparseIterator> &&delegate,
Value subSectSz)
: SparseIterator(IterKind::kNonEmptySubSect, delegate->tid, delegate->lvl,
- /*itVals=*/subSectMeta),
+ 3, /*itVals=*/subSectMeta),
parent(parent), delegate(std::move(delegate)),
tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
@@ -555,7 +558,7 @@ class NonEmptySubSectIterator : public SparseIterator {
return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
};
- void genInit(OpBuilder &b, Location l, const SparseIterator *) override;
+ void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;
void locate(OpBuilder &b, Location l, Value crd) override {
Value absOff = crd;
@@ -587,7 +590,7 @@ class NonEmptySubSectIterator : public SparseIterator {
return crd;
};
- ValueRange forward(OpBuilder &b, Location l) override;
+ ValueRange forwardImpl(OpBuilder &b, Location l) override;
Value getMinCrd() const { return subSectMeta[0]; }
Value getAbsOff() const { return subSectMeta[1]; }
@@ -605,7 +608,8 @@ class NonEmptySubSectIterator : public SparseIterator {
const Value subSectSz;
- Value subSectMeta[3]; // minCrd, absolute offset, notEnd
+ // minCrd, absolute offset, notEnd
+ SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
};
class SubSectIterator;
@@ -628,41 +632,18 @@ struct SubSectIterHelper {
};
class SubSectIterator : public SparseIterator {
- // RAII to sync iterator values between the wrap the iterator and the
- // SubSectIterator.
- struct WrapItValSyncer {
- explicit WrapItValSyncer(SubSectIterator &it) : it(it) {
- if (!it.randomAccessible())
- it.wrap->seek(it.getItVals().drop_back());
- }
- ~WrapItValSyncer() {
- if (!it.randomAccessible()) {
- ValueRange wrapItVals = it.wrap->getItVals();
- std::copy(wrapItVals.begin(), wrapItVals.end(), it.itVals.begin());
- }
- }
- SubSectIterator ⁢
- };
-
public:
SubSectIterator(const NonEmptySubSectIterator &subSect,
const SparseIterator &parent,
std::unique_ptr<SparseIterator> &&wrap, Value size,
unsigned stride)
- : SparseIterator(IterKind::kSubSect, *wrap), itVals(), subSect(subSect),
- wrap(std::move(wrap)), parent(parent), size(size), stride(stride),
- helper(*this) {
+ : SparseIterator(IterKind::kSubSect, *wrap,
+ /*extraVal=*/wrap->randomAccessible() ? 0 : 1),
+ subSect(subSect), wrap(std::move(wrap)), parent(parent), size(size),
+ stride(stride), helper(*this) {
assert(stride == 1 && "Not implemented.");
assert(subSect.tid == tid && subSect.lvl == lvl);
assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
-
- if (!randomAccessible()) {
- // We maintain a extra counter to count the actually sparse coordinate
- // included in the subsection.
- unsigned itValSz = this->wrap->getItVals().size() + 1;
- itVals.resize(itValSz, nullptr);
- relinkItVals(itVals);
- }
};
// For LLVM-style RTTI.
@@ -681,11 +662,10 @@ class SubSectIterator : public SparseIterator {
if (randomAccessible()) {
return ADDI(getCrd(), nxLvlTupleStart);
};
- return ADDI(itVals.back(), nxLvlTupleStart);
+ return ADDI(getItVals().back(), nxLvlTupleStart);
}
- void genInit(OpBuilder &b, Location l, const SparseIterator *) override {
- WrapItValSyncer syncer(*this);
+ void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
if (randomAccessible()) {
if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
assert(p->lvl + 1 == lvl);
@@ -700,10 +680,10 @@ class SubSectIterator : public SparseIterator {
return;
}
assert(!randomAccessible());
- assert(itVals.size() == wrap->getItVals().size() + 1);
+ assert(getItVals().size() == wrap->getItVals().size() + 1);
// Extra counter that counts the number of actually visited coordinates in
// the sparse subsection.
- itVals.back() = C_IDX(0);
+ getMutItVals().back() = C_IDX(0);
Value tupleId;
if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
assert(p->lvl + 1 == lvl);
@@ -717,35 +697,28 @@ class SubSectIterator : public SparseIterator {
}
void locate(OpBuilder &b, Location l, Value crd) override {
- WrapItValSyncer syncer(*this);
helper.locate(b, l, crd);
updateCrd(crd);
}
Value genNotEnd(OpBuilder &b, Location l) override {
- WrapItValSyncer syncer(*this);
return helper.genNotEnd(b, l);
}
Value deref(OpBuilder &b, Location l) override {
- WrapItValSyncer syncer(*this);
Value crd = helper.deref(b, l);
updateCrd(crd);
return crd;
};
- ValueRange forward(OpBuilder &b, Location l) override {
- {
- WrapItValSyncer syncer(*this);
- helper.forward(b, l);
- }
+ ValueRange forwardImpl(OpBuilder &b, Location l) override {
+ helper.forward(b, l);
assert(!randomAccessible());
- assert(itVals.size() == wrap->getItVals().size() + 1);
- itVals.back() = ADDI(itVals.back(), C_IDX(1));
+ assert(getItVals().size() == wrap->getItVals().size() + 1);
+ getMutItVals().back() = ADDI(getItVals().back(), C_IDX(1));
return getItVals();
};
- SmallVector<Value> itVals;
Value nxLvlTupleStart;
const NonEmptySubSectIterator &subSect;
@@ -764,6 +737,17 @@ class SubSectIterator : public SparseIterator {
// SparseIterator derived classes implementation.
//===----------------------------------------------------------------------===//
+void SparseIterator::genInit(OpBuilder &b, Location l,
+ const SparseIterator *p) {
+ // TODO: support lowering to function call.
+ return genInitImpl(b, l, p);
+}
+
+ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
+ // TODO: support lowering to function call.
+ return forwardImpl(b, l);
+}
+
ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
auto ifOp = b.create<scf::IfOp>(l, getItVals().getTypes(), cond, true);
// Generate else branch first, otherwise iterator values will be updated by
@@ -846,7 +830,7 @@ Value FilterIterator::genNotEnd(OpBuilder &b, Location l) {
return r.front();
}
-ValueRange FilterIterator::forward(OpBuilder &b, Location l) {
+ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
assert(!randomAccessible());
// Generates
//
@@ -1013,8 +997,8 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
}
-void NonEmptySubSectIterator::genInit(OpBuilder &b, Location l,
- const SparseIterator *) {
+void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
+ const SparseIterator *) {
Value c0 = C_IDX(0);
if (!isSubSectRoot()) {
assert(parent->lvl + 1 == lvl);
@@ -1096,7 +1080,7 @@ void NonEmptySubSectIterator::genInit(OpBuilder &b, Location l,
seek(meta);
}
-ValueRange NonEmptySubSectIterator::forward(OpBuilder &b, Location l) {
+ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
assert(!randomAccessible());
Value c0 = C_IDX(0), c1 = C_IDX(1);
// Forward to the next non empty slice by generating
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index d2b3396b7283..1771a058f00c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -80,24 +80,37 @@ class SparseIterator {
SparseIterator &operator=(const SparseIterator &) = delete;
protected:
- SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
- MutableArrayRef<Value> itVals)
- : kind(kind), tid(tid), lvl(lvl), crd(nullptr), itVals(itVals){};
+ SparseIterator(IterKind kind, unsigned tid, unsigned lvl, unsigned itValsCnt,
+ SmallVectorImpl<Value> &storage)
+ : kind(kind), tid(tid), lvl(lvl), crd(nullptr), itValsCnt(itValsCnt),
+ itValsStorageRef(storage){};
SparseIterator(IterKind kind, const SparseIterator &wrap)
: kind(kind), tid(wrap.tid), lvl(wrap.lvl), crd(nullptr),
- itVals(wrap.itVals){};
+ itValsCnt(wrap.itValsCnt), itValsStorageRef(wrap.itValsStorageRef) {
+ assert(wrap.itValsCnt == itValsStorageRef.size());
+ };
+
+ SparseIterator(IterKind kind, const SparseIterator &wrap, unsigned extraVal)
+ : kind(kind), tid(wrap.tid), lvl(wrap.lvl), crd(nullptr),
+ itValsCnt(wrap.itValsCnt + extraVal),
+ itValsStorageRef(wrap.itValsStorageRef) {
+ itValsStorageRef.append(extraVal, nullptr);
+ assert(itValsCnt == itValsStorageRef.size());
+ };
public:
virtual ~SparseIterator() = default;
Value getCrd() const { return crd; }
- ValueRange getItVals() const { return itVals; };
+ ValueRange getItVals() const {
+ return ValueRange(itValsStorageRef).take_front(itValsCnt);
+ };
// Sets the iterate to the specified position.
void seek(ValueRange vals) {
- assert(vals.size() == itVals.size());
- std::copy(vals.begin(), vals.end(), itVals.begin());
+ assert(vals.size() == itValsCnt);
+ std::copy(vals.begin(), vals.end(), itValsStorageRef.begin());
// Now that the iterator is re-positioned, the coordinate becomes invalid.
crd = nullptr;
}
@@ -119,8 +132,8 @@ class SparseIterator {
virtual Value upperBound(OpBuilder &b, Location l) const = 0;
// Serializes and deserializes the current status to/from a set of values. The
- // ValueRange should contain values that specifies the current postion and
- // loop bound.
+ // ValueRange should contain values that are sufficient to recover the current
+ // iterating position (i.e., itVals) as well as the loop bound.
//
// Not every type of iterator supports the operations, e.g., non-empty
// subsection iterator does not because the the number of non-empty
@@ -132,10 +145,33 @@ class SparseIterator {
};
virtual void deserialize(ValueRange vs) { llvm_unreachable("unsupported"); };
+ // virtual std::string getFuncNamePostfix() const = 0;
+ // virtual SmallVector<Value> toFuncArgs() const = 0;
+ // virtual void linkFuncScope(ValueRange ret) = 0;
+
//
// Core functions.
//
+ // Initializes the iterator according to the parent iterator's state.
+ void genInit(OpBuilder &b, Location l, const SparseIterator *p);
+
+ // Forwards the iterator to the next element.
+ ValueRange forward(OpBuilder &b, Location l);
+
+ // Actual implementation, provided by derived classes.
+ virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0;
+ virtual ValueRange forwardImpl(OpBuilder &b, Location l) = 0;
+
+ // Returns a boolean value that equals `!it.end()`
+ virtual Value genNotEnd(OpBuilder &b, Location l) = 0;
+
+ // Dereference the iterator, loads the coordinate at the current position.
+ //
+ // The method assumes that the iterator is not currently exhausted (i.e.,
+ // it != it.end()).
+ virtual Value deref(OpBuilder &b, Location l) = 0;
+
// Gets the current position and the optional *position high* (for non-unique
// iterators), the value is essentially the number of sparse coordinate that
// the iterator is current visiting. It should be able to uniquely identify
@@ -148,9 +184,6 @@ class SparseIterator {
llvm_unreachable("unsupported");
};
- // Initializes the iterator according to the parent iterator's state.
- virtual void genInit(OpBuilder &, Location, const SparseIterator *) = 0;
-
// Returns a pair of values for *upper*, *lower* bound respectively.
virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
assert(randomAccessible());
@@ -158,22 +191,13 @@ class SparseIterator {
return {getCrd(), upperBound(b, l)};
}
- // Returns a boolean value that equals `!it.end()`
- virtual Value genNotEnd(OpBuilder &b, Location l) = 0;
+ // Generates a bool value for scf::ConditionOp.
std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l,
ValueRange vs) {
ValueRange rem = linkNewScope(vs);
return std::make_pair(genNotEnd(b, l), rem);
}
- // Dereference the iterator, loads the coordinate at the current position.
- //
- // The method assumes that the iterator is not currently exhausted (i.e.,
- // it != it.end()).
- virtual Value deref(OpBuilder &b, Location l) = 0;
-
- virtual ValueRange forward(OpBuilder &b, Location l) = 0;
-
// Generate a conditional it.next() in the following form
//
// if (cond)
@@ -198,13 +222,13 @@ ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/80046
More information about the Mlir-commits
mailing list