[Mlir-commits] [mlir] 372d88b - [mlir][sparse] code cleanup.
Peiming Liu
llvmlistbot at llvm.org
Wed Aug 2 16:20:01 PDT 2023
Author: Peiming Liu
Date: 2023-08-02T23:19:55Z
New Revision: 372d88b051515f54264bf6568395bcb7f3db4de1
URL: https://github.com/llvm/llvm-project/commit/372d88b051515f54264bf6568395bcb7f3db4de1
DIFF: https://github.com/llvm/llvm-project/commit/372d88b051515f54264bf6568395bcb7f3db4de1.diff
LOG: [mlir][sparse] code cleanup.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D156941
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 315cc1d05e9266..30f4c1db3c4cd0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -184,24 +184,32 @@ class LoopEmitter {
void exitCurrentLoop(RewriterBase &rewriter, Location loc,
MutableArrayRef<Value> reduc = {});
+ /// Gets a range over the induction variables of all loops in the current
+ /// loop-stack, given in loop-stack order (i.e., indexed by `LoopOrd`, not
+ /// `LoopId`). The range is a mapped view over `loopStack`, not a copy.
+ /// NOTE(review): presumably invalidated when loops are entered or exited;
+ /// do not hold it across such calls.
+ auto getLoopIVsRange() const {
+ return llvm::map_range(loopStack, [](const LoopInfo &li) { return li.iv; });
+ }
+
- /// Fills the out-parameter with the loop induction variables for all
- /// loops in the current loop-stack. The variables are given in the
- /// same order as the loop-stack, hence `ivs` should be indexed into
- /// by `LoopOrd` (not `LoopId`).
- void getLoopIVs(SmallVectorImpl<Value> &ivs) const {
- ivs.clear();
- ivs.reserve(getCurrentDepth());
- for (auto &l : loopStack)
- ivs.push_back(l.iv);
+ /// Returns a copy of the loop induction variables for all loops in the
+ /// current loop-stack. The variables are given in the same order as the
+ /// loop-stack, hence the result should be indexed into by `LoopOrd`
+ /// (not `LoopId`).
+ SmallVector<Value> getLoopIVs() const {
+ return llvm::to_vector(getLoopIVsRange());
}
/// Gets the current depth of the loop-stack. The result is given
/// the type `LoopOrd` for the same reason as one-past-the-end iterators.
- LoopOrd getCurrentDepth() const { return loopStack.size(); }
+ LoopOrd getCurrentDepth() const {
+ // `getLoopIVsRange()` yields one IV per loop-stack entry, so its size is
+ // the loop-stack size.
+ return llvm::range_size(getLoopIVsRange());
+ }
/// Gets loop induction variable for the given `LoopOrd`.
Value getLoopIV(LoopOrd n) const {
- return n < getCurrentDepth() ? loopStack[n].iv : Value();
+ // Positions at or beyond the current depth yield a null `Value`.
+ if (n >= getCurrentDepth())
+ return Value();
+ // NOTE(review): `std::advance` is O(1) only if the mapped iterator keeps
+ // the loop-stack iterator's random-access category; confirm.
+ auto it = getLoopIVsRange().begin();
+ std::advance(it, n);
+ return *it;
}
/// Gets the total number of manifest tensors (excluding the synthetic
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index ebbe88ee902948..2a290f202c70a5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1318,10 +1318,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
reduc);
}
- SmallVector<Value> lcvs;
- lcvs.reserve(lvlRank);
- loopEmitter.getLoopIVs(lcvs);
-
+ SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
if (op.getOrder()) {
// FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank`
SmallVector<Value> dcvs = lcvs; // keep a copy
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index b75cdba8449a1a..2450fd6c7d03f6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -977,17 +977,10 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
// Direct insertion in lexicographic coordinate order.
if (!env.isExpand()) {
const LoopOrd numLoops = op.getRank(t);
- // TODO: rewrite this to use `env.emitter().getLoopIVs(ivs)`
- // instead. We just need to either assert that `numLoops ==
- // env.emitter().getCurrentDepth()`, or else update the `getLoopIVs`
- // method to take an optional parameter to restrict to a smaller depth.
- SmallVector<Value> ivs;
- ivs.reserve(numLoops);
- for (LoopOrd n = 0; n < numLoops; n++) {
- const auto iv = env.emitter().getLoopIV(n);
- assert(iv);
- ivs.push_back(iv);
- }
+ // Retrieves the first `numLoop` induction variables.
+ SmallVector<Value> ivs = llvm::to_vector(
+ llvm::drop_end(env.emitter().getLoopIVsRange(),
+ env.emitter().getCurrentDepth() - numLoops));
Value chain = env.getInsertionChain();
if (!env.getValidLexInsert()) {
env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
@@ -1438,7 +1431,7 @@ static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
/// Generates the induction structure for a while-loop.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
- bool needsUniv, scf::WhileOp whileOp) {
+ bool needsUniv) {
Location loc = env.op().getLoc();
// Finalize each else branch of all if statements.
if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
@@ -1472,7 +1465,8 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
builder.setInsertionPointAfter(ifOp);
}
}
- builder.setInsertionPointToEnd(&whileOp.getAfter().front());
+ // No need to set the insertion point here as LoopEmitter keeps track of the
+ // basic block where scf::Yield should be inserted.
}
/// Generates a single if-statement within a while-loop.
@@ -1525,8 +1519,8 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
/// Generates end of true branch of if-statement within a while-loop.
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
- Operation *loop, Value redInput, Value cntInput,
- Value insInput, Value validIns) {
+ Value redInput, Value cntInput, Value insInput,
+ Value validIns) {
SmallVector<Value> operands;
if (env.isReduc()) {
operands.push_back(env.getReduc());
@@ -1800,7 +1794,7 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
} else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
// End a while-loop.
- finalizeWhileOp(env, rewriter, idx, needsUniv, whileOp);
+ finalizeWhileOp(env, rewriter, idx, needsUniv);
} else {
needsUniv = false;
}
@@ -1875,8 +1869,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
if (!isSingleCond) {
scf::IfOp ifOp = genIf(env, rewriter, idx, lj);
genStmt(env, rewriter, ej, at + 1);
- endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput,
- validIns);
+ endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
} else {
genStmt(env, rewriter, ej, at + 1);
}
More information about the Mlir-commits mailing list