[Mlir-commits] [mlir] fd68d36 - [mlir][sparse] unifying enterLoopOverTensorAtLvl and enterCoIterationOverTensorsAtLvls
Peiming Liu
llvmlistbot at llvm.org
Wed Jun 14 13:03:17 PDT 2023
Author: Peiming Liu
Date: 2023-06-14T20:03:10Z
New Revision: fd68d36109c6fcebb6d758046b88b0664acccf51
URL: https://github.com/llvm/llvm-project/commit/fd68d36109c6fcebb6d758046b88b0664acccf51
DIFF: https://github.com/llvm/llvm-project/commit/fd68d36109c6fcebb6d758046b88b0664acccf51.diff
LOG: [mlir][sparse] unifying enterLoopOverTensorAtLvl and enterCoIterationOverTensorsAtLvls
The tensor levels are now explicitly categorized into different `LoopCondKind`s to instruct the LoopEmitter to generate different code for different kinds of conditions (e.g., `SparseCond`, `SparseSliceCond`, `SparseAffineIdxCond`, etc.).
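For orientation, `LoopCondKind` is encoded as a small bit set (see the `LoopEmitter.h` hunk below). The following condensed C++ sketch restates that encoding as a free-standing program; the real definitions are nested inside the `LoopEmitter` class, only the sparse kinds are exercised here, and the asserts merely illustrate the ordering that `categorizeLoopCondition` relies on when sorting sparse conditions:

#include <cassert>
#include <cstdint>

// Predicate bits, mirroring the constants added in LoopEmitter.h.
constexpr uint8_t kSparseCond = 1 << 3;         // condition is on a sparse level
constexpr uint8_t kSliceCond = 1 << 2;          // level is a sparse tensor slice
constexpr uint8_t kAffineIdxCond = 1 << 1;      // non-trivial affine index expression
constexpr uint8_t kAffineIdxCondUnRed = 1 << 0; // affine index not fully reduced

enum class LoopCondKind : uint8_t {
  DenseCond = 0,
  SparseCond = kSparseCond,
  SparseSliceCond = kSparseCond | kSliceCond,
  SparseAffineCond = kSparseCond | kAffineIdxCond,
  SparseAffineUnRedCond = kSparseCond | kAffineIdxCond | kAffineIdxCondUnRed,
};

int main() {
  // makeLoopCondKind simply ORs the predicate bits together.
  auto kind = static_cast<LoopCondKind>(kSparseCond | kAffineIdxCond);
  assert(kind == LoopCondKind::SparseAffineCond);
  // Sorting by descending raw value places unreduced affine conditions before
  // reduced ones, and both before trivial sparse conditions.
  assert(static_cast<uint8_t>(LoopCondKind::SparseAffineUnRedCond) >
         static_cast<uint8_t>(LoopCondKind::SparseAffineCond));
  assert(static_cast<uint8_t>(LoopCondKind::SparseAffineCond) >
         static_cast<uint8_t>(LoopCondKind::SparseCond));
  return 0;
}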
The process of generating a while loop is now decomposed into three steps, which are dispatched to different `LoopCondKind` handlers (a sketch of the resulting loop shape follows the list):
1. Generate LoopCondition (e.g., `pos <= posHi` for `SparseCond`, `slice.isNonEmpty` for `SparseAffineIdxCond`)
2. Generate LoopBody (e.g., compute the coordinates)
3. Generate ExtraChecks (e.g., `if (onSlice(crd))` for `SparseSliceCond`)
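To make the three steps concrete, below is a hand-written, runnable simulation of the loop shape they produce for a single `SparseSliceCond` level. The coordinate data and the slice parameters (offset, size, stride) are invented for illustration, and the real emitter builds this control flow as an `scf.while` with an `scf.if` guard rather than as plain C++:

#include <cstdio>
#include <vector>

int main() {
  // Stored coordinates at one sparse level (invented data).
  std::vector<int> coords = {0, 2, 3, 4, 6, 9};
  // Invented slice parameters: slice(offset=2, size=3, stride=2) selects the
  // absolute coordinates {2, 4, 6}, i.e., slice coordinates {0, 1, 2}.
  const int offset = 2, size = 3, stride = 2;
  size_t pos = 0, posHi = coords.size();

  while (pos < posHi) {                      // (1) LoopCondition
    int absCrd = coords[pos];                // (2) LoopBody: load the coordinate
    int relCrd = (absCrd - offset) / stride; //     and map it into slice space
    bool onSlice = absCrd >= offset &&
                   (absCrd - offset) % stride == 0 && relCrd < size;
    if (onSlice)                             // (3) ExtraChecks: skip stored
      std::printf("visit crd %d at pos %zu\n", relCrd, pos); // entries off the slice
    ++pos;                                   // forward the position either way
  }
  return 0;
}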
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D152464
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/SparseTensor/sorted_coo.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index f466cce68a34b..d0884ca482de2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -101,6 +101,28 @@ static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
return std::make_pair(crd, rem);
}
+// Generates a bool value for the while-loop condition used when iterating
+// over a fully reduced level with an affine index expression.
+static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
+ Value crdBuf, Value crdHi, Value posit,
+ Value posHi) {
+ Value inBound = CMPI(ult, posit, posHi);
+ auto ifOp =
+ builder.create<scf::IfOp>(loc, builder.getI1Type(), inBound, true);
+ // if (inbound)
+ // yield coord < crdHi
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ Value crd = genIndexLoad(builder, loc, crdBuf, posit);
+ YIELD(CMPI(ult, crd, crdHi));
+ // else
+ // yield false
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ YIELD(constantI1(builder, loc, false));
+
+ builder.setInsertionPointAfter(ifOp);
+ return ifOp.getResult(0);
+}
+
std::pair<Value, Value>
LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
TensorId tid, Level lvl) {
@@ -470,6 +492,41 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
localInsertPos = builder.getInsertionPoint()->getPrevNode();
}
+void LoopEmitter::categorizeLoopCondition(
+ ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<TensorLvlCond> &dnConds,
+ SmallVectorImpl<TensorLvlCond> &spConds) {
+ // Finds out the tensor level that we should use to generate loops. Among all
+ // the tensor levels, there is at most one sparse tensor level.
+ for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
+ assert(lvlTypes[t].size() > l); // Must be a valid tid, dim pair
+ auto lvlType = lvlTypes[t][l];
+ // Must be a recognizable DLT.
+ assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) ||
+ isCompressedWithHiDLT(lvlType) || isSingletonDLT(lvlType));
+
+ bool isSparse = !isDenseDLT(lvlType);
+ bool isSlice = isSparseSlices[t];
+ bool isAffine = !dependentLvlMap[t][l].empty();
+ bool isUnRedu = false;
+ // TODO: Support affine index expressions on sparse tensor slices.
+ assert(!isSlice || !isAffine);
+
+ // Whether the affine index expression has been fully reduced or not.
+ if (!dependentLvlMap[t][l].empty())
+ isUnRedu = !depFullyReduced(t, l);
+
+ auto &dstVec = isSparse ? spConds : dnConds;
+ dstVec.emplace_back(
+ makeTensorLevel(t, l),
+ makeLoopCondKind(isSparse, isSlice, isAffine, isUnRedu));
+ }
+
+ std::sort(spConds.begin(), spConds.end(), [](auto lhs, auto rhs) {
+ // AffineUnRed > Affine > Slice > Trivial
+ return static_cast<uint8_t>(lhs.second) > static_cast<uint8_t>(rhs.second);
+ });
+}
+
void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
ArrayRef<TensorLevel> tidLvls) {
// TODO: sort
@@ -561,7 +618,7 @@ Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
}
}
-Operation *LoopEmitter::emitForLoopOverTensorAtLvl(
+std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
OpBuilder &builder, Location loc, TensorId tid, Level dstLvl, Value lo,
Value hi, MutableArrayRef<Value> reduc, bool isParallel) {
bool isSparseCond = isCompressedDLT(lvlTypes[tid][dstLvl]) ||
@@ -651,166 +708,433 @@ Operation *LoopEmitter::emitForLoopOverTensorAtLvl(
assert(crd);
coords[tid][dstLvl] = crd;
- return loop;
+ return {loop, crd};
}
-Operation *LoopEmitter::emitWhileLoopOverSliceAtSparseLvl(
- OpBuilder &builder, Location loc, Value pLo, Value pHi, Value offset,
- Value sliceSize, TensorId tid, Level lvl, MutableArrayRef<Value> reduc) {
- // TODO: we should generalize the method to support iteration over for
- // normal slices as well to allow early break.
- Operation *insertPoint = nullptr;
- Operation *loop =
- genSliceLvlTraverseLoop(
- builder, loc, pLo, pHi, offset, sliceSize, tid, lvl, reduc,
- /*genYield=*/false, // unaware of the yield values from user yet
- [this, tid, lvl, reduc, offset,
- &insertPoint](OpBuilder &builder, Location loc, Value iv,
- MutableArrayRef<Value> innerReduc) {
- assert(innerReduc.size() == reduc.size());
- // Updates users' reduction variable inplace
- for (unsigned i = 0, e = reduc.size(); i < e; i++)
- reduc[i] = innerReduc[i];
- // Loads the coordinates.
- Value absC =
- genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], iv);
-
- // We need to substract the offset to get relative coordinates.
- // TODO: how to assert relC >=0 during runtime?
- insertPoint = builder.create<arith::SubIOp>(loc, absC, offset);
- posits[tid][lvl] = iv;
- coords[tid][lvl] = insertPoint->getResult(0);
- })
- .first;
- // Sets the insertionn pointer inside loop body.
- builder.setInsertionPointAfter(insertPoint);
- return loop;
+Value LoopEmitter::genWhileLoopConditions(OpBuilder &builder, Location loc,
+ ValueRange ivs, TensorLvlCond cond) {
+ auto [tid, lvl] = unpackTensorLevel(cond.first);
+
+ switch (cond.second) {
+ case LoopCondKind::SparseCond: {
+ const auto reassoc = getCollapseReassociation(tid, lvl);
+ assert(reassoc.size() == ivs.size());
+ assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
+ // We use the first level bound as the bound for the collapsed set of levels.
+ return CMPI(ult, ivs.back(), highs[tid][reassoc.front()]);
+ }
+ case LoopCondKind::SparseSliceCond: {
+ assert(ivs.size() == 1);
+ return CMPI(ult, ivs.back(), highs[tid][lvl]);
+ }
+ case LoopCondKind::SparseAffineCond: {
+ assert(ivs.size() == 1);
+ Value crdHi; // loop upper bound
+ {
+ OpBuilder::InsertionGuard guard(builder);
+ Operation *loop = builder.getInsertionBlock()->getParentOp();
+ // crdHi is a loop invariant; hoist the computation outside the loop.
+ if (llvm::isa_and_nonnull<scf::WhileOp>(loop))
+ builder.setInsertionPoint(loop);
+ crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset,
+ sliceSizes[tid][lvl].back());
+ }
+ assert(crdHi);
+ return genSparseReducedAffineCond(builder, loc,
+ coordinatesBuffers[tid][lvl], crdHi,
+ ivs[0], highs[tid][lvl]);
+ }
+ case LoopCondKind::SparseAffineUnRedCond: {
+ assert(ivs.size() == 3);
+ return ivs.front(); // isNonEmpty
+ }
+ default:
+ llvm_unreachable("Unhandled LoopCondKind");
+ }
+ llvm_unreachable("Unhandled LoopCondKind");
}
-Operation *LoopEmitter::enterLoopOverTensorAtLvl(OpBuilder &builder,
- Location loc,
- ArrayRef<TensorLevel> tidLvls,
- MutableArrayRef<Value> reduc,
- bool isParallel) {
- // TODO: support multiple return on parallel for?
- assert(!isParallel || reduc.size() <= 1);
- bool isSparseCond = false, isSparseSliceCond = false;
- auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
+std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
+ Location loc, ValueRange ivs,
+ TensorLvlCond cond) {
+ auto [tid, lvl] = unpackTensorLevel(cond.first);
- // Finds out the tensor level that we should use to generate loops. Amongs all
- // the tensor levels, there is at most one sparse tensor level.
- for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
- assert(lvlTypes[t].size() > l); // Must be a valid tid, dim pair
- assert(!coords[t][l] || // We cannot re-enter the same level
- !dependentLvlMap[t][l].empty()); // unless it is a slice-driver loop
- auto lvlType = lvlTypes[t][l];
- // Must be a recognizable DLT.
- assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) ||
- isCompressedWithHiDLT(lvlType) || isSingletonDLT(lvlType));
+ switch (cond.second) {
+ case LoopCondKind::SparseCond: {
+ const auto reassoc = getCollapseReassociation(tid, lvl);
+ assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
+ // Links the SSA chain for segHi.
+ for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++)
+ if (!isUniqueDLT(lvlTypes[tid][reassoc[i]]))
+ segHi[tid][reassoc[i]] = ivs[i];
+
+ // Updates position. For collapsed COO, the position is the same across
+ // consecutive levels.
+ for (auto srcLvl : reassoc)
+ posits[tid][srcLvl] = ivs.back();
+
+ // Update coordinates.
+ coords[tid][lvl] = genSparseCrd(builder, loc, tid, lvl);
+ return std::nullopt;
+ }
+ case LoopCondKind::SparseSliceCond: {
+ assert(ivs.size() == 1);
+ posits[tid][lvl] = ivs.front();
+ Value sCrd = genSparseCrd(builder, loc, tid, lvl);
+ // Converts the coordinate loaded from the actual sparse tensor to the
+ // coordinates in the sparse slice.
+ auto [dCrd, pred] = genSliceLegitPredicate(builder, loc, sCrd, tid, lvl);
+ coords[tid][lvl] = dCrd;
+ return pred;
+ }
+ case LoopCondKind::SparseAffineCond: {
+ assert(ivs.size() == 1);
+ // Coord is the offset relative to its parents.
+ assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implemented");
+ // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
+ Value posit = ivs[0];
+ Value crdBuf = coordinatesBuffers[tid][lvl];
+ // We need to subtract the offset to get relative coordinates.
+ // TODO: Maybe assert relC >=0 during runtime in debug build?
+ Value absC = genIndexLoad(builder, loc, crdBuf, posit);
+ auto relC = SUBI(absC, getFinalSliceOnLvl(tid, lvl).offset);
+ posits[tid][lvl] = posit;
+ coords[tid][lvl] = relC;
+ return std::nullopt;
+ }
+ case LoopCondKind::SparseAffineUnRedCond: {
+ assert(ivs.size() == 3);
+ // Coord is the offset relative to its parents.
+ // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
+ assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implemented");
+ // Updates the current slice info
+ SliceInfo &sliceInfo = sliceStack[tid].back();
+ sliceInfo.isNonEmpty = ivs[0];
+ sliceInfo.minCrd = ivs[1];
+ sliceInfo.offset = ivs[2];
+ coords[tid][lvl] = sliceInfo.offset;
+ // No extra check is needed before accessing the tensor level.
+ return std::nullopt;
+ }
+ default:
+ llvm_unreachable("Unhandled LoopCondKind");
+ }
+ llvm_unreachable("Unhandled LoopCondKind");
+}
- // This is a slice-driven loop on sparse level.
- if (!dependentLvlMap[t][l].empty() && !isDenseDLT(lvlType)) {
- assert(!isSparseSliceCond && !isSparseCond);
- isSparseSliceCond = true;
- tid = t;
- lvl = l;
- continue;
+ValueRange LoopEmitter::genCheckedValue(OpBuilder &builder, Location loc,
+ Value pred, ValueRange curArgs,
+ TensorLvlCond cond) {
+ // Currently only the sparse slice condition needs an extra check.
+ assert(isSliceCond(cond.second) && isSparseCond(cond.second));
+ assert(curArgs.size() == 1);
+ Value nextPos = ADDI(curArgs.front(), C_IDX(1));
+ return SELECT(pred, curArgs.front(), nextPos)->getResults();
+}
+
+std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
+ OpBuilder &builder, Location loc, ArrayRef<TensorLvlCond> spConds,
+ MutableArrayRef<Value> reduc, bool needsUniv) {
+ // NOTE: the slice-driven, tensor-related reduction variables must
+ // appear before those of the normal tensors.
+ assert(!spConds.empty());
+
+ // The set of induction variables for the while loop.
+ SmallVector<Value> ivs;
+ // Segment sizes for induction variables used for different kinds of loop
+ // conditions.
+ SmallVector<unsigned> opSegSize;
+
+ // Construct the while-loop with a parameter for each coordinate.
+ for (auto [tl, cKind] : spConds) {
+ auto [tid, lvl] = unpackTensorLevel(tl);
+ const auto lvlTp = lvlTypes[tid][lvl];
+ // Dense levels are handled by the shared universal index.
+ assert(!isDenseCond(cKind));
+ // Must be a recognizable sparse level.
+ assert(isCompressedDLT(lvlTp) || isCompressedWithHiDLT(lvlTp) ||
+ isSingletonDLT(lvlTp));
+
+ unsigned prevSz = ivs.size();
+ const auto reassoc = getCollapseReassociation(tid, lvl);
+ if (isAffineIdxCond(cKind)) {
+ // TODO: Support view-based reshape on sparse levels with affine index
+ // expressions.
+ assert(reassoc.size() == 1);
+ if (isAffineIdxUnRedCond(cKind)) {
+ SliceInfo &sliceInfo = sliceStack[tid].back();
+ // The order matters!
+ ivs.push_back(sliceInfo.isNonEmpty);
+ ivs.push_back(sliceInfo.minCrd);
+ ivs.push_back(sliceInfo.offset);
+ } else {
+ ivs.push_back(posits[tid][lvl]); // loop lower bound (pos low).
+ }
+ // We reduced one more dependency after entering the loop.
+ levelReducedDep[tid][lvl]++;
+ } else {
+ assert(dependentLvlMap[tid][lvl].empty());
+ for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
+ // This is the segment high for each non-unique level.
+ if (!isUniqueDLT(lvlTypes[tid][reassoc[i]]))
+ ivs.push_back(C_IDX(0));
+ }
+ const Value pos = posits[tid][reassoc.front()];
+ ivs.push_back(pos);
}
+ opSegSize.push_back(ivs.size() - prevSz);
+ }
+
+ // The position where user-supplied reduction variable starts.
+ ivs.append(reduc.begin(), reduc.end());
+ // Update universal index.
+ if (needsUniv)
+ ivs.push_back(loopSeqStack.back().first);
+
+ // Ensures all operands are valid.
+ assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
+ TypeRange types = ValueRange(ivs).getTypes();
+ auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);
+
+ SmallVector<Location> locs(types.size(), loc);
+ Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
+ Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
- bool isSparse = isCompressedDLT(lvlType) || isSingletonDLT(lvlType) ||
- isCompressedWithHiDLT(lvlType);
- // We can at most have one sparse input, otherwise, a while loop is
- // required to co-iterate multiple sparse tensors.
- assert(!isSparseCond || !isSparse);
- assert(!isSparseSliceCond || !isSparseCond);
- if (isSparse) {
- tid = t;
- lvl = l;
+ // Generates loop conditions.
+ builder.setInsertionPointToStart(before);
+ ValueRange bArgs = before->getArguments();
+ Value whileCond = nullptr; // bool values for loop condition.
+ for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) {
+ Value cv = genWhileLoopConditions(builder, loc, bArgs.take_front(segSz), c);
+ bArgs = bArgs.drop_front(segSz);
+ whileCond = !whileCond ? cv : ANDI(whileCond, cv);
+ }
+ // The remaining block arguments are user-provided reduction values and an
+ // optional universal index. Make sure their sizes match.
+ assert(bArgs.size() == reduc.size() + (needsUniv ? 1 : 0));
+ builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+ // Generates loop body.
+ builder.setInsertionPointToStart(after);
+ ValueRange aArgs = after->getArguments();
+ // Since some LoopCondKind might need extra checks to filter out invalid
+ // iterations, we maintain another array to hold the iteration arguments to
+ // yield if the checks fail.
+ SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());
+ // A mutable alias for convenient slicing.
+ MutableArrayRef<Value> nextArgsRef = nextArgs;
+ Value extraPred = nullptr;
+ for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) {
+ ValueRange condArgs = aArgs.take_front(segSz);
+ auto pred = genWhileLoopBody(builder, loc, condArgs, c);
+ assert(pred.has_value() == isCondWithExtraCheck(c.second));
+ if (pred.has_value()) {
+ // We need all extra checks to pass.
+ extraPred = extraPred == nullptr ? *pred : ANDI(*pred, extraPred);
+ ValueRange nxArgs = genCheckedValue(builder, loc, *pred, condArgs, c);
+ assert(nxArgs.size() == segSz);
+ // Update the value for cases when some check fails.
+ for (unsigned i = 0; i < segSz; i++) {
+ nextArgsRef[i] = nxArgs[i];
+ }
}
- isSparseCond = isSparseCond || isSparse;
+ aArgs = aArgs.drop_front(segSz);
+ nextArgsRef = nextArgsRef.drop_front(segSz);
}
- DimLevelType lvlType = lvlTypes[tid][lvl];
- // TODO: Dense slice driven loop can be generated using for loop as well.
- assert(!isSparseSliceCond || !isDenseDLT(lvlType));
- bool isDenseSliceCond =
- isDenseDLT(lvlType) && !dependentLvlMap[tid][lvl].empty();
- // if the slice is fully reduced, we can now use TACO-based algorithm to
- // iterate it.
+ if (extraPred) {
+ auto ifOp = builder.create<scf::IfOp>(loc, types, extraPred, /*else*/ true);
+ // Marks this special IfOp so that Sparsification does not finalize it.
+ ifOp->setAttr(getLoopEmitterLoopAttrName(),
+ StringAttr::get(builder.getContext(), "slice"));
+ // Links the SSA chain outside the if statement.
+ YIELD(ifOp->getResults());
- Operation *l = nullptr;
+ // If not all slices are legit, yield the updated value.
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ YIELD(nextArgs);
+
+ // If all slices are legit, start the user generated code.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ }
+
+ for (auto [tid, dstLvl] : unpackTensorLevelFromCondRange(spConds)) {
+ const auto reassoc = getCollapseReassociation(tid, dstLvl);
+ assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
+ // TODO: Refactor this into smaller functions.
+ // NOTE: This applies to all the collapsed levels except the last one (which
+ // is why the loop ends at `reassoc.size() - 1`): each iteration advances by
+ // the segment size of the last level, which does not always invalidate the
+ // segment sizes of the previous levels, so we need to propagate the segment
+ // sizes across loop iterations and only forward them when needed.
+ //
+ // E.g., for a COO tensor with the following coordinates array.
+ // (0, 0, 1),
+ // (0, 0, 2),
+ // (1, 1, 1),
+ // segHi[lvl=0] = segHi[lvl=1] = 2
+ // segHi[lvl=2] = 1,
+ // the first iteration does not invalidate segHi[0] and segHi[1]
+ for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
+ const Level srcLvl = reassoc[i];
+ if (!isUniqueDLT(lvlTypes[tid][srcLvl])) {
+ const Value pos = posits[tid][srcLvl];
+ const auto oldSegHi = segHi[tid][srcLvl];
+ assert(oldSegHi);
+ Value newSegHi = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::uge, pos, oldSegHi);
+ auto ifNewSegHi = builder.create<scf::IfOp>(loc, builder.getIndexType(),
+ newSegHi, true);
+ {
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(ifNewSegHi.thenBlock());
+ YIELD(genSegmentHigh(builder, loc, tid, srcLvl, pos,
+ highs[tid][srcLvl]));
+ // Else, reuses the same segment high.
+ builder.setInsertionPointToStart(ifNewSegHi.elseBlock());
+ YIELD(oldSegHi);
+ }
+ highs[tid][srcLvl + 1] = segHi[tid][srcLvl] = ifNewSegHi.getResult(0);
+ }
+ };
+ const auto srcLvl = reassoc.back();
+ if (!isUniqueDLT(lvlTypes[tid][srcLvl])) {
+ segHi[tid][srcLvl] = genSegmentHigh(
+ builder, loc, tid, srcLvl, posits[tid][srcLvl], highs[tid][srcLvl]);
+ }
+ }
- // At most one tensor used as condition in for loop;
- SmallVector<TensorLevel, 1> condTidLvl;
- // There might be multiple dense slice driven tensor.
+ // In-place update on reduction variable.
+ assert(aArgs.size() == reduc.size() + (needsUniv ? 1 : 0));
+ for (unsigned i = 0, e = reduc.size(); i < e; i++)
+ reduc[i] = aArgs[i];
+
+ Value min;
+ // Finds the minimum coordinate
+ if (!needsUniv) {
+ for (auto [tid, lvl] : unpackTensorLevelFromCondRange(spConds)) {
+ const auto lvlTp = lvlTypes[tid][lvl];
+ if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
+ isCompressedWithHiDLT(lvlTp)) {
+ const auto crd = coords[tid][lvl];
+ if (min) {
+ Value cmp = CMPI(ult, coords[tid][lvl], min);
+ min = SELECT(cmp, coords[tid][lvl], min);
+ } else {
+ min = crd;
+ }
+ }
+ }
+ } else {
+ assert(!min);
+ // Otherwise, the universal index is the minimal position.
+ min = whileOp.getAfterArguments().back();
+ }
+
+ return {whileOp, min};
+}
+
+bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<TensorLvlCond> sparseConds,
+ bool genDedup) {
+ assert(llvm::all_of(sparseConds,
+ [](TensorLvlCond c) { return isSparseCond(c.second); }));
+
+ // If we need to co-iterate over more than one sparse tensor, we need a while loop.
+ if (sparseConds.size() > 1)
+ return false;
+
+ // We also need a while loop for levels with affine index expressions and
+ // for non-unique levels when deduplication is required.
+ if (sparseConds.size() == 1) {
+ auto [tid, lvl] = unpackTensorLevel(sparseConds.back().first);
+ auto reassoc = getCollapseReassociation(tid, lvl);
+ return !isAffineIdxCond(sparseConds.back().second) &&
+ !(genDedup && !isUniqueDLT(lvlTypes[tid][reassoc.back()]));
+ }
+
+ return true;
+}
+
+Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
+ OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
+ MutableArrayRef<Value> reduc, bool tryParallel, bool genDedup,
+ bool needsUniv) {
+ // Sanity checks.
+ assert(!tidLvls.empty());
+ for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
+ assert(!coords[t][l] || // We cannot re-enter the same level
+ !dependentLvlMap[t][l].empty()); // unless it is a slice-driver loop
+ }
+ // TODO: support multiple return on parallel for?
+ tryParallel = tryParallel && reduc.size() <= 1;
+
+ SmallVector<TensorLvlCond> spConds;
+ SmallVector<TensorLvlCond> dnConds;
+ categorizeLoopCondition(tidLvls, dnConds, spConds);
+
+ // Only when there is at least one sparse condition do we really need the
+ // universal index.
+ // TODO: Maybe we should instead require the merger to pass in a valid value
+ // in the first place instead of adjusting it in LoopEmitter?
+ needsUniv = !spConds.empty() && needsUniv;
+ // The TensorLevel used for loop conditions.
+ // If there is any sparse level, we need to use the sparse condition.
+ // If all levels are dense, we can pick an arbitrary one (a dense slice-driven loop
+ // can be generated using a simple ForOp as well).
+ Operation *l = nullptr;
+ Value iv = nullptr;
SmallVector<SliceLoopInfo> sliceDrivenInfo;
+ SmallVector<TensorLevel> trivialLvls;
// Generates loops differently depending on whether we need a slice-driven
// loop or a simple level traversal loop.
- if (isSparseSliceCond) {
- bool fullyReduced = depFullyReduced(tid, lvl);
- if (!fullyReduced) {
- l = emitSliceDrivenLoopOverTensorAtLvl(builder, loc, tid, lvl, reduc);
- } else {
- // If the slice is fully reduced, we can now use TACO-based algorithm to
- // iterate it.
- l = emitWhileLoopOverSliceAtSparseLvl(
- builder, loc, posits[tid][lvl], highs[tid][lvl],
- getFinalSliceOnLvl(tid, lvl).offset, sliceSizes[tid][lvl].back(), tid,
- lvl, reduc);
- }
- levelReducedDep[tid][lvl]++;
- sliceDrivenInfo.emplace_back(tid, lvl, fullyReduced);
- } else {
- Value lo = isSparseCond ? posits[tid][lvl] // current offset
- : loopSeqStack.back().first; // universal index
+ if (shouldIteratedByForLoop(spConds, genDedup) && !needsUniv) {
+ assert(spConds.size() <= 1);
+ TensorLvlCond tlCond = spConds.empty() ? dnConds.front() : spConds.front();
+ auto loopCondKind = tlCond.second;
+ auto [tid, lvl] = unpackTensorLevel(tlCond.first);
+ Value lo = isSparseCond(loopCondKind)
+ ? posits[tid][lvl] // current offset
+ : loopSeqStack.back().first; // universal index
Value hi = highs[tid][lvl];
- if (isDenseSliceCond) {
- bool fullyReduced = depFullyReduced(tid, lvl);
- Value sliceSz = sliceSizes[tid][lvl][sliceStack[tid].back().depth - 1];
- // Adjust for loop hi for dense slice-driven loop.
- if (fullyReduced) {
- hi = sliceSz;
- condTidLvl.push_back(makeTensorLevel(tid, lvl));
- } else {
- hi = SUBI(lvlSizes[tid][lvl], sliceSz);
+ if (isDenseCond(loopCondKind) && isAffineIdxCond(loopCondKind)) {
+ bool unReduc = isAffineIdxUnRedCond(loopCondKind);
+ assert(unReduc == !depFullyReduced(tid, lvl));
+ hi = sliceSizes[tid][lvl][sliceStack[tid].back().depth - 1];
+ if (unReduc) {
+ // Adjust the for-loop upper bound for the dense slice-driven loop.
+ hi = SUBI(lvlSizes[tid][lvl], hi);
hi = ADDI(hi, C_IDX(1));
}
- } else {
- condTidLvl.push_back(makeTensorLevel(tid, lvl));
}
- l = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, lo, hi, reduc,
- isParallel);
- }
- Value iv = coords[tid][lvl];
- for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
- // We only need to handle slice-driven loops on dense level here.
- // If it is a slice-driven loop on sparse level, it needs a while loop to
- // insert break statements, and it must have been handled correctly in L692.
- if (!dependentLvlMap[t][l].empty() && isDenseDLT(lvlTypes[t][l])) {
- // Pushes sliced levels to build correct LoopInfo.
- bool fullyReduc = depFullyReduced(t, l);
- SliceInfo &info = sliceStack[t].back();
- if (fullyReduc) {
- posits[t][l] = genAddress(builder, loc, t, l, ADDI(info.offset, iv));
+ std::tie(l, iv) = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, lo, hi,
+ reduc, tryParallel);
+ // A for-loop condition must be a trivial condition (a level without an
+ // affine index expression).
+ trivialLvls.push_back(tlCond.first);
+ } else {
+ for (auto [tl, cKind] : spConds) {
+ if (isAffineIdxCond(cKind)) {
+ auto [tid, lvl] = unpackTensorLevel(tl);
+ bool unReduc = isAffineIdxUnRedCond(cKind);
+ assert(unReduc == !depFullyReduced(tid, lvl));
+ sliceDrivenInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc);
} else {
- // Puts sliced dense loop into LoopInfo so that LoopEmitter knows how to
- // exit it.
- sliceDrivenInfo.emplace_back(t, l, fullyReduc);
- // Update the slice information as we enter the new loop.
- assert(*info.slicedOnLvl == l);
- info.minCrd = info.offset = iv;
- info.isNonEmpty = constantI1(builder, loc, true);
- levelReducedDep[t][l]++;
+ trivialLvls.push_back(tl);
}
}
+
+ std::tie(l, iv) =
+ emitWhileLoopOverTensorsAtLvls(builder, loc, spConds, reduc, needsUniv);
}
+
+ // Enter dense tensor levels.
+ enterTensorsAtDenseLvls(builder, loc, dnConds, iv, sliceDrivenInfo);
// NOTE: we can also prepare for next dim here in advance
+
// Pushes the loop into stack.
- loopStack.emplace_back(condTidLvl, sliceDrivenInfo, l,
+ loopStack.emplace_back(trivialLvls, sliceDrivenInfo, l,
builder.getInsertionBlock(), iv, loopTag);
- // Emit extra locals.
- emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tidLvls);
return l;
}
@@ -886,229 +1210,11 @@ void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc,
AffineExpr lvlExpr) {
auto [tid, lvl] = unpackTensorLevel(tidLvl);
assert(isDenseDLT(lvlTypes[tid][lvl]));
- // For dense levels, the level-coordinate also serves as the position.
+ // For dense levels, the level-coordinate also serves as the position.
Value lvlCrd = genAffine(builder, loc, lvlExpr);
posits[tid][lvl] = genAddress(builder, loc, tid, lvl, lvlCrd);
}
-Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
- OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
- bool needsUniv, MutableArrayRef<Value> reduc) {
- // NOTE: the slice driven tensor-related reduction variable must
- // appear before normal tensors.
- SmallVector<Type> types;
- SmallVector<Value> operands;
- // Construct the while-loop with a parameter for each coordinate.
- const Type indexType = builder.getIndexType();
- for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
- // TODO: support coiteration with slice driven tensors.
- const auto lvlTp = lvlTypes[tid][lvl];
- assert(dependentLvlMap[tid][lvl].empty() && "TODO: not yet implemented");
- if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
- isCompressedWithHiDLT(lvlTp)) {
- const auto reassoc = getCollapseReassociation(tid, lvl);
- for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
- if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) {
- // This is the segment high for each non-unique levels.
- types.push_back(indexType);
- operands.push_back(C_IDX(0));
- }
- }
- const auto pos = posits[tid][reassoc.front()];
- assert(pos);
- types.push_back(indexType);
- operands.push_back(pos);
- }
- }
- // The position where user-supplied reduction variable starts.
- for (Value rec : reduc) {
- types.push_back(rec.getType());
- operands.push_back(rec);
- }
- if (needsUniv) {
- types.push_back(indexType);
- // Update universal index.
- operands.push_back(loopSeqStack.back().first);
- }
- assert(types.size() == operands.size());
- scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
-
- SmallVector<Location> locs(types.size(), loc);
- Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
- Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
-
- // Build the "before" region, which effectively consists
- // of a conjunction of "i < upper" tests on all induction.
- builder.setInsertionPointToStart(&whileOp.getBefore().front());
- Value cond;
- unsigned o = 0;
- for (auto [t, lvl] : unpackTensorLevelRange(tidLvls)) {
- const TensorId tid = t; // Why `t` can not be captured by lambda?
- const auto lvlTp = lvlTypes[tid][lvl];
- if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
- isCompressedWithHiDLT(lvlTp)) {
- const auto reassoc = getCollapseReassociation(tid, lvl);
- assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
- for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
- if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) {
- // Links the SSA chain for segHi.
- segHi[tid][reassoc[i]] = after->getArgument(o++);
- }
- }
- Value op1 = before->getArgument(o);
- // We used the first level bound as the bound the collapsed set of levels.
- Value op2 = highs[tid][reassoc.front()];
- Value opc = CMPI(ult, op1, op2);
- cond = cond ? ANDI(cond, opc) : opc;
- // Update positions
- Value pos = after->getArgument(o++);
- // For COO, the position is the same across consecutive levels.
- /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
- llvm::for_each(reassoc, [this, tid, pos](Level srcLvl) {
- posits[tid][srcLvl] = pos;
- });
- }
- }
- builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
-
- // Generates while body.
- builder.setInsertionPointToStart(&whileOp.getAfter().front());
-
- SmallVector<std::pair<Value, unsigned>> slicesPreds;
- unsigned i = 0;
- for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
- // Prepares for next level.
- const auto lvlTp = lvlTypes[tid][lvl];
- if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
- isCompressedWithHiDLT(lvlTp)) {
- coords[tid][lvl] = genSparseCrd(builder, loc, tid, lvl);
- if (isSparseSlices[tid]) {
- auto [trans, pred] =
- genSliceLegitPredicate(builder, loc, coords[tid][lvl], tid, lvl);
- slicesPreds.emplace_back(pred, i);
- // Updates to the relative coordinate to the slice.
- coords[tid][lvl] = trans;
- }
- i++;
- }
- }
-
- if (!slicesPreds.empty()) {
- // Skips invalid loop iteration when slice coordinate is inapplicable.
- SmallVector<Value> yields(after->getArguments());
- // Generates a list of if statments
- // pos = in_slice ? pos : pos + 1
- // TODO: instead of always picking pos + 1, we should set pos = high to
- // break to loop if the coordinates are larger than the slice size.
- //
- // This "idx" is the index into `llvm::zip(tids, lvls)`
- for (auto [pred, idx] : slicesPreds) {
- Value nextPos = ADDI(yields[idx], C_IDX(1));
- yields[idx] = SELECT(pred, yields[idx], nextPos);
- }
-
- Value pred = slicesPreds.front().first;
- for (int i = 1, e = slicesPreds.size(); i < e; i++) {
- pred = ANDI(pred, slicesPreds[i].first);
- }
- auto ifOp = builder.create<scf::IfOp>(loc, types, pred, /*else*/ true);
- ifOp->setAttr(getLoopEmitterLoopAttrName(),
- StringAttr::get(builder.getContext(), "slice"));
- YIELD(ifOp->getResults());
- assert(types.size() == yields.size());
- // If not all slices are legit
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- YIELD(yields);
-
- // If all slices are legit, start the user generated code.
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- }
-
- Value min;
- // Finds the minimum coordinate
- if (!needsUniv) {
- for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
- const auto lvlTp = lvlTypes[tid][lvl];
- if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
- isCompressedWithHiDLT(lvlTp)) {
- const auto crd = coords[tid][lvl];
- if (min) {
- Value cmp = CMPI(ult, coords[tid][lvl], min);
- min = SELECT(cmp, coords[tid][lvl], min);
- } else {
- min = crd;
- }
- }
- }
- } else {
- assert(!min);
- // Otherwise, universal index is the minimal pos.
- min = after->getArguments().back();
- }
-
- // Sets up the loop stack.
- loopStack.emplace_back(tidLvls, ArrayRef<SliceLoopInfo>(), whileOp,
- builder.getInsertionBlock(), min, loopTag);
- assert(loopStack.size() == loopSeqStack.size());
-
- for (auto [tid, dstLvl] : unpackTensorLevelRange(tidLvls)) {
- const auto reassoc = getCollapseReassociation(tid, dstLvl);
- assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
- // TODO: Refactors this into smaller functions.
- // NOTE: For all the collapsed level (except for the last one, that is why
- // the loop ends with `reassoc.size() - 1`), as each iteration is advanced
- // by the segment size of the last level, which does not always invalidate
- // the segment size for the previous levels, thus we need to propagate the
- // segment sizes across loop iterations and only forward if needed.
- //
- // E.g., for a COO tensor with the following coordinates array.
- // (0, 0, 1),
- // (0, 0, 2),
- // (1, 1, 1),
- // segHi[lvl=0] = segHi[lvl=1] = 2
- // segHi[lvl=2] = 1,
- // the first iteration does not invalidate segHi[0] and segHi[1]
- for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
- const Level srcLvl = reassoc[i];
- if (!isUniqueDLT(lvlTypes[tid][srcLvl])) {
- const Value pos = posits[tid][srcLvl];
- const auto oldSegHi = segHi[tid][srcLvl];
- assert(oldSegHi);
- Value newSegHi = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::uge, pos, oldSegHi);
- auto ifNewSegHi = builder.create<scf::IfOp>(loc, builder.getIndexType(),
- newSegHi, true);
- {
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(ifNewSegHi.thenBlock());
- YIELD(genSegmentHigh(builder, loc, tid, srcLvl, pos,
- highs[tid][srcLvl]));
- // Else, resues the same segment high.
- builder.setInsertionPointToStart(ifNewSegHi.elseBlock());
- YIELD(oldSegHi);
- }
- highs[tid][srcLvl + 1] = segHi[tid][srcLvl] = ifNewSegHi.getResult(0);
- }
- };
- const auto srcLvl = reassoc.back();
- if (!isUniqueDLT(lvlTypes[tid][srcLvl])) {
- segHi[tid][srcLvl] = genSegmentHigh(
- builder, loc, tid, srcLvl, posits[tid][srcLvl], highs[tid][srcLvl]);
- }
- }
-
- // Emits extra locals
- emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tidLvls);
-
- // Updates reduction variables
- assert(after->getNumArguments() == o + reduc.size() + (needsUniv ? 1 : 0));
- // In-place update on reduction variable.
- for (unsigned i = 0, e = reduc.size(); i < e; i++)
- reduc[i] = after->getArgument(o + i);
-
- return whileOp;
-}
-
void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
TensorId tid, Level dstLvl) {
assert(isValidLevel(tid, dstLvl));
@@ -1159,20 +1265,35 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
llvm_unreachable("Unrecognized level-type!");
}
-void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(
- OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls) {
- // Initialize dense positions. Note that we generate dense coordinates of the
- // output tensor unconditionally, since they may not appear in the lattice,
- // but may be needed for linearized codegen.
- for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
- if (isSynTensor(tid))
- continue;
+void LoopEmitter::enterTensorsAtDenseLvls(
+ OpBuilder &builder, Location loc, ArrayRef<TensorLvlCond> dnConds, Value iv,
+ SmallVectorImpl<SliceLoopInfo> &sliceInfo) {
+ for (auto [dnTidLvl, denseLoopCond] : dnConds) {
+ auto [tid, lvl] = unpackTensorLevel(dnTidLvl);
+ assert(isDenseDLT(lvlTypes[tid][lvl]));
- if (isDenseDLT(lvlTypes[tid][lvl])) {
- // Slice-driven dense level should have be handled already.
- if (!dependentLvlMap[tid][lvl].empty())
+ if (isAffineIdxCond(denseLoopCond)) {
+ // Pushes sliced levels to build correct LoopInfo.
+ bool unReduc = isAffineIdxUnRedCond(denseLoopCond);
+ SliceInfo &info = sliceStack[tid].back();
+ if (unReduc) {
+ // Pushes sliced dense loop info to tell LoopEmitter how to exit it.
+ sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/false);
+ // Update the slice information as we enter the new loop.
+ assert(*info.slicedOnLvl == lvl);
+ info.minCrd = info.offset = iv;
+ info.isNonEmpty = constantI1(builder, loc, true);
+ levelReducedDep[tid][lvl]++;
+ } else {
+ posits[tid][lvl] =
+ genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv));
+ }
+ } else {
+ // Skips the synthetic tensor
+ if (isSynTensor(tid))
continue;
-
+ // A dense level with trivial index expression.
+ assert(dependentLvlMap[tid][lvl].empty());
auto enc = getSparseTensorEncoding(tensors[tid].getType());
if (enc && !isSparseOutput(tid)) {
bool validPos = lvl == 0 || posits[tid][lvl - 1];
@@ -1182,8 +1303,7 @@ void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(
assert(isOutputTensor(tid));
continue;
}
- posits[tid][lvl] =
- genAddress(builder, loc, tid, lvl, loopStack.back().iv);
+ posits[tid][lvl] = genAddress(builder, loc, tid, lvl, iv);
// NOTE: we can also prepare for next lvl here in advance
}
}
@@ -1270,7 +1390,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
// Finished iterating a tensor, clean up
// We only do the clean up on for loop as while loops do not necessarily
// finish the iteration on a sparse tensor
- for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
+ for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.trivialTidLvls)) {
// Reset to null.
coords[tid][lvl] = Value();
posits[tid][lvl] = Value();
@@ -1285,6 +1405,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
const LoopInfo &loopInfo = loopStack.back();
auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
Value iv = loopInfo.iv;
+ Value one = C_IDX(1);
// Finalize the induction. Note that the induction could be performed
// in the individual if-branches to avoid re-evaluating the conditions.
@@ -1299,31 +1420,32 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
assert(isCompressedDLT(lvlTypes[tid][lvl]));
levelReducedDep[tid][lvl]--;
if (!resolved) {
+ // TODO: support coiterating multiple slices
+ assert(loopInfo.trivialTidLvls.empty() &&
+ loopInfo.sliceDrivenInfo.size() == 1);
genSliceNextInduction(builder, loc, whileOp, tid, lvl, operands, o);
continue;
}
- // TODO: We need to distinguish coiterate loop with slice-driven loop and
- // fully reduced while op for iterating one slices.
- // FIXME: since we didn't implement coiteration, this must be iteration
- // just on fully resolved slice.
- assert(loopInfo.sliceDrivenInfo.size() == 1 && loopInfo.tidLvls.empty());
- // The if guard to filter out out-range coordinates.
- assert(llvm::isa<scf::IfOp>(builder.getInsertionBlock()->getParentOp()));
+
+ if (loopInfo.trivialTidLvls.empty() &&
+ loopInfo.sliceDrivenInfo.size() == 1) {
+ // Forwards the position iterator.
+ operands.push_back(ADDI(posits[tid][lvl], one));
+ } else {
+ const Value pos = posits[tid][lvl];
+ const Value nxPos = ADDI(posits[tid][lvl], one);
+ Value cmp = CMPI(eq, coords[tid][lvl], iv);
+ operands.push_back(SELECT(cmp, nxPos, pos));
+ }
+
+ // The coordinate is invalid now.
+ coords[tid][lvl] = nullptr;
+
+ // Update the position iterator as we exit the while loop.
posits[tid][lvl] = whileOp->getResult(o++);
- // FIXME: we are not using continue here since we do not support
- // coiteration on slices. But it need to be treated similarly as the
- // universal index.
- o++; // skip continue flag.
- // Since we did not push two results from whileOp. The size of the
- // operands vector is smaller than the actual number of return values from
- // the whileOp.
- // It is because we are actually generating yield in the IfOp inside the
- // whileOp to only iterates over inbound coordinates within the slices.
- delta += 2;
};
- Value one = C_IDX(1);
- for (auto [tid, dstLvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
+ for (auto [tid, dstLvl] : unpackTensorLevelRange(loopInfo.trivialTidLvls)) {
const auto lvlTp = lvlTypes[tid][dstLvl];
if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
isCompressedWithHiDLT(lvlTp)) {
@@ -1357,6 +1479,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
llvm::for_each(reassoc, [this, newTid, newPos](Level srcLvl) {
posits[newTid][srcLvl] = newPos;
});
+
// The coordinate is invalid now.
coords[tid][dstLvl] = nullptr;
// The segment high is invalid now.
@@ -1439,25 +1562,6 @@ const LoopEmitter::SliceInfo &LoopEmitter::getMostRecentSliceOnLvl(TensorId tid,
llvm_unreachable("Failed to find sliceInfo");
}
-static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
- Value crdBuf, Value crdHi, Value posit,
- Value posHi, Value cont) {
- Value inBound = CMPI(ult, posit, posHi);
- auto ifOp = builder.create<scf::IfOp>(loc, cont.getType(), inBound, true);
- // if (inbound)
- // yield coord < crdHi
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- Value crd = genIndexLoad(builder, loc, crdBuf, posit);
- YIELD(CMPI(ult, crd, crdHi));
- // else
- // yield false
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- YIELD(constantI1(builder, loc, false));
-
- builder.setInsertionPointAfter(ifOp);
- return ifOp.getResult(0);
-}
-
// Generates a while loop to iterate over a slice sparse level as follows.
//
// while(coords[loopLo] < offset + size) {
@@ -1466,15 +1570,13 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
// }
std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
OpBuilder &builder, Location loc, Value posLo, Value posHi, Value offset,
- Value size, TensorId tid, Level lvl, ValueRange userReduc, bool genYield,
+ Value size, TensorId tid, Level lvl, ValueRange userReduc,
LoopBodyBuilder bodyBuilder) {
Value c1 = C_IDX(1);
Value sliceHi = ADDI(offset, sliceSizes[tid][lvl].back());
+ SmallVector<Value> reduc{posLo}; // loop lower bounds
+ const unsigned numMetaReduc = reduc.size();
- SmallVector<Value> reduc = {
- posLo, // loop lower bounds
- constantI1(builder, loc, true), // continue
- };
// Append user required reduction value.
reduc.append(userReduc.begin(), userReduc.end());
scf::WhileOp whileOp = builder.create<scf::WhileOp>(
@@ -1482,28 +1584,28 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
/*beforeBuilder=*/
[this, posHi, sliceHi, tid, lvl](OpBuilder &builder, Location loc,
ValueRange args) {
- Value cond = genSparseReducedAffineCond(
- builder, loc, coordinatesBuffers[tid][lvl], sliceHi, args[0], posHi,
- args[1]);
+ Value cond = genSparseReducedAffineCond(builder, loc,
+ coordinatesBuffers[tid][lvl],
+ sliceHi, args[0], posHi);
// continue if not yet break nor out of bound.
builder.create<scf::ConditionOp>(loc, cond, args);
},
/*afterBuilder=*/
- [c1, genYield, bodyBuilder](OpBuilder &builder, Location loc,
- ValueRange args) {
+ [c1, numMetaReduc, bodyBuilder](OpBuilder &builder, Location loc,
+ ValueRange args) {
Value iv = args[0];
- TypeRange types = args.drop_front(2).getTypes();
- // The coordinate must be in bound as guaranteed by the loop condition.
- // We generate a fake if operation here only to hide the two extra loop
- // induction variable maintained by us from user, and relies on later
- // optimization pass to remove it.
- Value cont = constantI1(builder, loc, true);
- auto ifOp = builder.create<scf::IfOp>(loc, types, cont,
+ TypeRange types = args.drop_front(numMetaReduc).getTypes();
+ // The coordinate must be in bound as guaranteed by the loop
+ // condition. We generate a fake if operation here only to hide the
+ // extra loop induction variables maintained by us from users, which
+ // will be removed by a later optimization pass.
+ auto ifOp = builder.create<scf::IfOp>(loc, types,
+ constantI1(builder, loc, true),
/*withElseBlock=*/!types.empty());
{
// 2 reduction variable maintained by us.
- SmallVector<Value> ifRet = args.drop_front(2);
- assert(ifRet.size() == args.size() - 2);
+ SmallVector<Value> ifRet = args.drop_front(numMetaReduc);
+ assert(ifRet.size() == args.size() - 1);
OpBuilder::InsertionGuard guard(builder);
// If coord >= sliceHi.
@@ -1516,10 +1618,6 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
// Delegates to users' callback.
bodyBuilder(builder, loc, iv, ifRet);
- if (genYield) {
- builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
- YIELD(ifRet);
- }
}
// Marks this special ifOp to avoid sparsification finalizing it.
ifOp->setAttr(getLoopEmitterLoopAttrName(),
@@ -1528,13 +1626,12 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
SmallVector<Value> yields;
// Increase induction variable.
yields.push_back(ADDI(iv, c1));
- yields.push_back(cont);
yields.append(ifOp.getResults().begin(), ifOp.getResults().end());
YIELD(yields);
});
builder.setInsertionPointAfter(whileOp);
- return std::make_pair(whileOp, whileOp.getResults().drop_front(2));
+ return std::make_pair(whileOp, whileOp.getResults().drop_front(numMetaReduc));
}
// Generates a loop nest that traverse all the unresolved levels in between.
@@ -1590,7 +1687,6 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
genSliceLvlTraverseLoop(
builder, loc, loopLo, loopHi, offset,
sliceSizes[tid][firstLvl].back(), tid, firstLvl, iterArgs,
- false,
[&](OpBuilder &builder, Location, Value iv,
MutableArrayRef<Value> reduc) {
ip = builder.saveInsertionPoint();
@@ -1710,7 +1806,8 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
// FIXME: We need the relative offset related to the base slice.
Value absOffset = offsetFromMinCoord(builder, loc, minCrd, size, isNonEmpty);
- sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, /*depth=*/1);
+ sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl,
+ /*depth=*/1);
}
// Fills in the slicePosBuffer before slice-driven loop begin.
@@ -1796,8 +1893,8 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
Value sPHi =
genIndexLoad(builder, loc, positionsBuffers[tid][lvl], pHi);
- // isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is one
- // non-empty lvl, the slice is non-empty.
+ // isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is
+ // one non-empty lvl, the slice is non-empty.
Value lvlNonEmpty = CMPI(ult, sPLo, sPHi);
nonEmpty = builder.create<arith::OrIOp>(loc, lvlNonEmpty, nonEmpty);
@@ -1884,8 +1981,8 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
// We do not need cache for dense levels.
if (slicePosBuffer[tid][lvl][0] == nullptr && !isDenseDLT(lvlType)) {
OpBuilder::InsertionGuard guard(builder);
- // The buffer can be reused, and the size is loop invariant: it only depends
- // on the iteration graph's toposort.
+ // The buffer can be reused, and the size is loop invariant: it only
+ // depends on the iteration graph's toposort.
builder.setInsertionPointAfter(localInsertPos);
Value bufSize = C_IDX(1);
Value c2 = C_IDX(2);
@@ -1904,9 +2001,9 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
Value sz = *(sliceSizes[tid][lvl].rbegin() + depth - 1);
bufSize = MULI(bufSize, sz);
}
- // For a pair of [pLo, pHi]. Note that we can not compress pHi because slice
- // creates segments in the index buffer so that the pHi for the current
- // level is no longer the pLo for the next level.
+ // For a pair of [pLo, pHi]. Note that we can not compress pHi because
+ // slice creates segments in the index buffer so that the pHi for the
+ // current level is no longer the pLo for the next level.
bufSize = MULI(bufSize, c2);
// Additional two metadata {memSize, idx} at head.
bufSize = ADDI(bufSize, c2);
@@ -2117,59 +2214,6 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
info.offset = whileOp.getResult(retIdx++);
}
-Operation *LoopEmitter::emitSliceDrivenLoopOverTensorAtLvl(
- OpBuilder &builder, Location loc, TensorId tid, Level lvl,
- MutableArrayRef<Value> reduc) {
- assert(!depFullyReduced(tid, lvl));
- SliceInfo &sliceInfo = sliceStack[tid].back();
- assert(sliceInfo.slicedOnLvl == lvl);
-
- // The order matters!
- SmallVector<Value, 3> operands{sliceInfo.isNonEmpty, sliceInfo.minCrd,
- sliceInfo.offset};
- // number of reduction maintained by us.
- size_t numMetaReduc = operands.size();
-
- // Append user-required reduction values.
- operands.append(reduc.begin(), reduc.end());
- assert(operands.size() == numMetaReduc + reduc.size());
-
- // while (slice.nonEmpty()) {
- // bodyBuilder();
- // SliceNext();
- // }
- auto whileOp = builder.create<scf::WhileOp>(
- loc, ValueRange(operands).getTypes(), operands,
- /*beforeBuilder=*/
- [](OpBuilder &builder, Location loc, ValueRange args) {
- builder.create<scf::ConditionOp>(loc, /*isNonEmpty*/ args[0], args);
- },
- /*afterBuilder=*/
- [this, tid, lvl, reduc, numMetaReduc,
- &sliceInfo](OpBuilder &builder, Location loc, ValueRange args) {
- assert(args.size() == reduc.size() + numMetaReduc);
- sliceInfo.isNonEmpty = args[0];
- sliceInfo.minCrd = args[1];
- sliceInfo.offset = args[2];
- // The slice offset is used to coiterate with other tensors'
- // coordinates.
- Value c = sliceInfo.offset;
- if (sliceInfo.depth > 1) {
- // Coord is the relative offset related to its parents.
- // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
- llvm_unreachable("TODO: not yet implement");
- }
- coords[tid][lvl] = c;
-
- for (unsigned i = 0, e = reduc.size(); i < e; i++)
- reduc[i] = args[i + numMetaReduc];
- });
-
- // Set the insertion point to while loop body.
- builder.setInsertionPointToEnd(&whileOp.getAfter().front());
- return whileOp;
-}
-
#undef CMPI
#undef C_IDX
#undef YIELD
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 8fa79128889e2..f178366a738a4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -141,33 +141,39 @@ class LoopEmitter {
/// Exits the current loop sequence, this will reset universal index to 0.
void exitCurrentLoopSeq(OpBuilder &builder, Location loc);
- // TODO: Get rid of `lvls` in the argument list? Track the level we
- // are currently at internally. Then it would be enterNextLvlForTensor.
- // Still need a way to specify the lvl for non-annotated tensors though,
- // as those can be accessed out of order.
- //
- /// Emits loop over tensor_tid_lvl, it assumes that loops between
- /// tensor_tid_[0, lvl - 1] have already been generated.
- /// The function will also perform in-place update on the `reduc` vector to
- /// return the reduction variable used inside the generated loop.
- Operation *enterLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
- ArrayRef<TensorLevel> tidLvls,
- MutableArrayRef<Value> reduc = {},
- bool isParallel = false);
-
+ /// Enters a loop that tries to locate a coordinate in a sparse level based
+ /// on the value evaluated by the provided affine expression.
+ /// DEPRECATED: affine index expressions should be handled by index-reduction
+ /// loops; the filter-loop-based solution is slow.
Operation *enterFilterLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
TensorId tid, Level lvl,
AffineExpr affine,
MutableArrayRef<Value> reduc = {});
+ /// Emits the address for a dense level based on the value evaluated by the
+ /// provided affine expression.
+ /// DEPRECATED: affine index expressions should be handled by index-reduction
+ /// loops; the filter-loop-based solution is slow.
void genDenseAffineAddress(OpBuilder &builder, Location loc,
TensorLevel tidLvl, AffineExpr lvlExpr);
+ // TODO: Get rid of `lvls` in the argument list? Track the level we
+ // are currently at internally. Then it would be enterNextLvlForTensor.
+ // Still need a way to specify the lvl for non-annotated tensors though,
+ // as those can be accessed out of order.
+ //
/// Emits a co-iteration loop over a set of tensors.
+ /// Emits a loop over tensor_tid_lvl; it assumes that the loops over
+ /// tensor_tid_[0, lvl - 1] have already been generated.
+ /// The function will also perform an in-place update on the `reduc` vector
+ /// to return the reduction variables used inside the generated loop.
Operation *enterCoIterationOverTensorsAtLvls(
OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
- bool needsUniv, MutableArrayRef<Value> reduc = {});
+ MutableArrayRef<Value> reduc = {}, bool isParallel = false,
+ bool genDedup = false, bool needsUniv = false);
+ /// Generates code to exit the current loop (e.g., generates yields, forwards
+ /// loop induction variables, etc).
void exitCurrentLoop(RewriterBase &rewriter, Location loc,
MutableArrayRef<Value> reduc = {});
@@ -232,6 +238,15 @@ class LoopEmitter {
});
}
+ template <class ContainerTy>
+ auto unpackTensorLevelFromCondRange(ContainerTy &&c) const {
+ using EltTy = decltype(*c.begin());
+ static_assert(std::is_same_v<llvm::remove_cvref_t<EltTy>, TensorLvlCond>,
+ "Must be unpacking a TensorLvlCond range");
+ return unpackTensorLevelRange(
+ llvm::make_first_range(std::forward<ContainerTy>(c)));
+ }
+
///
/// Getters.
///
@@ -251,6 +266,10 @@ class LoopEmitter {
}
private:
+ ///
+ /// Structure definitions that hold different kinds of loop information.
+ ///
+
// A tuple that stores the slice-driven loop information.
struct SliceLoopInfo final {
SliceLoopInfo(TensorId tid, Level lvl, bool reduced)
@@ -262,18 +281,22 @@ class LoopEmitter {
// LoopInfo stores information of a loop generated by LoopEmitter. E.g.,
// the set of tensors levels that the loop is iterating over.
struct LoopInfo final {
- LoopInfo(ArrayRef<TensorLevel> tidLvls,
+ LoopInfo(ArrayRef<TensorLevel> trivialTidLvls,
ArrayRef<SliceLoopInfo> sliceDrivenInfo, Operation *loop,
Block *userBlock, Value iv, StringAttr loopTag)
- : tidLvls(tidLvls), sliceDrivenInfo(sliceDrivenInfo), loop(loop),
- userCodeBlock(userBlock), iv(iv) {
+ : trivialTidLvls(trivialTidLvls), sliceDrivenInfo(sliceDrivenInfo),
+ loop(loop), userCodeBlock(userBlock), iv(iv) {
// Attached a special tag to loop emitter generated loop.
if (loopTag)
loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
}
- // The set of <tensor, lvl> that the loop is operating on
- const llvm::SmallVector<TensorLevel> tidLvls;
- // Slice-driven loop conditions.
+ // The set of <tensor, lvl>, with *only* trivial index expressions, that are
+ // used as the condition for the generated loop. Extra information is
+ // required for levels with non-trivial index expressions, which is
+ // maintained by the sliceDrivenInfo array below.
+ const llvm::SmallVector<TensorLevel> trivialTidLvls;
+ // The set of <tensor, lvl>, with *only* non-trivial index expressions, that
+ // are used as the condition for the generated loop.
const llvm::SmallVector<SliceLoopInfo> sliceDrivenInfo;
const Operation *loop; // the loop operation
Block *const userCodeBlock; // the block holding users' generated code.
@@ -304,9 +327,100 @@ class LoopEmitter {
unsigned depth; // the depth (relative to dependentDimMap[tid][lvl]).
};
+ ///
+ /// Enums for different kinds of loop conditions.
+ ///
+
+ // The bit indicating whether the loop condition is sparse.
+ static constexpr uint8_t kSparseCond = 1 << 3;
+ // The bit indicating whether the loop iterates over sparse tensor slices
+ // (i.e., with non-empty SliceDimAttr).
+ static constexpr uint8_t kSliceCond = 1 << 2;
+ // The bit indicating whether the loop iterates over tensor levels with
+ // non-trivial affine index reduction.
+ static constexpr uint8_t kAffineIdxCond = 1 << 1;
+ // The bit indicating whether the loop iterates over tensor levels with
+ // non-trivial affine index reduction, and it is not fully reduced.
+ static constexpr uint8_t kAffineIdxCondUnRed = 1 << 0;
+
+ enum class LoopCondKind : uint8_t {
+ // Dense conditions.
+ DenseCond = 0,
+ DenseSliceCond = kSliceCond,
+ DenseAffineCond = kAffineIdxCond,
+ DenseAffineUnRedCond = kAffineIdxCond | kAffineIdxCondUnRed,
+ // Sparse Conditions.
+ SparseCond = kSparseCond,
+ SparseSliceCond = kSparseCond | kSliceCond,
+ SparseAffineCond = kSparseCond | kAffineIdxCond,
+ SparseAffineUnRedCond = kSparseCond | kAffineIdxCond | kAffineIdxCondUnRed,
+ };
+ using TensorLvlCond = std::pair<TensorLevel, LoopCondKind>;
+
+ /// Sparse or dense loop condition.
+ static bool isSparseCond(LoopCondKind k) {
+ return static_cast<uint8_t>(k) & kSparseCond;
+ }
+ static bool isDenseCond(LoopCondKind k) { return !isSparseCond(k); }
+
+ /// Whether the loop iterates over sparse tensor slices (vs. sparse tensors).
+ static bool isSliceCond(LoopCondKind k) {
+ return static_cast<uint8_t>(k) & kSliceCond;
+ }
+
+ /// Affine or trivial index expression loop condition.
+ static bool isAffineIdxCond(LoopCondKind k) {
+ return static_cast<uint8_t>(k) & kAffineIdxCond;
+ }
+ static bool isTrivalIdxCond(LoopCondKind k) { return !isAffineIdxCond(k); }
+
+ /// Whether the affine index expression is not fully reduced.
+ static bool isAffineIdxUnRedCond(LoopCondKind k) {
+ return isAffineIdxCond(k) && static_cast<uint8_t>(k) & kAffineIdxCondUnRed;
+ }
+ static bool isAffineIdxRedCond(LoopCondKind k) {
+ return isAffineIdxCond(k) && !isAffineIdxUnRedCond(k);
+ }
+
+ // Whether the loop condition kind requires an extra check inside the loop
+ // body. E.g., to iterate over a sparse tensor slice, we need to check
+ // whether the current coordinate is on the slice (e.g., due to stride).
+ static bool isCondWithExtraCheck(LoopCondKind k) {
+ return isSparseCond(k) && isSliceCond(k);
+ }
+
+ static LoopCondKind makeLoopCondKind(bool isSparse, bool isSlice,
+ bool isAffine, bool isUnRedu) {
+ assert(!isUnRedu || isAffine);
+ uint8_t bits = 0;
+ bits = isSparse ? bits | kSparseCond : bits;
+ bits = isSlice ? bits | kSliceCond : bits;
+ bits = isAffine ? bits | kAffineIdxCond : bits;
+ bits = isUnRedu ? bits | kAffineIdxCondUnRed : bits;
+ LoopCondKind kind = static_cast<LoopCondKind>(bits);
+
+ // Sanity checks.
+ assert(isSparse == isSparseCond(kind));
+ assert(isSlice == isSliceCond(kind));
+ assert(isAffine == isAffineIdxCond(kind));
+ assert(isUnRedu == isAffineIdxUnRedCond(kind));
+ return kind;
+ }
+
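A quick usage sketch, assuming only the definitions above: composing the kind for a sparse slice condition and querying it.

  // Sketch: kSparseCond | kSliceCond yields SparseSliceCond.
  LoopCondKind k = makeLoopCondKind(/*isSparse=*/true, /*isSlice=*/true,
                                    /*isAffine=*/false, /*isUnRedu=*/false);
  assert(k == LoopCondKind::SparseSliceCond);
  // Slice iteration needs an in-loop validity check (see genCheckedValue).
  assert(isCondWithExtraCheck(k));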
+ void categorizeLoopCondition(ArrayRef<TensorLevel> tidLvls,
+ SmallVectorImpl<TensorLvlCond> &dnConds,
+ SmallVectorImpl<TensorLvlCond> &spConds);
+
+ ///
+ /// LoopEmitter internal helper functions.
+ ///
+
using LoopBodyBuilder = llvm::function_ref<void(OpBuilder &, Location, Value,
MutableArrayRef<Value>)>;
+ /// Whether the given list of sparse conditions should be iterated by a for loop.
+ bool shouldIteratedByForLoop(ArrayRef<TensorLvlCond> spConds, bool genDedup);
+
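The header does not spell out the rule, but from the comments here and in Sparsification.cpp (a while loop is required to co-iterate over several sparse conditions, or over a single complex one), a plausible sketch of the decision is:

  // Hedged sketch only; the actual rule (e.g., how genDedup factors in)
  // may differ: use a for loop iff there is at most one sparse condition
  // and it is a plain, trivially indexed one.
  bool useForLoop = spConds.empty() ||
                    (spConds.size() == 1 &&
                     isTrivalIdxCond(spConds.front().second) &&
                     !isSliceCond(spConds.front().second));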
/// Linearizes address for dense dimension (i.e., p = (i * d0) + j).
Value genAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl,
Value iv);
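As a worked instance of that linearization (illustration only, not the emitter code): with inner dimension size d0 = 64, coordinates (i, j) = (2, 5) map to position p = 2 * 64 + 5 = 133.

  // p = (i * d0) + j, per the doc comment above.
  unsigned linearize(unsigned i, unsigned j, unsigned d0) {
    return i * d0 + j; // linearize(2, 5, 64) == 133
  }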
@@ -354,31 +468,51 @@ class LoopEmitter {
void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
TensorId tid, Level lvl);
- /// Emits extra locals, since the locals might not be in simplified lattices
- /// point used to generate the loops, but are still required to generate
- /// expressions.
- void emitExtraLocalsForTensorsAtDenseLvls(OpBuilder &builder, Location loc,
- ArrayRef<TensorLevel> tidLvls);
-
- /// Emits a for loop to iterate over a tensor level with the provided lower
- /// bound `lo` and upper bound `hi`.
- /// Apart from iterating just single tensor level, for loops can be used for
- /// slice-driven loop on dense level too.
- Operation *emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
- TensorId tid, Level lvl, Value lo,
- Value hi, MutableArrayRef<Value> reduc,
- bool isParallel);
-
- /// Emits a while loop to iterate over a sparse level that has been sliced.
- /// Inserts break statement when the coordinate exceeds the sliceSize;
- /// The method sets the insertion point inside the generated while loop body
- /// after the break statement before return (so that callers need to handle
- /// only in-bound coordinates).
- Operation *emitWhileLoopOverSliceAtSparseLvl(OpBuilder &builder, Location loc,
- Value pLo, Value pHi,
- Value offset, Value sliceSize,
- TensorId tid, Level lvl,
- MutableArrayRef<Value> reduc);
+ /// Enter dense tensor levels. Since the dense tensor condition could be
+ /// optimized away from the loop condition, we need to compute the
+ /// positions/coordinates inside the loop body.
+ void enterTensorsAtDenseLvls(OpBuilder &builder, Location loc,
+ ArrayRef<TensorLvlCond> dnConds, Value iv,
+ SmallVectorImpl<SliceLoopInfo> &sliceInfo);
+
+ /// Emits a for loop to iterate over a tensor level with the provided
+ /// lower bound `lo` and upper bound `hi`. Apart from iterating over just a
+ /// single tensor level, for loops can be used for slice-driven loops on
+ /// dense levels too.
+ /// Returns a pair: the loop generated and the value for the induction
+ /// variable.
+ std::pair<Operation *, Value>
+ emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc, TensorId tid,
+ Level lvl, Value lo, Value hi,
+ MutableArrayRef<Value> reduc, bool isParallel);
+
+ /// Emits a while loop to co-iterate over a list of sparse conditions, or a
+ /// (complex) single sparse condition that cannot be handled by a for loop
+ /// (e.g., an index reduction loop).
+ /// Returns a pair: the loop generated and the value of the induction
+ /// variable (which is the minimum coordinate of all the tensors being
+ /// iterated).
+ std::pair<Operation *, Value>
+ emitWhileLoopOverTensorsAtLvls(OpBuilder &builder, Location loc,
+ ArrayRef<TensorLvlCond> spConds,
+ MutableArrayRef<Value> reduc, bool needsUniv);
+
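The shape of the emitted co-iteration can be pictured with a scalar sketch (an assumed analogue of the generated IR, not the IR itself): each sparse tensor contributes a position into its coordinate buffer, the induction variable is the minimum coordinate, and every iterator sitting at that minimum is forwarded.

  #include <algorithm>
  #include <cstddef>
  #include <vector>

  void coiterate(const std::vector<std::size_t> &crdA,
                 const std::vector<std::size_t> &crdB) {
    std::size_t pa = 0, pb = 0;
    // Loop condition: every position stays in bounds (a conjunction,
    // matching the arith.andi in the sorted_coo.mlir test below).
    while (pa < crdA.size() && pb < crdB.size()) {
      std::size_t m = std::min(crdA[pa], crdB[pb]); // induction variable
      // ..user code for coordinate m..
      if (crdA[pa] == m) ++pa; // forward the iterators at the minimum
      if (crdB[pb] == m) ++pb;
    }
  }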
+ /// Generates the while loop condition for the given tensor level condition.
+ Value genWhileLoopConditions(OpBuilder &builder, Location loc, ValueRange ivs,
+ TensorLvlCond cond);
+
+ /// Generates the while loop body for the given tensor level condition.
+ std::optional<Value> genWhileLoopBody(OpBuilder &builder, Location loc,
+ ValueRange ivs, TensorLvlCond cond);
+
+ /// Generates the values (to forward the loop) if the extra check fails.
+ /// E.g., to iterate over a sparse tensor slice, we need:
+ ///
+ /// pos = onSlice(curCrd) ? pos : pos + 1
+ ///
+ /// to skip invalid coordinates that are not actually on the slice.
+ ValueRange genCheckedValue(OpBuilder &builder, Location loc, Value pred,
+ ValueRange curArg, TensorLvlCond cond);
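A scalar sketch of that extra check for slice iteration (assumed semantics based on the doc comment; onSlice and its stride arithmetic are illustrative, not the emitter's actual helpers):

  #include <cstddef>

  // Illustrative only: a coordinate is on the slice iff it is reachable
  // from `offset` by whole strides and within `size` steps.
  static bool onSlice(std::size_t crd, std::size_t offset,
                      std::size_t stride, std::size_t size) {
    return crd >= offset && (crd - offset) % stride == 0 &&
           (crd - offset) / stride < size;
  }

  void iterateSlice(const std::size_t *crdBuf, std::size_t pos,
                    std::size_t posHi, std::size_t offset,
                    std::size_t stride, std::size_t size) {
    while (pos < posHi) {
      std::size_t crd = crdBuf[pos];
      if (onSlice(crd, offset, stride, size)) {
        // ..user code on a valid coordinate..
      }
      // Coordinates failing the check are skipped by simply forwarding
      // the position (the `pos = onSlice(crd) ? pos : pos + 1` forwarding
      // above; here both paths advance, with user code gated).
      ++pos;
    }
  }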
/// Exits a for loop, returns the reduction results, e.g.,
/// For sequential for loops:
@@ -488,7 +622,7 @@ class LoopEmitter {
std::pair<Operation *, ValueRange>
genSliceLvlTraverseLoop(OpBuilder &builder, Location loc, Value pLo,
Value pHi, Value offset, Value size, TensorId tid,
- Level lvl, ValueRange userReduc, bool genYield,
+ Level lvl, ValueRange userReduc,
LoopBodyBuilder bodyBuilder);
/// Generates a nested loop that iterates over tid on all the coordinates on
@@ -530,19 +664,6 @@ class LoopEmitter {
SmallVectorImpl<Value> &operands,
unsigned &retIdx);
- /// Generates a slice-driven while loop as follows.
- ///
- /// curSlice = getFirstNonEmptySlice(tensor).
- ///
- /// while(isNonEmpty) {
- /// ..user code..
- /// isNonEmpty, curSlice = getNextNonEmptySlice(curSlice)
- /// }
- Operation *emitSliceDrivenLoopOverTensorAtLvl(OpBuilder &builder,
- Location loc, TensorId tid,
- Level lvl,
- MutableArrayRef<Value> reduc);
-
/// An optional string attribute that should be attached to the loop
/// generated by the loop emitter; it might help subsequent passes to identify
/// loops that operate on sparse tensors more easily.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 2c5289aa8ae16..bb9c15ea463da 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1180,7 +1180,8 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
// Note that reduc will be taken care of by loop emitter and get updated
// in place.
- loopEmitter.enterLoopOverTensorAtLvl(rewriter, loc, tidLvls, reduc);
+ loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
+ reduc);
}
SmallVector<Value> lcvs;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 891befdcb51ea..637c16f92a293 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1306,12 +1306,11 @@ static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
llvm_unreachable("unexpected parallelization strategy");
}
-/// Generates a for-loop on a single index.
-static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
- bool isInner, LoopId ldx,
- ArrayRef<TensorLevel> tidLvls) {
+/// Whether or not the current loop being generated should be parallelized (if
+/// possible) according to the configuration.
+static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
+ ArrayRef<TensorLevel> tidLvls) {
linalg::GenericOp op = env.op();
- Location loc = op.getLoc();
auto iteratorTypes = op.getIteratorTypesArray();
bool isSparse = llvm::any_of(tidLvls, [ldx, &env](TensorLevel tidLvl) {
// Queries the DLT based on the tensor id and loop idx, as requested by
@@ -1321,38 +1320,44 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
return isCompressedDLT(dlt) || isSingletonDLT(dlt);
});
- bool isParallel = isParallelFor(env, isOuter, isSparse);
+ return isParallelFor(env, isOuter, isSparse);
+}
+/// Generates a "filter loop" on the given tid level to locate a coordinate that
+/// is of the same value as evaluated by the affine expression in its matching
+/// indexing map.
+static Operation *genFilterLoop(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
+ TensorLevel tidLvl) {
+ linalg::GenericOp op = env.op();
+ Location loc = op.getLoc();
Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
- if (env.merger().isFilterLoop(ldx)) {
- const auto [tid, lvl] = env.unpackTensorLevel(tidLvls.front());
- // tids/lvls must only have one value because filter loops only
- // corresponding to the one and only sparse tensor level.
- assert(isSparse && tidLvls.size() == 1);
- OpOperand *t = &op->getOpOperand(tid);
- auto enc = getSparseTensorEncoding(t->get().getType());
- // Retrieves the affine expression for the filter loop.
- // FIXME: `toOrigDim` is deprecated.
- AffineExpr a =
- op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, lvl));
- return env.emitter().enterFilterLoopOverTensorAtLvl(builder, loc, tid,
- lvl, a, reduc);
- }
- return env.emitter().enterLoopOverTensorAtLvl(builder, loc, tidLvls, reduc,
- isParallel);
+ assert(env.merger().isFilterLoop(ldx));
+ const auto [tid, lvl] = env.unpackTensorLevel(tidLvl);
+ // tids/lvls must have exactly one value because filter loops only
+ // correspond to the one and only sparse tensor level.
+ OpOperand *t = &op->getOpOperand(tid);
+ auto enc = getSparseTensorEncoding(t->get().getType());
+ // Retrieves the affine expression for the filter loop.
+ // FIXME: `toOrigDim` is deprecated.
+ AffineExpr a = op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, lvl));
+ return env.emitter().enterFilterLoopOverTensorAtLvl(builder, loc, tid, lvl,
+ a, reduc);
});
- assert(loop);
return loop;
}
-/// Emit a while-loop for co-iteration over multiple indices.
-static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, LoopId idx,
- bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
+/// Emits a loop to co-iterate over the list of tensor levels. The generated
+/// loop can either be a for loop or a while loop, depending on whether there
+/// is at most one sparse level in the list.
+static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
+ LoopId idx, ArrayRef<TensorLevel> tidLvls,
+ bool tryParallel, bool needsUniv) {
Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
// Construct the while-loop with a parameter for each
// index.
return env.emitter().enterCoIterationOverTensorsAtLvls(
- builder, env.op().getLoc(), tidLvls, needsUniv, reduc);
+ builder, env.op().getLoc(), tidLvls, reduc, tryParallel,
+ /*genDedup=*/true, needsUniv);
});
assert(loop);
return loop;
@@ -1361,15 +1366,15 @@ static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, LoopId idx,
/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
- bool needsUniv, ArrayRef<TensorLevel> tidLvls,
- bool isFor) {
- const LoopId idx = env.topSortAt(at);
- if (isFor) {
- bool isOuter = at == 0;
- bool isInner = at == env.topSortSize() - 1;
- return genFor(env, builder, isOuter, isInner, idx, tidLvls);
+ bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
+ const LoopId ldx = env.topSortAt(at);
+ if (env.merger().isFilterLoop(ldx)) {
+ assert(tidLvls.size() == 1);
+ return genFilterLoop(env, builder, ldx, tidLvls.front());
}
- return genWhile(env, builder, idx, needsUniv, tidLvls);
+
+ bool tryParallel = shouldTryParallize(env, ldx, at == 0, tidLvls);
+ return genCoIteration(env, builder, ldx, tidLvls, tryParallel, needsUniv);
}
/// Generates the induction structure for a while-loop.
@@ -1684,7 +1689,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
tidLvls, affineTidLvls);
// Emit the for/while-loop control.
- Operation *loop = genLoop(env, builder, at, needsUniv, tidLvls, isSingleCond);
+ Operation *loop = genLoop(env, builder, at, needsUniv, tidLvls);
Location loc = env.op().getLoc();
for (auto [tidLvl, exp] : affineTidLvls) {
env.emitter().genDenseAffineAddress(builder, loc, tidLvl, exp);
diff --git a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
index e4e65ef4b4e71..01189328eeb61 100644
--- a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
+++ b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
@@ -185,8 +185,6 @@ func.func @matvec(%arga: tensor<32x64xf64, #SortedCOO>,
// CHECK: ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index):
// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref<?xindex, strided<[?], offset: ?>>
// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
-// CHECK: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref<?xindex, strided<[?], offset: ?>>
// CHECK: %[[VAL_33:.*]] = scf.while (%[[VAL_34:.*]] = %[[VAL_26]]) : (index) -> index {
// CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_17]] : index
@@ -219,6 +217,8 @@ func.func @matvec(%arga: tensor<32x64xf64, #SortedCOO>,
// CHECK: %[[VAL_51:.*]] = arith.addi %[[VAL_50]], %[[VAL_6]] : index
// CHECK: scf.yield %[[VAL_51]] : index
// CHECK: }
+// CHECK: %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
+// CHECK: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
// CHECK: %[[VAL_52:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
// CHECK: %[[VAL_53:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
// CHECK: %[[VAL_54:.*]] = arith.andi %[[VAL_52]], %[[VAL_53]] : i1