[Mlir-commits] [mlir] [mlir][sparse] use shared value storage between wrapped iterator and the wrapper. (PR #80046)

Peiming Liu llvmlistbot at llvm.org
Tue Jan 30 11:39:08 PST 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/80046

>From 23191c9eca794e52640778acc07f7831a96dee79 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 25 Jan 2024 17:38:27 +0000
Subject: [PATCH 1/2] [mlir][sparse] use shared value storage for wrapped
 iterator and wrapper

---
 .../Transforms/Utils/SparseTensorLevel.cpp    | 176 ++++++++----------
 .../Transforms/Utils/SparseTensorLevel.h      |  84 ++++++---
 2 files changed, 138 insertions(+), 122 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index e43896942d7fe..98323c2195461 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -205,34 +205,48 @@ static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
 
 namespace {
 
-class TrivialIterator : public SparseIterator {
-  Value getLoopLo(OpBuilder &b, Location l) const {
-    // Dense loop are traversed by coordinate, delinearize the position to get
-    // the coordinate.
-    if (randomAccessible())
-      return SUBI(itPos, posLo);
-    return itPos;
+// The iterator that traverses a concrete sparse tensor level. High-level
+// abstract iterators wrap it to achieve more complex goals (such as collapsing
+// several levels). It also owns the common storage that holds the mlir::Values
+// for itself as well as for its wrappers.
+class ConcreteIterator : public SparseIterator {
+protected:
+  ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
+                   unsigned itValCnt)
+      : SparseIterator(kind, stl.tid, stl.lvl, itValCnt, itValsStorage),
+        stl(stl) {
+    // Allocate enough storage for iterator values.
+    itValsStorage.resize(itValCnt);
   }
 
 public:
-  TrivialIterator(const SparseTensorLevel &stl,
-                  const IterKind kind = IterKind::kTrivial)
-      : SparseIterator(kind, stl.tid, stl.lvl, itPos), stl(stl) {}
-
   // For LLVM-style RTTI.
   static bool classof(const SparseIterator *from) {
     return from->kind == IterKind::kTrivial;
   }
 
   bool randomAccessible() const override { return isDenseLT(stl.getLT()); };
-  bool iteratableByFor() const override { return true; };
+  bool iteratableByFor() const override { return kind != IterKind::kDedup; };
   Value upperBound(OpBuilder &b, Location l) const override {
     return stl.size();
   };
 
+protected:
+  // Owner of the storage. All wrappers built on top of a concrete iterator
+  // share the same storage so that the iterator values are always
+  // synchronized.
+  SmallVector<Value> itValsStorage;
+  const SparseTensorLevel &stl;
+};
+
+class TrivialIterator : public ConcreteIterator {
+public:
+  TrivialIterator(const SparseTensorLevel &stl)
+      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
+
   SmallVector<Value> serialize() const override {
     SmallVector<Value> ret;
-    ret.push_back(itPos);
+    ret.push_back(getItPos());
     if (randomAccessible()) {
       // Loop high is implicit (defined by `upperBound()`) for random-access
       // iterator, but we need to memorize posLo for linearization.
@@ -252,10 +266,10 @@ class TrivialIterator : public SparseIterator {
       posHi = vs.back();
   };
 
-  ValuePair getCurPosition() const override { return {itPos, nullptr}; }
+  ValuePair getCurPosition() const override { return {getItPos(), nullptr}; }
 
-  void genInit(OpBuilder &b, Location l,
-               const SparseIterator *parent) override {
+  void genInitImpl(OpBuilder &b, Location l,
+                   const SparseIterator *parent) override {
     Value pos = C_IDX(0);
     Value hi = nullptr;
     if (parent)
@@ -269,25 +283,25 @@ class TrivialIterator : public SparseIterator {
   ValuePair genForCond(OpBuilder &b, Location l) override {
     if (randomAccessible())
       return {deref(b, l), upperBound(b, l)};
-    return std::make_pair(getLoopLo(b, l), posHi);
+    return std::make_pair(getItPos(), posHi);
   }
 
   Value genNotEnd(OpBuilder &b, Location l) override {
     // We used the first level bound as the bound the collapsed set of levels.
-    return CMPI(ult, itPos, posHi);
+    return CMPI(ult, getItPos(), posHi);
   }
 
   Value deref(OpBuilder &b, Location l) override {
     if (randomAccessible()) {
-      updateCrd(SUBI(itPos, posLo));
+      updateCrd(SUBI(getItPos(), posLo));
     } else {
-      updateCrd(stl.peekCrdAt(b, l, itPos));
+      updateCrd(stl.peekCrdAt(b, l, getItPos()));
     }
     return getCrd();
   };
 
-  ValueRange forward(OpBuilder &b, Location l) override {
-    seek(ADDI(itPos, C_IDX(1)));
+  ValueRange forwardImpl(OpBuilder &b, Location l) override {
+    seek(ADDI(getItPos(), C_IDX(1)));
     return getItVals();
   }
 
@@ -305,20 +319,17 @@ class TrivialIterator : public SparseIterator {
     updateCrd(crd);
   }
 
-  Value itPos; // the position that represent the iterator
-
+  Value getItPos() const { return getItVals().front(); }
   Value posLo, posHi;
-  const SparseTensorLevel &stl;
 };
 
-class DedupIterator : public SparseIterator {
+class DedupIterator : public ConcreteIterator {
 private:
   Value genSegmentHigh(OpBuilder &b, Location l, Value pos);
 
 public:
   DedupIterator(const SparseTensorLevel &stl)
-      : SparseIterator(IterKind::kDedup, stl.tid, stl.lvl, posAndSegHi),
-        stl(stl) {
+      : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
     assert(!stl.isUnique());
   }
   // For LLVM-style RTTI.
@@ -326,16 +337,10 @@ class DedupIterator : public SparseIterator {
     return from->kind == IterKind::kDedup;
   }
 
-  bool randomAccessible() const override { return false; };
-  bool iteratableByFor() const override { return false; };
-  Value upperBound(OpBuilder &b, Location l) const override {
-    return stl.size();
-  };
-
   ValuePair getCurPosition() const override { return {getPos(), getSegHi()}; }
 
-  void genInit(OpBuilder &b, Location l,
-               const SparseIterator *parent) override {
+  void genInitImpl(OpBuilder &b, Location l,
+                   const SparseIterator *parent) override {
 
     Value pos = C_IDX(0);
     Value hi = nullptr;
@@ -369,18 +374,16 @@ class DedupIterator : public SparseIterator {
     return getCrd();
   };
 
-  ValueRange forward(OpBuilder &b, Location l) override {
+  ValueRange forwardImpl(OpBuilder &b, Location l) override {
     Value nxPos = getSegHi(); // forward the position to the next segment.
     seek({nxPos, genSegmentHigh(b, l, nxPos)});
     return getItVals();
   }
 
-  Value getPos() const { return posAndSegHi[0]; }
-  Value getSegHi() const { return posAndSegHi[1]; }
+  Value getPos() const { return getItVals()[0]; }
+  Value getSegHi() const { return getItVals()[1]; }
 
   Value posHi;
-  Value posAndSegHi[2]; // position and segment high
-  const SparseTensorLevel &stl;
 };
 
 //
@@ -424,8 +427,8 @@ class FilterIterator : public SparseIterator {
   void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
   ValuePair getCurPosition() const override { return wrap->getCurPosition(); }
 
-  void genInit(OpBuilder &b, Location l,
-               const SparseIterator *parent) override {
+  void genInitImpl(OpBuilder &b, Location l,
+                   const SparseIterator *parent) override {
     wrap->genInit(b, l, parent);
     if (!randomAccessible()) {
       // TODO: we can skip this when stride == 1 and offset == 0, we can also
@@ -451,9 +454,9 @@ class FilterIterator : public SparseIterator {
     updateCrd(crd);
   }
 
-  ValueRange forward(OpBuilder &b, Location l) override;
+  ValueRange forwardImpl(OpBuilder &b, Location l) override;
 
-  const Value offset, stride, size;
+  Value offset, stride, size;
   std::unique_ptr<SparseIterator> wrap;
 };
 
@@ -467,7 +470,7 @@ class NonEmptySubSectIterator : public SparseIterator {
                           std::unique_ptr<SparseIterator> &&delegate,
                           Value subSectSz)
       : SparseIterator(IterKind::kNonEmptySubSect, delegate->tid, delegate->lvl,
-                       /*itVals=*/subSectMeta),
+                       3, /*itVals=*/subSectMeta),
         parent(parent), delegate(std::move(delegate)),
         tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
     auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
@@ -555,7 +558,7 @@ class NonEmptySubSectIterator : public SparseIterator {
     return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
   };
 
-  void genInit(OpBuilder &b, Location l, const SparseIterator *) override;
+  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;
 
   void locate(OpBuilder &b, Location l, Value crd) override {
     Value absOff = crd;
@@ -587,7 +590,7 @@ class NonEmptySubSectIterator : public SparseIterator {
     return crd;
   };
 
-  ValueRange forward(OpBuilder &b, Location l) override;
+  ValueRange forwardImpl(OpBuilder &b, Location l) override;
 
   Value getMinCrd() const { return subSectMeta[0]; }
   Value getAbsOff() const { return subSectMeta[1]; }
@@ -605,7 +608,8 @@ class NonEmptySubSectIterator : public SparseIterator {
 
   const Value subSectSz;
 
-  Value subSectMeta[3]; // minCrd, absolute offset, notEnd
+  // minCrd, absolute offset, notEnd
+  SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
 };
 
 class SubSectIterator;
@@ -628,41 +632,18 @@ struct SubSectIterHelper {
 };
 
 class SubSectIterator : public SparseIterator {
-  // RAII to sync iterator values between the wrap the iterator and the
-  // SubSectIterator.
-  struct WrapItValSyncer {
-    explicit WrapItValSyncer(SubSectIterator &it) : it(it) {
-      if (!it.randomAccessible())
-        it.wrap->seek(it.getItVals().drop_back());
-    }
-    ~WrapItValSyncer() {
-      if (!it.randomAccessible()) {
-        ValueRange wrapItVals = it.wrap->getItVals();
-        std::copy(wrapItVals.begin(), wrapItVals.end(), it.itVals.begin());
-      }
-    }
-    SubSectIterator ⁢
-  };
-
 public:
   SubSectIterator(const NonEmptySubSectIterator &subSect,
                   const SparseIterator &parent,
                   std::unique_ptr<SparseIterator> &&wrap, Value size,
                   unsigned stride)
-      : SparseIterator(IterKind::kSubSect, *wrap), itVals(), subSect(subSect),
-        wrap(std::move(wrap)), parent(parent), size(size), stride(stride),
-        helper(*this) {
+      : SparseIterator(IterKind::kSubSect, *wrap,
+                       /*extraVal=*/wrap->randomAccessible() ? 0 : 1),
+        subSect(subSect), wrap(std::move(wrap)), parent(parent), size(size),
+        stride(stride), helper(*this) {
     assert(stride == 1 && "Not implemented.");
     assert(subSect.tid == tid && subSect.lvl == lvl);
     assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
-
-    if (!randomAccessible()) {
-      // We maintain a extra counter to count the actually sparse coordinate
-      // included in the subsection.
-      unsigned itValSz = this->wrap->getItVals().size() + 1;
-      itVals.resize(itValSz, nullptr);
-      relinkItVals(itVals);
-    }
   };
 
   // For LLVM-style RTTI.
@@ -681,11 +662,10 @@ class SubSectIterator : public SparseIterator {
     if (randomAccessible()) {
       return ADDI(getCrd(), nxLvlTupleStart);
     };
-    return ADDI(itVals.back(), nxLvlTupleStart);
+    return ADDI(getItVals().back(), nxLvlTupleStart);
   }
 
-  void genInit(OpBuilder &b, Location l, const SparseIterator *) override {
-    WrapItValSyncer syncer(*this);
+  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
     if (randomAccessible()) {
       if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
         assert(p->lvl + 1 == lvl);
@@ -700,10 +680,10 @@ class SubSectIterator : public SparseIterator {
       return;
     }
     assert(!randomAccessible());
-    assert(itVals.size() == wrap->getItVals().size() + 1);
+    assert(getItVals().size() == wrap->getItVals().size() + 1);
     // Extra counter that counts the number of actually visited coordinates in
     // the sparse subsection.
-    itVals.back() = C_IDX(0);
+    getMutItVals().back() = C_IDX(0);
     Value tupleId;
     if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
       assert(p->lvl + 1 == lvl);
@@ -717,35 +697,28 @@ class SubSectIterator : public SparseIterator {
   }
 
   void locate(OpBuilder &b, Location l, Value crd) override {
-    WrapItValSyncer syncer(*this);
     helper.locate(b, l, crd);
     updateCrd(crd);
   }
 
   Value genNotEnd(OpBuilder &b, Location l) override {
-    WrapItValSyncer syncer(*this);
     return helper.genNotEnd(b, l);
   }
 
   Value deref(OpBuilder &b, Location l) override {
-    WrapItValSyncer syncer(*this);
     Value crd = helper.deref(b, l);
     updateCrd(crd);
     return crd;
   };
 
-  ValueRange forward(OpBuilder &b, Location l) override {
-    {
-      WrapItValSyncer syncer(*this);
-      helper.forward(b, l);
-    }
+  ValueRange forwardImpl(OpBuilder &b, Location l) override {
+    helper.forward(b, l);
     assert(!randomAccessible());
-    assert(itVals.size() == wrap->getItVals().size() + 1);
-    itVals.back() = ADDI(itVals.back(), C_IDX(1));
+    assert(getItVals().size() == wrap->getItVals().size() + 1);
+    getMutItVals().back() = ADDI(getItVals().back(), C_IDX(1));
     return getItVals();
   };
 
-  SmallVector<Value> itVals;
   Value nxLvlTupleStart;
 
   const NonEmptySubSectIterator &subSect;
@@ -764,6 +737,17 @@ class SubSectIterator : public SparseIterator {
 // SparseIterator derived classes implementation.
 //===----------------------------------------------------------------------===//
 
+void SparseIterator::genInit(OpBuilder &b, Location l,
+                             const SparseIterator *p) {
+  // TODO: support lowering to function call.
+  return genInitImpl(b, l, p);
+}
+
+ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
+  // TODO: support lowering to function call.
+  return forwardImpl(b, l);
+}
+
 ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
   auto ifOp = b.create<scf::IfOp>(l, getItVals().getTypes(), cond, true);
   // Generate else branch first, otherwise iterator values will be updated by
@@ -846,7 +830,7 @@ Value FilterIterator::genNotEnd(OpBuilder &b, Location l) {
   return r.front();
 }
 
-ValueRange FilterIterator::forward(OpBuilder &b, Location l) {
+ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
   assert(!randomAccessible());
   // Generates
   //
@@ -1013,8 +997,8 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
   return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
 }
 
-void NonEmptySubSectIterator::genInit(OpBuilder &b, Location l,
-                                      const SparseIterator *) {
+void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
+                                          const SparseIterator *) {
   Value c0 = C_IDX(0);
   if (!isSubSectRoot()) {
     assert(parent->lvl + 1 == lvl);
@@ -1096,7 +1080,7 @@ void NonEmptySubSectIterator::genInit(OpBuilder &b, Location l,
   seek(meta);
 }
 
-ValueRange NonEmptySubSectIterator::forward(OpBuilder &b, Location l) {
+ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
   assert(!randomAccessible());
   Value c0 = C_IDX(0), c1 = C_IDX(1);
   // Forward to the next non empty slice by generating
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index d2b3396b72836..ee5b4ee39003b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -80,24 +80,37 @@ class SparseIterator {
   SparseIterator &operator=(const SparseIterator &) = delete;
 
 protected:
-  SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
-                 MutableArrayRef<Value> itVals)
-      : kind(kind), tid(tid), lvl(lvl), crd(nullptr), itVals(itVals){};
+  SparseIterator(IterKind kind, unsigned tid, unsigned lvl, unsigned itValsCnt,
+                 SmallVectorImpl<Value> &storage)
+      : kind(kind), tid(tid), lvl(lvl), crd(nullptr), itValsCnt(itValsCnt),
+        itValsStorageRef(storage){};
 
   SparseIterator(IterKind kind, const SparseIterator &wrap)
       : kind(kind), tid(wrap.tid), lvl(wrap.lvl), crd(nullptr),
-        itVals(wrap.itVals){};
+        itValsCnt(wrap.itValsCnt), itValsStorageRef(wrap.itValsStorageRef) {
+    assert(wrap.itValsCnt == itValsStorage.size());
+  };
+
+  SparseIterator(IterKind kind, const SparseIterator &wrap, unsigned extraVal)
+      : kind(kind), tid(wrap.tid), lvl(wrap.lvl), crd(nullptr),
+        itValsCnt(wrap.itValsCnt + extraVal),
+        itValsStorageRef(wrap.itValsStorageRef) {
+    itValsStorageRef.append(extraVal, nullptr);
+    assert(itValsCnt == itValsStorage.size());
+  };
 
 public:
   virtual ~SparseIterator() = default;
 
   Value getCrd() const { return crd; }
-  ValueRange getItVals() const { return itVals; };
+  ValueRange getItVals() const {
+    return ValueRange(itValsStorageRef).take_front(itValsCnt);
+  };
 
   // Sets the iterate to the specified position.
   void seek(ValueRange vals) {
-    assert(vals.size() == itVals.size());
-    std::copy(vals.begin(), vals.end(), itVals.begin());
+    assert(vals.size() == itValsCnt);
+    std::copy(vals.begin(), vals.end(), itValsStorageRef.begin());
     // Now that the iterator is re-positioned, the coordinate becomes invalid.
     crd = nullptr;
   }
@@ -119,8 +132,8 @@ class SparseIterator {
   virtual Value upperBound(OpBuilder &b, Location l) const = 0;
 
   // Serializes and deserializes the current status to/from a set of values. The
-  // ValueRange should contain values that specifies the current postion and
-  // loop bound.
+  // ValueRange should contain values that are sufficient to recover the current
+  // iterating position (i.e., itVals) as well as the loop bound.
   //
   // Not every type of iterator supports the operations, e.g., non-empty
   // subsection iterator does not because the the number of non-empty
@@ -136,6 +149,25 @@ class SparseIterator {
   // Core functions.
   //
 
+  // Initializes the iterator according to the parent iterator's state.
+  void genInit(OpBuilder &b, Location l, const SparseIterator *p);
+
+  // Forwards the iterator to the next element.
+  ValueRange forward(OpBuilder &b, Location l);
+
+  // Actual implementations provided by derived classes.
+  virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0;
+  virtual ValueRange forwardImpl(OpBuilder &b, Location l) = 0;
+
+  // Returns a boolean value that equals `!it.end()`
+  virtual Value genNotEnd(OpBuilder &b, Location l) = 0;
+
+  // Dereferences the iterator, loads the coordinate at the current position.
+  //
+  // The method assumes that the iterator is not currently exhausted (i.e.,
+  // it != it.end()).
+  virtual Value deref(OpBuilder &b, Location l) = 0;
+
   // Gets the current position and the optional *position high* (for non-unique
   // iterators), the value is essentially the number of sparse coordinate that
   // the iterator is current visiting. It should be able to uniquely identify
@@ -148,9 +180,6 @@ class SparseIterator {
     llvm_unreachable("unsupported");
   };
 
-  // Initializes the iterator according to the parent iterator's state.
-  virtual void genInit(OpBuilder &, Location, const SparseIterator *) = 0;
-
   // Returns a pair of values for *upper*, *lower* bound respectively.
   virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
     assert(randomAccessible());
@@ -158,22 +187,13 @@ class SparseIterator {
     return {getCrd(), upperBound(b, l)};
   }
 
-  // Returns a boolean value that equals `!it.end()`
-  virtual Value genNotEnd(OpBuilder &b, Location l) = 0;
+  // Generates a bool value for scf::ConditionOp.
   std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l,
                                             ValueRange vs) {
     ValueRange rem = linkNewScope(vs);
     return std::make_pair(genNotEnd(b, l), rem);
   }
 
-  // Dereference the iterator, loads the coordinate at the current position.
-  //
-  // The method assumes that the iterator is not currently exhausted (i.e.,
-  // it != it.end()).
-  virtual Value deref(OpBuilder &b, Location l) = 0;
-
-  virtual ValueRange forward(OpBuilder &b, Location l) = 0;
-
   // Generate a conditional it.next() in the following form
   //
   // if (cond)
@@ -198,13 +218,16 @@ class SparseIterator {
   ValueRange linkNewScope(ValueRange pos) {
     assert(!randomAccessible() && "random accessible iterators are traversed "
                                   "by coordinate, call locate() instead.");
-    seek(pos.take_front(itVals.size()));
-    return pos.drop_front(itVals.size());
+    seek(pos.take_front(itValsCnt));
+    return pos.drop_front(itValsCnt);
   };
 
 protected:
   void updateCrd(Value crd) { this->crd = crd; }
-  void relinkItVals(MutableArrayRef<Value> itVals) { this->itVals = itVals; }
+  MutableArrayRef<Value> getMutItVals() {
+    MutableArrayRef<Value> ref = itValsStorageRef;
+    return ref.take_front(itValsCnt);
+  }
 
 public:
   const IterKind kind;     // For LLVM-style RTTI.
@@ -219,7 +242,16 @@ class SparseIterator {
   // For trivial iterators, it is the position; for dedup iterators, it consists
   // of the positon and the segment high, for non-empty subsection iterator, it
   // is the metadata that specifies the subsection.
-  MutableArrayRef<Value> itVals;
+  // Note that the wrapped iterator shares the same storage to maintain itVals
+  // with it wrapper, which means the wrapped iterator might only owns a subset
+  // of all the values stored in itValStorage.
+  const unsigned itValsCnt;
+  SmallVectorImpl<Value> &itValsStorageRef;
+  // All other (loop invariant) values used by the iterator. Although these
+  // values are not updated between loop iterations, they still need to be
+  // passed as function parameters to reconstruct the iterator in a new function
+  // scope.
+  // SmallVectorImpl<Value> &metaValsStorageRef;
 };
 
 /// Helper function to create a TensorLevel object from given `tensor`.

>From f50da8b5174cf362c85e7530a8259a676c696476 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 30 Jan 2024 19:38:54 +0000
Subject: [PATCH 2/2] address comments

---
 .../SparseTensor/Transforms/Utils/SparseTensorLevel.h | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index ee5b4ee39003b..bf115712bdfc1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -88,7 +88,7 @@ class SparseIterator {
   SparseIterator(IterKind kind, const SparseIterator &wrap)
       : kind(kind), tid(wrap.tid), lvl(wrap.lvl), crd(nullptr),
         itValsCnt(wrap.itValsCnt), itValsStorageRef(wrap.itValsStorageRef) {
-    assert(wrap.itValsCnt == itValsStorage.size());
+    assert(wrap.itValsCnt == itValsStorageRef.size());
   };
 
   SparseIterator(IterKind kind, const SparseIterator &wrap, unsigned extraVal)
@@ -96,7 +96,7 @@ class SparseIterator {
         itValsCnt(wrap.itValsCnt + extraVal),
         itValsStorageRef(wrap.itValsStorageRef) {
     itValsStorageRef.append(extraVal, nullptr);
-    assert(itValsCnt == itValsStorage.size());
+    assert(itValsCnt == itValsStorageRef.size());
   };
 
 public:
@@ -243,15 +243,10 @@ class SparseIterator {
   // of the positon and the segment high, for non-empty subsection iterator, it
   // is the metadata that specifies the subsection.
   // Note that the wrapped iterator shares the same storage to maintain itVals
-  // with it wrapper, which means the wrapped iterator might only owns a subset
+  // with its wrapper, which means the wrapped iterator might only own a subset
   // of all the values stored in itValStorage.
   const unsigned itValsCnt;
   SmallVectorImpl<Value> &itValsStorageRef;
-  // All other (loop invariant) values used by the iterator. Although these
-  // values are not updated between loop iterations, they still need to be
-  // passed as function parameters to reconstruct the iterator in a new function
-  // scope.
-  // SmallVectorImpl<Value> &metaValsStorageRef;
 };
 
 /// Helper function to create a TensorLevel object from given `tensor`.



More information about the Mlir-commits mailing list