[Mlir-commits] [mlir] fd68d36 - [mlir][sparse] unifying enterLoopOverTensorAtLvl and enterCoIterationOverTensorsAtLvls

Peiming Liu llvmlistbot at llvm.org
Wed Jun 14 13:03:17 PDT 2023


Author: Peiming Liu
Date: 2023-06-14T20:03:10Z
New Revision: fd68d36109c6fcebb6d758046b88b0664acccf51

URL: https://github.com/llvm/llvm-project/commit/fd68d36109c6fcebb6d758046b88b0664acccf51
DIFF: https://github.com/llvm/llvm-project/commit/fd68d36109c6fcebb6d758046b88b0664acccf51.diff

LOG: [mlir][sparse] unifying enterLoopOverTensorAtLvl and enterCoIterationOverTensorsAtLvls

The tensor levels are now explicitly categorized into different `LoopCondKind`s to instruct LoopEmitter to generate different code for different kinds of conditions (e.g., `SparseCond`, `SparseSliceCond`, `SparseAffineCond`, etc.).

The process of generating a while loop is now broken down into three steps, each of which is dispatched to the handler for the corresponding `LoopCondKind`:
1. Generate LoopCondition (e.g., `pos < posHi` for `SparseCond`, `slice.isNonEmpty` for `SparseAffineUnRedCond`)
2. Generate LoopBody (e.g., compute the coordinates)
3. Generate ExtraChecks (e.g., `if (onSlice(crd))` for `SparseSliceCond`)

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D152464

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/sorted_coo.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index f466cce68a34b..d0884ca482de2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -101,6 +101,28 @@ static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
   return std::make_pair(crd, rem);
 }
 
+// Generates a bool value for while loop condition that tries to iterate over a
+// fully reduced level with affine index expression.
+static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
+                                        Value crdBuf, Value crdHi, Value posit,
+                                        Value posHi) {
+  Value inBound = CMPI(ult, posit, posHi);
+  auto ifOp =
+      builder.create<scf::IfOp>(loc, builder.getI1Type(), inBound, true);
+  // if (inbound)
+  //   yield coord < crdHi
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  Value crd = genIndexLoad(builder, loc, crdBuf, posit);
+  YIELD(CMPI(ult, crd, crdHi));
+  // else
+  //   yield false
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  YIELD(constantI1(builder, loc, false));
+
+  builder.setInsertionPointAfter(ifOp);
+  return ifOp.getResult(0);
+}
+
 std::pair<Value, Value>
 LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
                                     TensorId tid, Level lvl) {
@@ -470,6 +492,41 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
   localInsertPos = builder.getInsertionPoint()->getPrevNode();
 }
 
+void LoopEmitter::categorizeLoopCondition(
+    ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<TensorLvlCond> &dnConds,
+    SmallVectorImpl<TensorLvlCond> &spConds) {
+  // Finds out the tensor level that we should use to generate loops. Amongs all
+  // the tensor levels, there is at most one sparse tensor level.
+  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
+    assert(lvlTypes[t].size() > l); // Must be a valid tid, dim pair
+    auto lvlType = lvlTypes[t][l];
+    // Must be a recognizable DLT.
+    assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) ||
+           isCompressedWithHiDLT(lvlType) || isSingletonDLT(lvlType));
+
+    bool isSparse = !isDenseDLT(lvlType);
+    bool isSlice = isSparseSlices[t];
+    bool isAffine = !dependentLvlMap[t][l].empty();
+    bool isUnRedu = false;
+    // TODO: Supports affine index expression on sparse tensor slices.
+    assert(!isSlice || !isAffine);
+
+    // Whether the affine index expression has been fully reduced or not.
+    if (!dependentLvlMap[t][l].empty())
+      isUnRedu = !depFullyReduced(t, l);
+
+    auto &dstVec = isSparse ? spConds : dnConds;
+    dstVec.emplace_back(
+        makeTensorLevel(t, l),
+        makeLoopCondKind(isSparse, isSlice, isAffine, isUnRedu));
+  }
+
+  std::sort(spConds.begin(), spConds.end(), [](auto lhs, auto rhs) {
+    // AffineUnRed > Affine > Slice > Trivial
+    return static_cast<uint8_t>(lhs.second) > static_cast<uint8_t>(rhs.second);
+  });
+}
+
 void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                   ArrayRef<TensorLevel> tidLvls) {
   // TODO: sort
@@ -561,7 +618,7 @@ Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
   }
 }
 
-Operation *LoopEmitter::emitForLoopOverTensorAtLvl(
+std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
     OpBuilder &builder, Location loc, TensorId tid, Level dstLvl, Value lo,
     Value hi, MutableArrayRef<Value> reduc, bool isParallel) {
   bool isSparseCond = isCompressedDLT(lvlTypes[tid][dstLvl]) ||
@@ -651,166 +708,433 @@ Operation *LoopEmitter::emitForLoopOverTensorAtLvl(
 
   assert(crd);
   coords[tid][dstLvl] = crd;
-  return loop;
+  return {loop, crd};
 }
 
-Operation *LoopEmitter::emitWhileLoopOverSliceAtSparseLvl(
-    OpBuilder &builder, Location loc, Value pLo, Value pHi, Value offset,
-    Value sliceSize, TensorId tid, Level lvl, MutableArrayRef<Value> reduc) {
-  // TODO: we should generalize the method to support iteration over for
-  // normal slices as well to allow early break.
-  Operation *insertPoint = nullptr;
-  Operation *loop =
-      genSliceLvlTraverseLoop(
-          builder, loc, pLo, pHi, offset, sliceSize, tid, lvl, reduc,
-          /*genYield=*/false, // unaware of the yield values from user yet
-          [this, tid, lvl, reduc, offset,
-           &insertPoint](OpBuilder &builder, Location loc, Value iv,
-                         MutableArrayRef<Value> innerReduc) {
-            assert(innerReduc.size() == reduc.size());
-            // Updates users' reduction variable inplace
-            for (unsigned i = 0, e = reduc.size(); i < e; i++)
-              reduc[i] = innerReduc[i];
-            // Loads the coordinates.
-            Value absC =
-                genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], iv);
-
-            // We need to substract the offset to get relative coordinates.
-            // TODO: how to assert relC >=0 during runtime?
-            insertPoint = builder.create<arith::SubIOp>(loc, absC, offset);
-            posits[tid][lvl] = iv;
-            coords[tid][lvl] = insertPoint->getResult(0);
-          })
-          .first;
-  // Sets the insertionn pointer inside loop body.
-  builder.setInsertionPointAfter(insertPoint);
-  return loop;
+Value LoopEmitter::genWhileLoopConditions(OpBuilder &builder, Location loc,
+                                          ValueRange ivs, TensorLvlCond cond) {
+  auto [tid, lvl] = unpackTensorLevel(cond.first);
+
+  switch (cond.second) {
+  case LoopCondKind::SparseCond: {
+    const auto reassoc = getCollapseReassociation(tid, lvl);
+    assert(reassoc.size() == ivs.size());
+    assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
+    // We used the first level bound as the bound the collapsed set of levels.
+    return CMPI(ult, ivs.back(), highs[tid][reassoc.front()]);
+  }
+  case LoopCondKind::SparseSliceCond: {
+    assert(ivs.size() == 1);
+    return CMPI(ult, ivs.back(), highs[tid][lvl]);
+  }
+  case LoopCondKind::SparseAffineCond: {
+    assert(ivs.size() == 1);
+    Value crdHi; // loop upper bound
+    {
+      OpBuilder::InsertionGuard guard(builder);
+      Operation *loop = builder.getInsertionBlock()->getParentOp();
+      // crdHi is a loop invariant, hosit the computation outside the loop.
+      if (llvm::isa_and_nonnull<scf::WhileOp>(loop))
+        builder.setInsertionPoint(loop);
+      crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset,
+                   sliceSizes[tid][lvl].back());
+    }
+    assert(crdHi);
+    return genSparseReducedAffineCond(builder, loc,
+                                      coordinatesBuffers[tid][lvl], crdHi,
+                                      ivs[0], highs[tid][lvl]);
+  }
+  case LoopCondKind::SparseAffineUnRedCond: {
+    assert(ivs.size() == 3);
+    return ivs.front(); // isNonEmpty
+  }
+  default:
+    llvm_unreachable("Unhandled LoopCondKind");
+  }
+  llvm_unreachable("Unhandled LoopCondKind");
 }
 
-Operation *LoopEmitter::enterLoopOverTensorAtLvl(OpBuilder &builder,
-                                                 Location loc,
-                                                 ArrayRef<TensorLevel> tidLvls,
-                                                 MutableArrayRef<Value> reduc,
-                                                 bool isParallel) {
-  // TODO: support multiple return on parallel for?
-  assert(!isParallel || reduc.size() <= 1);
-  bool isSparseCond = false, isSparseSliceCond = false;
-  auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
+std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
+                                                   Location loc, ValueRange ivs,
+                                                   TensorLvlCond cond) {
+  auto [tid, lvl] = unpackTensorLevel(cond.first);
 
-  // Finds out the tensor level that we should use to generate loops. Amongs all
-  // the tensor levels, there is at most one sparse tensor level.
-  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
-    assert(lvlTypes[t].size() > l);         // Must be a valid tid, dim pair
-    assert(!coords[t][l] ||                 // We cannot re-enter the same level
-           !dependentLvlMap[t][l].empty()); // unless it is a slice-driver loop
-    auto lvlType = lvlTypes[t][l];
-    // Must be a recognizable DLT.
-    assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) ||
-           isCompressedWithHiDLT(lvlType) || isSingletonDLT(lvlType));
+  switch (cond.second) {
+  case LoopCondKind::SparseCond: {
+    const auto reassoc = getCollapseReassociation(tid, lvl);
+    assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
+    // Links the SSA chain for segHi.
+    for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++)
+      if (!isUniqueDLT(lvlTypes[tid][reassoc[i]]))
+        segHi[tid][reassoc[i]] = ivs[i];
+
+    // Updates position. For collapsed COO, the position is the same across
+    // consecutive levels.
+    for (auto srcLvl : reassoc)
+      posits[tid][srcLvl] = ivs.back();
+
+    // Update coordinates.
+    coords[tid][lvl] = genSparseCrd(builder, loc, tid, lvl);
+    return std::nullopt;
+  }
+  case LoopCondKind::SparseSliceCond: {
+    assert(ivs.size() == 1);
+    posits[tid][lvl] = ivs.front();
+    Value sCrd = genSparseCrd(builder, loc, tid, lvl);
+    // Converts the coordinate loaded from the actual sparse tensor to the
+    // coordinates in the sparse slice.
+    auto [dCrd, pred] = genSliceLegitPredicate(builder, loc, sCrd, tid, lvl);
+    coords[tid][lvl] = dCrd;
+    return pred;
+  }
+  case LoopCondKind::SparseAffineCond: {
+    assert(ivs.size() == 1);
+    // Coord is the relative offset related to its parents.
+    assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implement");
+    // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
+    Value posit = ivs[0];
+    Value crdBuf = coordinatesBuffers[tid][lvl];
+    // We need to substract the offset to get relative coordinates.
+    // TODO: Maybe assert relC >=0 during runtime in debug build?
+    Value absC = genIndexLoad(builder, loc, crdBuf, posit);
+    auto relC = SUBI(absC, getFinalSliceOnLvl(tid, lvl).offset);
+    posits[tid][lvl] = posit;
+    coords[tid][lvl] = relC;
+    return std::nullopt;
+  }
+  case LoopCondKind::SparseAffineUnRedCond: {
+    assert(ivs.size() == 3);
+    // Coord is the relative offset related to its parents.
+    // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
+    assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implement");
+    // Updates the current slice info
+    SliceInfo &sliceInfo = sliceStack[tid].back();
+    sliceInfo.isNonEmpty = ivs[0];
+    sliceInfo.minCrd = ivs[1];
+    sliceInfo.offset = ivs[2];
+    coords[tid][lvl] = sliceInfo.offset;
+    // No extra check is needed before accessing the tensor level.
+    return std::nullopt;
+  }
+  default:
+    llvm_unreachable("Unhandled LoopCondKind");
+  }
+  llvm_unreachable("Unhandled LoopCondKind");
+}
 
-    // This is a slice-driven loop on sparse level.
-    if (!dependentLvlMap[t][l].empty() && !isDenseDLT(lvlType)) {
-      assert(!isSparseSliceCond && !isSparseCond);
-      isSparseSliceCond = true;
-      tid = t;
-      lvl = l;
-      continue;
+ValueRange LoopEmitter::genCheckedValue(OpBuilder &builder, Location loc,
+                                        Value pred, ValueRange curArgs,
+                                        TensorLvlCond cond) {
+  // Currently only sparse slice condition need extra check.
+  assert(isSliceCond(cond.second) && isSparseCond(cond.second));
+  assert(curArgs.size() == 1);
+  Value nextPos = ADDI(curArgs.front(), C_IDX(1));
+  return SELECT(pred, curArgs.front(), nextPos)->getResults();
+}
+
+std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
+    OpBuilder &builder, Location loc, ArrayRef<TensorLvlCond> spConds,
+    MutableArrayRef<Value> reduc, bool needsUniv) {
+  // NOTE: the slice driven tensor-related reduction variable must
+  // appear before normal tensors.
+  assert(!spConds.empty());
+
+  // The set of induction variables for the while loop.
+  SmallVector<Value> ivs;
+  // Segement sizes for induction variables used for different kinds of loop
+  // conditions.
+  SmallVector<unsigned> opSegSize;
+
+  // Construct the while-loop with a parameter for each coordinate.
+  for (auto [tl, cKind] : spConds) {
+    auto [tid, lvl] = unpackTensorLevel(tl);
+    const auto lvlTp = lvlTypes[tid][lvl];
+    // Dense level are handled by the shared univeral index.
+    assert(!isDenseCond(cKind));
+    // Must be a recognizable sparse level.
+    assert(isCompressedDLT(lvlTp) || isCompressedWithHiDLT(lvlTp) ||
+           isSingletonDLT(lvlTp));
+
+    unsigned prevSz = ivs.size();
+    const auto reassoc = getCollapseReassociation(tid, lvl);
+    if (isAffineIdxCond(cKind)) {
+      // TODO: Support view-based reshape on sparse levels with affine index
+      // expressions.
+      assert(reassoc.size() == 1);
+      if (isAffineIdxUnRedCond(cKind)) {
+        SliceInfo &sliceInfo = sliceStack[tid].back();
+        // The order matters!
+        ivs.push_back(sliceInfo.isNonEmpty);
+        ivs.push_back(sliceInfo.minCrd);
+        ivs.push_back(sliceInfo.offset);
+      } else {
+        ivs.push_back(posits[tid][lvl]); // loop lower bound (pos low).
+      }
+      // We reduced one more dependency after entering the loop.
+      levelReducedDep[tid][lvl]++;
+    } else {
+      assert(dependentLvlMap[tid][lvl].empty());
+      for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
+        // This is the segment high for each non-unique levels.
+        if (!isUniqueDLT(lvlTypes[tid][reassoc[i]]))
+          ivs.push_back(C_IDX(0));
+      }
+      const Value pos = posits[tid][reassoc.front()];
+      ivs.push_back(pos);
     }
+    opSegSize.push_back(ivs.size() - prevSz);
+  }
+
+  // The position where user-supplied reduction variable starts.
+  ivs.append(reduc.begin(), reduc.end());
+  // Update universal index.
+  if (needsUniv)
+    ivs.push_back(loopSeqStack.back().first);
+
+  // Ensures all operands are valid.
+  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
+  TypeRange types = ValueRange(ivs).getTypes();
+  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);
+
+  SmallVector<Location> locs(types.size(), loc);
+  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
+  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
 
-    bool isSparse = isCompressedDLT(lvlType) || isSingletonDLT(lvlType) ||
-                    isCompressedWithHiDLT(lvlType);
-    // We can at most have one sparse input, otherwise, a while loop is
-    // required to co-iterate multiple sparse tensors.
-    assert(!isSparseCond || !isSparse);
-    assert(!isSparseSliceCond || !isSparseCond);
-    if (isSparse) {
-      tid = t;
-      lvl = l;
+  // Generates loop conditions.
+  builder.setInsertionPointToStart(before);
+  ValueRange bArgs = before->getArguments();
+  Value whileCond = nullptr; // bool values for loop condition.
+  for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) {
+    Value cv = genWhileLoopConditions(builder, loc, bArgs.take_front(segSz), c);
+    bArgs = bArgs.drop_front(segSz);
+    whileCond = !whileCond ? cv : ANDI(whileCond, cv);
+  }
+  // The remaining block arguments are user-provided reduction values and an
+  // optional universal index. Make sure their sizes match.
+  assert(bArgs.size() == reduc.size() + needsUniv ? 1 : 0);
+  builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+  // Generates loop body.
+  builder.setInsertionPointToStart(after);
+  ValueRange aArgs = after->getArguments();
+  // Since some LoopCondKind might need extra checks to filter out invalid
+  // iterations, we maintains another array to hold the iteration arguments to
+  // yield if the checks fails.
+  SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());
+  // A mutable alias for convenient slicing.
+  MutableArrayRef<Value> nextArgsRef = nextArgs;
+  Value extraPred = nullptr;
+  for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) {
+    ValueRange condArgs = aArgs.take_front(segSz);
+    auto pred = genWhileLoopBody(builder, loc, condArgs, c);
+    assert(pred.has_value() == isCondWithExtraCheck(c.second));
+    if (pred.has_value()) {
+      // We need all extra checks to pass.
+      extraPred = extraPred == nullptr ? *pred : ANDI(*pred, extraPred);
+      ValueRange nxArgs = genCheckedValue(builder, loc, *pred, condArgs, c);
+      assert(nxArgs.size() == segSz);
+      // Update the value for cases when some check fails.
+      for (unsigned i = 0; i < segSz; i++) {
+        nextArgsRef[i] = nxArgs[i];
+      }
     }
-    isSparseCond = isSparseCond || isSparse;
+    aArgs = aArgs.drop_front(segSz);
+    nextArgsRef = nextArgsRef.drop_front(segSz);
   }
 
-  DimLevelType lvlType = lvlTypes[tid][lvl];
-  // TODO: Dense slice driven loop can be generated using for loop as well.
-  assert(!isSparseSliceCond || !isDenseDLT(lvlType));
-  bool isDenseSliceCond =
-      isDenseDLT(lvlType) && !dependentLvlMap[tid][lvl].empty();
-  // if the slice is fully reduced, we can now use TACO-based algorithm to
-  // iterate it.
+  if (extraPred) {
+    auto ifOp = builder.create<scf::IfOp>(loc, types, extraPred, /*else*/ true);
+    // Marks this special IfOp so that Sparsification does not finalizing it.
+    ifOp->setAttr(getLoopEmitterLoopAttrName(),
+                  StringAttr::get(builder.getContext(), "slice"));
+    // Links the SSA chain outside the if statement.
+    YIELD(ifOp->getResults());
 
-  Operation *l = nullptr;
+    // If not all slices are legit, yield the updated value.
+    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    YIELD(nextArgs);
+
+    // If all slices are legit, start the user generated code.
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  }
+
+  for (auto [tid, dstLvl] : unpackTensorLevelFromCondRange(spConds)) {
+    const auto reassoc = getCollapseReassociation(tid, dstLvl);
+    assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
+    // TODO: Refactors this into smaller functions.
+    // NOTE: For all the collapsed level (except for the last one, that is why
+    // the loop ends with `reassoc.size() - 1`), as each iteration is advanced
+    // by the segment size of the last level, which does not always invalidate
+    // the segment size for the previous levels, thus we need to propagate the
+    // segment sizes across loop iterations and only forward if needed.
+    //
+    // E.g., for a COO tensor with the following coordinates array.
+    // (0, 0, 1),
+    // (0, 0, 2),
+    // (1, 1, 1),
+    // segHi[lvl=0] = segHi[lvl=1] = 2
+    // segHi[lvl=2] = 1,
+    // the first iteration does not invalidate segHi[0] and segHi[1]
+    for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
+      const Level srcLvl = reassoc[i];
+      if (!isUniqueDLT(lvlTypes[tid][srcLvl])) {
+        const Value pos = posits[tid][srcLvl];
+        const auto oldSegHi = segHi[tid][srcLvl];
+        assert(oldSegHi);
+        Value newSegHi = builder.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::uge, pos, oldSegHi);
+        auto ifNewSegHi = builder.create<scf::IfOp>(loc, builder.getIndexType(),
+                                                    newSegHi, true);
+        {
+          OpBuilder::InsertionGuard guard(builder);
+          builder.setInsertionPointToStart(ifNewSegHi.thenBlock());
+          YIELD(genSegmentHigh(builder, loc, tid, srcLvl, pos,
+                               highs[tid][srcLvl]));
+          // Else, resues the same segment high.
+          builder.setInsertionPointToStart(ifNewSegHi.elseBlock());
+          YIELD(oldSegHi);
+        }
+        highs[tid][srcLvl + 1] = segHi[tid][srcLvl] = ifNewSegHi.getResult(0);
+      }
+    };
+    const auto srcLvl = reassoc.back();
+    if (!isUniqueDLT(lvlTypes[tid][srcLvl])) {
+      segHi[tid][srcLvl] = genSegmentHigh(
+          builder, loc, tid, srcLvl, posits[tid][srcLvl], highs[tid][srcLvl]);
+    }
+  }
 
-  // At most one tensor used as condition in for loop;
-  SmallVector<TensorLevel, 1> condTidLvl;
-  // There might be multiple dense slice driven tensor.
+  // In-place update on reduction variable.
+  assert(aArgs.size() == reduc.size() + needsUniv ? 1 : 0);
+  for (unsigned i = 0, e = reduc.size(); i < e; i++)
+    reduc[i] = aArgs[i];
+
+  Value min;
+  // Finds the minimum coordinate
+  if (!needsUniv) {
+    for (auto [tid, lvl] : unpackTensorLevelFromCondRange(spConds)) {
+      const auto lvlTp = lvlTypes[tid][lvl];
+      if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
+          isCompressedWithHiDLT(lvlTp)) {
+        const auto crd = coords[tid][lvl];
+        if (min) {
+          Value cmp = CMPI(ult, coords[tid][lvl], min);
+          min = SELECT(cmp, coords[tid][lvl], min);
+        } else {
+          min = crd;
+        }
+      }
+    }
+  } else {
+    assert(!min);
+    // Otherwise, universal index is the minimal pos.
+    min = whileOp.getAfterArguments().back();
+  }
+
+  return {whileOp, min};
+}
+
+bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<TensorLvlCond> sparseConds,
+                                          bool genDedup) {
+  assert(llvm::all_of(sparseConds,
+                      [](TensorLvlCond c) { return isSparseCond(c.second); }));
+
+  // If we need to co-iterate over two sparse tensors, we need a while loop
+  if (sparseConds.size() > 1)
+    return false;
+
+  // We also need a while loop for levels with affine index expression for
+  // non-unique levels when deduplication is required.
+  if (sparseConds.size() == 1) {
+    auto [tid, lvl] = unpackTensorLevel(sparseConds.back().first);
+    auto reassoc = getCollapseReassociation(tid, lvl);
+    return !isAffineIdxCond(sparseConds.back().second) &&
+           !(genDedup && !isUniqueDLT(lvlTypes[tid][reassoc.back()]));
+  }
+
+  return true;
+}
+
+Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
+    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
+    MutableArrayRef<Value> reduc, bool tryParallel, bool genDedup,
+    bool needsUniv) {
+  // Sanity checks.
+  assert(!tidLvls.empty());
+  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
+    assert(!coords[t][l] ||                 // We cannot re-enter the same level
+           !dependentLvlMap[t][l].empty()); // unless it is a slice-driver loop
+  }
+  // TODO: support multiple return on parallel for?
+  tryParallel = tryParallel && reduc.size() <= 1;
+
+  SmallVector<TensorLvlCond> spConds;
+  SmallVector<TensorLvlCond> dnConds;
+  categorizeLoopCondition(tidLvls, dnConds, spConds);
+
+  // Only when there is at least one sparse conditions, do we really need the
+  // universal index.
+  // TODO: Maybe we should instead requires merger to pass in a valid value at
+  // the first place instead of adjusting it in LoopEmitter?
+  needsUniv = !spConds.empty() && needsUniv;
+  // The TensorLevel used for loop conditions.
+  // If there is any sparse level, we need to use the sparse condition.
+  // If all levels are dense, we can pick arbitary one (dense slice-driven loop
+  // can be generated using a simple ForOp as well).
+  Operation *l = nullptr;
+  Value iv = nullptr;
   SmallVector<SliceLoopInfo> sliceDrivenInfo;
+  SmallVector<TensorLevel> trivialLvls;
 
   // Generates loops differently depending on whether we need a slice-driven
   // loop or a simple level traversal loop.
-  if (isSparseSliceCond) {
-    bool fullyReduced = depFullyReduced(tid, lvl);
-    if (!fullyReduced) {
-      l = emitSliceDrivenLoopOverTensorAtLvl(builder, loc, tid, lvl, reduc);
-    } else {
-      // If the slice is fully reduced, we can now use TACO-based algorithm to
-      // iterate it.
-      l = emitWhileLoopOverSliceAtSparseLvl(
-          builder, loc, posits[tid][lvl], highs[tid][lvl],
-          getFinalSliceOnLvl(tid, lvl).offset, sliceSizes[tid][lvl].back(), tid,
-          lvl, reduc);
-    }
-    levelReducedDep[tid][lvl]++;
-    sliceDrivenInfo.emplace_back(tid, lvl, fullyReduced);
-  } else {
-    Value lo = isSparseCond ? posits[tid][lvl]           // current offset
-                            : loopSeqStack.back().first; // universal index
+  if (shouldIteratedByForLoop(spConds, genDedup) && !needsUniv) {
+    assert(spConds.size() <= 1);
+    TensorLvlCond tlCond = spConds.empty() ? dnConds.front() : spConds.front();
+    auto loopCondKind = tlCond.second;
+    auto [tid, lvl] = unpackTensorLevel(tlCond.first);
+    Value lo = isSparseCond(loopCondKind)
+                   ? posits[tid][lvl]           // current offset
+                   : loopSeqStack.back().first; // universal index
     Value hi = highs[tid][lvl];
-    if (isDenseSliceCond) {
-      bool fullyReduced = depFullyReduced(tid, lvl);
-      Value sliceSz = sliceSizes[tid][lvl][sliceStack[tid].back().depth - 1];
-      // Adjust for loop hi for dense slice-driven loop.
-      if (fullyReduced) {
-        hi = sliceSz;
-        condTidLvl.push_back(makeTensorLevel(tid, lvl));
-      } else {
-        hi = SUBI(lvlSizes[tid][lvl], sliceSz);
+    if (isDenseCond(loopCondKind) && isAffineIdxCond(loopCondKind)) {
+      bool unReduc = isAffineIdxUnRedCond(loopCondKind);
+      assert(unReduc == !depFullyReduced(tid, lvl));
+      hi = sliceSizes[tid][lvl][sliceStack[tid].back().depth - 1];
+      if (unReduc) {
+        // Adjust for loop hi for dense slice-driven loop.
+        hi = SUBI(lvlSizes[tid][lvl], hi);
         hi = ADDI(hi, C_IDX(1));
       }
-    } else {
-      condTidLvl.push_back(makeTensorLevel(tid, lvl));
     }
-    l = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, lo, hi, reduc,
-                                   isParallel);
-  }
-  Value iv = coords[tid][lvl];
-  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
-    // We only need to handle slice-driven loops on dense level here.
-    // If it is a slice-driven loop on sparse level, it needs a while loop to
-    // insert break statements, and it must have been handled correctly in L692.
-    if (!dependentLvlMap[t][l].empty() && isDenseDLT(lvlTypes[t][l])) {
-      // Pushes sliced levels to build correct LoopInfo.
-      bool fullyReduc = depFullyReduced(t, l);
-      SliceInfo &info = sliceStack[t].back();
-      if (fullyReduc) {
-        posits[t][l] = genAddress(builder, loc, t, l, ADDI(info.offset, iv));
+    std::tie(l, iv) = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, lo, hi,
+                                                 reduc, tryParallel);
+    // For loop condition must be a trivial condition (levels without affine
+    // index expression).
+    trivialLvls.push_back(tlCond.first);
+  } else {
+    for (auto [tl, cKind] : spConds) {
+      if (isAffineIdxCond(cKind)) {
+        auto [tid, lvl] = unpackTensorLevel(tl);
+        bool unReduc = isAffineIdxUnRedCond(cKind);
+        assert(unReduc == !depFullyReduced(tid, lvl));
+        sliceDrivenInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc);
       } else {
-        // Puts sliced dense loop into LoopInfo so that LoopEmitter knows how to
-        // exit it.
-        sliceDrivenInfo.emplace_back(t, l, fullyReduc);
-        // Update the slice information as we enter the new loop.
-        assert(*info.slicedOnLvl == l);
-        info.minCrd = info.offset = iv;
-        info.isNonEmpty = constantI1(builder, loc, true);
-        levelReducedDep[t][l]++;
+        trivialLvls.push_back(tl);
       }
     }
+
+    std::tie(l, iv) =
+        emitWhileLoopOverTensorsAtLvls(builder, loc, spConds, reduc, needsUniv);
   }
+
+  // Enter dense tensor levels.
+  enterTensorsAtDenseLvls(builder, loc, dnConds, iv, sliceDrivenInfo);
   // NOTE: we can also prepare for next dim here in advance
+
   // Pushes the loop into stack.
-  loopStack.emplace_back(condTidLvl, sliceDrivenInfo, l,
+  loopStack.emplace_back(trivialLvls, sliceDrivenInfo, l,
                          builder.getInsertionBlock(), iv, loopTag);
-  // Emit extra locals.
-  emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tidLvls);
   return l;
 }
 
@@ -886,229 +1210,11 @@ void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc,
                                         AffineExpr lvlExpr) {
   auto [tid, lvl] = unpackTensorLevel(tidLvl);
   assert(isDenseDLT(lvlTypes[tid][lvl]));
-  // For dense levels, the level-coordinate also serves as the position.
+  // For dense levels, the level-coordinate also serves as the position.
   Value lvlCrd = genAffine(builder, loc, lvlExpr);
   posits[tid][lvl] = genAddress(builder, loc, tid, lvl, lvlCrd);
 }
 
-Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
-    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
-    bool needsUniv, MutableArrayRef<Value> reduc) {
-  // NOTE: the slice driven tensor-related reduction variable must
-  // appear before normal tensors.
-  SmallVector<Type> types;
-  SmallVector<Value> operands;
-  // Construct the while-loop with a parameter for each coordinate.
-  const Type indexType = builder.getIndexType();
-  for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
-    // TODO: support coiteration with slice driven tensors.
-    const auto lvlTp = lvlTypes[tid][lvl];
-    assert(dependentLvlMap[tid][lvl].empty() && "TODO: not yet implemented");
-    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
-        isCompressedWithHiDLT(lvlTp)) {
-      const auto reassoc = getCollapseReassociation(tid, lvl);
-      for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
-        if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) {
-          // This is the segment high for each non-unique levels.
-          types.push_back(indexType);
-          operands.push_back(C_IDX(0));
-        }
-      }
-      const auto pos = posits[tid][reassoc.front()];
-      assert(pos);
-      types.push_back(indexType);
-      operands.push_back(pos);
-    }
-  }
-  // The position where user-supplied reduction variable starts.
-  for (Value rec : reduc) {
-    types.push_back(rec.getType());
-    operands.push_back(rec);
-  }
-  if (needsUniv) {
-    types.push_back(indexType);
-    // Update universal index.
-    operands.push_back(loopSeqStack.back().first);
-  }
-  assert(types.size() == operands.size());
-  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
-
-  SmallVector<Location> locs(types.size(), loc);
-  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
-  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
-
-  // Build the "before" region, which effectively consists
-  // of a conjunction of "i < upper" tests on all induction.
-  builder.setInsertionPointToStart(&whileOp.getBefore().front());
-  Value cond;
-  unsigned o = 0;
-  for (auto [t, lvl] : unpackTensorLevelRange(tidLvls)) {
-    const TensorId tid = t; // Why `t` can not be captured by lambda?
-    const auto lvlTp = lvlTypes[tid][lvl];
-    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
-        isCompressedWithHiDLT(lvlTp)) {
-      const auto reassoc = getCollapseReassociation(tid, lvl);
-      assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
-      for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
-        if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) {
-          // Links the SSA chain for segHi.
-          segHi[tid][reassoc[i]] = after->getArgument(o++);
-        }
-      }
-      Value op1 = before->getArgument(o);
-      // We used the first level bound as the bound the collapsed set of levels.
-      Value op2 = highs[tid][reassoc.front()];
-      Value opc = CMPI(ult, op1, op2);
-      cond = cond ? ANDI(cond, opc) : opc;
-      // Update positions
-      Value pos = after->getArgument(o++);
-      // For COO, the position is the same across consecutive levels.
-      /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
-      llvm::for_each(reassoc, [this, tid, pos](Level srcLvl) {
-        posits[tid][srcLvl] = pos;
-      });
-    }
-  }
-  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
-
-  // Generates while body.
-  builder.setInsertionPointToStart(&whileOp.getAfter().front());
-
-  SmallVector<std::pair<Value, unsigned>> slicesPreds;
-  unsigned i = 0;
-  for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
-    // Prepares for next level.
-    const auto lvlTp = lvlTypes[tid][lvl];
-    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
-        isCompressedWithHiDLT(lvlTp)) {
-      coords[tid][lvl] = genSparseCrd(builder, loc, tid, lvl);
-      if (isSparseSlices[tid]) {
-        auto [trans, pred] =
-            genSliceLegitPredicate(builder, loc, coords[tid][lvl], tid, lvl);
-        slicesPreds.emplace_back(pred, i);
-        // Updates to the relative coordinate to the slice.
-        coords[tid][lvl] = trans;
-      }
-      i++;
-    }
-  }
-
-  if (!slicesPreds.empty()) {
-    // Skips invalid loop iteration when slice coordinate is inapplicable.
-    SmallVector<Value> yields(after->getArguments());
-    // Generates a list of if statments
-    //  pos = in_slice ? pos : pos + 1
-    // TODO: instead of always picking pos + 1, we should set pos = high to
-    // break to loop if the coordinates are larger than the slice size.
-    //
-    // This "idx" is the index into `llvm::zip(tids, lvls)`
-    for (auto [pred, idx] : slicesPreds) {
-      Value nextPos = ADDI(yields[idx], C_IDX(1));
-      yields[idx] = SELECT(pred, yields[idx], nextPos);
-    }
-
-    Value pred = slicesPreds.front().first;
-    for (int i = 1, e = slicesPreds.size(); i < e; i++) {
-      pred = ANDI(pred, slicesPreds[i].first);
-    }
-    auto ifOp = builder.create<scf::IfOp>(loc, types, pred, /*else*/ true);
-    ifOp->setAttr(getLoopEmitterLoopAttrName(),
-                  StringAttr::get(builder.getContext(), "slice"));
-    YIELD(ifOp->getResults());
-    assert(types.size() == yields.size());
-    // If not all slices are legit
-    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-    YIELD(yields);
-
-    // If all slices are legit, start the user generated code.
-    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  }
-
-  Value min;
-  // Finds the minimum coordinate
-  if (!needsUniv) {
-    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
-      const auto lvlTp = lvlTypes[tid][lvl];
-      if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
-          isCompressedWithHiDLT(lvlTp)) {
-        const auto crd = coords[tid][lvl];
-        if (min) {
-          Value cmp = CMPI(ult, coords[tid][lvl], min);
-          min = SELECT(cmp, coords[tid][lvl], min);
-        } else {
-          min = crd;
-        }
-      }
-    }
-  } else {
-    assert(!min);
-    // Otherwise, universal index is the minimal pos.
-    min = after->getArguments().back();
-  }
-
-  // Sets up the loop stack.
-  loopStack.emplace_back(tidLvls, ArrayRef<SliceLoopInfo>(), whileOp,
-                         builder.getInsertionBlock(), min, loopTag);
-  assert(loopStack.size() == loopSeqStack.size());
-
-  for (auto [tid, dstLvl] : unpackTensorLevelRange(tidLvls)) {
-    const auto reassoc = getCollapseReassociation(tid, dstLvl);
-    assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
-    // TODO: Refactors this into smaller functions.
-    // NOTE: For all the collapsed level (except for the last one, that is why
-    // the loop ends with `reassoc.size() - 1`), as each iteration is advanced
-    // by the segment size of the last level, which does not always invalidate
-    // the segment size for the previous levels, thus we need to propagate the
-    // segment sizes across loop iterations and only forward if needed.
-    //
-    // E.g., for a COO tensor with the following coordinates array.
-    // (0, 0, 1),
-    // (0, 0, 2),
-    // (1, 1, 1),
-    // segHi[lvl=0] = segHi[lvl=1] = 2
-    // segHi[lvl=2] = 1,
-    // the first iteration does not invalidate segHi[0] and segHi[1]
-    for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
-      const Level srcLvl = reassoc[i];
-      if (!isUniqueDLT(lvlTypes[tid][srcLvl])) {
-        const Value pos = posits[tid][srcLvl];
-        const auto oldSegHi = segHi[tid][srcLvl];
-        assert(oldSegHi);
-        Value newSegHi = builder.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::uge, pos, oldSegHi);
-        auto ifNewSegHi = builder.create<scf::IfOp>(loc, builder.getIndexType(),
-                                                    newSegHi, true);
-        {
-          OpBuilder::InsertionGuard guard(builder);
-          builder.setInsertionPointToStart(ifNewSegHi.thenBlock());
-          YIELD(genSegmentHigh(builder, loc, tid, srcLvl, pos,
-                               highs[tid][srcLvl]));
-          // Else, resues the same segment high.
-          builder.setInsertionPointToStart(ifNewSegHi.elseBlock());
-          YIELD(oldSegHi);
-        }
-        highs[tid][srcLvl + 1] = segHi[tid][srcLvl] = ifNewSegHi.getResult(0);
-      }
-    };
-    const auto srcLvl = reassoc.back();
-    if (!isUniqueDLT(lvlTypes[tid][srcLvl])) {
-      segHi[tid][srcLvl] = genSegmentHigh(
-          builder, loc, tid, srcLvl, posits[tid][srcLvl], highs[tid][srcLvl]);
-    }
-  }
-
-  // Emits extra locals
-  emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tidLvls);
-
-  // Updates reduction variables
-  assert(after->getNumArguments() == o + reduc.size() + (needsUniv ? 1 : 0));
-  // In-place update on reduction variable.
-  for (unsigned i = 0, e = reduc.size(); i < e; i++)
-    reduc[i] = after->getArgument(o + i);
-
-  return whileOp;
-}
-
 void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                              TensorId tid, Level dstLvl) {
   assert(isValidLevel(tid, dstLvl));
@@ -1159,20 +1265,35 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
   llvm_unreachable("Unrecognized level-type!");
 }
 
-void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(
-    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls) {
-  // Initialize dense positions. Note that we generate dense coordinates of the
-  // output tensor unconditionally, since they may not appear in the lattice,
-  // but may be needed for linearized codegen.
-  for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
-    if (isSynTensor(tid))
-      continue;
+void LoopEmitter::enterTensorsAtDenseLvls(
+    OpBuilder &builder, Location loc, ArrayRef<TensorLvlCond> dnConds, Value iv,
+    SmallVectorImpl<SliceLoopInfo> &sliceInfo) {
+  for (auto [dnTidLvl, denseLoopCond] : dnConds) {
+    auto [tid, lvl] = unpackTensorLevel(dnTidLvl);
+    assert(isDenseDLT(lvlTypes[tid][lvl]));
 
-    if (isDenseDLT(lvlTypes[tid][lvl])) {
-      // Slice-driven dense level should have be handled already.
-      if (!dependentLvlMap[tid][lvl].empty())
+    if (isAffineIdxCond(denseLoopCond)) {
+      // Pushes sliced levels to build correct LoopInfo.
+      bool unReduc = isAffineIdxUnRedCond(denseLoopCond);
+      SliceInfo &info = sliceStack[tid].back();
+      if (unReduc) {
+        // Pushes sliced dense loop info to tell LoopEmitter how to exit it.
+        sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/false);
+        // Update the slice information as we enter the new loop.
+        assert(*info.slicedOnLvl == lvl);
+        info.minCrd = info.offset = iv;
+        info.isNonEmpty = constantI1(builder, loc, true);
+        levelReducedDep[tid][lvl]++;
+      } else {
+        posits[tid][lvl] =
+            genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv));
+      }
+    } else {
+      // Skips the synthetic tensor
+      if (isSynTensor(tid))
         continue;
-
+      // A dense level with trivial index expression.
+      assert(dependentLvlMap[tid][lvl].empty());
       auto enc = getSparseTensorEncoding(tensors[tid].getType());
       if (enc && !isSparseOutput(tid)) {
         bool validPos = lvl == 0 || posits[tid][lvl - 1];
@@ -1182,8 +1303,7 @@ void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(
           assert(isOutputTensor(tid));
           continue;
         }
-        posits[tid][lvl] =
-            genAddress(builder, loc, tid, lvl, loopStack.back().iv);
+        posits[tid][lvl] = genAddress(builder, loc, tid, lvl, iv);
         // NOTE: we can also prepare for next lvl here in advance
       }
     }
@@ -1270,7 +1390,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
   // Finished iterating a tensor, clean up
   // We only do the clean up on for loop as while loops do not necessarily
   // finish the iteration on a sparse tensor
-  for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
+  for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.trivialTidLvls)) {
     // Reset to null.
     coords[tid][lvl] = Value();
     posits[tid][lvl] = Value();
@@ -1285,6 +1405,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
   const LoopInfo &loopInfo = loopStack.back();
   auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
   Value iv = loopInfo.iv;
+  Value one = C_IDX(1);
 
   // Finalize the induction. Note that the induction could be performed
   // in the individual if-branches to avoid re-evaluating the conditions.
@@ -1299,31 +1420,32 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
     assert(isCompressedDLT(lvlTypes[tid][lvl]));
     levelReducedDep[tid][lvl]--;
     if (!resolved) {
+      // TODO: support coiterating multiple slices
+      assert(loopInfo.trivialTidLvls.empty() &&
+             loopInfo.sliceDrivenInfo.size() == 1);
       genSliceNextInduction(builder, loc, whileOp, tid, lvl, operands, o);
       continue;
     }
-    // TODO: We need to distinguish coiterate loop with slice-driven loop and
-    // fully reduced while op for iterating one slices.
-    // FIXME: since we didn't implement coiteration, this must be iteration
-    // just on fully resolved slice.
-    assert(loopInfo.sliceDrivenInfo.size() == 1 && loopInfo.tidLvls.empty());
-    // The if guard to filter out out-range coordinates.
-    assert(llvm::isa<scf::IfOp>(builder.getInsertionBlock()->getParentOp()));
+
+    if (loopInfo.trivialTidLvls.empty() &&
+        loopInfo.sliceDrivenInfo.size() == 1) {
+      // Forwards the position iterator.
+      operands.push_back(ADDI(posits[tid][lvl], one));
+    } else {
+      const Value pos = posits[tid][lvl];
+      const Value nxPos = ADDI(posits[tid][lvl], one);
+      Value cmp = CMPI(eq, coords[tid][lvl], iv);
+      operands.push_back(SELECT(cmp, nxPos, pos));
+    }
+
+    // The coordinate is invalid now.
+    coords[tid][lvl] = nullptr;
+
+    // Update the position iterator as we exit the while loop.
     posits[tid][lvl] = whileOp->getResult(o++);
-    // FIXME: we are not using continue here since we do not support
-    // coiteration on slices. But it need to be treated similarly as the
-    // universal index.
-    o++; // skip continue flag.
-    // Since we did not push two results from whileOp. The size of the
-    // operands vector is smaller than the actual number of return values from
-    // the whileOp.
-    // It is because we are actually generating yield in the IfOp inside the
-    // whileOp to only iterates over inbound coordinates within the slices.
-    delta += 2;
   };
 
-  Value one = C_IDX(1);
-  for (auto [tid, dstLvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
+  for (auto [tid, dstLvl] : unpackTensorLevelRange(loopInfo.trivialTidLvls)) {
     const auto lvlTp = lvlTypes[tid][dstLvl];
     if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
         isCompressedWithHiDLT(lvlTp)) {
@@ -1357,6 +1479,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
       llvm::for_each(reassoc, [this, newTid, newPos](Level srcLvl) {
         posits[newTid][srcLvl] = newPos;
       });
+
       // The coordinate is invalid now.
       coords[tid][dstLvl] = nullptr;
       // The segment high is invalid now.
@@ -1439,25 +1562,6 @@ const LoopEmitter::SliceInfo &LoopEmitter::getMostRecentSliceOnLvl(TensorId tid,
   llvm_unreachable("Failed to find sliceInfo");
 }
 
-static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
-                                        Value crdBuf, Value crdHi, Value posit,
-                                        Value posHi, Value cont) {
-  Value inBound = CMPI(ult, posit, posHi);
-  auto ifOp = builder.create<scf::IfOp>(loc, cont.getType(), inBound, true);
-  // if (inbound)
-  //   yield coord < crdHi
-  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  Value crd = genIndexLoad(builder, loc, crdBuf, posit);
-  YIELD(CMPI(ult, crd, crdHi));
-  // else
-  //   yield false
-  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-  YIELD(constantI1(builder, loc, false));
-
-  builder.setInsertionPointAfter(ifOp);
-  return ifOp.getResult(0);
-}
-
 // Generates a while loop to iterate over a slice sparse level as follows.
 //
 // while(coords[loopLo] < offset + size) {
@@ -1466,15 +1570,13 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
 // }
 std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
     OpBuilder &builder, Location loc, Value posLo, Value posHi, Value offset,
-    Value size, TensorId tid, Level lvl, ValueRange userReduc, bool genYield,
+    Value size, TensorId tid, Level lvl, ValueRange userReduc,
     LoopBodyBuilder bodyBuilder) {
   Value c1 = C_IDX(1);
   Value sliceHi = ADDI(offset, sliceSizes[tid][lvl].back());
+  SmallVector<Value> reduc{posLo}; // loop lower bounds
+  const unsigned numMetaReduc = reduc.size();
 
-  SmallVector<Value> reduc = {
-      posLo,                          // loop lower bounds
-      constantI1(builder, loc, true), // continue
-  };
   // Append user required reduction value.
   reduc.append(userReduc.begin(), userReduc.end());
   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
@@ -1482,28 +1584,28 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
       /*beforeBuilder=*/
       [this, posHi, sliceHi, tid, lvl](OpBuilder &builder, Location loc,
                                        ValueRange args) {
-        Value cond = genSparseReducedAffineCond(
-            builder, loc, coordinatesBuffers[tid][lvl], sliceHi, args[0], posHi,
-            args[1]);
+        Value cond = genSparseReducedAffineCond(builder, loc,
+                                                coordinatesBuffers[tid][lvl],
+                                                sliceHi, args[0], posHi);
         // continue if not yet break nor out of bound.
         builder.create<scf::ConditionOp>(loc, cond, args);
       },
       /*afterBuilder=*/
-      [c1, genYield, bodyBuilder](OpBuilder &builder, Location loc,
-                                  ValueRange args) {
+      [c1, numMetaReduc, bodyBuilder](OpBuilder &builder, Location loc,
+                                      ValueRange args) {
         Value iv = args[0];
-        TypeRange types = args.drop_front(2).getTypes();
-        // The coordinate must be in bound as guaranteed by the loop condition.
-        // We generate a fake if operation here only to hide the two extra loop
-        // induction variable maintained by us from user, and relies on later
-        // optimization pass to remove it.
-        Value cont = constantI1(builder, loc, true);
-        auto ifOp = builder.create<scf::IfOp>(loc, types, cont,
+        TypeRange types = args.drop_front(numMetaReduc).getTypes();
+        // The coordinate must be in bound as guaranteed by the loop
+        // condition. We generate a fake if operation here only to hide the
+        // extra loop induction variables maintained by us from users, which
+        // will be removed by later optimization pass.
+        auto ifOp = builder.create<scf::IfOp>(loc, types,
+                                              constantI1(builder, loc, true),
                                               /*withElseBlock=*/!types.empty());
         {
           // 2 reduction variable maintained by us.
-          SmallVector<Value> ifRet = args.drop_front(2);
-          assert(ifRet.size() == args.size() - 2);
+          SmallVector<Value> ifRet = args.drop_front(numMetaReduc);
+          assert(ifRet.size() == args.size() - 1);
 
           OpBuilder::InsertionGuard guard(builder);
           // If coord >= sliceHi.
@@ -1516,10 +1618,6 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
           builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
           // Delegates to users' callback.
           bodyBuilder(builder, loc, iv, ifRet);
-          if (genYield) {
-            builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
-            YIELD(ifRet);
-          }
         }
         // Marks this speical ifOp to avoid sparisification finalizing it.
         ifOp->setAttr(getLoopEmitterLoopAttrName(),
@@ -1528,13 +1626,12 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
         SmallVector<Value> yields;
         // Increase induction variable.
         yields.push_back(ADDI(iv, c1));
-        yields.push_back(cont);
         yields.append(ifOp.getResults().begin(), ifOp.getResults().end());
         YIELD(yields);
       });
 
   builder.setInsertionPointAfter(whileOp);
-  return std::make_pair(whileOp, whileOp.getResults().drop_front(2));
+  return std::make_pair(whileOp, whileOp.getResults().drop_front(numMetaReduc));
 }
 
 // Generates a loop nest that traverse all the unresolved levels in between.
@@ -1590,7 +1687,6 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
                   genSliceLvlTraverseLoop(
                       builder, loc, loopLo, loopHi, offset,
                       sliceSizes[tid][firstLvl].back(), tid, firstLvl, iterArgs,
-                      false,
                       [&](OpBuilder &builder, Location, Value iv,
                           MutableArrayRef<Value> reduc) {
                         ip = builder.saveInsertionPoint();
@@ -1710,7 +1806,8 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
 
   // FIXME: We need the relative offset related to the base slice.
   Value absOffset = offsetFromMinCoord(builder, loc, minCrd, size, isNonEmpty);
-  sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, /*depth=*/1);
+  sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl,
+                               /*depth=*/1);
 }
 
 // Fills in the slicePosBuffer before slice-driven loop begin.
@@ -1796,8 +1893,8 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
         Value sPHi =
             genIndexLoad(builder, loc, positionsBuffers[tid][lvl], pHi);
 
-        // isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is one
-        // non-empty lvl, the slice is non-empty.
+        // isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is
+        // one non-empty lvl, the slice is non-empty.
         Value lvlNonEmpty = CMPI(ult, sPLo, sPHi);
         nonEmpty = builder.create<arith::OrIOp>(loc, lvlNonEmpty, nonEmpty);
 
@@ -1884,8 +1981,8 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
   // We do not need cache for dense levels.
   if (slicePosBuffer[tid][lvl][0] == nullptr && !isDenseDLT(lvlType)) {
     OpBuilder::InsertionGuard guard(builder);
-    // The buffer can be reused, and the size is loop invariant: it only depends
-    // on the iteration graph's toposort.
+    // The buffer can be reused, and the size is loop invariant: it only
+    // depends on the iteration graph's toposort.
     builder.setInsertionPointAfter(localInsertPos);
     Value bufSize = C_IDX(1);
     Value c2 = C_IDX(2);
@@ -1904,9 +2001,9 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
       Value sz = *(sliceSizes[tid][lvl].rbegin() + depth - 1);
       bufSize = MULI(bufSize, sz);
     }
-    // For a pair of [pLo, pHi]. Note that we can not compress pHi because slice
-    // creates segments in the index buffer so that the pHi for the current
-    // level is no longer the pLo for the next level.
+    // For a pair of [pLo, pHi]. Note that we can not compress pHi because
+    // slice creates segments in the index buffer so that the pHi for the
+    // current level is no longer the pLo for the next level.
     bufSize = MULI(bufSize, c2);
     // Additional two metadata {memSize, idx} at head.
     bufSize = ADDI(bufSize, c2);
@@ -2117,59 +2214,6 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
   info.offset = whileOp.getResult(retIdx++);
 }
 
-Operation *LoopEmitter::emitSliceDrivenLoopOverTensorAtLvl(
-    OpBuilder &builder, Location loc, TensorId tid, Level lvl,
-    MutableArrayRef<Value> reduc) {
-  assert(!depFullyReduced(tid, lvl));
-  SliceInfo &sliceInfo = sliceStack[tid].back();
-  assert(sliceInfo.slicedOnLvl == lvl);
-
-  // The order matters!
-  SmallVector<Value, 3> operands{sliceInfo.isNonEmpty, sliceInfo.minCrd,
-                                 sliceInfo.offset};
-  // number of reduction maintained by us.
-  size_t numMetaReduc = operands.size();
-
-  // Append user-required reduction values.
-  operands.append(reduc.begin(), reduc.end());
-  assert(operands.size() == numMetaReduc + reduc.size());
-
-  // while (slice.nonEmpty()) {
-  //   bodyBuilder();
-  //   SliceNext();
-  // }
-  auto whileOp = builder.create<scf::WhileOp>(
-      loc, ValueRange(operands).getTypes(), operands,
-      /*beforeBuilder=*/
-      [](OpBuilder &builder, Location loc, ValueRange args) {
-        builder.create<scf::ConditionOp>(loc, /*isNonEmpty*/ args[0], args);
-      },
-      /*afterBuilder=*/
-      [this, tid, lvl, reduc, numMetaReduc,
-       &sliceInfo](OpBuilder &builder, Location loc, ValueRange args) {
-        assert(args.size() == reduc.size() + numMetaReduc);
-        sliceInfo.isNonEmpty = args[0];
-        sliceInfo.minCrd = args[1];
-        sliceInfo.offset = args[2];
-        // The slice offset is used to coiterate with other tensors'
-        // coordinates.
-        Value c = sliceInfo.offset;
-        if (sliceInfo.depth > 1) {
-          // Coord is the relative offset related to its parents.
-          // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
-          llvm_unreachable("TODO: not yet implement");
-        }
-        coords[tid][lvl] = c;
-
-        for (unsigned i = 0, e = reduc.size(); i < e; i++)
-          reduc[i] = args[i + numMetaReduc];
-      });
-
-  // Set the insertion point to while loop body.
-  builder.setInsertionPointToEnd(&whileOp.getAfter().front());
-  return whileOp;
-}
-
 #undef CMPI
 #undef C_IDX
 #undef YIELD

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 8fa79128889e2..f178366a738a4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -141,33 +141,39 @@ class LoopEmitter {
   /// Exits the current loop sequence, this will reset universal index to 0.
   void exitCurrentLoopSeq(OpBuilder &builder, Location loc);
 
-  // TODO: Get rid of `lvls` in the argument list? Track the level we
-  // are currently at internally. Then it would be enterNextLvlForTensor.
-  // Still need a way to specify the lvl for non-annotated tensors though,
-  // as those can be accessed out of order.
-  //
-  /// Emits loop over tensor_tid_lvl, it assumes that loops between
-  /// tensor_tid_[0, lvl - 1] have already been generated.
-  /// The function will also perform in-place update on the `reduc` vector to
-  /// return the reduction variable used inside the generated loop.
-  Operation *enterLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
-                                      ArrayRef<TensorLevel> tidLvls,
-                                      MutableArrayRef<Value> reduc = {},
-                                      bool isParallel = false);
-
+  /// Enters a loop that tries to locate a coordinate in a sparse level based
+  /// on the value evaluated by the provided affine expression.
+  /// DEPRECATED: affine index expression should be handled by index reduction
+  /// loop, filter loop-based solution is slow.
   Operation *enterFilterLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                             TensorId tid, Level lvl,
                                             AffineExpr affine,
                                             MutableArrayRef<Value> reduc = {});
 
+  /// Emits the address for a dense level based on the value evaluated by the
+  /// provided affine expression.
+  /// DEPRECATED: affine index expression should be handled by index reduction
+  /// loop, filter loop-based solution is slow.
   void genDenseAffineAddress(OpBuilder &builder, Location loc,
                              TensorLevel tidLvl, AffineExpr lvlExpr);
 
+  // TODO: Get rid of `lvls` in the argument list? Track the level we
+  // are currently at internally. Then it would be enterNextLvlForTensor.
+  // Still need a way to specify the lvl for non-annotated tensors though,
+  // as those can be accessed out of order.
+  //
   /// Emits a co-iteration loop over a set of tensors.
+  /// Emits loop over tensor_tid_lvl, it assumes that loops between
+  /// tensor_tid_[0, lvl - 1] have already been generated.
+  /// The function will also perform in-place update on the `reduc` vector to
+  /// return the reduction variable used inside the generated loop.
   Operation *enterCoIterationOverTensorsAtLvls(
       OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
-      bool needsUniv, MutableArrayRef<Value> reduc = {});
+      MutableArrayRef<Value> reduc = {}, bool isParallel = false,
+      bool genDedup = false, bool needsUniv = false);
 
+  /// Generates code to exit the current loop (e.g., generates yields, forwards
+  /// loop induction variables, etc).
   void exitCurrentLoop(RewriterBase &rewriter, Location loc,
                        MutableArrayRef<Value> reduc = {});
 
@@ -232,6 +238,15 @@ class LoopEmitter {
     });
   }
 
+  template <class ContainerTy>
+  auto unpackTensorLevelFromCondRange(ContainerTy &&c) const {
+    using EltTy = decltype(*c.begin());
+    static_assert(std::is_same_v<llvm::remove_cvref_t<EltTy>, TensorLvlCond>,
+                  "Must be unpacking a TensorLvlCond range");
+    return unpackTensorLevelRange(
+        llvm::make_first_range(std::forward<ContainerTy>(c)));
+  }
+
   ///
   /// Getters.
   ///
@@ -251,6 +266,10 @@ class LoopEmitter {
   }
 
 private:
+  ///
+  /// Structure definitions that hold different kinds of loops information.
+  ///
+
   // A tuple that stored the slice-driven loop information.
   struct SliceLoopInfo final {
     SliceLoopInfo(TensorId tid, Level lvl, bool reduced)
@@ -262,18 +281,22 @@ class LoopEmitter {
   // LoopInfo stores information of a loop generated by LoopEmitter. E.g.,
   // the set of tensors levels that the loop is iterating over.
   struct LoopInfo final {
-    LoopInfo(ArrayRef<TensorLevel> tidLvls,
+    LoopInfo(ArrayRef<TensorLevel> trivialTidLvls,
              ArrayRef<SliceLoopInfo> sliceDrivenInfo, Operation *loop,
              Block *userBlock, Value iv, StringAttr loopTag)
-        : tidLvls(tidLvls), sliceDrivenInfo(sliceDrivenInfo), loop(loop),
-          userCodeBlock(userBlock), iv(iv) {
+        : trivialTidLvls(trivialTidLvls), sliceDrivenInfo(sliceDrivenInfo),
+          loop(loop), userCodeBlock(userBlock), iv(iv) {
       // Attached a special tag to loop emitter generated loop.
       if (loopTag)
         loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
     }
-    // The set of <tensor, lvl> that the loop is operating on
-    const llvm::SmallVector<TensorLevel> tidLvls;
-    // Slice-driven loop conditions.
+    // The set of <tensor, lvl>, with *only* trivial index expressions, that are
+    // used as the condition for the generated loop. Extra information is
+    // required for levels with non-trivial index expressions, which is
+    // maintained by the sliceDrivenInfo array below.
+    const llvm::SmallVector<TensorLevel> trivialTidLvls;
+    // The set of <tensor, lvl>, with *only* non-trivial index expressions, that
+    // are used as the condition for the generated loop.
     const llvm::SmallVector<SliceLoopInfo> sliceDrivenInfo;
     const Operation *loop;      // the loop operation
     Block *const userCodeBlock; // the block holding users' generated code.
@@ -304,9 +327,100 @@ class LoopEmitter {
     unsigned depth; // the depth (relative to dependentDimMap[tid][lvl]).
   };
 
+  ///
+  /// Enums for different kinds of loop conditions.
+  ///
+
+  // The bit indicating whether the loop conditions is sparse.
+  static constexpr uint8_t kSparseCond = 1 << 3;
+  // The bit indicating whether the loop iterates over sparse tensor slices
+  // (i.e., with non-empty SliceDimAttr).
+  static constexpr uint8_t kSliceCond = 1 << 2;
+  // The bit indicating whether the loop iterates over tensor levels with
+  // non-trivial affine index reduction.
+  static constexpr uint8_t kAffineIdxCond = 1 << 1;
+  // The bit indicating whether the loop iterates over tensor levels with
+  // non-trivial affine index reduction, and it is not fully reduced.
+  static constexpr uint8_t kAffineIdxCondUnRed = 1 << 0;
+
+  enum class LoopCondKind : uint8_t {
+    // Dense conditions.
+    DenseCond = 0,
+    DenseSliceCond = kSliceCond,
+    DenseAffineCond = kAffineIdxCond,
+    DenseAffineUnRedCond = kAffineIdxCond | kAffineIdxCondUnRed,
+    // Sparse Conditions.
+    SparseCond = kSparseCond,
+    SparseSliceCond = kSparseCond | kSliceCond,
+    SparseAffineCond = kSparseCond | kAffineIdxCond,
+    SparseAffineUnRedCond = kSparseCond | kAffineIdxCond | kAffineIdxCondUnRed,
+  };
+  using TensorLvlCond = std::pair<TensorLevel, LoopCondKind>;
+
+  /// Sparse or dense loop condition.
+  static bool isSparseCond(LoopCondKind k) {
+    return static_cast<uint8_t>(k) & kSparseCond;
+  }
+  static bool isDenseCond(LoopCondKind k) { return !isSparseCond(k); }
+
+  /// Whether loops over sparse tensor slices or sparse tensors.
+  static bool isSliceCond(LoopCondKind k) {
+    return static_cast<uint8_t>(k) & kSliceCond;
+  }
+
+  /// Affine or trivial index expression loop condition.
+  static bool isAffineIdxCond(LoopCondKind k) {
+    return static_cast<uint8_t>(k) & kAffineIdxCond;
+  }
+  static bool isTrivalIdxCond(LoopCondKind k) { return !isAffineIdxCond(k); }
+
+  /// Whether the affine index expression is not fully reduced.
+  static bool isAffineIdxUnRedCond(LoopCondKind k) {
+    return isAffineIdxCond(k) && static_cast<uint8_t>(k) & kAffineIdxCondUnRed;
+  }
+  static bool isAffineIdxRedCond(LoopCondKind k) {
+    return isAffineIdxCond(k) && !isAffineIdxUnRedCond(k);
+  }
+
+  // Whether the loop condition kind requires extra check inside the loop body.
+  // E.g., to iterate over sparse tensor slice, we need to check whether the
+  // current cooridnate is on the slice (e.g., due to stride) or not.
+  static bool isCondWithExtraCheck(LoopCondKind k) {
+    return isSparseCond(k) && isSliceCond(k);
+  }
+
+  static LoopCondKind makeLoopCondKind(bool isSparse, bool isSlice,
+                                       bool isAffine, bool isUnRedu) {
+    assert(!isUnRedu || isAffine);
+    uint8_t bits = 0;
+    bits = isSparse ? bits | kSparseCond : bits;
+    bits = isSlice ? bits | kSliceCond : bits;
+    bits = isAffine ? bits | kAffineIdxCond : bits;
+    bits = isUnRedu ? bits | kAffineIdxCondUnRed : bits;
+    LoopCondKind kind = static_cast<LoopCondKind>(bits);
+
+    // Sanity checks.
+    assert(isSparse == isSparseCond(kind));
+    assert(isSlice == isSliceCond(kind));
+    assert(isAffine == isAffineIdxCond(kind));
+    assert(isUnRedu == isAffineIdxUnRedCond(kind));
+    return kind;
+  }
+
+  void categorizeLoopCondition(ArrayRef<TensorLevel> tidLvls,
+                               SmallVectorImpl<TensorLvlCond> &dnConds,
+                               SmallVectorImpl<TensorLvlCond> &spConds);
+
+  ///
+  /// LoopEmitter internal helper functions.
+  ///
+
   using LoopBodyBuilder = llvm::function_ref<void(OpBuilder &, Location, Value,
                                                   MutableArrayRef<Value>)>;
 
+  /// Whether the list of the sparse condition should be iterated by for loop.
+  bool shouldIteratedByForLoop(ArrayRef<TensorLvlCond> spConds, bool genDedup);
+
   /// Linearizes address for dense dimension (i.e., p = (i * d0) + j).
   Value genAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl,
                    Value iv);
@@ -354,31 +468,51 @@ class LoopEmitter {
   void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                   TensorId tid, Level lvl);
 
-  /// Emits extra locals, since the locals might not be in simplified lattices
-  /// point used to generate the loops, but are still required to generate
-  /// expressions.
-  void emitExtraLocalsForTensorsAtDenseLvls(OpBuilder &builder, Location loc,
-                                            ArrayRef<TensorLevel> tidLvls);
-
-  /// Emits a for loop to iterate over a tensor level with the provided lower
-  /// bound `lo` and upper bound `hi`.
-  /// Apart from iterating just single tensor level, for loops can be used for
-  /// slice-driven loop on dense level too.
-  Operation *emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
-                                        TensorId tid, Level lvl, Value lo,
-                                        Value hi, MutableArrayRef<Value> reduc,
-                                        bool isParallel);
-
-  /// Emits a while loop to iterate over a sparse level that has been sliced.
-  /// Inserts break statement when the coordinate exceeds the sliceSize;
-  /// The method sets the insertion point inside the generated while loop body
-  /// after the break statement before return (so that callers need to handle
-  /// only in-bound coordinates).
-  Operation *emitWhileLoopOverSliceAtSparseLvl(OpBuilder &builder, Location loc,
-                                               Value pLo, Value pHi,
-                                               Value offset, Value sliceSize,
-                                               TensorId tid, Level lvl,
-                                               MutableArrayRef<Value> reduc);
+  /// Enter dense tensor levels. Since the dense tensor condition could be
+  /// optimized from the loop condition, we need to compute the
+  /// positions/coordinates inside the loop body.
+  void enterTensorsAtDenseLvls(OpBuilder &builder, Location loc,
+                               ArrayRef<TensorLvlCond> dnConds, Value iv,
+                               SmallVectorImpl<SliceLoopInfo> &sliceInfo);
+
+  /// Emits a for loop to iterate over a tensor level with the provided
+  /// lower bound `lo` and upper bound `hi`. Apart from iterating just
+  /// single tensor level, for loops can be used for slice-driven loop on
+  /// dense level too.
+  /// Returns a pair: the loop generated and the value for the induction
+  /// variable.
+  std::pair<Operation *, Value>
+  emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc, TensorId tid,
+                             Level lvl, Value lo, Value hi,
+                             MutableArrayRef<Value> reduc, bool isParallel);
+
+  /// Emits a while loop to co-iterate over a list of sparse condition, or
+  /// (complex) single sparse condition that can not be handled by for loop
+  /// (e.g., index reduction loop).
+  /// Returns a pair: the loop generated and the value for the induction
+  /// variable (which is the minimum coordinate of all the tensor that being
+  /// iterated).
+  std::pair<Operation *, Value>
+  emitWhileLoopOverTensorsAtLvls(OpBuilder &builder, Location loc,
+                                 ArrayRef<TensorLvlCond> spConds,
+                                 MutableArrayRef<Value> reduc, bool needsUniv);
+
+  /// Generates the while loop condition for the given tensor level condition.
+  Value genWhileLoopConditions(OpBuilder &builder, Location loc, ValueRange ivs,
+                               TensorLvlCond cond);
+
+  /// Generates the while loop body for the given tensor level condition.
+  std::optional<Value> genWhileLoopBody(OpBuilder &builder, Location loc,
+                                        ValueRange ivs, TensorLvlCond cond);
+
+  /// Generates the values (to forward the loop) if the extra check failes.
+  /// E.g., to iterate over a sparse tensor slice, we need:
+  ///
+  /// pos = onSlice(curCrd) ? pos : pos + 1
+  ///
+  /// to skip invalid coordinate that is included in the slice.
+  ValueRange genCheckedValue(OpBuilder &builder, Location loc, Value pred,
+                             ValueRange curArg, TensorLvlCond cond);
 
   /// Exits a for loop, returns the reduction results, e.g.,
   /// For sequential for loops:
@@ -488,7 +622,7 @@ class LoopEmitter {
   std::pair<Operation *, ValueRange>
   genSliceLvlTraverseLoop(OpBuilder &builder, Location loc, Value pLo,
                           Value pHi, Value offset, Value size, TensorId tid,
-                          Level lvl, ValueRange userReduc, bool genYield,
+                          Level lvl, ValueRange userReduc,
                           LoopBodyBuilder bodyBuilder);
 
   /// Generates a nested loop that iterates over tid on all the coordinates on
@@ -530,19 +664,6 @@ class LoopEmitter {
                              SmallVectorImpl<Value> &operands,
                              unsigned &retIdx);
 
-  /// Generates a slice-driven while loop as follows.
-  ///
-  /// curSlice = getFirstNonEmptySlice(tensor).
-  ///
-  /// while(isNonEmpty) {
-  ///   ..user code..
-  ///   isNonEmpty, curSlice = getNextNonEmptySlice(curSlice)
-  /// }
-  Operation *emitSliceDrivenLoopOverTensorAtLvl(OpBuilder &builder,
-                                                Location loc, TensorId tid,
-                                                Level lvl,
-                                                MutableArrayRef<Value> reduc);
-
   /// A optional string attribute that should be attached to the loop
   /// generated by loop emitter, it might help following passes to identify
   /// loops that operates on sparse tensors more easily.

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 2c5289aa8ae16..bb9c15ea463da 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1180,7 +1180,8 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
       loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
       // Note that reduc will be taken care of by loop emitter and get updated
       // in place.
-      loopEmitter.enterLoopOverTensorAtLvl(rewriter, loc, tidLvls, reduc);
+      loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
+                                                    reduc);
     }
 
     SmallVector<Value> lcvs;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 891befdcb51ea..637c16f92a293 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1306,12 +1306,11 @@ static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
   llvm_unreachable("unexpected parallelization strategy");
 }
 
-/// Generates a for-loop on a single index.
-static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
-                         bool isInner, LoopId ldx,
-                         ArrayRef<TensorLevel> tidLvls) {
+/// Whether or not the current loop being generated should be parallized (if
+/// possible) according to the configuration.
+static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
+                               ArrayRef<TensorLevel> tidLvls) {
   linalg::GenericOp op = env.op();
-  Location loc = op.getLoc();
   auto iteratorTypes = op.getIteratorTypesArray();
   bool isSparse = llvm::any_of(tidLvls, [ldx, &env](TensorLevel tidLvl) {
     // Queries the DLT based on the tensor id and loop idx, as requested by
@@ -1321,38 +1320,44 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
     return isCompressedDLT(dlt) || isSingletonDLT(dlt);
   });
 
-  bool isParallel = isParallelFor(env, isOuter, isSparse);
+  return isParallelFor(env, isOuter, isSparse);
+}
 
+/// Generates a "filter loop" on the given tid level to locate a coordinate that
+/// is of the same value as evaluated by the affine expression in its matching
+/// indexing map.
+static Operation *genFilterLoop(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
+                                TensorLevel tidLvl) {
+  linalg::GenericOp op = env.op();
+  Location loc = op.getLoc();
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
-    if (env.merger().isFilterLoop(ldx)) {
-      const auto [tid, lvl] = env.unpackTensorLevel(tidLvls.front());
-      // tids/lvls must only have one value because filter loops only
-      // corresponding to the one and only sparse tensor level.
-      assert(isSparse && tidLvls.size() == 1);
-      OpOperand *t = &op->getOpOperand(tid);
-      auto enc = getSparseTensorEncoding(t->get().getType());
-      // Retrieves the affine expression for the filter loop.
-      // FIXME: `toOrigDim` is deprecated.
-      AffineExpr a =
-          op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, lvl));
-      return env.emitter().enterFilterLoopOverTensorAtLvl(builder, loc, tid,
-                                                          lvl, a, reduc);
-    }
-    return env.emitter().enterLoopOverTensorAtLvl(builder, loc, tidLvls, reduc,
-                                                  isParallel);
+    assert(env.merger().isFilterLoop(ldx));
+    const auto [tid, lvl] = env.unpackTensorLevel(tidLvl);
+    // tids/lvls must only have one value because filter loops only
+    // corresponding to the one and only sparse tensor level.
+    OpOperand *t = &op->getOpOperand(tid);
+    auto enc = getSparseTensorEncoding(t->get().getType());
+    // Retrieves the affine expression for the filter loop.
+    // FIXME: `toOrigDim` is deprecated.
+    AffineExpr a = op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, lvl));
+    return env.emitter().enterFilterLoopOverTensorAtLvl(builder, loc, tid, lvl,
+                                                        a, reduc);
   });
-  assert(loop);
   return loop;
 }
 
-/// Emit a while-loop for co-iteration over multiple indices.
-static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, LoopId idx,
-                           bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
+/// Emit a loop to coiterate over the list of tensor levels. The generated loop
+/// can either be a for loop or while loop depending on whether there is at most
+/// one sparse level in the list.
+static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
+                                 LoopId idx, ArrayRef<TensorLevel> tidLvls,
+                                 bool tryParallel, bool needsUniv) {
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     // Construct the while-loop with a parameter for each
     // index.
     return env.emitter().enterCoIterationOverTensorsAtLvls(
-        builder, env.op().getLoc(), tidLvls, needsUniv, reduc);
+        builder, env.op().getLoc(), tidLvls, reduc, tryParallel,
+        /*genDedup=*/true, needsUniv);
   });
   assert(loop);
   return loop;
@@ -1361,15 +1366,15 @@ static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, LoopId idx,
 /// Generates a for-loop or a while-loop, depending on whether it implements
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
-                          bool needsUniv, ArrayRef<TensorLevel> tidLvls,
-                          bool isFor) {
-  const LoopId idx = env.topSortAt(at);
-  if (isFor) {
-    bool isOuter = at == 0;
-    bool isInner = at == env.topSortSize() - 1;
-    return genFor(env, builder, isOuter, isInner, idx, tidLvls);
+                          bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
+  const LoopId ldx = env.topSortAt(at);
+  if (env.merger().isFilterLoop(ldx)) {
+    assert(tidLvls.size() == 1);
+    return genFilterLoop(env, builder, ldx, tidLvls.front());
   }
-  return genWhile(env, builder, idx, needsUniv, tidLvls);
+
+  bool tryParallel = shouldTryParallize(env, ldx, at == 0, tidLvls);
+  return genCoIteration(env, builder, ldx, tidLvls, tryParallel, needsUniv);
 }
 
 /// Generates the induction structure for a while-loop.
@@ -1684,7 +1689,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                                  tidLvls, affineTidLvls);
 
   // Emit the for/while-loop control.
-  Operation *loop = genLoop(env, builder, at, needsUniv, tidLvls, isSingleCond);
+  Operation *loop = genLoop(env, builder, at, needsUniv, tidLvls);
   Location loc = env.op().getLoc();
   for (auto [tidLvl, exp] : affineTidLvls) {
     env.emitter().genDenseAffineAddress(builder, loc, tidLvl, exp);

diff  --git a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
index e4e65ef4b4e71..01189328eeb61 100644
--- a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
+++ b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
@@ -185,8 +185,6 @@ func.func @matvec(%arga: tensor<32x64xf64, #SortedCOO>,
 // CHECK:           ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index):
 // CHECK:             %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref<?xindex, strided<[?], offset: ?>>
 // CHECK:             %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK:             %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
-// CHECK:             %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
 // CHECK:             %[[VAL_32:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref<?xindex, strided<[?], offset: ?>>
 // CHECK:             %[[VAL_33:.*]] = scf.while (%[[VAL_34:.*]] = %[[VAL_26]]) : (index) -> index {
 // CHECK:               %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_17]] : index
@@ -219,6 +217,8 @@ func.func @matvec(%arga: tensor<32x64xf64, #SortedCOO>,
 // CHECK:               %[[VAL_51:.*]] = arith.addi %[[VAL_50]], %[[VAL_6]] : index
 // CHECK:               scf.yield %[[VAL_51]] : index
 // CHECK:             }
+// CHECK:             %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
+// CHECK:             %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
 // CHECK:             %[[VAL_52:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
 // CHECK:             %[[VAL_53:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
 // CHECK:             %[[VAL_54:.*]] = arith.andi %[[VAL_52]], %[[VAL_53]] : i1


        


More information about the Mlir-commits mailing list