[Mlir-commits] [mlir] [mlir][sparse] use a common util function to query the tensor level set in a lattice point (PR #76764)
Peiming Liu
llvmlistbot at llvm.org
Tue Jan 2 15:13:15 PST 2024
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/76764
Use a common util function to query the tensor level set in a lattice point.
>From 00bad41f7857ae52def62169f5d5c4413fd623bc Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 2 Jan 2024 23:11:19 +0000
Subject: [PATCH] [mlir][sparse] use a common util function to query the tensor
level set in a lattice point.
---
.../Transforms/Sparsification.cpp | 180 +++++++++---------
1 file changed, 86 insertions(+), 94 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 934e1e559f44d6..7be2f30d26d8ba 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -949,94 +949,9 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
// Sparsifier synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//
-/// Starts a loop sequence at given level. Returns true if
-/// the universal loop index must be maintained at this level.
-static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
- LoopId curr, LatSetId lts) {
- assert(!env.getLoopVar(curr));
- // Emit invariants at this loop sequence level.
- genInvariants(env, builder, exp, curr, /*isStart=*/true);
- // Emit access pattern expansion for sparse tensor output.
- genExpand(env, builder, curr, /*isStart=*/true);
- // Emit further intitialization at this loop sequence level.
- const LatPointId l0 = env.set(lts)[0];
- bool needsUniv = false;
-
- SmallVector<TensorLevel> tidLvls;
- env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
- std::optional<Level> lvl,
- LevelType lt, bool isIdxReduc) {
- assert(env.merger().loop(b) == curr);
- if (isDenseLT(lt) || isUndefLT(lt)) {
- if (tid == env.merger().getSynTensorID()) {
- // Needs loop emitter to set up loop bounds for synthetic tensor too if
- // there is a loop condition imposed on the synthetic tensor.
- tidLvls.push_back(env.makeTensorLevel(tid, env.getCurrentDepth()));
- }
- needsUniv = true;
- }
- if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- is2OutOf4LT(lt) || isIdxReduc) {
- // Only when this is a index reduction loop, can the lt be undefined.
- assert(!isUndefLT(lt) || isIdxReduc);
- // sparse/singleton levels, or a dense/sparse index reduction loop.
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
- }
- });
-
- env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
-
- // Maintain the universal index only if it is actually
- // consumed by a subsequent lattice point.
- if (needsUniv) {
- for (const LatPointId li : env.set(lts).drop_front())
- if (!env.merger().hasAnySparse(env.lat(li).simple))
- return true;
- }
- return false;
-}
-
-// Generates dense affine address for encoding.
-static void genConstantDenseAddressFromLevel(CodegenEnv &env,
- OpBuilder &builder, TensorId tid,
- Level startLvl) {
- // TODO: Handle affine expression on output tensor.
- linalg::GenericOp op = env.op();
- assert(tid < op.getNumDpsInputs());
- OpOperand *input = op.getDpsInputOperands()[tid];
- const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
- const auto enc = getSparseTensorEncoding(input->get().getType());
- if (enc) {
- const Location loc = op.getLoc();
- const TensorId tid = env.makeTensorId(input->getOperandNumber());
- const Level lvlRank = enc.getLvlRank();
- assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
- for (Level l = startLvl; l < lvlRank; l++) {
- AffineExpr lvlExpr = lvlExprs[l];
- if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
- env.emitter().genDenseAffineAddress(
- builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
- else
- return; // break on first non-dense non-constant level
- }
- }
-}
-
-// We can generate address for constant affine expression before any loops
-// starting from the first level as they do not depend on any thing.
-// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-// levels can be determined before loops.
-static void genInitConstantDenseAddress(CodegenEnv &env,
- RewriterBase &rewriter) {
- for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
- genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
-}
-
-/// Return true if the lattices bit can be iterated by a for loop.
-static bool translateBitsToTidLvlPairs(
+/// Invokes `callback` for each tensor level that must be visited for lattice
+/// point `li` at loop `curr`. The AffineExpr argument is null for ordinary
+/// levels and non-null for a level whose address must be generated from that
+/// affine expression. Returns true if the lattice point can be iterated by a
+/// for loop, i.e. exactly one loop condition is needed and it is not imposed
+/// on a non-unique level.
+static bool getAllTidLvlsInLatPoints(
CodegenEnv &env, LatPointId li, LoopId curr,
- SmallVectorImpl<TensorLevel> &tidLvls,
- SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+ llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
const BitVector &simple = env.lat(li).simple;
const TensorId outTid = env.merger().getOutTensorID();
const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
@@ -1048,7 +963,7 @@ static bool translateBitsToTidLvlPairs(
LevelType lt, bool isIdxReduc) {
if (simple[b]) {
if (isIdxReduc) {
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+ callback(env.makeTensorLevel(tid, *lvl), nullptr);
numloopCond++;
return;
}
@@ -1072,10 +987,10 @@ static bool translateBitsToTidLvlPairs(
}
}
hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+ callback(env.makeTensorLevel(tid, *lvl), nullptr);
numloopCond++;
} else if (isDenseLT(lt) || isIdxReduc) {
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+ callback(env.makeTensorLevel(tid, *lvl), nullptr);
} else {
assert(isUndefLT(lt));
linalg::GenericOp op = env.op();
@@ -1109,7 +1024,7 @@ static bool translateBitsToTidLvlPairs(
// level. We need to generate the address according to the
// affine expression. This is also the best place we can do it
// to avoid putting it inside inner loops.
- affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
+ callback(env.makeTensorLevel(tid, l), exp);
}
}
}
@@ -1120,15 +1035,14 @@ static bool translateBitsToTidLvlPairs(
// Note that we generate dense indices of the output tensor
// unconditionally, since they may not appear in the lattice, but may be
// needed for linearized env.
+ // NOTE: if the output level's lattice bit is set, the dense branch above
+ // has already reported it, so callers may receive this level twice.
- tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
+ callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
}
if (numloopCond == 0) {
// Corner cases where the loop bound is defined by a *unused* operand, in
// this case, we just generate a dense "fake" loop by iterating over the
// synthetic tensor.
- tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
- env.getCurrentDepth()));
+ callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
numloopCond++;
}
// If we just need to one loop conditions and the conditions is not imposed on
@@ -1136,6 +1050,84 @@ static bool translateBitsToTidLvlPairs(
return numloopCond == 1 && !hasNonUnique;
}
+/// Starts a loop sequence at given level. Returns true if
+/// the universal loop index must be maintained at this level.
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+ LoopId curr, LatSetId lts) {
+ assert(!env.getLoopVar(curr));
+ // Emit invariants at this loop sequence level.
+ genInvariants(env, builder, exp, curr, /*isStart=*/true);
+ // Emit access pattern expansion for sparse tensor output.
+ genExpand(env, builder, curr, /*isStart=*/true);
+ // Emit further initialization at this loop sequence level.
+ const LatPointId l0 = env.set(lts)[0];
+ // The same tensor level may be reported more than once (e.g. a dense output
+ // level seen both via its lattice bit and unconditionally); keep one copy.
+ SmallVector<TensorLevel> tidLvls;
+ getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+ if (!llvm::is_contained(tidLvls, tl))
+ tidLvls.emplace_back(tl);
+ });
+ env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
+
+ // Maintain the universal index only if it is actually consumed by a
+ // subsequent lattice point whose `simple` bits contain no sparse level.
+ for (const LatPointId li : env.set(lts).drop_front())
+ if (!env.merger().hasAnySparse(env.lat(li).simple))
+ return true;
+ return false;
+}
+
+// Generates addresses for the levels of input tensor `tid`, starting at
+// `startLvl`: each consecutive level that is dense and indexed by a constant
+// affine expression gets its address emitted here, and the scan stops at the
+// first level that is not both. Inputs without a sparse encoding are skipped.
+static void genConstantDenseAddressFromLevel(CodegenEnv &env,
+ OpBuilder &builder, TensorId tid,
+ Level startLvl) {
+ // TODO: Handle affine expression on output tensor.
+ linalg::GenericOp op = env.op();
+ assert(tid < op.getNumDpsInputs());
+ OpOperand *input = op.getDpsInputOperands()[tid];
+ const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
+ const auto enc = getSparseTensorEncoding(input->get().getType());
+ if (enc) {
+ const Location loc = op.getLoc();
+ // NOTE(review): this local shadows the `tid` parameter; presumably both
+ // are equal for DPS inputs -- confirm, then drop or rename the local.
+ const TensorId tid = env.makeTensorId(input->getOperandNumber());
+ const Level lvlRank = enc.getLvlRank();
+ assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
+ for (Level l = startLvl; l < lvlRank; l++) {
+ AffineExpr lvlExpr = lvlExprs[l];
+ if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+ env.emitter().genDenseAffineAddress(
+ builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
+ else
+ return; // stop at the first level that is non-dense or non-constant
+ }
+ }
+}
+
+// Generates addresses for constant affine expressions of every DPS input,
+// starting from level 0, before any loop, as they do not depend on anything.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
+static void genInitConstantDenseAddress(CodegenEnv &env,
+ RewriterBase &rewriter) {
+ for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
+ genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
+}
+
+/// Collects the tensor levels of lattice point `li` at loop `curr`. Levels
+/// that need an address computed from an affine expression are appended to
+/// `affineTidLvls` (paired with that expression); all others go to `tidLvls`.
+/// Returns true if the lattice point can be iterated by a for loop.
+static bool translateBitsToTidLvlPairs(
+ CodegenEnv &env, LatPointId li, LoopId curr,
+ SmallVectorImpl<TensorLevel> &tidLvls,
+ SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+ auto route = [&](TensorLevel tl, AffineExpr exp) {
+ if (!exp) {
+ tidLvls.emplace_back(tl);
+ return;
+ }
+ affineTidLvls.emplace_back(tl, exp);
+ };
+ return getAllTidLvlsInLatPoints(env, li, curr, route);
+}
+
/// Starts a single loop in current sequence.
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
OpBuilder &builder, LoopId curr,
More information about the Mlir-commits
mailing list