[Mlir-commits] [mlir] [mlir][sparse] use ValueRange instead of std::pair for iterator position. (PR #90243)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 26 11:11:03 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Peiming Liu (PeimingLiu)
<details>
<summary>Changes</summary>
`ValueRange` is easier to extend (e.g., for a padded iterator).
---
Full diff: https://github.com/llvm/llvm-project/pull/90243.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp (+34-31)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h (+12-15)
``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 59c3e49264dbe1..34312df912997b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -222,7 +222,7 @@ class LoopEmitter {
///
SmallVector<Value> getValPosits(TensorId tid) const {
SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
- Value lastLvlPos = iters[tid].back().back()->getCurPosition().first;
+ Value lastLvlPos = iters[tid].back().back()->getCurPosition().front();
batchCrds.push_back(lastLvlPos);
return batchCrds;
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index 60dca3c55dec3d..745c081247dee8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -94,8 +94,10 @@ class DenseLevel : public SparseTensorLevel {
ValueRange getLvlBuffers() const override { return {}; }
- ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
- Value max) const override {
+ ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+ ValueRange parentPos) const override {
+ assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
+ Value p = parentPos.front();
Value posLo = MULI(p, lvlSize);
return {posLo, lvlSize};
}
@@ -112,9 +114,9 @@ class BatchLevel : public SparseTensorLevel {
ValueRange getLvlBuffers() const override { return {}; }
- ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
- Value max) const override {
- assert(max == nullptr && "Dense level can not be non-unique.");
+ ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
+ ValueRange parentPos) const override {
+ assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
// No need to linearize the position for non-annotated tensors.
return {C_IDX(0), lvlSize};
}
@@ -127,9 +129,11 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- Value p, Value max) const override {
- assert(max == nullptr &&
+ ValueRange parentPos) const override {
+
+ assert(parentPos.size() == 1 &&
"compressed level must be the first non-unique level.");
+ Value p = parentPos.front();
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(p);
@@ -147,11 +151,11 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- Value p, Value max) const override {
- assert(max == nullptr &&
+ ValueRange parentPos) const override {
+ assert(parentPos.size() == 1 &&
"loose-compressed level must be the first non-unique level.");
SmallVector<Value> memCrd(batchPrefix);
-
+ Value p = parentPos.front();
p = MULI(p, C_IDX(2));
memCrd.push_back(p);
Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
@@ -168,10 +172,13 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- Value p, Value segHi) const override {
+ ValueRange parentPos) const override {
+ assert(parentPos.size() == 1 || parentPos.size() == 2);
+ Value p = parentPos.front();
+ Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
+
if (segHi == nullptr)
return {p, ADDI(p, C_IDX(1))};
-
// Use the segHi as the loop upper bound.
return {p, segHi};
}
@@ -184,11 +191,12 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- Value p, Value max) const override {
- assert(max == nullptr && isUnique() && "n:m level can not be non-unique.");
+ ValueRange parentPos) const override {
+ assert(parentPos.size() == 1 && isUnique() &&
+ "n:m level can not be non-unique.");
// Each n:m blk has exactly n specified elements.
auto n = getN(lt);
- Value posLo = MULI(p, C_IDX(n));
+ Value posLo = MULI(parentPos.front(), C_IDX(n));
return {posLo, ADDI(posLo, C_IDX(n))};
}
};
@@ -316,23 +324,21 @@ class TrivialIterator : public ConcreteIterator {
posHi = vs.back();
};
- ValuePair getCurPosition() const override { return {getItPos(), nullptr}; }
-
void genInitImpl(OpBuilder &b, Location l,
const SparseIterator *parent) override {
if (isBatchIterator() && batchCrds.size() <= stl.lvl)
batchCrds.resize(stl.lvl + 1, nullptr);
- Value pos = C_IDX(0);
- Value hi = nullptr;
+ Value c0 = C_IDX(0);
+ ValueRange pPos = c0;
// If the parent iterator is a batch iterator, we also start from 0 (but
// on a different batch).
if (parent && !parent->isBatchIterator())
- std::tie(pos, hi) = parent->getCurPosition();
+ pPos = parent->getCurPosition();
ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
- std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
+ std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
// Seek to the lowest position.
seek(posLo);
}
@@ -406,21 +412,19 @@ class DedupIterator : public ConcreteIterator {
return {b.getIndexType(), b.getIndexType()};
}
- ValuePair getCurPosition() const override { return {getPos(), getSegHi()}; }
-
void genInitImpl(OpBuilder &b, Location l,
const SparseIterator *parent) override {
+ Value c0 = C_IDX(0);
+ ValueRange pPos = c0;
- Value pos = C_IDX(0);
- Value hi = nullptr;
// If the parent iterator is a batch iterator, we also start from 0 (but
// on a different batch).
if (parent && !parent->isBatchIterator())
- std::tie(pos, hi) = parent->getCurPosition();
+ pPos = parent->getCurPosition();
Value posLo;
ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
- std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
+ std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
seek({posLo, genSegmentHigh(b, l, posLo)});
}
@@ -505,7 +509,7 @@ class FilterIterator : public SparseIterator {
SmallVector<Value> serialize() const override { return wrap->serialize(); };
void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
- ValuePair getCurPosition() const override { return wrap->getCurPosition(); }
+ ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
void genInitImpl(OpBuilder &b, Location l,
const SparseIterator *parent) override {
@@ -756,9 +760,8 @@ class SubSectIterator : public SparseIterator {
Value upperBound(OpBuilder &b, Location l) const override {
return subSect.subSectSz;
}
- std::pair<Value, Value> getCurPosition() const override {
- return wrap->getCurPosition();
- };
+
+ ValueRange getCurPosition() const override { return wrap->getCurPosition(); };
Value getNxLvlTupleId(OpBuilder &b, Location l) const {
if (randomAccessible()) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 46b923250dd893..b692848ec67bd8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -36,8 +36,9 @@ class SparseTensorLevel {
Value iv) const = 0;
/// Peeks the lower and upper bound to *fully* traverse the level with
- /// the given position `p` that the immediate parent level is current at.
- /// Returns a pair of values for *posLo* and *loopHi* respectively.
+ /// the given position `parentPos`, see SparseTensorIterator::getCurPosition(),
+ /// that the immediate parent level is current at. Returns a pair of values
+ /// for *posLo* and *loopHi* respectively.
///
/// For a dense level, the *posLo* is the linearized position at beginning,
/// while *loopHi* is the largest *coordinate*, it also implies that the
@@ -45,12 +46,9 @@ class SparseTensorLevel {
///
/// For a sparse level, [posLo, loopHi) specifies the range of index pointer
/// to load coordinate from the coordinate buffer.
- ///
- /// `bound` is only used when the level is `non-unique` and deduplication is
- /// required. It specifies the max upper bound of the non-unique segment.
virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
- ValueRange batchPrefix, Value p,
- Value segHi = Value()) const = 0;
+ ValueRange batchPrefix,
+ ValueRange parentPos) const = 0;
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
@@ -199,18 +197,17 @@ class SparseIterator {
}
virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0;
virtual Value derefImpl(OpBuilder &b, Location l) = 0;
- // Gets the current position and the optional *position high* (for
- // non-unique iterators), the value is essentially the number of sparse
- // coordinate that the iterator is current visiting. It should be able to
- // uniquely identify the sparse range for the next level. See
- // SparseTensorLevel::peekRangeAt();
+ // Gets the ValueRange that together specifies the current position of the
+ // iterator. For a unique level, the position can be a single index pointing
+ // to the current coordinate being visited. For a non-unique level, an extra
+ // index for the `segment high` is needed to specify the range of duplicated
+ // coordinates. The ValueRange should be able to uniquely identify the
+ // sparse range for the next level. See SparseTensorLevel::peekRangeAt();
//
// Not every type of iterator supports the operation, e.g., non-empty
// subsection iterator does not because it represents a range of coordinates
// instead of just one.
- virtual std::pair<Value, Value> getCurPosition() const {
- llvm_unreachable("unsupported");
- };
+ virtual ValueRange getCurPosition() const { return getCursor(); };
// Returns a pair of values for *upper*, *lower* bound respectively.
virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/90243
More information about the Mlir-commits
mailing list