[Mlir-commits] [mlir] 7c7c10a - [mlir][sparse] Updating the `Merger::{exp,lat,set}` methods to return const
wren romano
llvmlistbot at llvm.org
Fri Mar 24 14:48:41 PDT 2023
Author: wren romano
Date: 2023-03-24T14:48:33-07:00
New Revision: 7c7c10a0233fc8060aab4082094a189803cbe5ac
URL: https://github.com/llvm/llvm-project/commit/7c7c10a0233fc8060aab4082094a189803cbe5ac
DIFF: https://github.com/llvm/llvm-project/commit/7c7c10a0233fc8060aab4082094a189803cbe5ac.diff
LOG: [mlir][sparse] Updating the `Merger::{exp,lat,set}` methods to return const
This helps the `Merger` maintain invariants, as well as clarifying the immutability of the underlying objects (with the one exception of `TensorExp::val`).
Depends On: D146559
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D146083
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 8b1e91ae8df56..7e83dfb6bce65 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -498,12 +498,60 @@ class Merger {
}
/// Convenience getters to immediately access the stored nodes.
- /// Typically it is inadvisible to keep the reference around, as in
- /// `TensorExpr &te = merger.exp(e)`, since insertions into the merger
- /// may cause data movement and invalidate the underlying memory address.
- TensorExp &exp(ExprId e) { return tensorExps[e]; }
- LatPoint &lat(LatPointId p) { return latPoints[p]; }
- SmallVector<LatPointId> &set(LatSetId s) { return latSets[s]; }
+ /// These methods return `const&` because the underlying objects must
+ /// not be mutated by client code. The only exception is for mutating
+ /// the value associated with an expression, for which there are
+ /// dedicated methods below.
+ ///
+ /// NOTE: It is inadvisable to keep the reference alive for a long
+ /// time (e.g., as in `const TensorExp &te = merger.exp(e)`), since insertions
+ /// into the merger can cause data movement which will invalidate the
+ /// underlying memory address. This isn't just a problem with the `&`
+ /// references, but also applies to the `ArrayRef`. In particular,
+ /// using `for (LatPointId p : merger.set(s))` will run into the same
+ /// dangling-reference problems if the loop body inserts new sets.
+ const TensorExp &exp(ExprId e) const { return tensorExps[e]; }
+ const LatPoint &lat(LatPointId p) const { return latPoints[p]; }
+ ArrayRef<LatPointId> set(LatSetId s) const { return latSets[s]; }
+
+ /// Checks whether the given expression has an associated value.
+ bool hasExprValue(ExprId e) const {
+ return static_cast<bool>(tensorExps[e].val);
+ }
+
+ /// Sets the expression to have the associated value. Asserts that
+ /// the new value is defined, and that the expression does not already
+ /// have a value. If you want to overwrite a previous associated value,
+ /// use `updateExprValue` instead.
+ void setExprValue(ExprId e, Value v) {
+ assert(v && "Got an undefined value");
+ auto &val = tensorExps[e].val;
+ assert(!val && "Expression already has an associated value");
+ val = v;
+ }
+
+ /// Clears the value associated with the expression. Asserts that the
+ /// expression does indeed have an associated value before clearing it.
+ /// If you don't want to check for a previous associated value first,
+ /// then use `updateExprValue` instead.
+ void clearExprValue(ExprId e) {
+ auto &val = tensorExps[e].val;
+ assert(val && "Expression does not have an associated value to clear");
+ val = Value();
+ }
+
+ /// Unilaterally updates the expression to have the associated value.
+ /// That is, unlike `setExprValue` and `clearExprValue`, this method
+ /// does not perform any checks on whether the expression had a
+ /// previously associated value nor whether the new value is defined.
+ //
+ // TODO: The unilateral update semantics are required by the
+ // current implementation of `CodegenEnv::genLoopBoundary`; however,
+ // that implementation seems a bit dubious. We would much rather have
+ // the semantics `{ clearExprValue(e); setExprValue(e, v); }` or
+ // `{ clearExprValue(e); if (v) setExprValue(e, v); }` since those
+ // provide better invariants.
+ void updateExprValue(ExprId e, Value v) { tensorExps[e].val = v; }
#ifndef NDEBUG
/// Print methods (for debugging).
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 974c86d1fab5a..5d9c347b62327 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -130,6 +130,9 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
auto r = callback(params); // may update parameters
unsigned i = 0;
if (isReduc()) {
+ // FIXME: This requires `updateExprValue` to perform updates without
+ // checking for a previous value; but it's not clear whether that's
+ // by design or might be a potential source for bugs.
updateReduc(params[i++]);
if (redValidLexInsert)
setValidLexInsert(params[i++]);
@@ -281,12 +284,18 @@ void CodegenEnv::startReduc(ExprId exp, Value val) {
void CodegenEnv::updateReduc(Value val) {
assert(isReduc());
- redVal = exp(redExp).val = val;
+ redVal = val;
+ // NOTE: `genLoopBoundary` requires that this performs a unilateral
+ // update without checking for a previous value first. (It's not
+ // clear whether any other callsites also require that.)
+ latticeMerger.updateExprValue(redExp, val);
}
Value CodegenEnv::endReduc() {
+ assert(isReduc());
Value val = redVal;
- updateReduc(Value());
+ redVal = val;
+ latticeMerger.clearExprValue(redExp);
redExp = kInvalidId;
return val;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index 0041ad0a272cb..e11e2428d86c9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -66,9 +66,9 @@ class CodegenEnv {
// Merger delegates.
//
- TensorExp &exp(ExprId e) { return latticeMerger.exp(e); }
- LatPoint &lat(LatPointId l) { return latticeMerger.lat(l); }
- SmallVector<LatPointId> &set(LatSetId s) { return latticeMerger.set(s); }
+ const TensorExp &exp(ExprId e) const { return latticeMerger.exp(e); }
+ const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); }
+ ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); }
DimLevelType dlt(TensorId t, LoopId i) const {
return latticeMerger.getDimLevelType(t, i);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index f760244d59d8b..3343a5103671e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1057,7 +1057,7 @@ static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
assert(env.exp(exp).val);
Value v0 = env.exp(exp).val;
genInsertionStore(env, builder, t, v0);
- env.exp(exp).val = Value();
+ env.merger().clearExprValue(exp);
// Yield modified insertion chain along true branch.
Value mchain = env.getInsertionChain();
builder.create<scf::YieldOp>(op.getLoc(), mchain);
@@ -1137,10 +1137,8 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
if (kind == TensorExp::Kind::kReduce)
env.endCustomReduc(); // exit custom
- if (kind == TensorExp::Kind::kSelect) {
- assert(!exp.val);
- env.exp(e).val = v0; // Preserve value for later use.
- }
+ if (kind == TensorExp::Kind::kSelect)
+ env.merger().setExprValue(e, v0); // Preserve value for later use.
return ee;
}
@@ -1192,7 +1190,10 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
}
} else {
// Start or end loop invariant hoisting of a tensor load.
- env.exp(exp).val = atStart ? genTensorLoad(env, builder, exp) : Value();
+ if (atStart)
+ env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
+ else
+ env.merger().clearExprValue(exp);
}
} else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
env.exp(exp).kind != TensorExp::Kind::kLoopVar) {
@@ -1346,8 +1347,7 @@ static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
/// Generates the induction structure for a while-loop.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
- bool needsUniv, BitVector &induction,
- scf::WhileOp whileOp) {
+ bool needsUniv, scf::WhileOp whileOp) {
Location loc = env.op().getLoc();
// Finalize each else branch of all if statements.
if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
@@ -1386,7 +1386,7 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
- BitVector &conditions) {
+ const BitVector &conditions) {
Location loc = env.op().getLoc();
SmallVector<Type> types;
Value cond;
@@ -1486,13 +1486,10 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
// Maintain the universal index only if it is actually
// consumed by a subsequent lattice point.
if (needsUniv) {
- unsigned lsize = env.set(lts).size();
- for (unsigned i = 1; i < lsize; i++) {
- const LatPointId li = env.set(lts)[i];
+ for (const LatPointId li : env.set(lts).drop_front())
if (!env.merger().hasAnySparse(env.lat(li).simple) &&
!env.merger().hasSparseIdxReduction(env.lat(li).simple))
return true;
- }
}
return false;
}
@@ -1675,7 +1672,7 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
LoopId idx, LatPointId li, bool needsUniv) {
// End a while-loop.
if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
- finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp);
+ finalizeWhileOp(env, rewriter, idx, needsUniv, whileOp);
} else if (auto forOp = dyn_cast<scf::ForOp>(loop)) {
// Any iteration of a reduction for-loop creates a valid lex insert.
if (env.isReduc() && env.getValidLexInsert())
@@ -1726,10 +1723,14 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
bool needsUniv = startLoopSeq(env, rewriter, exp, at, idx, ldx, lts);
// Emit a loop for every lattice point L0 >= Li in this loop sequence.
- unsigned lsize = env.set(lts).size();
+ //
+ // NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))`
+ // because the loop body causes data-movement which invalidates
+ // the iterator.
+ const unsigned lsize = env.set(lts).size();
for (unsigned i = 0; i < lsize; i++) {
- // Start a loop.
const LatPointId li = env.set(lts)[i];
+ // Start a loop.
auto [loop, isSingleCond] = startLoop(env, rewriter, at, li, needsUniv);
// Visit all lattices points with Li >= Lj to generate the
@@ -1737,6 +1738,9 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
Value redInput = env.getReduc();
Value cntInput = env.getExpandCount();
Value insInput = env.getInsertionChain();
+ // NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
+ // because the loop body causes data-movement which invalidates the
+ // iterator.
for (unsigned j = 0; j < lsize; j++) {
const LatPointId lj = env.set(lts)[j];
const ExprId ej = env.lat(lj).exp;
More information about the Mlir-commits
mailing list