[Mlir-commits] [mlir] [mlir][sparse] cleanup ldx/idx/depth/at usage (PR #74654)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 6 13:02:17 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
Author: Aart Bik (aartbik)
<details>
<summary>Changes</summary>
This makes the use of `at` consistent for everything that refers to the current loop nesting. It also cleans up some redundant legacy code left over from when we were still using topSort inside the sparsifier code.
---
Full diff: https://github.com/llvm/llvm-project/pull/74654.diff
1 Files Affected:
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+69-70)
``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d03e9615d340e..6637a26d0e5af 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -44,23 +44,23 @@ using namespace mlir::sparse_tensor;
// Sparsifier analysis methods.
//===----------------------------------------------------------------------===//
-/// Determines if affine expression is invariant.
-static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
- bool &isAtLoop) {
+/// Returns true iff affine expression is invariant. Sets the
+/// parameter `isAtLoop` when expression just became invariant.
+static bool isInvariantAffine(AffineExpr a, LoopId at, bool &isAtLoop) {
switch (a.getKind()) {
case AffineExprKind::DimId: {
const LoopId i = cast<AffineDimExpr>(a).getPosition();
- if (i == ldx) {
+ if (i + 1 == at) {
isAtLoop = true;
- return true; // invariant at given loop
+ return true; // becomes invariant at current loop
}
- return i < loopDepth; // invariant when already generated
+ return i < at; // invariant when already generated
}
case AffineExprKind::Add:
case AffineExprKind::Mul: {
auto binOp = cast<AffineBinaryOpExpr>(a);
- return isInvariantAffine(binOp.getLHS(), loopDepth, ldx, isAtLoop) &&
- isInvariantAffine(binOp.getRHS(), loopDepth, ldx, isAtLoop);
+ return isInvariantAffine(binOp.getLHS(), at, isAtLoop) &&
+ isInvariantAffine(binOp.getRHS(), at, isAtLoop);
}
default: {
assert(isa<AffineConstantExpr>(a));
@@ -126,8 +126,8 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
if (coefficient <= 0)
return false;
- const LoopId ldx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
- if (!isUndefLT(merger.getLvlType(tensor, ldx)))
+ const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
+ if (!isUndefLT(merger.getLvlType(tensor, idx)))
return false; // used more than once, e.g., A[i][i]
// TODO: Generalizes the following two cases. A[i] (with trivial index
@@ -135,14 +135,14 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
// not necessarily need to differentiate them.
if (!isSubExp) {
assert(coefficient == 1);
- merger.setLevelAndType(tensor, ldx, lvl, lt);
+ merger.setLevelAndType(tensor, idx, lvl, lt);
}
if (isSubExp) {
// The current loops appears in more than one affine expressions on the
// same tensor. We can not handle this case. e.g., A[i+j][i+k], `i` is
// used twice.
- if (merger.hasDependentLvl(ldx, tensor)) {
+ if (merger.hasDependentLvl(idx, tensor)) {
// TODO: This can be supported by coiterate slices if the loop idx is
// appeared on affine index for different tensor, or take slice on
// multiple dimensions when it is on the same tensor.
@@ -154,7 +154,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
// else increase min(d0_1, d0_2).
return false;
}
- merger.setLoopDependentTensorLevel(ldx, tensor, lvl, lt, coefficient);
+ merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
}
return true;
}
@@ -613,9 +613,9 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
if (kind == TensorExp::Kind::kReduce)
env.startCustomReduc(e); // enter custom
- Value v0, v1;
// If either lhs/rhs is a synthetic zero, we infer the type for the zero value
// based on the type of the other operand.
+ Value v0, v1;
if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
v1 = genExp(env, rewriter, exp.children.e1);
@@ -655,21 +655,21 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
- LoopId ldx, bool atStart) {
+ LoopId at, bool atStart) {
if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
return;
if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
// Inspect tensor indices.
- bool isAtLoop = ldx == ::mlir::sparse_tensor::detail::kInvalidId;
linalg::GenericOp op = env.op();
OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
const auto map = op.getMatchingIndexingMap(&t);
const auto stt = getSparseTensorType(t.get());
const Level lvlRank = stt.getLvlRank();
assert(static_cast<Level>(map.getNumResults()) == lvlRank);
+ bool isAtLoop = at == 0; // for scalar tensors
for (Level l = 0; l < lvlRank; l++) {
const AffineExpr a = map.getResult(l);
- if (!isInvariantAffine(a, env.getLoopDepth(), ldx, isAtLoop))
+ if (!isInvariantAffine(a, at, /*out*/ isAtLoop))
return; // still in play
}
// All exhausted at this level (isAtLoop denotes exactly at this LoopId).
@@ -705,8 +705,8 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
env.startCustomReduc(exp); // enter custom
const ExprId e0 = env.exp(exp).children.e0;
const ExprId e1 = env.exp(exp).children.e1;
- genInvariants(env, builder, e0, ldx, atStart);
- genInvariants(env, builder, e1, ldx, atStart);
+ genInvariants(env, builder, e0, at, atStart);
+ genInvariants(env, builder, e1, at, atStart);
if (env.exp(exp).kind == TensorExp::Kind::kReduce)
env.endCustomReduc(); // exit custom
}
@@ -782,29 +782,28 @@ static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
/// Whether or not the current loop being generated should be parallized (if
/// possible) according to the configuration.
-static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
+static bool shouldTryParallize(CodegenEnv &env, LoopId at,
ArrayRef<TensorLevel> tidLvls) {
linalg::GenericOp op = env.op();
auto iteratorTypes = op.getIteratorTypesArray();
- bool isSparse = llvm::any_of(tidLvls, [ldx, &env](TensorLevel tidLvl) {
- // Queries the LT based on the tensor id and loop idx, as requested by
- // `CodegenEnv::lt(TensorId, LoopIdx)`. The returned LT from CodegenEnv
+ bool isSparse = llvm::any_of(tidLvls, [at, &env](TensorLevel tidLvl) {
+ // Queries the LT based on the tensor and loop id, as requested by
+ // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
// should be consistent with the LT indexed by <TensorId, Level>.
- const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
+ const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, at);
return isCompressedLT(lt) || isSingletonLT(lt);
});
- return isParallelFor(env, isOuter, isSparse);
+ return isParallelFor(env, /*isOuter=*/at == 0, isSparse);
}
/// Emit a loop to coiterate over the list of tensor levels. The generated loop
/// can either be a for loop or while loop depending on whether there is at most
/// one sparse level in the list.
static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
- LoopId idx, ArrayRef<TensorLevel> tidLvls,
+ ArrayRef<TensorLevel> tidLvls,
bool tryParallel, bool needsUniv) {
Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
- // Construct the while-loop with a parameter for each
- // index.
+ // Construct while-loop with a parameter for each index.
return env.emitter().enterCoIterationOverTensorsAtLvls(
builder, env.op().getLoc(), tidLvls, reduc, tryParallel,
/*genDedup=*/true, needsUniv);
@@ -817,12 +816,12 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId at,
bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
- bool tryParallel = shouldTryParallize(env, at, at == 0, tidLvls);
- return genCoIteration(env, builder, at, tidLvls, tryParallel, needsUniv);
+ bool tryParallel = shouldTryParallize(env, at, tidLvls);
+ return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
}
/// Generates the induction structure for a while-loop.
-static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
+static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
bool needsUniv) {
Location loc = env.op().getLoc();
// Finalize each else branch of all if statements.
@@ -862,7 +861,7 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
}
/// Generates a single if-statement within a while-loop.
-static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
+static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId at,
LatPointId p) {
Location loc = env.op().getLoc();
SmallVector<Type> types;
@@ -880,13 +879,13 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
auto stt = getSparseTensorType(env.op().getInputs()[tid]);
lt = stt.getLvlType(*lvl);
}
- assert(ldx == env.merger().loop(b));
+ assert(at == env.merger().loop(b));
Value clause;
if (isCompressedLT(lt) || isSingletonLT(lt) ||
isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
assert(lvl.has_value());
const Value crd = env.emitter().getCoords()[tid][*lvl];
- const Value lvar = env.getLoopVar(ldx);
+ const Value lvar = env.getLoopVar(at);
clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
crd, lvar);
} else {
@@ -943,12 +942,12 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
- LoopId idx, LoopId ldx, LatSetId lts) {
- assert(!env.getLoopVar(idx));
+ LoopId at, LatSetId lts) {
+ assert(!env.getLoopVar(at));
// Emit invariants at this loop sequence level.
- genInvariants(env, builder, exp, ldx, /*atStart=*/true);
+ genInvariants(env, builder, exp, at, /*atStart=*/true);
// Emit access pattern expansion for sparse tensor output.
- genExpand(env, builder, idx, /*atStart=*/true);
+ genExpand(env, builder, at, /*atStart=*/true);
// Emit further intitialization at this loop sequence level.
const LatPointId l0 = env.set(lts)[0];
bool needsUniv = false;
@@ -957,7 +956,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
std::optional<Level> lvl,
LevelType lt, bool isIdxReduc) {
- assert(env.merger().loop(b) == idx);
+ assert(env.merger().loop(b) == at);
if (isDenseLT(lt) || isUndefLT(lt)) {
if (tid == env.merger().getSynTensorID()) {
// Needs loop emitter to set up loop bounds for synthetic tensor too if
@@ -988,6 +987,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
return false;
}
+// Generates dense affine address for encoding.
static void genConstantDenseAddressFromLevel(CodegenEnv &env,
OpBuilder &builder, TensorId tid,
Level startLvl) {
@@ -1013,30 +1013,30 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
}
}
+// We can generate address for constant affine expression before any loops
+// starting from the first level as they do not depend on any thing.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
static void genInitConstantDenseAddress(CodegenEnv &env,
RewriterBase &rewriter) {
- // We can generate address for constant affine expression before any loops
- // starting from the first level as they do not depend on any thing.
- // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
- // levels can be determined before loops.
for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
}
/// Return true if the lattices bit can be iterated by a for loop.
static bool translateBitsToTidLvlPairs(
- CodegenEnv &env, LatPointId li, LoopId ldx,
+ CodegenEnv &env, LatPointId li, LoopId at,
SmallVectorImpl<TensorLevel> &tidLvls,
SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
const BitVector &simple = env.lat(li).simple;
const TensorId outTid = env.merger().getOutTensorID();
- const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);
+ const std::optional<Level> outLvl = env.merger().getLvl(outTid, at);
unsigned numloopCond = 0;
bool hasNonUnique = false;
- env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
- std::optional<Level> lvl,
- LevelType lt, bool isIdxReduc) {
+ env.merger().foreachTensorLoopId(li, [&, at](TensorLoopId b, TensorId tid,
+ std::optional<Level> lvl,
+ LevelType lt, bool isIdxReduc) {
if (simple[b]) {
if (isIdxReduc) {
tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
@@ -1089,11 +1089,11 @@ static bool translateBitsToTidLvlPairs(
if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
continue;
- // Constant affine expression are handled in genLoop
+ // Constant affine expression are handled in genLoop.
if (!isa<AffineConstantExpr>(exp)) {
bool isAtLoop = false;
- if (isInvariantAffine(exp, env.getLoopDepth(), ldx, isAtLoop) &&
- isAtLoop) {
+ assert(at == env.getLoopDepth());
+ if (isInvariantAffine(exp, at + 1, /*out*/ isAtLoop) && isAtLoop) {
// If the compound affine is invariant and we are right at the
// level. We need to generate the address according to the
// affine expression. This is also the best place we can do it
@@ -1105,7 +1105,7 @@ static bool translateBitsToTidLvlPairs(
}
});
- if (isDenseLT(env.lt(outTid, ldx))) {
+ if (isDenseLT(env.lt(outTid, at))) {
// Note that we generate dense indices of the output tensor
// unconditionally, since they may not appear in the lattice, but may be
// needed for linearized env.
@@ -1131,9 +1131,9 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
LatPointId li, bool needsUniv) {
// The set of tensors + lvls to generate loops on
SmallVector<TensorLevel> tidLvls;
+
// The set of dense tensors with non-trivial affine expression that just
- // becomes invariant and the address shall now be generated at the current
- // level.
+ // becomes invariant and the address are generated at the current level.
SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
bool isSingleCond =
translateBitsToTidLvlPairs(env, li, at, tidLvls, affineTidLvls);
@@ -1161,38 +1161,34 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
/// Ends a single loop in current sequence. Returns new values for needsUniv.
static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
- LoopId idx, LatPointId li, bool needsUniv,
- bool isSingleCond) {
-
+ LatPointId li, bool needsUniv, bool isSingleCond) {
+ // Either a for-loop or a while-loop that iterates over a slice.
if (isSingleCond) {
- // Either a for-loop or a while-loop that iterates over a slice.
// Any iteration creates a valid lex insert.
if (env.isReduc() && env.getValidLexInsert())
env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
} else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
// End a while-loop.
- finalizeWhileOp(env, rewriter, idx, needsUniv);
+ finalizeWhileOp(env, rewriter, needsUniv);
} else {
needsUniv = false;
}
-
env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
return std::nullopt;
});
-
return needsUniv;
}
/// Ends a loop sequence at given level.
static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
- unsigned idx, unsigned ldx) {
- assert(!env.getLoopVar(idx));
+ unsigned at) {
+ assert(!env.getLoopVar(at));
env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
// Unmark bookkeeping of invariants and loop index.
- genInvariants(env, builder, exp, ldx, /*atStart=*/false);
+ genInvariants(env, builder, exp, at, /*atStart=*/false);
// Finalize access pattern expansion for sparse tensor output.
- genExpand(env, builder, idx, /*atStart=*/false);
+ genExpand(env, builder, at, /*atStart=*/false);
}
/// Recursively generates code while computing iteration lattices in order
@@ -1200,6 +1196,8 @@ static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
/// and intersections of sparse iterations spaces.
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
LoopId at) {
+ assert(at == env.getLoopDepth());
+
// At each leaf, assign remaining tensor (sub)expression to output tensor.
if (at == env.getLoopNum()) {
Value rhs = genExp(env, rewriter, exp);
@@ -1207,13 +1205,12 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
return;
}
- // Construct iteration lattices for current loop index, with L0 at top.
- const LoopId ldx = at == 0 ? sparse_tensor::detail::kInvalidId : at - 1;
+ // Construct iteration lattices for current loop index.
const LatSetId lts =
env.merger().optimizeSet(env.merger().buildLattices(exp, at));
// Start a loop sequence.
- bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
+ bool needsUniv = startLoopSeq(env, rewriter, exp, at, lts);
// Emit a loop for every lattice point L0 >= Li in this loop sequence.
// We cannot change this to `for (const LatPointId li : env.set(lts))`
@@ -1250,11 +1247,12 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
}
// End a loop.
- needsUniv = endLoop(env, rewriter, loop, at, li, needsUniv, isSingleCond);
+ needsUniv = endLoop(env, rewriter, loop, at, needsUniv, isSingleCond);
}
// End a loop sequence.
- endLoopSeq(env, rewriter, exp, at, ldx);
+ endLoopSeq(env, rewriter, exp, at);
+ assert(at == env.getLoopDepth());
}
/// Converts the result computed by the sparse kernel into the required form.
@@ -1309,6 +1307,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
"before sparsification.");
}
+
// Must have been demapped as well if the generic op is sorted.
assert(!hasAnyNonIdentityOperandsOrResults(op));
``````````
</details>
https://github.com/llvm/llvm-project/pull/74654
More information about the Mlir-commits
mailing list