[Mlir-commits] [mlir] e015d38 - [mlir][sparse] Pass down constant coefficients of affine index expressions to LoopEmitter.
Peiming Liu
llvmlistbot at llvm.org
Wed Aug 30 11:55:59 PDT 2023
Author: Peiming Liu
Date: 2023-08-30T18:44:50Z
New Revision: e015d385c913daae3ec9654b84104caf28940c77
URL: https://github.com/llvm/llvm-project/commit/e015d385c913daae3ec9654b84104caf28940c77
DIFF: https://github.com/llvm/llvm-project/commit/e015d385c913daae3ec9654b84104caf28940c77.diff
LOG: [mlir][sparse] Pass down constant coefficients of affine index expressions to LoopEmitter.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D158914
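
In short: the sparsifier can now accept affine subscripts whose dimensions carry constant multipliers, e.g. A[2*i + j], by recording a (loop id, coefficient) pair per dimension instead of a bare loop id. Below is a minimal, self-contained sketch of the decomposition the patch performs in findDepIdxSet; the Expr tree and collect() are toy stand-ins for mlir::AffineExpr and the real function, not MLIR's actual API:

  #include <cstdint>
  #include <cstdio>
  #include <memory>
  #include <utility>
  #include <vector>

  // Toy stand-in for mlir::AffineExpr.
  struct Expr {
    enum Kind { Dim, Mul, Add } kind;
    unsigned pos = 0;               // Dim: the loop id (d0 -> 0)
    int64_t cst = 1;                // Mul: the constant factor
    std::shared_ptr<Expr> lhs, rhs; // Mul: lhs is the dim; Add: both sides
  };

  // Mirrors the patched findDepIdxSet: flatten `2*d0 + d1` into
  // (loop id, coefficient) pairs; reject unsupported forms such as a
  // non-positive coefficient or a standalone `2 * d0` at the top level.
  static bool collect(const Expr &e,
                      std::vector<std::pair<unsigned, unsigned>> &out,
                      bool isSubExp = false, int64_t coefficient = 1) {
    switch (e.kind) {
    case Expr::Dim:
      if (coefficient <= 0)
        return false;
      out.emplace_back(e.pos, (unsigned)coefficient);
      return true;
    case Expr::Mul: // assumed normalized to `constant * d`
      if (!isSubExp)
        return false; // `2 * d0` alone is not yet supported
      return collect(*e.lhs, out, isSubExp, e.cst);
    case Expr::Add:
      return collect(*e.lhs, out, true) && collect(*e.rhs, out, true);
    }
    return false;
  }

  int main() {
    auto d0 = std::make_shared<Expr>(Expr{Expr::Dim, 0});
    auto d1 = std::make_shared<Expr>(Expr{Expr::Dim, 1});
    auto mul = std::make_shared<Expr>(Expr{Expr::Mul, 0, 2, d0, nullptr});
    Expr add{Expr::Add, 0, 1, mul, d1};

    std::vector<std::pair<unsigned, unsigned>> pairs;
    if (collect(add, pairs))
      for (auto [loop, coeff] : pairs)
        std::printf("loop d%u, coefficient %u\n", loop, coeff);
    // Prints: loop d0, coefficient 2 / loop d1, coefficient 1
  }

As in the patch, only sums such as `2 * d0 + d1` are accepted for now; a lone `2 * d0` still returns false.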
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 6293badb6df5c4..5e753800675728 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -56,6 +56,13 @@ using LatPointId = unsigned;
/// for the corresponding `SmallVector<LatPointId>` object.
using LatSetId = unsigned;
+/// A pair of a tensor level and its corresponding DimLevelType.
+using LvlDLTPair = std::pair<Level, DimLevelType>;
+
+/// A pair of a loop id and its coefficient. E.g., for the affine expression
+/// `2 * d0` in an affine map, loop id = 0 and coefficient = 2.
+using LoopCoeffPair = std::pair<LoopId, unsigned>;
+
/// Tensor expression. Represents an MLIR expression in tensor index notation.
struct TensorExp final {
enum class Kind;
@@ -509,22 +516,22 @@ class Merger {
/// Establishes the two-way map i <-> <t, lvl, dlt>.
void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl,
- DimLevelType dlt) {
+ DimLevelType dlt, unsigned coefficient) {
assert(isValidLoopId(i) && isValidLevel(t, lvl));
- assert(!loopToDependencies[i][t].has_value()); // must be the first def
- loopToDependencies[i][t] = std::make_pair(lvl, dlt);
- levelToDependentLoop[t][lvl].push_back(i);
+ assert(!loopToUnresolvedLvls[i][t].has_value()); // must be the first def
+ loopToUnresolvedLvls[i][t] = std::make_pair(lvl, dlt);
+ levelToDependentLoop[t][lvl].emplace_back(i, coefficient);
}
/// Whether the loop has a dependent slice.
bool hasDependentLvl(LoopId i, TensorId t) {
assert(isValidTensorId(t) && isValidLoopId(i));
- return loopToDependencies[i][t].has_value();
+ return loopToUnresolvedLvls[i][t].has_value();
}
/// Returns the list of loop indices which appear in the non-trivial index
/// expression on t_l, e.g., A[i+j] => {i, j}
- std::vector<LoopId> &getDependentLoops(TensorId t, Level lvl) {
+ std::vector<LoopCoeffPair> &getDependentLoops(TensorId t, Level lvl) {
assert(isValidLevel(t, lvl));
return levelToDependentLoop[t][lvl];
}
@@ -541,7 +548,7 @@ class Merger {
const TensorId t = tensor(b);
const LoopId i = loop(b);
assert(isValidTensorId(t) && isValidLoopId(i));
- return loopToDependencies[i][t].has_value();
+ return loopToUnresolvedLvls[i][t].has_value();
}
/// Checks whether the TensorLoopId represents a sparse tensor level that contains
@@ -556,12 +563,12 @@ class Merger {
Level getLoopDependentLevel(TensorLoopId b) const {
assert(isLvlWithNonTrivialIdxExp(b));
- return loopToDependencies[loop(b)][tensor(b)]->first;
+ return loopToUnresolvedLvls[loop(b)][tensor(b)]->first;
}
DimLevelType getLoopDependentLevelType(TensorLoopId b) const {
assert(isLvlWithNonTrivialIdxExp(b));
- return loopToDependencies[loop(b)][tensor(b)]->second;
+ return loopToUnresolvedLvls[loop(b)][tensor(b)]->second;
}
/// Convenience getters to immediately access the stored nodes.
@@ -715,13 +722,13 @@ class Merger {
/// It is currently only set for non-trivial index expressions.
/// E.g., A[i+j] => i and j will have dependencies {A0, dlt(A0)} to indicate
/// that i and j are used in the non-trivial index expression on A0.
- std::vector<std::vector<std::optional<std::pair<Level, DimLevelType>>>>
- loopToDependencies;
+ std::vector<std::vector<std::optional<LvlDLTPair>>> loopToUnresolvedLvls;
/// The inverse map of loopToUnresolvedLvls from tensor level -> dependent loop
- /// E.g., A[i+j], we have A0 => {i, j}, to indicate that A0 uses both {i, j}
- /// to compute its indices.
- std::vector<std::vector<std::vector<LoopId>>> levelToDependentLoop;
+ /// E.g., for A[2i+j], we have A0 => {(i, 2), (j, 1)}, to indicate that A0
+ /// uses both {i, j} to compute its indices, with coefficients 2 and 1 on
+ /// loops i and j respectively.
+ std::vector<std::vector<std::vector<LoopCoeffPair>>> levelToDependentLoop;
/// Map from a loop to the [tid, lvl] pair that defines the loop boundary.
std::vector<std::pair<TensorId, Level>> loopBounds;
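
To make the new map contents concrete: for A[2*i + j] on level 0 of tensor A, the forward map records one unresolved (lvl, dlt) entry per loop, and the inverse map records the coefficients. A toy sketch of the inverse side only (the surrounding Merger class is omitted; names follow the typedefs above):

  #include <cstdio>
  #include <utility>
  #include <vector>

  using LoopId = unsigned;
  using LoopCoeffPair = std::pair<LoopId, unsigned>;

  int main() {
    const LoopId i = 0, j = 1;
    // levelToDependentLoop[t][lvl] for a single tensor A (t = 0) with one
    // level: after processing A[2*i + j], A0 => {(i, 2), (j, 1)}.
    std::vector<std::vector<std::vector<LoopCoeffPair>>> levelToDependentLoop =
        {{{{i, 2}, {j, 1}}}};
    for (auto [loop, coeff] : levelToDependentLoop[0][0])
      std::printf("A0 depends on loop %u with coefficient %u\n", loop, coeff);
  }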
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 854b6f5587073b..924b0a0dac8113 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -29,16 +29,16 @@ static bool isMaterializing(Value val) {
}
/// Sorts the target array's elements according to the `order` array.
-static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
+static void sortArrayBasedOnOrder(std::vector<LoopCoeffPair> &target,
ArrayRef<LoopId> order) {
std::sort(target.begin(), target.end(),
- [&order](const LoopId &l, const LoopId &r) {
+ [&order](const LoopCoeffPair &l, const LoopCoeffPair &r) {
assert(std::addressof(l) == std::addressof(r) || l != r);
int idxL = -1, idxR = -1;
for (int i = 0, e = order.size(); i < e; i++) {
- if (order[i] == l)
+ if (order[i] == l.first)
idxL = i;
- if (order[i] == r)
+ if (order[i] == r.first)
idxR = i;
}
assert(idxL >= 0 && idxR >= 0);
@@ -104,13 +104,17 @@ void CodegenEnv::startEmit() {
/*isSparseOut=*/sparseOut != nullptr, topSort,
// TODO: compute the map and pass it to loop emitter directly instead of
// passing in a callback.
- [this](TensorId t, Level lvl) -> std::vector<std::pair<TensorId, Level>> {
- // Translates from a list of loop index to a list of [tid, dim] pair.
- std::vector<LoopId> rLoops = this->merger().getDependentLoops(t, lvl);
- std::vector<std::pair<TensorId, Level>> ret;
+ /*dependentLvlGetter=*/
+ [this](TensorId t,
+ Level lvl) -> std::vector<std::pair<TensorLevel, unsigned>> {
+ // Translates from a list of loop indices to a list of [tid, lvl] pairs.
+ std::vector<LoopCoeffPair> &rLoops = merger().getDependentLoops(t, lvl);
+ std::vector<std::pair<TensorLevel, unsigned>> ret;
ret.reserve(rLoops.size());
- for (LoopId l : rLoops)
- ret.emplace_back(this->merger().getLoopDefiningLvl(l));
+ for (auto [loop, coeff] : rLoops) {
+ TensorLevel tl = makeTensorLevel(merger().getLoopDefiningLvl(loop));
+ ret.emplace_back(tl, coeff);
+ }
return ret;
});
}
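
The comparator above now orders LoopCoeffPairs by where their loop id sits in the topological `order`; the coefficient just travels along. A small standalone sketch under assumed inputs (order and target are made up for illustration):

  #include <algorithm>
  #include <cstdio>
  #include <utility>
  #include <vector>

  using LoopCoeffPair = std::pair<unsigned, unsigned>;

  int main() {
    std::vector<unsigned> order = {1, 0};                 // topological: j before i
    std::vector<LoopCoeffPair> target = {{0, 2}, {1, 1}}; // {(i, 2), (j, 1)}
    std::sort(target.begin(), target.end(),
              [&order](const LoopCoeffPair &l, const LoopCoeffPair &r) {
                int idxL = -1, idxR = -1;
                for (int k = 0, e = (int)order.size(); k < e; k++) {
                  if (order[k] == l.first)
                    idxL = k;
                  if (order[k] == r.first)
                    idxR = k;
                }
                return idxL < idxR;
              });
    for (auto [loop, coeff] : target)
      std::printf("loop %u, coeff %u\n", loop, coeff); // j (1,1) first, then i (0,2)
  }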
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index dea9e740b8db64..c0fc505d153a45 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -96,6 +96,9 @@ class CodegenEnv {
loopEmitter.getSynTensorId() == latticeMerger.getSynTensorID());
return loopEmitter.makeTensorLevel(t, l);
}
+ TensorLevel makeTensorLevel(std::pair<TensorId, Level> tlPair) const {
+ return makeTensorLevel(tlPair.first, tlPair.second);
+ }
std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
return loopEmitter.unpackTensorLevel(tl);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 06db5b0ab78e35..441f29dedcdafb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -264,14 +264,11 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
Level lvl) {
- Value crd = C_IDX(0);
// A load on the coordinates array yields the coordinate.
const Value mem = coordinatesBuffers[tid][lvl];
/// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
const Value pos = posits[tid][lvl];
- const Value off = genIndexLoad(builder, loc, mem, pos);
- // Linearized the coordinates within the same collapse reassociation.
- crd = ADDI(crd, off);
+ const Value crd = genIndexLoad(builder, loc, mem, pos);
return crd;
}
@@ -317,9 +314,10 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
// Index-reduction related fields.
this->dependentLvlMap.assign(
- numTensors, std::vector<std::vector<std::pair<TensorId, Level>>>());
+ numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
this->slicePosBuffer.assign(numTensors, std::vector<std::vector<Value>>());
- this->sliceSizes.assign(numTensors, std::vector<std::vector<Value>>());
+ this->sliceMeta.assign(
+ numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
this->sliceStack.assign(numTensors, std::vector<SliceInfo>());
this->levelReducedDep.assign(numTensors, std::vector<unsigned>());
@@ -367,10 +365,10 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
// Slice-driven loops related initialization.
levelReducedDep[tid].assign(lvlRank, 0);
- dependentLvlMap[tid].assign(lvlRank,
- std::vector<std::pair<TensorId, Level>>());
+ dependentLvlMap[tid].assign(
+ lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
slicePosBuffer[tid].assign(lvlRank, std::vector<Value>());
- sliceSizes[tid].assign(lvlRank, std::vector<Value>());
+ sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
sliceStack[tid].emplace_back(/*minCrd=*/Value(),
/*offset=*/Value(), /*isNonEmpty*/ Value(),
std::nullopt, 0);
@@ -380,8 +378,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
unsigned depends = dependentLvlMap[tid][l].size();
if (depends == 0)
continue;
- // We need `depends - 1` slices to fully the affine expression.
- sliceSizes[tid][l].assign(depends - 1, nullptr);
+ sliceMeta[tid][l].assign(depends, std::make_pair(nullptr, 0));
+ // We need `depends - 1` slices to fully reduce the affine expression.
slicePosBuffer[tid][l].assign(depends - 1, nullptr);
}
}
@@ -502,15 +500,20 @@ void LoopEmitter::initializeLoopEmit(
Level lvlRank = SparseTensorType(rtp).getLvlRank();
for (Level lvl = 0; lvl < lvlRank; lvl++) {
if (!dependentLvlMap[t][lvl].empty()) {
- ArrayRef<std::pair<TensorId, Level>> depLvls = dependentLvlMap[t][lvl];
+ ArrayRef<std::pair<TensorLevel, unsigned>> depLvls =
+ dependentLvlMap[t][lvl];
// Needs at least two operands to form a non-trivial affine expression.
- assert(depLvls.size() > 1);
+ assert(depLvls.size() == sliceMeta[t][lvl].size());
Value size = c0;
- for (unsigned e = depLvls.size() - 1; e >= 1; e--) {
- auto [dt, dd] = depLvls[e];
- size = ADDI(size, lvlSizes[dt][dd]);
- sliceSizes[t][lvl][e - 1] = size;
+ for (int e = depLvls.size() - 1; e >= 0; e--) {
+ auto [dt, dl] = unpackTensorLevel(depLvls[e].first);
+ unsigned stride = depLvls[e].second;
+ Value stridedSize = lvlSizes[dt][dl];
+ if (stride != 1)
+ stridedSize = MULI(stridedSize, C_IDX(stride));
+ size = ADDI(size, stridedSize);
+ sliceMeta[t][lvl][e] = std::make_pair(size, stride);
}
}
}
@@ -729,8 +732,9 @@ Value LoopEmitter::genWhileLoopConditions(OpBuilder &builder, Location loc,
// crdHi is a loop invariant, hoist the computation outside the loop.
if (llvm::isa_and_nonnull<scf::WhileOp>(loop))
builder.setInsertionPoint(loop);
- crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset,
- sliceSizes[tid][lvl].back());
+ auto [size, stride] = sliceMeta[tid][lvl].back();
+ assert(stride == 1 && "Not yet implemented");
+ crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset, size);
}
assert(crdHi);
return genSparseReducedAffineCond(builder, loc,
@@ -984,7 +988,7 @@ bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<TensorLvlCond> sparseConds,
if (sparseConds.size() > 1)
return false;
- // We also need a while loop for levels with affine index expression for
+ // We also need a while loop for levels with affine index expression and
// non-unique levels when deduplication is required.
if (sparseConds.size() == 1) {
auto [tid, lvl] = unpackTensorLevel(sparseConds.back().first);
@@ -1042,7 +1046,9 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
if (isDenseCond(loopCondKind) && isAffineIdxCond(loopCondKind)) {
bool unReduc = isAffineIdxUnRedCond(loopCondKind);
assert(unReduc == !depFullyReduced(tid, lvl));
- hi = sliceSizes[tid][lvl][sliceStack[tid].back().depth - 1];
+ auto [size, stride] = sliceMeta[tid][lvl][sliceStack[tid].back().depth];
+ assert(stride == 1 && "Not yet implemented");
+ hi = size;
if (unReduc) {
// Adjust for loop hi for dense slice-driven loop.
hi = SUBI(lvlSizes[tid][lvl], hi);
@@ -1215,6 +1221,8 @@ void LoopEmitter::enterTensorsAtDenseLvls(
SliceInfo &info = sliceStack[tid].back();
// Pushes sliced dense loop info to tell LoopEmitter how to exit it.
sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc);
+ // FIXME: The offset and position iterator need to be adjusted when the
+ // slice is strided.
if (unReduc) {
assert(*info.slicedOnLvl == lvl);
// Update the slice information as we enter the new loop.
@@ -1361,7 +1369,9 @@ void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
while (curLvl < leafLvl && isDenseDLT(lvlTypes[tid][curLvl])) {
// One step forward in parent level results in forwarding `slice.size` step
// in child dense level.
- fcnt = MULI(sliceSizes[tid][curLvl].back(), fcnt);
+ auto [size, stride] = sliceMeta[tid][curLvl].back();
+ assert(stride == 1 && "Not yet implemented");
+ fcnt = MULI(size, fcnt);
curLvl++;
}
@@ -1420,7 +1430,18 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
// TODO: support coiterating multiple slices
assert(loopInfo.trivialTidLvls.empty() &&
loopInfo.sliceDrivenInfo.size() == 1);
- genSliceNextInduction(builder, loc, whileOp, tid, lvl, operands, o);
+ auto [nxNonEmpty, nxMinCrd, nxAbsOffset] =
+ genSliceNextInduction(builder, loc, tid, lvl);
+ // Update while loop induction operands.
+ operands.push_back(nxNonEmpty);
+ operands.push_back(nxMinCrd);
+ operands.push_back(nxAbsOffset);
+
+ // Update the slice stack.
+ SliceInfo &info = sliceStack[tid].back();
+ info.isNonEmpty = whileOp.getResult(o++);
+ info.minCrd = whileOp.getResult(o++);
+ info.offset = whileOp.getResult(o++);
continue;
}
@@ -1566,7 +1587,10 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
Value size, TensorId tid, Level lvl, ValueRange userReduc,
LoopBodyBuilder bodyBuilder) {
Value c1 = C_IDX(1);
- Value sliceHi = ADDI(offset, sliceSizes[tid][lvl].back());
+ auto [sliceSz, stride] = sliceMeta[tid][lvl].back();
+ assert(stride == 1 && "Not yet implemented");
+ Value sliceHi = ADDI(offset, sliceSz);
+
SmallVector<Value> reduc{posLo}; // loop lower bounds
const unsigned numMetaReduc = reduc.size();
@@ -1663,6 +1687,8 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
reduc.back() = ADDI(reduc.back(), C_IDX(1));
};
+ // FIXME: Need special handling when the previous unresolved slice is strided:
+ // We probably need to filter out coordinates that are not on the stride.
if (firstResLvl.has_value()) {
// Overwrite position when the first level is fully resolved.
pos = posits[firstResLvl->first][firstResLvl->second];
@@ -1694,10 +1720,13 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
// non-consecutive segments.
builder.create<memref::StoreOp>(loc, iterArgs.back(), sPtrBuf,
ADDI(iv, c2).getResult());
+
+ auto [size, stride] = sliceMeta[tid][firstLvl].back();
+ assert(stride == 1 && "Not yet implemented");
ValueRange itArgs =
genSliceLvlTraverseLoop(
- builder, loc, loopLo, loopHi, offset,
- sliceSizes[tid][firstLvl].back(), tid, firstLvl, iterArgs,
+ builder, loc, loopLo, loopHi, offset, size, tid, firstLvl,
+ iterArgs,
[&](OpBuilder &builder, Location, Value iv,
MutableArrayRef<Value> reduc) {
ip = builder.saveInsertionPoint();
@@ -1710,8 +1739,9 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
} else if (isDenseDLT(lvlTypes[tid][firstLvl])) {
assert(firstLvl == 0); // This must be the first level.
Value lb = frontSlice.offset;
- Value sliceSz =
- sliceSizes[tid][*frontSlice.slicedOnLvl][frontSlice.depth - 1];
+ auto [sliceSz, stride] =
+ sliceMeta[tid][*frontSlice.slicedOnLvl][frontSlice.depth];
+ assert(stride == 1 && "Not yet implemented");
Value ub = ADDI(lb, sliceSz);
outerMost = builder.create<scf::ForOp>(
loc, lb, ub, c1, innerArgs,
@@ -1735,7 +1765,8 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
Level sliceLvl = *slice->slicedOnLvl;
assert(isDenseDLT(lvlTypes[tid][sliceLvl]));
Value offset = slice->offset;
- Value sliceSz = sliceSizes[tid][sliceLvl][slice->depth - 1];
+ auto [sliceSz, stride] = sliceMeta[tid][sliceLvl][slice->depth];
+ assert(stride == 1 && "Not yet implemented");
lbs.push_back(offset);
ubs.push_back(ADDI(offset, sliceSz));
steps.push_back(c1);
@@ -1788,7 +1819,8 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
lvl, /*depth=*/1);
return;
}
- Value size = sliceSizes[tid][lvl][0];
+ auto [nxSz, stride] = sliceMeta[tid][lvl][1];
+ assert(stride == 1 && "Not yet implemented");
Value sPtrBuf = slicePosBuffer[tid][lvl][0];
Value pHi, pLo;
if (lvl == 0) {
@@ -1816,7 +1848,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
Value minCrd = genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo);
// FIXME: We need the relative offset related to the base slice.
- Value absOffset = offsetFromMinCoord(builder, loc, minCrd, size, isNonEmpty);
+ Value absOffset = offsetFromMinCoord(builder, loc, minCrd, nxSz, isNonEmpty);
sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl,
/*depth=*/1);
}
@@ -1845,7 +1877,12 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
TensorId tid, Level lvl) {
Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
unsigned depth = levelReducedDep[tid][lvl];
- Value size = sliceSizes[tid][lvl][depth]; // Dense slice begin is trivial
+ // TODO: handle case when the current slice stride is not one.
+ assert(sliceMeta[tid][lvl][depth].second == 1 && "Not yet implemented");
+
+ // The remaining slice size after reduction.
+ Value remSz = sliceMeta[tid][lvl][depth + 1].first;
+ // Dense slice begin is trivial
if (isDenseDLT(lvlTypes[tid][lvl])) {
sliceStack[tid].emplace_back(c0, c0, constantI1(builder, loc, false), lvl,
depth + 1);
@@ -1941,7 +1978,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
builder.create<memref::StoreOp>(loc, result[2], sPtrBuf, c0);
builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c1);
// FIXME: we need the relative offset related to the base slice.
- Value absOffset = offsetFromMinCoord(builder, loc, minCrd, size, isNonEmpty);
+ Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, depth + 1);
}
@@ -2005,9 +2042,12 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
// TODO: Maybe using allocaScopeOp inside the loop to resolve the issue?
for (Level curLevel = lvl;
curLevel >= 1 && !lvlFullyResolved(tid, curLevel - 1); curLevel--) {
- auto depth = remDepOnLevel(tid, curLevel - 1);
- assert(sliceSizes[tid][lvl].size() >= depth);
- Value sz = *(sliceSizes[tid][lvl].rbegin() + depth - 1);
+ // We only handle cases when all the previously unresolved levels are
+ // fully reduced.
+ assert(depFullyReduced(tid, curLevel - 1));
+ assert(!sliceMeta[tid][curLevel - 1].empty());
+ auto [sz, stride] = sliceMeta[tid][curLevel - 1].back();
+ assert(stride == 1 && "Not yet implemented");
bufSize = MULI(bufSize, sz);
}
// For a triple of [pLo, pHi, pPtr]. Note that we can not compress pHi
@@ -2042,18 +2082,15 @@ void LoopEmitter::invalidateSliceIterIdx(OpBuilder &builder, Location loc,
}
}
-void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
- const Operation *op, TensorId tid,
- Level lvl,
- SmallVectorImpl<Value> &operands,
- unsigned &retIdx) {
+std::tuple<Value, Value, Value>
+LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
+ TensorId tid, Level lvl) {
if (!isCompressedDLT(lvlTypes[tid][lvl]))
llvm_unreachable("TODO");
// else generate code to compute next non empty slice.
Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
- auto whileOp = llvm::cast<scf::WhileOp>(op);
SliceInfo &info = sliceStack[tid].back();
assert(info.slicedOnLvl == lvl);
//
@@ -2182,9 +2219,12 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
builder.setInsertionPointAfter(forOp.loops.front());
// minOffset = minCrd + 1 >= size ? minCrd + 1 - size : c0
Value tmp = ADDI(forOp.results.front(), c1);
- Value minOffset = SUBI(tmp, sliceSizes[tid][lvl][info.depth - 1]);
- Value p = CMPI(uge, tmp, sliceSizes[tid][lvl][info.depth - 1]);
+ auto [size, stride] = sliceMeta[tid][lvl][info.depth];
+ assert(stride == 1 && "Not yet implemented");
+ Value minOffset = SUBI(tmp, size);
+ Value p = CMPI(uge, tmp, size);
minOffset = SELECT(p, minOffset, c0);
+
SmallVector<Value, 3> yields;
yields.assign(forOp.results.begin(), forOp.results.end());
yields.push_back(minOffset);
@@ -2200,7 +2240,9 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
Value maxPred = CMPI(ugt, minOffset, nxOffset);
Value nextAbsOffset = SELECT(maxPred, minOffset, nxOffset);
- Value sliceUB = ADDI(nextAbsOffset, sliceSizes[tid][lvl][info.depth - 1]);
+ auto [size, stride] = sliceMeta[tid][lvl][info.depth];
+ assert(stride == 1 && "Not yet implemented");
+ Value sliceUB = ADDI(nextAbsOffset, size);
// FIXME: this only works if there is only one parent.
assert(info.depth - 1 == 0);
@@ -2211,15 +2253,7 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
assert(info.depth - 1 == 0);
Value nextRelOffset = nextAbsOffset;
nextRelOffset = SELECT(nextNonEmpty, nextRelOffset, c0);
-
- operands.push_back(nextNonEmpty);
- operands.push_back(nextMinCrd);
- operands.push_back(nextAbsOffset); // we push the absolute offset.
-
- // Update the slice stack.
- info.isNonEmpty = whileOp.getResult(retIdx++);
- info.minCrd = whileOp.getResult(retIdx++);
- info.offset = whileOp.getResult(retIdx++);
+ return std::make_tuple(nextNonEmpty, nextMinCrd, nextAbsOffset);
}
#undef CMPI
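
The sliceMeta initialization in initializeLoopEmit above accumulates sizes back-to-front, scaling each level size by its coefficient (the stride). A worked example with plain integers in place of SSA values, assuming A[2*i + j] with loop sizes |i| = 4 and |j| = 5 (assumed numbers, for illustration only):

  #include <cstdio>
  #include <utility>
  #include <vector>

  int main() {
    // (lvlSize, coefficient) of the loops defining `2*i + j`, in order:
    // i has size 4 and coefficient 2, j has size 5 and coefficient 1.
    std::vector<std::pair<unsigned, unsigned>> depLvls = {{4, 2}, {5, 1}};
    std::vector<std::pair<unsigned, unsigned>> sliceMeta(depLvls.size());
    unsigned size = 0;
    for (int e = (int)depLvls.size() - 1; e >= 0; e--) {
      auto [lvlSize, stride] = depLvls[e];
      size += stride * lvlSize; // the `stridedSize` term from the patch
      sliceMeta[e] = {size, stride};
    }
    for (unsigned e = 0; e < sliceMeta.size(); e++)
      std::printf("sliceMeta[%u] = (size %u, stride %u)\n", e,
                  sliceMeta[e].first, sliceMeta[e].second);
    // Prints (13, 2) for the outer entry and (5, 1) for the inner one.
  }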
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 1b3acd68e587d7..d9948d3f4db73b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -84,18 +84,22 @@ class LoopEmitter {
using SynTensorBoundSetter =
function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>;
- // Map from [tid, dim] to a list of dependent [tid, dim] for affine expression
- // index on sparse tensors.
- // E.g., for affine index (d0 + d1), it depends on two [tid, dim] that defines
- // d0 and d1 (for affine expression reduction).
+ // Map from [tid, lvl] to a list of dependent [tidlvl, coefficient] for
+ // subscript expressions on sparse tensors.
+ //
+ // E.g., for affine index (2 * d0 + d1), it depends on two tidlvls that
+ // define d0 and d1 (for affine expression reduction), with coefficients 2
+ // and 1 on d0 and d1 respectively.
// If the list is empty, it means that there is no affine expression on the
- // input [tid, dim].
+ // input [tid, lvl].
+ //
// NOTE: The caller is responsible for ensuring that the order of the
// returned list is consistent with the topological order of the iteration
// graph; otherwise the loop emitter might reduce the wrong dependent index
// variable when generating slice-driven loops.
using DependentLvlGetter =
- function_ref<std::vector<std::pair<TensorId, Level>>(TensorId, Level)>;
+ function_ref<std::vector<std::pair<TensorLevel, unsigned>>(TensorId,
+ Level)>;
LoopEmitter() = default;
@@ -335,9 +339,9 @@ class LoopEmitter {
// Whether this is the tensor that has not yet been sliced.
bool isInitialTensor() const { return !slicedOnLvl.has_value(); }
- Value minCrd; // the minimum coordinate of the slice.
- Value offset; // the offset of the current slice.
- Value isNonEmpty; // whether the slice is empty.
+ Value minCrd; // the minimum coordinate of the slice.
+ Value offset; // the *absolute* offset of the current slice.
+ Value isNonEmpty;   // whether the slice is non-empty.
std::optional<Level> slicedOnLvl; // the level on which the slice is done
unsigned depth; // the depth (relative to dependentDimMap[tid][lvl]).
};
@@ -645,10 +649,12 @@ class LoopEmitter {
bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl);
/// Generates code to get the next non-empty slices of tid on lvl.
- void genSliceNextInduction(OpBuilder &builder, Location loc,
- const Operation *whileOp, TensorId tid, Level lvl,
- SmallVectorImpl<Value> &operands,
- unsigned &retIdx);
+ /// Returns a tuple of values <NonEmpty, MinCrd, AbsOffset> (see
+ /// SliceInfo).
+ std::tuple<Value, Value, Value> genSliceNextInduction(OpBuilder &builder,
+ Location loc,
+ TensorId tid,
+ Level lvl);
/// An optional string attribute that should be attached to the loop
/// generated by loop emitter, it might help following passes to identify
@@ -707,9 +713,9 @@ class LoopEmitter {
std::vector<std::vector<Value>> sliceOffsets;
std::vector<std::vector<Value>> sliceStrides;
- // Map from [tid, level] to a list of dependent [tid, level].
- // See comments for `DependentDimGetter`.
- std::vector<std::vector<std::vector<std::pair<TensorId, Level>>>>
+ // Map from [tid, level] to a list of dependent [tidlevel, coefficient].
+ // See comments for `DependentLvlGetter`.
+ std::vector<std::vector<std::vector<std::pair<TensorLevel, unsigned>>>>
dependentLvlMap;
// The cached position buffer for the slices, they serve the same purpose as
@@ -718,8 +724,9 @@ class LoopEmitter {
// to avoid iteration from the beginning.
std::vector<std::vector<std::vector<Value>>> slicePosBuffer;
- // The cached size for each slices.
- std::vector<std::vector<std::vector<Value>>> sliceSizes;
+ // The (size, stride) for each conceptual slice used for index reduction
+ // loops.
+ std::vector<std::vector<std::vector<std::pair<Value, unsigned>>>> sliceMeta;
// The number of reduced dependencies on a tensor level so far.
std::vector<std::vector<unsigned>> levelReducedDep;
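
The DependentLvlGetter contract, sketched end-to-end with plain C++ stand-ins. Note the tidlvl packing below (t * lvlRank + l) is an assumption made for this sketch only, not LoopEmitter's actual TensorLevel encoding:

  #include <cstdio>
  #include <utility>
  #include <vector>

  using TensorId = unsigned;
  using Level = unsigned;
  using TensorLevel = unsigned;

  int main() {
    const unsigned lvlRank = 2;
    // Assumed packing, for the sketch only.
    auto makeTensorLevel = [&](TensorId t, Level l) -> TensorLevel {
      return t * lvlRank + l;
    };
    // For B[2*i + j]: level 0 of tensor 0 depends on the loops defined by
    // tensor 1, levels 0 and 1, with coefficients 2 and 1.
    auto getter = [&](TensorId t, Level l) {
      std::vector<std::pair<TensorLevel, unsigned>> deps;
      if (t == 0 && l == 0) {
        deps.emplace_back(makeTensorLevel(1, 0), 2u);
        deps.emplace_back(makeTensorLevel(1, 1), 1u);
      }
      return deps; // must already be in topological order
    };
    for (auto [tl, coeff] : getter(0, 0))
      std::printf("depends on tidlvl %u with coefficient %u\n", tl, coeff);
  }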
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 2450fd6c7d03f6..770349d6d1db0f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -282,10 +282,14 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
///
/// TODO: constant should be easy to handle.
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
- AffineExpr a, DimLevelType dlt,
- bool isSubExp = false) {
+ AffineExpr a, DimLevelType dlt, bool isSubExp = false,
+ int64_t coefficient = 1) {
switch (a.getKind()) {
case AffineExprKind::DimId: {
+ // Only allow positive coefficients on AffineDimExpr.
+ if (coefficient <= 0)
+ return false;
+
const LoopId ldx = merger.makeLoopId(a.cast<AffineDimExpr>().getPosition());
if (!isUndefDLT(merger.getLvlType(tensor, ldx)))
return false; // used more than once, e.g., A[i][i]
@@ -293,8 +297,10 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
// TODO: Generalize the following two cases. A[i] (with trivial index
// expression) can be treated as a special affine index expression. We do
// not necessarily need to differentiate them.
- if (!isSubExp)
+ if (!isSubExp) {
+ assert(coefficient == 1);
merger.setLevelAndType(tensor, ldx, lvl, dlt);
+ }
if (isSubExp) {
// The current loop appears in more than one affine expression on the
@@ -312,14 +318,26 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
// else increase min(d0_1, d0_2).
return false;
}
- merger.setLoopDependentTensorLevel(ldx, tensor, lvl, dlt);
+ merger.setLoopDependentTensorLevel(ldx, tensor, lvl, dlt, coefficient);
}
return true;
}
case AffineExprKind::Constant:
- case AffineExprKind::Mul:
- // TODO: Support Mul and Constant AffineExp for slice-based codegen
- return false;
+ // TODO: Support Constant AffineExp for slice-based codegen
+ case AffineExprKind::Mul: {
+ // TODO: Support index expressions like `2 * d0`; we currently only support
+ // the more complicated cases like `2 * d0 + d1`.
+ if (!isSubExp)
+ return false;
+ auto binOp = a.cast<AffineBinaryOpExpr>();
+ auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
+ if (rhs.isa<AffineConstantExpr>())
+ std::swap(lhs, rhs);
+ // Must be in form of `constant * d`.
+ assert(lhs.isa<AffineConstantExpr>() && rhs.isa<AffineDimExpr>());
+ int64_t coefficient = lhs.cast<AffineConstantExpr>().getValue();
+ return findDepIdxSet(merger, tensor, lvl, rhs, dlt, isSubExp, coefficient);
+ }
case AffineExprKind::Add: {
auto binOp = a.cast<AffineBinaryOpExpr>();
return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), dlt, true) &&
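
One subtlety in the new Mul case: affine multiplication is commutative, so the constant may arrive on either side, and the `std::swap` normalizes the operands to `constant * d` before the coefficient is extracted. A tiny sketch with stand-in types (MiniExpr is illustrative, not MLIR's AffineConstantExpr/AffineDimExpr API):

  #include <cassert>
  #include <cstdint>
  #include <cstdio>
  #include <utility>

  // Stand-in for AffineConstantExpr / AffineDimExpr.
  struct MiniExpr {
    bool isConst;
    int64_t value; // meaningful when isConst
    unsigned dim;  // meaningful when !isConst
  };

  int main() {
    // `d0 * 2`: the constant arrives on the right-hand side.
    MiniExpr lhs{false, 0, 0}, rhs{true, 2, 0};
    if (rhs.isConst)
      std::swap(lhs, rhs);               // normalize to `constant * d`
    assert(lhs.isConst && !rhs.isConst); // must be `constant * d`
    std::printf("coefficient %lld on d%u\n", (long long)lhs.value, rhs.dim);
  }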
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index f39a2069a57dd8..4143efbd0ab28e 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -232,11 +232,11 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
std::vector<std::optional<Level>>(numLoops, std::nullopt)),
lvlToLoop(numTensors,
std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
- loopToDependencies(
- numLoops, std::vector<std::optional<std::pair<Level, DimLevelType>>>(
- numTensors, std::nullopt)),
- levelToDependentLoop(numTensors, std::vector<std::vector<LoopId>>(
- maxLvlRank, std::vector<LoopId>())),
+ loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlDLTPair>>(
+ numTensors, std::nullopt)),
+ levelToDependentLoop(numTensors,
+ std::vector<std::vector<LoopCoeffPair>>(
+ maxLvlRank, std::vector<LoopCoeffPair>())),
loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
//===----------------------------------------------------------------------===//