[Mlir-commits] [mlir] 36c95ee - [mlir][sparse] group tensor id and levels into pairs in loop emitter

Peiming Liu llvmlistbot at llvm.org
Thu May 4 09:15:47 PDT 2023


Author: Peiming Liu
Date: 2023-05-04T16:15:42Z
New Revision: 36c95ee739c0d7e49dc69e0d8d86d30667c14d49

URL: https://github.com/llvm/llvm-project/commit/36c95ee739c0d7e49dc69e0d8d86d30667c14d49
DIFF: https://github.com/llvm/llvm-project/commit/36c95ee739c0d7e49dc69e0d8d86d30667c14d49.diff

LOG: [mlir][sparse] group tensor id and levels into pairs in loop emitter

This addresses some unresolved comments in https://reviews.llvm.org/D142930.

Reviewed By: aartbik, wrengr

Differential Revision: https://reviews.llvm.org/D148565

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index 1ee5c19b284fe..bf8bd722a50ad 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -85,6 +85,25 @@ class CodegenEnv {
     return latticeMerger.getDimLevelType(b);
   }
 
+  //
+  // LoopEmitter delegates.
+  //
+
+  constexpr TensorLevel makeTensorLevel(TensorId t, Level l) const {
+    // Make sure LoopEmitter, GenericOp, and Merger agree on the number of
+    // tensors. Merger has one more synthetic tensor for loop invariants.
+    assert(loopEmitter.getNumTensors() == linalgOp->getNumOperands() &&
+           loopEmitter.getNumTensors() == latticeMerger.getNumTensors() - 1);
+    return loopEmitter.makeTensorLevel(t, l);
+  }
+  std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
+    return loopEmitter.unpackTensorLevel(tl);
+  }
+  template <class ContainerTy>
+  auto unpackTensorLevelRange(ContainerTy &&c) const {
+    return loopEmitter.unpackTensorLevelRange(std::forward<ContainerTy>(c));
+  }
+
   //
   // Code generation environment verify functions.
   //

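The assertion in `CodegenEnv::makeTensorLevel` matters because the packed encoding divides and mods by the tensor count, so the LoopEmitter, the GenericOp, and the Merger must all agree on that count (the Merger carries one extra synthetic tensor for loop invariants). A minimal standalone sketch of the round-trip, and of how a mismatched count silently decodes to the wrong pair; plain `unsigned`s stand in for the TensorId/Level aliases and the concrete numbers are illustrative only:

  #include <cassert>
  #include <cstdio>
  #include <utility>

  // Same formula as LoopEmitter::makeTensorLevel: lvl * numTensors + tid.
  static unsigned pack(unsigned tid, unsigned lvl, unsigned numTensors) {
    return lvl * numTensors + tid;
  }
  static std::pair<unsigned, unsigned> unpack(unsigned tl, unsigned numTensors) {
    return {tl % numTensors, tl / numTensors};
  }

  int main() {
    const unsigned numTensors = 3;              // e.g. two inputs + one output
    unsigned tl = pack(/*tid=*/1, /*lvl=*/2, numTensors);    // 2 * 3 + 1 == 7
    auto [tid, lvl] = unpack(tl, numTensors);
    assert(tid == 1 && lvl == 2);               // round-trips with the same count
    auto [badTid, badLvl] = unpack(tl, /*numTensors=*/4);    // mismatched count
    std::printf("good: (%u,%u)  bad: (%u,%u)\n", tid, lvl, badTid, badLvl);
    // prints "good: (1,2)  bad: (3,1)" -- hence the agreement assert above.
    return 0;
  }
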
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index ba6b4641408a5..731a1a9e460e5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -456,13 +456,12 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
 }
 
 void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
-                                  ArrayRef<TensorId> tids,
-                                  ArrayRef<Level> lvls) {
+                                  ArrayRef<TensorLevel> tidLvls) {
   // TODO: sort
   assert(loopSeqStack.size() == loopStack.size());
   // Prepares for all the tensors used in the current loop sequence.
   std::vector<std::tuple<TensorId, Level, bool>> slicedTids;
-  for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+  for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
     if (!dependentLvlMap[tid][lvl].empty()) {
       bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
       slicedTids.emplace_back(tid, lvl, fullyRed);
@@ -660,17 +659,19 @@ Operation *LoopEmitter::emitWhileLoopOverSliceAtSparseLvl(
   return loop;
 }
 
-Operation *LoopEmitter::enterLoopOverTensorAtLvl(
-    OpBuilder &builder, Location loc, ArrayRef<TensorId> tids,
-    ArrayRef<Level> lvls, MutableArrayRef<Value> reduc, bool isParallel) {
+Operation *LoopEmitter::enterLoopOverTensorAtLvl(OpBuilder &builder,
+                                                 Location loc,
+                                                 ArrayRef<TensorLevel> tidLvls,
+                                                 MutableArrayRef<Value> reduc,
+                                                 bool isParallel) {
   // TODO: support multiple return on parallel for?
   assert(!isParallel || reduc.size() <= 1);
   bool isSparseCond = false, isSparseSliceCond = false;
-  size_t tid = tids.front(), lvl = lvls.front();
+  auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
 
   // Finds out the tensor level that we should use to generate loops. Amongs all
   // the tensor levels, there is at most one sparse tensor level.
-  for (auto [t, l] : llvm::zip(tids, lvls)) {
+  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
     assert(lvlTypes[t].size() > l);         // Must be a valid tid, dim pair
     assert(!coords[t][l] ||                 // We cannot re-enter the same level
            !dependentLvlMap[t][l].empty()); // unless it is a slice-driver loop
@@ -712,12 +713,9 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
   Operation *l = nullptr;
 
   // At most one tensor used as condition in for loop;
-  SmallVector<TensorId, 1> condTid;
-  SmallVector<Level, 1> condLvl;
-  // There Might be multiple dense slice driven tensor.
-  SmallVector<TensorId> sliceTids;
-  SmallVector<Level> sliceLvls;
-  SmallVector<bool> sliceReduc;
+  SmallVector<TensorLevel, 1> condTidLvl;
+  // There might be multiple dense slice-driven tensors.
+  SmallVector<SliceLoopInfo> sliceDrivenInfo;
 
   // Generates loops differently depending on whether we need a slice-driven
   // loop or a simple level traversal loop.
@@ -734,9 +732,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
           lvl, reduc);
     }
     levelReducedDep[tid][lvl]++;
-    sliceTids.push_back(tid);
-    sliceLvls.push_back(lvl);
-    sliceReduc.push_back(fullyReduced);
+    sliceDrivenInfo.emplace_back(tid, lvl, fullyReduced);
   } else {
     Value lo = isSparseCond ? posits[tid][lvl]           // current offset
                             : loopSeqStack.back().first; // universal index
@@ -747,21 +743,19 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
       // Adjust for loop hi for dense slice-driven loop.
       if (fullyReduced) {
         hi = sliceSz;
-        condTid.push_back(tid);
-        condLvl.push_back(lvl);
+        condTidLvl.push_back(makeTensorLevel(tid, lvl));
       } else {
         hi = SUBI(lvlSizes[tid][lvl], sliceSz);
         hi = ADDI(hi, C_IDX(1));
       }
     } else {
-      condTid.push_back(tid);
-      condLvl.push_back(lvl);
+      condTidLvl.push_back(makeTensorLevel(tid, lvl));
     }
     l = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, lo, hi, reduc,
                                    isParallel);
   }
   Value iv = coords[tid][lvl];
-  for (auto [t, l] : llvm::zip(tids, lvls)) {
+  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
     // We only need to handle slice-driven loops on dense level here.
     // If it is a slice-driven loop on sparse level, it needs a while loop to
     // insert break statements, and it must have been handled correctly in L692.
@@ -774,9 +768,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
       } else {
         // Puts sliced dense loop into LoopInfo so that LoopEmitter knows how to
         // exit it.
-        sliceTids.push_back(t);
-        sliceLvls.push_back(l);
-        sliceReduc.push_back(fullyReduc);
+        sliceDrivenInfo.emplace_back(t, l, fullyReduc);
         // Update the slice information as we enter the new loop.
         assert(*info.slicedOnLvl == l);
         info.minCrd = info.offset = iv;
@@ -787,10 +779,10 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
   }
   // NOTE: we can also prepare for next dim here in advance
   // Pushes the loop into stack.
-  loopStack.emplace_back(condTid, condLvl, sliceTids, sliceLvls, sliceReduc, l,
+  loopStack.emplace_back(condTidLvl, sliceDrivenInfo, l,
                          builder.getInsertionBlock(), iv, loopTag);
   // Emit extra locals.
-  emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tids, lvls);
+  emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tidLvls);
   return l;
 }
 
@@ -854,16 +846,17 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
 
   // NOTE: we can also prepare for next lvl here in advance
   // Push the loop into stack
-  loopStack.emplace_back(ArrayRef<TensorId>(tid), ArrayRef<Level>(lvl),
-                         ArrayRef<TensorId>(), ArrayRef<Level>(),
-                         ArrayRef<bool>(), forOp, builder.getInsertionBlock(),
-                         coords[tid][lvl], nullptr);
+  loopStack.emplace_back(ArrayRef<TensorLevel>(makeTensorLevel(tid, lvl)),
+                         ArrayRef<SliceLoopInfo>(), forOp,
+                         builder.getInsertionBlock(), coords[tid][lvl],
+                         nullptr);
   return forOp;
 }
 
 void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc,
-                                        TensorId tid, Level lvl,
+                                        TensorLevel tidLvl,
                                         AffineExpr lvlExpr) {
+  auto [tid, lvl] = unpackTensorLevel(tidLvl);
   assert(isDenseDLT(lvlTypes[tid][lvl]));
   // For dense levels, the level-coordinate also serves as the position.
   Value lvlCrd = genAffine(builder, loc, lvlExpr);
@@ -871,16 +864,15 @@ void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc,
 }
 
 Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
-    OpBuilder &builder, Location loc, ArrayRef<TensorId> tids,
-    ArrayRef<Level> lvls, bool needsUniv, MutableArrayRef<Value> reduc) {
+    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
+    bool needsUniv, MutableArrayRef<Value> reduc) {
   // NOTE: the slice driven tensor-related reduction variable must
   // appear before normal tensors.
-  assert(tids.size() == lvls.size());
   SmallVector<Type> types;
   SmallVector<Value> operands;
   // Construct the while-loop with a parameter for each coordinate.
   const Type indexType = builder.getIndexType();
-  for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+  for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
     // TODO: support coiteration with slice driven tensors.
     const auto lvlTp = lvlTypes[tid][lvl];
     assert(dependentLvlMap[tid][lvl].empty() && "TODO: not yet implemented");
@@ -922,7 +914,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
   builder.setInsertionPointToStart(&whileOp.getBefore().front());
   Value cond;
   unsigned o = 0;
-  for (auto [t, lvl] : llvm::zip(tids, lvls)) {
+  for (auto [t, lvl] : unpackTensorLevelRange(tidLvls)) {
     const TensorId tid = t; // Why `t` can not be captured by lambda?
     const auto lvlTp = lvlTypes[tid][lvl];
     if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
@@ -956,7 +948,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
 
   SmallVector<std::pair<Value, unsigned>> slicesPreds;
   unsigned i = 0;
-  for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+  for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
     // Prepares for next level.
     const auto lvlTp = lvlTypes[tid][lvl];
     if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
@@ -1007,7 +999,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
   Value min;
   // Finds the minimum coordinate
   if (!needsUniv) {
-    for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
       const auto lvlTp = lvlTypes[tid][lvl];
       if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
           isCompressedWithHiDLT(lvlTp)) {
@@ -1027,12 +1019,11 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
   }
 
   // Sets up the loop stack.
-  loopStack.emplace_back(tids, lvls, ArrayRef<TensorId>(), ArrayRef<Level>(),
-                         ArrayRef<bool>(), whileOp, builder.getInsertionBlock(),
-                         min, loopTag);
+  loopStack.emplace_back(tidLvls, ArrayRef<SliceLoopInfo>(), whileOp,
+                         builder.getInsertionBlock(), min, loopTag);
   assert(loopStack.size() == loopSeqStack.size());
 
-  for (auto [tid, dstLvl] : llvm::zip(tids, lvls)) {
+  for (auto [tid, dstLvl] : unpackTensorLevelRange(tidLvls)) {
     const auto reassoc = getCollapseReassociation(tid, dstLvl);
     assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
     // TODO: Refactors this into smaller functions.
@@ -1079,7 +1070,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
   }
 
   // Emits extra locals
-  emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tids, lvls);
+  emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tidLvls);
 
   // Updates reduction variables
   assert(after->getNumArguments() == o + reduc.size() + (needsUniv ? 1 : 0));
@@ -1140,15 +1131,12 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
   llvm_unreachable("Unrecognized level-type!");
 }
 
-void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(OpBuilder &builder,
-                                                       Location loc,
-                                                       ArrayRef<TensorId> tids,
-                                                       ArrayRef<Level> lvls) {
+void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(
+    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls) {
   // Initialize dense positions. Note that we generate dense coordinates of the
   // output tensor unconditionally, since they may not appear in the lattice,
   // but may be needed for linearized codegen.
-  assert(tids.size() == lvls.size());
-  for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+  for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
     if (isDenseDLT(lvlTypes[tid][lvl])) {
       // Slice-driven dense level should have be handled already.
       if (!dependentLvlMap[tid][lvl].empty())
@@ -1175,8 +1163,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                               MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
   rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
-  for (auto [tid, lvl, reduced] : llvm::zip(
-           loopInfo.slicedTids, loopInfo.slicedLvls, loopInfo.sliceReduced)) {
+  for (auto [tid, lvl, reduced] : loopInfo.sliceDrivenInfo) {
     SliceInfo &info = sliceStack[tid].back();
     assert(isDenseDLT(lvlTypes[tid][lvl]));
     assert(*info.slicedOnLvl == lvl && !reduced);
@@ -1253,7 +1240,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
   // Finished iterating a tensor, clean up
   // We only do the clean up on for loop as while loops do not necessarily
   // finish the iteration on a sparse tensor
-  for (auto [tid, lvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) {
+  for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
     // Reset to null.
     coords[tid][lvl] = Value();
     posits[tid][lvl] = Value();
@@ -1278,8 +1265,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
   unsigned o = 0;
   SmallVector<Value> operands;
   unsigned delta = 0;
-  for (auto [tid, lvl, resolved] : llvm::zip(
-           loopInfo.slicedTids, loopInfo.slicedLvls, loopInfo.sliceReduced)) {
+  for (auto [tid, lvl, resolved] : loopInfo.sliceDrivenInfo) {
     // TODO: handle dense.
     assert(isCompressedDLT(lvlTypes[tid][lvl]));
     levelReducedDep[tid][lvl]--;
@@ -1291,7 +1277,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
     // fully reduced while op for iterating one slices.
     // FIXME: since we didn't implement coiteration, this must be iteration
     // just on fully resolved slice.
-    assert(loopInfo.slicedTids.size() == 1 && loopInfo.tids.empty());
+    assert(loopInfo.sliceDrivenInfo.size() == 1 && loopInfo.tidLvls.empty());
     // The if guard to filter out out-range coordinates.
     assert(llvm::isa<scf::IfOp>(builder.getInsertionBlock()->getParentOp()));
     posits[tid][lvl] = whileOp->getResult(o++);
@@ -1308,7 +1294,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
   };
 
   Value one = C_IDX(1);
-  for (auto [tid, dstLvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) {
+  for (auto [tid, dstLvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
     const auto lvlTp = lvlTypes[tid][dstLvl];
     if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
         isCompressedWithHiDLT(lvlTp)) {
@@ -1376,7 +1362,6 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
   // Clean up the values, it would help use to discover potential bug at a
   // earlier stage (instead of silently using a wrong value).
   const LoopInfo &loopInfo = loopStack.back();
-  assert(loopInfo.tids.size() == loopInfo.lvls.size());
   SmallVector<Value> red;
   if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
     exitWhileLoop(rewriter, loc, reduc);

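A second part of the cleanup in this file: the three parallel vectors `slicedTids`/`slicedLvls`/`sliceReduced` collapse into a single vector of `SliceLoopInfo` (declared in LoopEmitter.h below), so the loop-exit paths can destructure each entry directly instead of zipping three containers. A standalone sketch of that pattern, with simplified field types rather than the actual MLIR aliases:

  #include <cstdio>
  #include <vector>

  // Mirrors the SliceLoopInfo struct added in LoopEmitter.h (types simplified).
  struct SliceLoopInfo {
    SliceLoopInfo(unsigned tid, unsigned lvl, bool reduced)
        : tid(tid), lvl(lvl), reduced(reduced) {}
    unsigned tid;
    unsigned lvl;
    bool reduced;
  };

  int main() {
    std::vector<SliceLoopInfo> sliceDrivenInfo;
    sliceDrivenInfo.emplace_back(/*tid=*/0, /*lvl=*/1, /*reduced=*/true);
    sliceDrivenInfo.emplace_back(/*tid=*/2, /*lvl=*/0, /*reduced=*/false);

    // One container, destructured in place -- no llvm::zip over three vectors.
    for (auto [tid, lvl, reduced] : sliceDrivenInfo)
      std::printf("tid=%u lvl=%u reduced=%d\n", tid, lvl, reduced);
    return 0;
  }
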
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index d069633836e6e..fe4354de3bb45 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -42,6 +42,8 @@ namespace sparse_tensor {
 // typecheck this to avoid mixups in the code.
 using LoopOrd = unsigned;
 
+// A compressed <tensor id, level> pair.
+using TensorLevel = unsigned;
 //===----------------------------------------------------------------------===//
 // SparseTensorLoopEmiter class, manages sparse tensors and helps to
 // generate loop structure to (co)-iterate sparse tensors.
@@ -134,7 +136,7 @@ class LoopEmitter {
   ///   // loop sequence end.
   /// }
   void enterNewLoopSeq(OpBuilder &builder, Location loc,
-                       ArrayRef<TensorId> tids, ArrayRef<Level> lvls);
+                       ArrayRef<TensorLevel> tidLvls);
 
   /// Exits the current loop sequence, this will reset universal index to 0.
   void exitCurrentLoopSeq(OpBuilder &builder, Location loc);
@@ -149,8 +151,7 @@ class LoopEmitter {
   /// The function will also perform in-place update on the `reduc` vector to
   /// return the reduction variable used inside the generated loop.
   Operation *enterLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
-                                      ArrayRef<TensorId> tids,
-                                      ArrayRef<Level> lvls,
+                                      ArrayRef<TensorLevel> tidLvls,
                                       MutableArrayRef<Value> reduc = {},
                                       bool isParallel = false);
 
@@ -159,13 +160,13 @@ class LoopEmitter {
                                             AffineExpr affine,
                                             MutableArrayRef<Value> reduc = {});
 
-  void genDenseAffineAddress(OpBuilder &builder, Location loc, TensorId tid,
-                             Level lvl, AffineExpr lvlExpr);
+  void genDenseAffineAddress(OpBuilder &builder, Location loc,
+                             TensorLevel tidLvl, AffineExpr lvlExpr);
 
   /// Emits a co-iteration loop over a set of tensors.
   Operation *enterCoIterationOverTensorsAtLvls(
-      OpBuilder &builder, Location loc, ArrayRef<TensorId> tids,
-      ArrayRef<Level> lvls, bool needsUniv, MutableArrayRef<Value> reduc = {});
+      OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
+      bool needsUniv, MutableArrayRef<Value> reduc = {});
 
   void exitCurrentLoop(RewriterBase &rewriter, Location loc,
                        MutableArrayRef<Value> reduc = {});
@@ -190,6 +191,31 @@ class LoopEmitter {
     return n < getCurrentDepth() ? loopStack[n].iv : Value();
   }
 
+  /// Gets the total number of tensors that loopEmitter is operating on.
+  unsigned getNumTensors() const { return tensors.size(); }
+
+  /// Compresses a TensorId and Level into a TensorLevel.
+  constexpr TensorLevel makeTensorLevel(TensorId t, Level l) const {
+    return l * getNumTensors() + t;
+  }
+
+  /// De-compresses a TensorLevel back to a pair of TensorId and Level.
+  std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tidLvl) const {
+    unsigned nt = getNumTensors();
+    return std::make_pair(tidLvl % nt, tidLvl / nt);
+  }
+
+  /// Converts a range of TensorLevel to a range of std::pair<TensorId, Level>
+  template <class ContainerTy>
+  auto unpackTensorLevelRange(ContainerTy &&c) const {
+    using EltTy = decltype(*c.begin());
+    static_assert(std::is_same_v<llvm::remove_cvref_t<EltTy>, TensorLevel>,
+                  "Must be unpacking a TensorLevel range");
+    return llvm::map_range(std::forward<ContainerTy>(c), [this](EltTy tl) {
+      return this->unpackTensorLevel(tl);
+    });
+  }
+
   ///
   /// Getters.
   ///
@@ -209,32 +235,30 @@ class LoopEmitter {
   }
 
 private:
+  // A tuple that stores the slice-driven loop information.
+  struct SliceLoopInfo final {
+    SliceLoopInfo(TensorId tid, Level lvl, bool reduced)
+        : tid(tid), lvl(lvl), reduced(reduced) {}
+    TensorId tid;
+    Level lvl;
+    bool reduced;
+  };
   // LoopInfo stores information of a loop generated by LoopEmitter. E.g.,
   // the set of tensors levels that the loop is iterating over.
   struct LoopInfo final {
-    LoopInfo(ArrayRef<TensorId> tids, ArrayRef<Level> lvls,
-             ArrayRef<TensorId> slicedTids, ArrayRef<Level> slicedLvls,
-             ArrayRef<bool> sliceReduced, Operation *loop, Block *userBlock,
-             Value iv, StringAttr loopTag)
-        : tids(tids), lvls(lvls), slicedTids(slicedTids),
-          slicedLvls(slicedLvls), sliceReduced(sliceReduced), loop(loop),
+    LoopInfo(ArrayRef<TensorLevel> tidLvls,
+             ArrayRef<SliceLoopInfo> sliceDrivenInfo, Operation *loop,
+             Block *userBlock, Value iv, StringAttr loopTag)
+        : tidLvls(tidLvls), sliceDrivenInfo(sliceDrivenInfo), loop(loop),
           userCodeBlock(userBlock), iv(iv) {
       // Attached a special tag to loop emitter generated loop.
       if (loopTag)
         loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
     }
-    // TODO: maybe use a vector<pair> for tid and lvl?
-    //       (Or compress them together with a `TensorLoopId`.)
-    // The set of tensors that the loop is operating on
-    const llvm::SmallVector<TensorId> tids;
-    // The corresponding levels for the tensors
-    const llvm::SmallVector<Level> lvls;
-    // The set of tensors for slice-driven loop conditions.
-    const llvm::SmallVector<TensorId> slicedTids;
-    // The corresponding level for slice-driven tensors.
-    const llvm::SmallVector<Level> slicedLvls;
-    // Whether the tensor is fully reduced (e.g., i + j => j).
-    const llvm::SmallVector<bool> sliceReduced;
+    // The set of <tensor, lvl> that the loop is operating on
+    const llvm::SmallVector<TensorLevel> tidLvls;
+    // Slice-driven loop conditions.
+    const llvm::SmallVector<SliceLoopInfo> sliceDrivenInfo;
     const Operation *loop;      // the loop operation
     Block *const userCodeBlock; // the block holding users' generated code.
     const Value iv;             // the induction variable for the loop
@@ -295,8 +319,6 @@ class LoopEmitter {
                                                  Location loc, Value crd,
                                                  TensorId tid, Level lvl);
 
-  unsigned getNumTensors() const { return tensors.size(); }
-
   bool isOutputTensor(TensorId tid) const {
     return hasOutput && tid == getNumTensors() - 1;
   }
@@ -318,8 +340,7 @@ class LoopEmitter {
   /// point used to generate the loops, but are still required to generate
   /// expressions.
   void emitExtraLocalsForTensorsAtDenseLvls(OpBuilder &builder, Location loc,
-                                            ArrayRef<TensorId> tids,
-                                            ArrayRef<Level> lvls);
+                                            ArrayRef<TensorLevel> tidLvls);
 
   /// Emits a for loop to iterate over a tensor level with the provided lower
   /// bound `lo` and upper bound `hi`.

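With the helpers above, callers build one `SmallVector<TensorLevel>` instead of parallel `tids`/`lvls` vectors and iterate it as decoded (tid, lvl) pairs. The real `unpackTensorLevelRange` is a lazy adaptor built on `llvm::map_range`; the standalone sketch below approximates it with an eager helper just to show the calling pattern (plain std types, illustrative values):

  #include <cstdio>
  #include <utility>
  #include <vector>

  using TensorLevel = unsigned;

  int main() {
    const unsigned numTensors = 2;   // one input + one output, for illustration
    auto make = [&](unsigned tid, unsigned lvl) -> TensorLevel {
      return lvl * numTensors + tid; // same formula as makeTensorLevel
    };
    // Eager stand-in for the lazy llvm::map_range-based unpackTensorLevelRange.
    auto unpackRange = [&](const std::vector<TensorLevel> &c) {
      std::vector<std::pair<unsigned, unsigned>> out;
      for (TensorLevel tl : c)
        out.emplace_back(tl % numTensors, tl / numTensors);
      return out;
    };

    // One vector replaces the old parallel `tids`/`lvls` arguments.
    std::vector<TensorLevel> tidLvls = {make(0, 0), make(1, 0), make(0, 1)};
    for (auto [tid, lvl] : unpackRange(tidLvls))
      std::printf("tensor %u, level %u\n", tid, lvl);
    return 0;
  }
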
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 73cbe611aa376..4aba829d7dbdd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -990,14 +990,12 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     for (Level l = 0; l < lvlRank; l++) {
       // TODO: provide utility function for loop sequences that only contains
       // one for loop?
-      // FIXME(wrengr): what is this "ld" supposed to be really?
-      const Level ld = op.getOrder() ? op.getOrder()->getDimPosition(l) : l;
-      const SmallVector<TensorId, 1> tids{0};
-      loopEmitter.enterNewLoopSeq(rewriter, loc, tids, ld);
+      const SmallVector<TensorLevel, 1> tidLvls{
+          loopEmitter.makeTensorLevel(0, l)};
+      loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
       // Note that reduc will be taken care of by loop emitter and get updated
       // in place.
-
-      loopEmitter.enterLoopOverTensorAtLvl(rewriter, loc, tids, l, reduc);
+      loopEmitter.enterLoopOverTensorAtLvl(rewriter, loc, tidLvls, reduc);
     }
 
     SmallVector<Value> lcvs;

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 4ed879121e98c..afeabb33fcd78 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1296,13 +1296,16 @@ static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
 
 /// Generates a for-loop on a single index.
 static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
-                         bool isInner, LoopId ldx, ArrayRef<TensorId> tids,
-                         ArrayRef<Level> lvls) {
+                         bool isInner, LoopId ldx,
+                         ArrayRef<TensorLevel> tidLvls) {
   linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   auto iteratorTypes = op.getIteratorTypesArray();
-  bool isSparse = llvm::any_of(tids, [ldx, &env](TensorId tid) {
-    const auto dlt = env.dlt(tid, ldx);
+  bool isSparse = llvm::any_of(tidLvls, [ldx, &env](TensorLevel tidLvl) {
+    // Queries the DLT based on the tensor id and loop idx, as requested by
+    // `CodegenEnv::dlt(TensorId, LoopIdx)`. The returned DLT from CodegenEnv
+    // should be consistent with the DLT indexed by <TensorId, Level>.
+    const auto dlt = env.dlt(env.unpackTensorLevel(tidLvl).first, ldx);
     return isCompressedDLT(dlt) || isSingletonDLT(dlt);
   });
 
@@ -1310,11 +1313,10 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
 
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     if (env.merger().isFilterLoop(ldx)) {
-      const TensorId tid = tids.front();
-      const Level lvl = lvls.front();
+      const auto [tid, lvl] = env.unpackTensorLevel(tidLvls.front());
       // tids/lvls must only have one value because filter loops only
       // corresponding to the one and only sparse tensor level.
-      assert(isSparse && tids.size() == 1 && lvls.size() == 1);
+      assert(isSparse && tidLvls.size() == 1);
       OpOperand *t = &op->getOpOperand(tid);
       auto enc = getSparseTensorEncoding(t->get().getType());
       // Retrieves the affine expression for the filter loop.
@@ -1324,8 +1326,8 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
       return env.emitter().enterFilterLoopOverTensorAtLvl(builder, loc, tid,
                                                           lvl, a, reduc);
     }
-    return env.emitter().enterLoopOverTensorAtLvl(builder, loc, tids, lvls,
-                                                  reduc, isParallel);
+    return env.emitter().enterLoopOverTensorAtLvl(builder, loc, tidLvls, reduc,
+                                                  isParallel);
   });
   assert(loop);
   return loop;
@@ -1333,13 +1335,12 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
 
 /// Emit a while-loop for co-iteration over multiple indices.
 static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, LoopId idx,
-                           bool needsUniv, ArrayRef<TensorId> tids,
-                           ArrayRef<Level> lvls) {
+                           bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     // Construct the while-loop with a parameter for each
     // index.
     return env.emitter().enterCoIterationOverTensorsAtLvls(
-        builder, env.op().getLoc(), tids, lvls, needsUniv, reduc);
+        builder, env.op().getLoc(), tidLvls, needsUniv, reduc);
   });
   assert(loop);
   return loop;
@@ -1348,16 +1349,15 @@ static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, LoopId idx,
 /// Generates a for-loop or a while-loop, depending on whether it implements
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
-                          bool needsUniv, ArrayRef<TensorId> tids,
-                          ArrayRef<Level> lvls, bool isFor) {
-  assert(tids.size() == lvls.size());
+                          bool needsUniv, ArrayRef<TensorLevel> tidLvls,
+                          bool isFor) {
   const LoopId idx = env.topSortAt(at);
   if (isFor) {
     bool isOuter = at == 0;
     bool isInner = at == env.topSortSize() - 1;
-    return genFor(env, builder, isOuter, isInner, idx, tids, lvls);
+    return genFor(env, builder, isOuter, isInner, idx, tidLvls);
   }
-  return genWhile(env, builder, idx, needsUniv, tids, lvls);
+  return genWhile(env, builder, idx, needsUniv, tidLvls);
 }
 
 /// Generates the induction structure for a while-loop.
@@ -1480,8 +1480,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   const LatPointId l0 = env.set(lts)[0];
   bool needsUniv = false;
 
-  SmallVector<TensorId> tids;
-  SmallVector<Level> lvls;
+  SmallVector<TensorLevel> tidLvls;
   env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
                                            std::optional<Level> lvl,
                                            DimLevelType dlt, bool isIdxReduc) {
@@ -1493,12 +1492,11 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
       // Only when this is a index reduction loop, can the dlt be undefined.
       assert(!isUndefDLT(dlt) || isIdxReduc);
       // sparse/singleton levels, or a dense/sparse index reduction loop.
-      tids.push_back(tid);
-      lvls.push_back(*lvl);
+      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
     }
   });
 
-  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, lvls);
+  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
 
   // Maintain the universal index only if it is actually
   // consumed by a subsequent lattice point.
@@ -1529,7 +1527,8 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
       // FIXME: `toOrigDim` is deprecated.
       AffineExpr lvlExpr = lvlExprs[toOrigDim(enc, l)];
       if (enc.isDenseLvl(l) && lvlExpr.isa<AffineConstantExpr>())
-        env.emitter().genDenseAffineAddress(builder, loc, tid, l, lvlExpr);
+        env.emitter().genDenseAffineAddress(
+            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
       else
         return; // break on first non-dense non-constant level
     }
@@ -1548,23 +1547,21 @@ static void genInitConstantDenseAddress(CodegenEnv &env,
 
 /// Return true if the lattices bit can be iterated by a for loop.
 static bool translateBitsToTidLvlPairs(
-    CodegenEnv &env, LatPointId li, LoopId ldx, SmallVectorImpl<TensorId> &tids,
-    SmallVectorImpl<Level> &lvls, SmallVectorImpl<TensorId> &affineTids,
-    SmallVectorImpl<Level> &affineLvls, SmallVectorImpl<AffineExpr> &exps) {
+    CodegenEnv &env, LatPointId li, LoopId ldx,
+    SmallVectorImpl<TensorLevel> &tidLvls,
+    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
   const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);
 
   unsigned numloopCond = 0;
   bool hasNonUnique = false;
-
   env.merger().foreachTensorLoopId(
       li, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
                    DimLevelType dlt, bool isIdxReduc) {
         if (simple[b]) {
           if (isIdxReduc) {
-            tids.push_back(tid);
-            lvls.push_back(*lvl);
+            tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
             numloopCond++;
             return;
           }
@@ -1582,12 +1579,10 @@ static bool translateBitsToTidLvlPairs(
               return;
           }
           hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
-          tids.push_back(tid);
-          lvls.push_back(*lvl);
+          tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
           numloopCond++;
         } else if (isDenseDLT(dlt) || isIdxReduc) {
-          tids.push_back(tid);
-          lvls.push_back(*lvl);
+          tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
         } else {
           assert(isUndefDLT(dlt));
           linalg::GenericOp op = env.op();
@@ -1625,9 +1620,7 @@ static bool translateBitsToTidLvlPairs(
                 // computeIterationGraph), another more admissible approach
                 // might be accepting out-of-order access between consecutive
                 // dense levels.
-                affineTids.push_back(tid);
-                affineLvls.push_back(l);
-                exps.push_back(exp);
+                affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
               }
             }
           }
@@ -1638,8 +1631,7 @@ static bool translateBitsToTidLvlPairs(
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
-    tids.push_back(outTid);
-    lvls.push_back(*outLvl);
+    tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
   }
 
   assert(numloopCond > 0);
@@ -1653,29 +1645,27 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                               OpBuilder &builder, LoopOrd at,
                                               LatPointId li, bool needsUniv) {
   // The set of tensors + lvls to generate loops on
-  SmallVector<TensorId> tids, affineTids;
-  SmallVector<Level> lvls, affineLvls;
+  SmallVector<TensorLevel> tidLvls;
   // The set of dense tensors with non-trivial affine expression that just
   // becomes invariant and the address shall now be generated at the current
   // level.
-  SmallVector<AffineExpr> affines;
-  bool isSingleCond = translateBitsToTidLvlPairs(
-      env, li, env.topSortAt(at), tids, lvls, affineTids, affineLvls, affines);
+  SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
+  bool isSingleCond = translateBitsToTidLvlPairs(env, li, env.topSortAt(at),
+                                                 tidLvls, affineTidLvls);
 
   // Emit the for/while-loop control.
-  Operation *loop =
-      genLoop(env, builder, at, needsUniv, tids, lvls, isSingleCond);
+  Operation *loop = genLoop(env, builder, at, needsUniv, tidLvls, isSingleCond);
   Location loc = env.op().getLoc();
-  for (auto [tid, lvl, exp] : llvm::zip(affineTids, affineLvls, affines)) {
-    env.emitter().genDenseAffineAddress(builder, loc, tid, lvl, exp);
+  for (auto [tidLvl, exp] : affineTidLvls) {
+    env.emitter().genDenseAffineAddress(builder, loc, tidLvl, exp);
   }
 
   // Until now, we have entered every <tid, lvl> pair in {cond, extra,
   // affine}Tids/Lvls. The addresses of the upcoming levels which are dependent
   // on constant affines expression may now be determined.
-  auto allTids = llvm::concat<TensorId>(tids, affineTids);
-  auto allLvls = llvm::concat<Level>(lvls, affineLvls);
-  for (auto [tid, lvl] : llvm::zip(allTids, allLvls)) {
+  auto allTidLvls =
+      llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
+  for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
     if (tid != env.merger().getOutTensorID())
       genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
   }

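In `startLoop` above, the loop-condition `tidLvls` and the first element of each `(TensorLevel, AffineExpr)` pair are traversed as one concatenated range (`llvm::concat` plus `llvm::make_first_range`) before unpacking, so the addresses of upcoming levels that depend on constant affine expressions can be generated for every entered <tid, lvl>. A plain-C++ approximation of that traversal; the LLVM adaptors are lazy so the copy below is only for illustration, and the affine-expression payload is reduced to a string:

  #include <cstdio>
  #include <string>
  #include <utility>
  #include <vector>

  using TensorLevel = unsigned;

  int main() {
    const unsigned numTensors = 2;
    std::vector<TensorLevel> tidLvls = {0, 3};          // loop-condition levels
    std::vector<std::pair<TensorLevel, std::string>> affineTidLvls = {
        {2, "constant affine expr"}};                   // exprs simplified here

    // Approximates llvm::concat<TensorLevel>(tidLvls,
    //                                        llvm::make_first_range(affineTidLvls)).
    std::vector<TensorLevel> allTidLvls(tidLvls);
    for (const auto &p : affineTidLvls)
      allTidLvls.push_back(p.first);

    for (TensorLevel tl : allTidLvls) {                 // unpackTensorLevelRange
      unsigned tid = tl % numTensors, lvl = tl / numTensors;
      std::printf("gen constant dense addresses for tid=%u from lvl=%u\n",
                  tid, lvl + 1);
    }
    return 0;
  }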
