[Mlir-commits] [mlir] [mlir][sparse] use a common util function to query the tensor level s… (PR #76764)
Yinying Li
llvmlistbot at llvm.org
Tue Jan 2 15:22:15 PST 2024
================
@@ -1120,22 +1035,99 @@ static bool translateBitsToTidLvlPairs(
// Note that we generate dense indices of the output tensor
// unconditionally, since they may not appear in the lattice, but may be
// needed for a linearized env.
- tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
+ callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
}
if (numloopCond == 0) {
// Corner cases where the loop bound is defined by an *unused* operand; in
// this case, we just generate a dense "fake" loop by iterating over the
// synthetic tensor.
- tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
- env.getCurrentDepth()));
+ callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
numloopCond++;
}
// If we just need one loop condition and the condition is not imposed on a
// non-unique level, the loop can be generated by a for loop.
return numloopCond == 1 && !hasNonUnique;
}
+/// Starts a loop sequence at the given level. Returns true if
+/// the universal loop index must be maintained at this level.
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+ LoopId curr, LatSetId lts) {
+ assert(!env.getLoopVar(curr));
+ // Emit invariants at this loop sequence level.
+ genInvariants(env, builder, exp, curr, /*isStart=*/true);
+ // Emit access pattern expansion for sparse tensor output.
+ genExpand(env, builder, curr, /*isStart=*/true);
+ // Emit further initialization at this loop sequence level.
+ const LatPointId l0 = env.set(lts)[0];
+
+ SmallVector<TensorLevel> tidLvls;
+ getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+ tidLvls.emplace_back(tl);
+ });
+
+ env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
+
+ // Maintain the universal index only if it is actually
+ // consumed by a subsequent lattice point.
+ for (const LatPointId li : env.set(lts).drop_front())
+ if (!env.merger().hasAnySparse(env.lat(li).simple))
+ return true;
+
+ return false;
+}
+
+// Generates dense affine addresses for an encoded tensor.
+static void genConstantDenseAddressFromLevel(CodegenEnv &env,
+ OpBuilder &builder, TensorId tid,
+ Level startLvl) {
+ // TODO: Handle affine expression on output tensor.
+ linalg::GenericOp op = env.op();
+ assert(tid < op.getNumDpsInputs());
+ OpOperand *input = op.getDpsInputOperands()[tid];
+ const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
+ const auto enc = getSparseTensorEncoding(input->get().getType());
+ if (enc) {
+ const Location loc = op.getLoc();
+ const TensorId tid = env.makeTensorId(input->getOperandNumber());
+ const Level lvlRank = enc.getLvlRank();
+ assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
+ for (Level l = startLvl; l < lvlRank; l++) {
+ AffineExpr lvlExpr = lvlExprs[l];
+ if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+ env.emitter().genDenseAffineAddress(
+ builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
+ else
+ return; // break on the first level that is not a constant dense level
+ }
+ }
+}
+
+// We can generate addresses for constant affine expressions before any loops,
+// starting from the first level, since they do not depend on anything.
+// E.g., for [Dense, Dense, Sparse] -> (1, 2, d0), the addresses of the first
+// two levels can be determined before the loops.
+static void genInitConstantDenseAddress(CodegenEnv &env,
+ RewriterBase &rewriter) {
+ for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
+ genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
+}
+
+/// Return true if the lattices bit can be iterated by a for loop.
----------------
yinying-lisa-li wrote:
nit: Returns. (consistent with all the other outermost comments)
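
For context, here is a minimal, self-contained sketch of the callback-style
query that the hunk above switches to. All names and types below
(getAllTidLvls, TidLvlCallback, the uint64_t stand-in for TensorLevel) are
simplified placeholders for illustration, not the actual MLIR sparse-compiler
API:

    // Hypothetical, simplified sketch of the pattern introduced above: one
    // shared utility walks a lattice point and reports every tensor level
    // through a caller-supplied callback, so callers such as startLoopSeq
    // merely collect what they need instead of duplicating the query logic.
    #include <cstdint>
    #include <functional>
    #include <vector>

    using TensorLevel = uint64_t;  // stand-in for the MLIR TensorLevel type
    using TidLvlCallback = std::function<void(TensorLevel, const void *)>;

    // Stand-in for getAllTidLvlsInLatPoints: enumerate every tensor level of
    // a lattice point and hand it (plus an optional affine expr) to callback.
    static void getAllTidLvls(const std::vector<TensorLevel> &latPoint,
                              const TidLvlCallback &callback) {
      for (TensorLevel tl : latPoint)
        callback(tl, /*affineExpr=*/nullptr);
    }

    int main() {
      const std::vector<TensorLevel> latPoint = {0, 1, 2};
      // A caller that still wants a flat list, as startLoopSeq does in the
      // hunk above, simply pushes into its own vector from the callback.
      std::vector<TensorLevel> tidLvls;
      getAllTidLvls(latPoint, [&](TensorLevel tl, const void *) {
        tidLvls.push_back(tl);
      });
      return tidLvls.size() == 3 ? 0 : 1;
    }

The design point is that only the common utility needs to know how lattice
points map to tensor levels; each caller decides whether to collect the pairs
(as startLoopSeq does) or act on each one directly.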
https://github.com/llvm/llvm-project/pull/76764