[Mlir-commits] [mlir] d03805f - [mlir][sparse] add merger/topo sort support for slice-based affine sparse index codegen
Peiming Liu
llvmlistbot at llvm.org
Mon Mar 20 14:24:15 PDT 2023
Author: Peiming Liu
Date: 2023-03-20T21:24:10Z
New Revision: d03805f2ee0bdaa2513fbc3efb9e404e128bdbb3
URL: https://github.com/llvm/llvm-project/commit/d03805f2ee0bdaa2513fbc3efb9e404e128bdbb3
DIFF: https://github.com/llvm/llvm-project/commit/d03805f2ee0bdaa2513fbc3efb9e404e128bdbb3.diff
LOG: [mlir][sparse] add merger/topo sort support for slice-based affine sparse index codegen
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D142928
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Removed:
################################################################################
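For context (editorial illustration, not part of this patch): a "non-trivial affine sparse index" is a subscript such as A[i+j]. A minimal dense C++ sketch of such a kernel, a 1-D convolution, with illustrative names only:

    #include <cstddef>
    #include <vector>

    // Dense reference semantics of a kernel with the non-trivial subscript
    // in[i + j]; this patch lets the sparsifier handle such subscripts on
    // sparse levels by slicing, instead of introducing filter loops.
    // Precondition: in.size() >= out.size() + filter.size() - 1.
    void conv1d(const std::vector<double> &in,
                const std::vector<double> &filter,
                std::vector<double> &out) {
      for (size_t i = 0; i < out.size(); ++i)      // parallel loop i
        for (size_t j = 0; j < filter.size(); ++j) // reduction loop j
          out[i] += in[i + j] * filter[j];         // i and j both index level 0
    }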
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 59c5b78fda7b8..0e6c2f1553f1c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -435,19 +435,58 @@ class Merger {
/// Iterates over a set of `TensorLoopId`s, invoking the callback
/// for each `TensorLoopId` and passing it the corresponding tensor
- /// identifier, level, and level-type.
- void
- foreachTensorLoopId(LatPointId p,
- function_ref<void(TensorLoopId, TensorId,
- std::optional<Level>, DimLevelType)>
- callback) const {
- for (const TensorLoopId b : latPoints[p].bits.set_bits())
- callback(b, tensor(b), getLvl(b), getDimLevelType(b));
+ /// identifier, level, and level-type, followed by a boolean value
+ /// indicating whether it is a dependent index reduction loop condition.
+ void foreachTensorLoopId(
+ LatPointId p, function_ref<void(TensorLoopId, TensorId,
+ std::optional<Level>, DimLevelType, bool)>
+ callback) {
+ for (const TensorLoopId b : latPoints[p].bits.set_bits()) {
+ TensorId t = tensor(b);
+ if (isLvlWithNonTrivialIdxExp(b)) {
+ // This must be an undefined level.
+ assert(!getLvl(b).has_value());
+ // Slice the tid along the dependent level to iterate the current loop.
+ callback(b, t, loopToDependencies[loop(b)][t], getDimLevelType(b),
+ /*isIdxReduc=*/true);
+ } else {
+ callback(b, t, getLvl(b), getDimLevelType(b), /*isIdxReduc=*/false);
+ }
+ }
}
/// Sets whether the output tensor is sparse or not.
void setHasSparseOut(bool s) { hasSparseOut = s; }
+ /// Establishes the two-way map i <-> <t, lvl>.
+ void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl) {
+ assert(lvl < numLoops);
+ loopToDependencies[i][t] = lvl;
+ levelToDependentIdx[t][lvl].push_back(i);
+ }
+
+ /// Returns whether the loop has a dependent slice.
+ bool hasDependentLvl(LoopId i, TensorId tid) {
+ return loopToDependencies[i][tid].has_value();
+ }
+
+ /// Returns the list of loop indices which appear in the non-trivial index
+ /// expression on t_l, e.g., A[i+j] => {i, j}
+ std::vector<LoopId> &getDependentLoops(TensorId t, Level lvl) {
+ return levelToDependentIdx[t][lvl];
+ }
+
+ /// Returns the defining [tid, lvl] for the loop.
+ std::pair<TensorId, Level> getLoopDefiningLvl(LoopId i) const {
+ return loopBounds[i];
+ }
+
+ /// Checks whether the TensorLoopId represents a tensor level with
+ /// non-trivial index expression on it.
+ bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const {
+ return loopToDependencies[loop(b)][tensor(b)].has_value();
+ }
+
/// Convenience getters to immediately access the stored nodes.
/// Typically it is inadvisable to keep the reference around, as in
/// `TensorExpr &te = merger.exp(e)`, since insertions into the merger
@@ -511,6 +550,20 @@ class Merger {
// Map that converts pair<TensorId, Level> to the corresponding LoopId.
std::vector<std::vector<std::optional<LoopId>>> lvlToLoop;
+ // Map from a loop to its dependencies if any.
+ // The dependencies of a loop are a set of (tensor, level) pairs.
+ // It is currently only set for non-trivial index expressions.
+ // E.g., A[i+j] => i and j will have dependencies {A0} to indicate that
+ // i and j are used in the non-trivial index expression on A0.
+ std::vector<std::vector<std::optional<Level>>> loopToDependencies;
+ // The inverse map of loopToDependencies, from tensor level -> dependent loops.
+ // E.g., for A[i+j], we have A0 => {i, j}, to indicate that A0 uses both i and
+ // j to compute its indices.
+ std::vector<std::vector<std::vector<LoopId>>> levelToDependentIdx;
+
+ // Map from a loop to the [tid, lvl] pair that defines the loop boundary.
+ std::vector<std::pair<TensorId, Level>> loopBounds;
+
llvm::SmallVector<TensorExp> tensorExps;
llvm::SmallVector<LatPoint> latPoints;
llvm::SmallVector<SmallVector<LatPointId>> latSets;
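A sketch of how the new two-way maps would be used (editorial, not part of the patch; assumes a live Merger `merger`, a TensorId `A`, and LoopIds `i` and `j` for a subscript A[i+j] on level 0):

    // Record i <-> (A, 0) and j <-> (A, 0): both loops feed level 0's index.
    merger.setLoopDependentTensorLevel(i, A, /*lvl=*/0);
    merger.setLoopDependentTensorLevel(j, A, /*lvl=*/0);
    assert(merger.hasDependentLvl(i, A) && merger.hasDependentLvl(j, A));
    // The inverse map answers "which loops appear in level 0's index on A?".
    std::vector<LoopId> &deps = merger.getDependentLoops(A, /*lvl=*/0); // {i, j}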
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index 8c6a7bd6433db..776d7f7f47ece 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -99,6 +99,7 @@ class CodegenEnv {
topSort.reserve(capacity);
}
+ ArrayRef<LoopId> getTopSort() const { return topSort; };
ArrayRef<LoopId> getTopSortSlice(LoopOrd n, LoopOrd m) const;
ArrayRef<LoopId> getLoopStackUpTo(LoopOrd n) const;
ArrayRef<LoopId> getCurrentLoopStack() const;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index f119ac3ba7ae5..d7ce2b7f63f5c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -109,6 +109,12 @@ class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
SmallVector<utils::IteratorType> iterTypes;
};
+// Flattens an affine expression into a list of AffineDimExprs.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+ void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
+ SmallVector<AffineDimExpr> dims;
+};
+
} // namespace
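A usage sketch for the collector (editorial, not part of the patch; `ctx` is an assumed MLIRContext*):

    // Flatten d0 + d1 into its dim operands.
    AffineDimCollector collector;
    collector.walkPostOrder(getAffineDimExpr(0, ctx) + getAffineDimExpr(1, ctx));
    // collector.dims now holds {d0, d1}.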
//===----------------------------------------------------------------------===//
@@ -254,6 +260,69 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
}
}
+/// Helper method to inspect affine expressions for index variable reduction
+/// based codegen. It finds the dependent index set for all tensor levels in the
+/// current expression we are generating.
+///
+/// For example, when handling A[i+j][j+k], we build the two-way mapping in
+/// the merger between (tensor, level) pairs and their dependent index
+/// variable sets: A_0 <=> [i, j] and A_1 <=> [j, k].
+///
+/// It rejects the following cases (returns false):
+/// 1st, when the same index is used more than once, e.g., A[i+j][i];
+/// 2nd, when multiplication is used in the non-trivial index expression;
+/// 3rd, when a constant operand is used in the non-trivial index expression.
+///
+/// TODO: constant should be easy to handle.
+static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
+ AffineExpr a, DimLevelType dlt,
+ bool isSubExp = false) {
+ switch (a.getKind()) {
+ case AffineExprKind::DimId: {
+ LoopId ldx = a.cast<AffineDimExpr>().getPosition();
+ if (!isUndefDLT(merger.getDimLevelType(tensor, ldx)))
+ return false; // used more than once, e.g., A[i][i]
+
+ // TODO: Generalize the following two cases. A[i] (with a trivial index
+ // expression) can be treated as a special affine index expression. We do
+ // not necessarily need to differentiate them.
+ if (!isSubExp)
+ merger.setLevelAndType(tensor, ldx, lvl, dlt);
+
+ if (isSubExp) {
+ // The current loop appears in more than one affine expression on the
+ // same tensor. We cannot handle this case, e.g., A[i+j][i+k], where `i`
+ // is used twice.
+ if (merger.hasDependentLvl(ldx, tensor)) {
+ // TODO: This can be supported by coiterating slices if the loop idx
+ // appears in affine indices for different tensors, or by taking slices
+ // on multiple dimensions when it is on the same tensor.
+ // E.g.,
+ // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
+ // d0_1 = getNextSliceOffset t0 along lvl0
+ // d0_2 = getNextSliceOffset t1 along lvl0
+ // if d0_1 == d0_2 then d0 = d0_1 = d0_2
+ // else increase min(d0_1, d0_2).
+ return false;
+ }
+ merger.setLoopDependentTensorLevel(ldx, tensor, lvl);
+ }
+ return true;
+ }
+ case AffineExprKind::Constant:
+ case AffineExprKind::Mul:
+ // TODO: Support Mul and Constant AffineExp for slice-based codegen
+ return false;
+ case AffineExprKind::Add: {
+ auto binOp = a.cast<AffineBinaryOpExpr>();
+ return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), dlt, true) &&
+ findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), dlt, true);
+ }
+ default:
+ return false;
+ }
+}
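An editorial trace (not part of the patch) of how this recursion handles d0 + d1 on level 0 of a tensor A:

    // findDepIdxSet(merger, A, 0, d0 + d1, dlt)     // Add, isSubExp = false
    //   findDepIdxSet(merger, A, 0, d0, dlt, true)  // DimId, isSubExp = true
    //     merger.setLoopDependentTensorLevel(d0, A, 0)
    //   findDepIdxSet(merger, A, 0, d1, dlt, true)  // DimId, isSubExp = true
    //     merger.setLoopDependentTensorLevel(d1, A, 0)
    // A second use of d0 on the same tensor, as in A[d0+d1][d0+d2], would
    // hit the hasDependentLvl(d0, A) check and reject the kernel.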
+
/// Get the total number of compound affine expressions in the
/// `getMatchingIndexingMap` for the given tensor. For the following inputs:
///
@@ -262,7 +331,8 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
///
/// Returns 1 (because the first level is compressed and its corresponding
/// indexing-expression is `d0 + d1`)
-static unsigned getNumCompoundAffineOnSparseLvls(AffineMap map, Value tensor) {
+static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
+ Value tensor) {
// The `tensor` is not guaranteed to have `RankedTensorType`, therefore
// we can't use `getRankedTensorType`/`getSparseTensorType` here.
// However, we don't need to handle `StorageSpecifierType`, so we
@@ -305,20 +375,20 @@ static unsigned getNumCompoundAffineOnSparseLvls(AffineMap map, Value tensor) {
/// Get the total number of sparse levels with compound affine
/// expressions, summed over all operands of the `GenericOp`.
-static unsigned getNumCompoundAffineOnSparseLvls(linalg::GenericOp op) {
+static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
unsigned num = 0;
for (OpOperand &t : op->getOpOperands())
- num += getNumCompoundAffineOnSparseLvls(op.getMatchingIndexingMap(&t),
- t.get());
+ num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
+ t.get());
return num;
}
-static bool hasCompoundAffineOnSparseOut(linalg::GenericOp op) {
+static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
OpOperand *out = op.getDpsInitOperand(0);
if (getSparseTensorType(out->get()).isAllDense())
return false;
- return getNumCompoundAffineOnSparseLvls(op.getMatchingIndexingMap(out),
- out->get());
+ return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
+ out->get());
}
/// Helper method to inspect sparse encodings in the tensor types.
@@ -326,7 +396,14 @@ static bool hasCompoundAffineOnSparseOut(linalg::GenericOp op) {
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
-static bool findSparseAnnotations(CodegenEnv &env) {
+/// We currently support two different ways to handle non-trivial index
+/// expressions on sparse tensors, and they accept different affine expressions.
+/// The filter-loop-based approach accepts (almost) arbitrary affine index
+/// expressions on sparse tensors but is much less efficient, and will be
+/// gradually removed from the codebase.
+/// The dependent-index-reduction-based approach currently only supports
+/// affine addition index expressions.
+static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
bool annotated = false;
// `filterLdx` may be mutated by `findAffine`.
LoopId filterLdx = env.merger().getStartingFilterLoopId();
@@ -335,17 +412,30 @@ static bool findSparseAnnotations(CodegenEnv &env) {
const auto enc = getSparseTensorEncoding(t.get().getType());
if (enc)
annotated = true;
+
const Level lvlRank = map.getNumResults();
assert(!enc || lvlRank == enc.getLvlRank());
assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
+
+ // We only need to do index reduction if there is at least one non-trivial
+ // index expression on sparse levels.
+ // If all non-trivial index expressions are on dense levels, we can
+ // efficiently rely on random access to locate the elements.
+ bool needIdxReduc =
+ enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
+ // If the current tensor being inspected requires affine indices, it needs
+ // to be sliced.
for (Level l = 0; l < lvlRank; l++) {
const TensorId tid = t.getOperandNumber();
- // FIXME: `toOrigDim` is deprecated.
- // FIXME: above we asserted that there are `lvlRank` many results,
- // but this is assuming there are in fact `dimRank` many results instead.
- const AffineExpr a = map.getResult(toOrigDim(enc, l));
- if (!findAffine(env.merger(), tid, l, a, enc.getLvlType(l), filterLdx))
- return false; // inadmissible affine expression
+ AffineExpr a = map.getResult(toOrigDim(enc, l));
+ DimLevelType dlt = enc.getLvlType(l);
+ if (idxReducBased && needIdxReduc) {
+ if (!findDepIdxSet(env.merger(), tid, l, a, dlt))
+ return false; // inadmissible affine expression
+ } else {
+ if (!findAffine(env.merger(), tid, l, a, dlt, filterLdx))
+ return false; // inadmissible affine expression
+ }
}
}
assert(filterLdx == env.merger().getNumLoops());
@@ -469,11 +559,11 @@ static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
}
}
-static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
- std::optional<LoopId> &fldx,
- AffineExpr &fa,
- std::optional<LoopId> &tldx,
- AffineExpr &ta) {
+static void tryRelaxAffineConstraints(linalg::GenericOp op,
+ std::optional<LoopId> &fldx,
+ AffineExpr &fa,
+ std::optional<LoopId> &tldx,
+ AffineExpr &ta) {
// We use a heuristic here to only pick one dim expression from each
// compound affine expression to establish the order between two dense
// dimensions.
@@ -494,7 +584,7 @@ static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
}
if (!ta.isa<AffineConstantExpr>()) {
// Heuristic: we prefer reduction loop for rhs to reduce the chance
- // addint reduce < parallel ordering.
+ // adding reduce < parallel ordering.
finder.setPickedIterType(utils::IteratorType::reduction);
finder.walkPostOrder(ta);
ta = finder.getDimExpr();
@@ -503,14 +593,183 @@ static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
}
}
+/// Makes the target array's elements appear in the same order as in the
+/// `order` array.
+static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
+ ArrayRef<LoopId> order) {
+ std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) {
+ assert(l != r);
+ int idxL = -1, idxR = -1;
+ for (int i = 0, e = order.size(); i < e; i++) {
+ if (order[i] == l)
+ idxL = i;
+ if (order[i] == r)
+ idxR = i;
+ }
+ assert(idxL >= 0 && idxR >= 0);
+ return idxL < idxR;
+ });
+}
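A usage sketch (editorial, not part of the patch; LoopId is an unsigned id as elsewhere in the tree):

    // Dependent loops recorded in discovery order ...
    std::vector<LoopId> deps = {2, 0};
    // ... reordered to match the topologically sorted loop order:
    sortArrayBasedOnOrder(deps, /*order=*/{0, 1, 2}); // deps becomes {0, 2}

The comparator scans `order` linearly on every comparison, which is fine for the small loop counts involved here.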
+
+static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
+ OpOperand *skip, SortMask mask,
+ std::vector<std::vector<bool>> &adjM,
+ std::vector<unsigned> &inDegree) {
+ // Get map and encoding.
+ auto map = env.op().getMatchingIndexingMap(&t);
+ auto enc = getSparseTensorEncoding(t.get().getType());
+
+ // Each tensor expression and optional dimension ordering (row-major
+ // by default) puts an ordering constraint on the loop indices. For
+ // example, the tensor expression A_ijk forces the ordering i < j < k
+ // on the loop indices if no explicit dimension ordering is given.
+ for (Level l = 0, rank = map.getNumResults(); l < rank; l++) {
+ AffineExpr ta = map.getResult(toOrigDim(enc, l));
+ std::optional<LoopId> tldx =
+ env.merger().getLoopId(t.getOperandNumber(), l);
+ // Filter loops should be constructed after all the dependent loops,
+ // i.e., d0 + d1 < filter_loop(d0 + d1)
+ if (tldx && env.merger().isFilterLoop(*tldx)) {
+ assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getDimLevelType()[l]));
+ addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx);
+ // Now that the ordering of the affine expression is captured by the
+ // filter loop idx, we only need to ensure the affine ordering against the
+ // filter loop. Thus, we reset the affine expression to nil here to mark
+ // it as resolved.
+ ta = AffineExpr();
+ }
+
+ // Skip tensor during cycle resolution, though the order between the filter
+ // loop and dependent loops needs to be guaranteed unconditionally.
+ if (&t == skip)
+ continue;
+
+ if (l > 0) {
+ AffineExpr fa = map.getResult(toOrigDim(enc, l - 1));
+ std::optional<LoopId> fldx =
+ env.merger().getLoopId(t.getOperandNumber(), l - 1);
+
+ // Applying order constraints on every pair of dimExprs between two
+ // compound affine expressions can sometimes be too strict:
+ // E.g., for [dense, dense] -> (d0 + d1, d2 + d3),
+ // it is totally fine to have loop sequence d0->d2->d1->d3 instead of
+ // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
+ // We also relax the affine constraints when using the slice-based algorithm,
+ // as there is no filter loop for affine indices on sparse dimensions.
+ // TODO: do we really need the condition?
+ if (!includesDense(mask))
+ tryRelaxAffineConstraints(env.op(), fldx, fa, tldx, ta);
+
+ // (d0 + d1) < (d2 + d3), or
+ // filter_loop_d-1 < (d2 + d3), or
+ // (d0 + d1) < filter_loop_d, or
+ // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset
+ // above.
+ addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
+ }
+ }
+}
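A sketch of the edges this produces (editorial, not part of the patch): consider a tensor with indexing map (d0, d1 + d2) whose level 1 is compressed and hence owns a filter loop f.

    // Filter loop f comes after every loop in its affine expression:
    //   d1 < f, d2 < f   (ta is then reset to nil)
    // Adjacent levels are then ordered through the filter loop:
    //   d0 < f
    // So f is scheduled after all loops feeding its index expression.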
+
+static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t,
+ OpOperand *skip, SortMask mask,
+ std::vector<std::vector<bool>> &adjM,
+ std::vector<unsigned> &inDegree) {
+ // Get map and encoding.
+ auto map = env.op().getMatchingIndexingMap(&t);
+ auto enc = getSparseTensorEncoding(t.get().getType());
+
+ // No special treatment for simple indices.
+ if (getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) == 0)
+ return addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree);
+
+ // Skip tensor during cycle resolution, though the order between the filter
+ // loop and dependent loops needs to be guaranteed unconditionally.
+ if (&t == skip)
+ return;
+
+ AffineDimFinder finder(env.op());
+ finder.setPickedIterType(utils::IteratorType::reduction);
+ // To compute the iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6],
+ // we require that there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6}
+ // such that d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y.
+ for (Level lvl = 1, rank = map.getNumResults(); lvl < rank; lvl++) {
+ AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1));
+ AffineExpr ta = map.getResult(toOrigDim(enc, lvl));
+
+ // This is a heuristic: we pick an arbitrary reduction loop from the lhs
+ // and rhs and use them as d_x and d_y.
+ finder.walkPostOrder(fa);
+ AffineDimExpr fexp = finder.getDimExpr();
+ LoopId fldx = fexp.getPosition();
+
+ finder.walkPostOrder(ta);
+ AffineDimExpr texp = finder.getDimExpr();
+ LoopId tldx = texp.getPosition();
+
+ // d_x > d_y
+ if (!adjM[fldx][tldx]) {
+ adjM[fldx][tldx] = true;
+ inDegree[tldx]++;
+ }
+
+ AffineDimCollector fCollector;
+ fCollector.walkPostOrder(fa);
+ AffineDimCollector tCollector;
+ tCollector.walkPostOrder(ta);
+
+ // Make sure d_x and d_y come last within their own expressions.
+ for (auto fd : fCollector.dims) {
+ LoopId f = fd.getPosition();
+ if (f == fldx)
+ continue;
+ if (!adjM[f][fldx]) {
+ adjM[f][fldx] = true;
+ inDegree[fldx]++;
+ }
+ }
+ for (auto td : tCollector.dims) {
+ LoopId t = td.getPosition();
+ if (t == tldx)
+ continue;
+ if (!adjM[t][tldx]) {
+ adjM[t][tldx] = true;
+ inDegree[tldx]++;
+ }
+ }
+ // Since we only support affine addition, the order between two dim
+ // expressions does not really matter:
+ // {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
+ // This is to ensure that the affine expressions are reduced in sparse
+ // tensor level ordering.
+ // TODO: this ordering could probably be loosened if we support out-of-order
+ // reduction.
+ // TODO: the evaluation order needs to be ensured to
+ // support affine multiplication.
+ for (auto fd : fCollector.dims) {
+ LoopId f = fd.getPosition();
+ if (f == fldx) // skip d_x
+ continue;
+
+ for (auto td : tCollector.dims) {
+ LoopId t = td.getPosition();
+ if (t == tldx) // skip d_y
+ continue;
+ if (!adjM[f][t]) {
+ adjM[f][t] = true;
+ inDegree[t]++;
+ }
+ }
+ }
+ }
+}
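A sketch of the emitted edges (editorial, not part of the patch): for a map (d0 + d1, d2 + d3), assuming the finder picks d_x = d1 and d_y = d3.

    //   d1 -> d3             // the picked dims order the two levels
    //   d0 -> d1, d2 -> d3   // picked dims come last within their own level
    //   d0 -> d2             // remaining lhs dims before remaining rhs dims
    // One loop order satisfying all edges: d0, d1, d2, d3.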
+
/// Computes a topologically sorted iteration graph for the linalg operation.
-/// Ensures all tensors are visited in natural coordinate order. This is
+/// Ensures all tensors are visited in natural index order. This is
/// essential for sparse storage formats since these only support access
-/// along fixed levels. Even for dense storage formats, however, the natural
-/// coordinate order yields innermost unit-stride access with better spatial
+/// along fixed dimensions. Even for dense storage formats, however, the natural
+/// index order yields innermost unit-stride access with better spatial
/// locality.
static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
- OpOperand *skip = nullptr) {
+ OpOperand *skip, bool idxReducBased = false) {
// Set up an n x n from/to adjacency matrix of the iteration graph
// for the implicit loop indices i_0 .. i_n-1.
const LoopId n = env.merger().getNumLoops();
@@ -522,7 +781,8 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
// Get map and encoding.
const auto map = env.op().getMatchingIndexingMap(&t);
const auto enc = getSparseTensorEncoding(t.get().getType());
- assert(map.getNumDims() + getNumCompoundAffineOnSparseLvls(env.op()) == n);
+ assert(map.getNumDims() + getNumNonTrivialIdxExpOnSparseLvls(env.op()) ==
+ n);
// Skips dense inputs/outputs when not requested.
const bool isDenseInput = !enc && env.op().isDpsInput(&t);
@@ -549,63 +809,12 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
}
}
}
-
- // Each tensor expression and optional dimension ordering (row-major
- // by default) puts an ordering constraint on the loop indices. For
- // example, the tensor expresion A_ijk forces the ordering i < j < k
- // on the loop indices if no explicit dimension ordering is given.
- const Level lvlRank = map.getNumResults();
- assert(!enc || lvlRank == enc.getLvlRank());
- for (Level l = 0; l < lvlRank; l++) {
- // FIXME: `toOrigDim` is deprecated.
- // FIXME: above we asserted that there are `lvlRank` many results,
- // but this is assuming there are in fact `dimRank` many results instead.
- AffineExpr ta = map.getResult(toOrigDim(enc, l));
- std::optional<LoopId> tldx =
- env.merger().getLoopId(t.getOperandNumber(), l);
-
- // Filter loops should be constructed after all the dependent loops,
- // i.e., d0 + d1 < filter_loop(d0 + d1)
- if (tldx && env.merger().isFilterLoop(*tldx)) {
- assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getLvlType(l)));
- addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt,
- tldx);
- // Now that the ordering of affine expression is captured by filter
- // loop idx, we only need to ensure the affine ordering against filter
- // loop. Thus, we reset the affine express to nil here to mark it as
- // resolved.
- ta = AffineExpr();
- }
-
- // Skip tensor during cycle resolution, though order between filter loop
- // and dependent loops need to be guaranteed unconditionally.
- if (&t == skip)
- continue;
-
- if (l > 0) {
- // FIXME: `toOrigDim` is deprecated.
- // FIXME: above we asserted that there are `lvlRank` many results,
- // but this is assuming there are in fact `dimRank` many results.
- AffineExpr fa = map.getResult(toOrigDim(enc, l - 1));
- std::optional<LoopId> fldx =
- env.merger().getLoopId(t.getOperandNumber(), l - 1);
-
- // Applying order constraints on every pair of dimExpr between two
- // compound affine expressions can sometime too strict:
- // E.g, for [dense, dense] -> (d0 + d1, d2 + d3).
- // It is totally fine to have loop sequence d0->d2->d1->d3 instead of
- // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
- if (!includesDense(mask))
- tryLoosenAffineDenseConstraints(env.op(), fldx, fa, tldx, ta);
-
- // (d0 + d1) < (d2 + d3), or
- // filter_loop_d-1 < (d2 + d3), or
- // (d0 + d1) < filter_loop_d, or
- // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset
- // above.
- addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
- }
- }
+ // Push unrelated loops into sparse iteration space, so these
+ // will be skipped more often.
+ if (idxReducBased)
+ addSliceBasedConstraints(env, t, skip, mask, adjM, inDegree);
+ else
+ addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree);
}
// Topologically sort the iteration graph to determine loop order.
// Report failure for a cyclic iteration graph.
@@ -1275,7 +1484,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
SmallVector<Level> lvls;
env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
std::optional<Level> lvl,
- DimLevelType dlt) {
+ DimLevelType dlt, bool) {
assert(env.merger().loop(b) == idx);
if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
needsUniv = true;
@@ -1350,7 +1559,7 @@ static bool translateBitsToTidLvlPairs(
bool hasNonUnique = false;
env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
std::optional<Level> lvl,
- DimLevelType dlt) {
+ DimLevelType dlt, bool) {
if (simple.test(b)) {
if (isUndefDLT(dlt)) {
// An undefined dlt in the lattices, we probably mean to
@@ -1596,21 +1805,25 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
PatternRewriter &rewriter) const override {
// Only accept single output operations without affine index on sparse
// output.
- if (op.getNumDpsInits() != 1 || hasCompoundAffineOnSparseOut(op))
+ if (op.getNumDpsInits() != 1 || hasNonTrivialAffineOnSparseOut(op))
return failure();
- if (options.enableIndexReduction)
- llvm_unreachable("not yet implemented");
-
// Sets up a code generation environment.
const unsigned numTensors = op->getNumOperands();
const unsigned numLoops = op.getNumLoops();
- const unsigned numFilterLoops = getNumCompoundAffineOnSparseLvls(op);
- CodegenEnv env(op, options, numTensors, numLoops, numFilterLoops);
+ const unsigned numFilterLoops = getNumNonTrivialIdxExpOnSparseLvls(op);
+ // TODO: we should probably always use slice-based codegen whenever
+ // possible; we could even intermix slice-based and filter-loop-based codegen.
+ bool idxReducBased = options.enableIndexReduction && numFilterLoops != 0;
+
+ // If we use the slice-based algorithm for affine indices, we do not need
+ // filter loops.
+ CodegenEnv env(op, options, numTensors, numLoops,
+ /*numFilterLoops=*/idxReducBased ? 0 : numFilterLoops);
// Detects sparse annotations and translates the per-level sparsity
// information for all tensors to loop indices in the kernel.
- if (!findSparseAnnotations(env))
+ if (!findSparseAnnotations(env, idxReducBased))
return failure();
// Constructs the tensor expressions tree from `op`, returns failure if the
@@ -1635,7 +1848,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
SortMask::kIncludeUndef, SortMask::kSparseOnly};
for (const SortMask mask : allMasks) {
- if (computeIterationGraph(env, mask)) {
+ if (computeIterationGraph(env, mask, nullptr, idxReducBased)) {
hasCycle = false;
if (env.isAdmissibleTopoOrder()) {
isAdmissible = true;
@@ -1644,11 +1857,24 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// else try a set of less strict constraints.
}
}
- if (hasCycle)
- return resolveCycle(env, rewriter); // one last shot
+ if (hasCycle) {
+ return idxReducBased
+ ? failure() // TODO: should cycle be resolved differently?
+ : resolveCycle(env, rewriter); // one last shot
+ }
+
if (!isAdmissible)
return failure(); // inadmissible expression, reject
+ for (OpOperand &t : env.op()->getOpOperands()) {
+ Level rank = env.op().getMatchingIndexingMap(&t).getNumResults();
+ for (Level lvl = 0; lvl < rank; lvl++) {
+ sortArrayBasedOnOrder(
+ env.merger().getDependentLoops(t.getOperandNumber(), lvl),
+ env.getTopSort());
+ }
+ }
+
// Recursively generates code if admissible.
env.startEmit();
genBuffers(env, rewriter);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 029ce3f3f91ec..7f4400188cf14 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -220,7 +220,12 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
loopToLvl(numTensors,
std::vector<std::optional<Level>>(numLoops, std::nullopt)),
lvlToLoop(numTensors,
- std::vector<std::optional<LoopId>>(numLoops, std::nullopt)) {}
+ std::vector<std::optional<LoopId>>(numLoops, std::nullopt)),
+ loopToDependencies(numLoops, std::vector<std::optional<Level>>(
+ numTensors, std::nullopt)),
+ levelToDependentIdx(numTensors, std::vector<std::vector<LoopId>>(
+ numLoops, std::vector<LoopId>())),
+ loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
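Worth noting (editorial, not part of the patch): loopBounds is seeded with the pair (numTensors, numLoops), which is out of range for both TensorId and LoopId, so it appears to act as an "unset" sentinel until a defining (tid, lvl) is recorded:

    // Hypothetical check, not in the patch:
    bool defined = merger.getLoopDefiningLvl(i).first < numTensors;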
//===----------------------------------------------------------------------===//
// Lattice methods.
@@ -762,7 +767,10 @@ void Merger::dumpBits(const BitVector &bits) const {
const TensorId t = tensor(b);
const LoopId i = loop(b);
const auto dlt = lvlTypes[t][i];
- llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
+ if (isLvlWithNonTrivialIdxExp(b))
+ llvm::dbgs() << " DEP_" << t << "_" << i;
+ else
+ llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
}
}
}