[Mlir-commits] [mlir] [mlir][sparse] use ValueRange instead of std::pair for iterator position. (PR #90243)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 26 11:11:03 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>

`ValueRange` is easier to extend (e.g., for a padded iterator).



---
Full diff: https://github.com/llvm/llvm-project/pull/90243.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (+1-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp (+34-31) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h (+12-15) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 59c3e49264dbe1..34312df912997b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -222,7 +222,7 @@ class LoopEmitter {
   ///
   SmallVector<Value> getValPosits(TensorId tid) const {
     SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
-    Value lastLvlPos = iters[tid].back().back()->getCurPosition().first;
+    Value lastLvlPos = iters[tid].back().back()->getCurPosition().front();
     batchCrds.push_back(lastLvlPos);
     return batchCrds;
   };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index 60dca3c55dec3d..745c081247dee8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -94,8 +94,10 @@ class DenseLevel : public SparseTensorLevel {
 
   ValueRange getLvlBuffers() const override { return {}; }
 
-  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
-                        Value max) const override {
+  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+                        ValueRange parentPos) const override {
+    assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
+    Value p = parentPos.front();
     Value posLo = MULI(p, lvlSize);
     return {posLo, lvlSize};
   }
@@ -112,9 +114,9 @@ class BatchLevel : public SparseTensorLevel {
 
   ValueRange getLvlBuffers() const override { return {}; }
 
-  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
-                        Value max) const override {
-    assert(max == nullptr && "Dense level can not be non-unique.");
+  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
+                        ValueRange parentPos) const override {
+    assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
     // No need to linearize the position for non-annotated tensors.
     return {C_IDX(0), lvlSize};
   }
@@ -127,9 +129,11 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        Value p, Value max) const override {
-    assert(max == nullptr &&
+                        ValueRange parentPos) const override {
+
+    assert(parentPos.size() == 1 &&
            "compressed level must be the first non-unique level.");
+    Value p = parentPos.front();
 
     SmallVector<Value> memCrd(batchPrefix);
     memCrd.push_back(p);
@@ -147,11 +151,11 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        Value p, Value max) const override {
-    assert(max == nullptr &&
+                        ValueRange parentPos) const override {
+    assert(parentPos.size() == 1 &&
            "loose-compressed level must be the first non-unique level.");
     SmallVector<Value> memCrd(batchPrefix);
-
+    Value p = parentPos.front();
     p = MULI(p, C_IDX(2));
     memCrd.push_back(p);
     Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
@@ -168,10 +172,13 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        Value p, Value segHi) const override {
+                        ValueRange parentPos) const override {
+    assert(parentPos.size() == 1 || parentPos.size() == 2);
+    Value p = parentPos.front();
+    Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
+
     if (segHi == nullptr)
       return {p, ADDI(p, C_IDX(1))};
-
     // Use the segHi as the loop upper bound.
     return {p, segHi};
   }
@@ -184,11 +191,12 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        Value p, Value max) const override {
-    assert(max == nullptr && isUnique() && "n:m level can not be non-unique.");
+                        ValueRange parentPos) const override {
+    assert(parentPos.size() == 1 && isUnique() &&
+           "n:m level can not be non-unique.");
     // Each n:m blk has exactly n specified elements.
     auto n = getN(lt);
-    Value posLo = MULI(p, C_IDX(n));
+    Value posLo = MULI(parentPos.front(), C_IDX(n));
     return {posLo, ADDI(posLo, C_IDX(n))};
   }
 };
@@ -316,23 +324,21 @@ class TrivialIterator : public ConcreteIterator {
       posHi = vs.back();
   };
 
-  ValuePair getCurPosition() const override { return {getItPos(), nullptr}; }
-
   void genInitImpl(OpBuilder &b, Location l,
                    const SparseIterator *parent) override {
 
     if (isBatchIterator() && batchCrds.size() <= stl.lvl)
       batchCrds.resize(stl.lvl + 1, nullptr);
 
-    Value pos = C_IDX(0);
-    Value hi = nullptr;
+    Value c0 = C_IDX(0);
+    ValueRange pPos = c0;
     // If the parent iterator is a batch iterator, we also start from 0 (but
     // on a different batch).
     if (parent && !parent->isBatchIterator())
-      std::tie(pos, hi) = parent->getCurPosition();
+      pPos = parent->getCurPosition();
 
     ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
-    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
+    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
     // Seek to the lowest position.
     seek(posLo);
   }
@@ -406,21 +412,19 @@ class DedupIterator : public ConcreteIterator {
     return {b.getIndexType(), b.getIndexType()};
   }
 
-  ValuePair getCurPosition() const override { return {getPos(), getSegHi()}; }
-
   void genInitImpl(OpBuilder &b, Location l,
                    const SparseIterator *parent) override {
+    Value c0 = C_IDX(0);
+    ValueRange pPos = c0;
 
-    Value pos = C_IDX(0);
-    Value hi = nullptr;
     // If the parent iterator is a batch iterator, we also start from 0 (but
     // on a different batch).
     if (parent && !parent->isBatchIterator())
-      std::tie(pos, hi) = parent->getCurPosition();
+      pPos = parent->getCurPosition();
 
     Value posLo;
     ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
-    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
+    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
 
     seek({posLo, genSegmentHigh(b, l, posLo)});
   }
@@ -505,7 +509,7 @@ class FilterIterator : public SparseIterator {
 
   SmallVector<Value> serialize() const override { return wrap->serialize(); };
   void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
-  ValuePair getCurPosition() const override { return wrap->getCurPosition(); }
+  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
 
   void genInitImpl(OpBuilder &b, Location l,
                    const SparseIterator *parent) override {
@@ -756,9 +760,8 @@ class SubSectIterator : public SparseIterator {
   Value upperBound(OpBuilder &b, Location l) const override {
     return subSect.subSectSz;
   }
-  std::pair<Value, Value> getCurPosition() const override {
-    return wrap->getCurPosition();
-  };
+
+  ValueRange getCurPosition() const override { return wrap->getCurPosition(); };
 
   Value getNxLvlTupleId(OpBuilder &b, Location l) const {
     if (randomAccessible()) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 46b923250dd893..b692848ec67bd8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -36,8 +36,9 @@ class SparseTensorLevel {
                           Value iv) const = 0;
 
   /// Peeks the lower and upper bound to *fully* traverse the level with
-  /// the given position `p` that the immediate parent level is current at.
-  /// Returns a pair of values for *posLo* and *loopHi* respectively.
+  /// the given position `parentPos`, see SparseTensorIterator::getCurPostion(),
+  /// that the immediate parent level is current at. Returns a pair of values
+  /// for *posLo* and *loopHi* respectively.
   ///
   /// For a dense level, the *posLo* is the linearized position at beginning,
   /// while *loopHi* is the largest *coordinate*, it also implies that the
@@ -45,12 +46,9 @@ class SparseTensorLevel {
   ///
   /// For a sparse level, [posLo, loopHi) specifies the range of index pointer
   /// to load coordinate from the coordinate buffer.
-  ///
-  /// `bound` is only used when the level is `non-unique` and deduplication is
-  /// required. It specifies the max upper bound of the non-unique segment.
   virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
-                                              ValueRange batchPrefix, Value p,
-                                              Value segHi = Value()) const = 0;
+                                              ValueRange batchPrefix,
+                                              ValueRange parentPos) const = 0;
 
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
@@ -199,18 +197,17 @@ class SparseIterator {
   }
   virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0;
   virtual Value derefImpl(OpBuilder &b, Location l) = 0;
-  // Gets the current position and the optional *position high* (for
-  // non-unique iterators), the value is essentially the number of sparse
-  // coordinate that the iterator is current visiting. It should be able to
-  // uniquely identify the sparse range for the next level. See
-  // SparseTensorLevel::peekRangeAt();
+  // Gets the ValueRange that together specifies the current position of the
+  // iterator. For a unique level, the position can be a single index points to
+  // the current coordinate being visited. For a non-unique level, an extra
+  // index for the `segment high` is needed to to specifies the range of
+  // duplicated coordinates. The ValueRange should be able to uniquely identify
+  // the sparse range for the next level. See SparseTensorLevel::peekRangeAt();
   //
   // Not every type of iterator supports the operation, e.g., non-empty
   // subsection iterator does not because it represent a range of coordinates
   // instead of just one.
-  virtual std::pair<Value, Value> getCurPosition() const {
-    llvm_unreachable("unsupported");
-  };
+  virtual ValueRange getCurPosition() const { return getCursor(); };
 
   // Returns a pair of values for *upper*, *lower* bound respectively.
   virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/90243


More information about the Mlir-commits mailing list