[Mlir-commits] [mlir] [mlir][sparse] merger cleanup (PR #70371)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 26 12:36:52 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
Author: Aart Bik (aartbik)
<details>
<summary>Changes</summary>
Implemented some TODOs and removed those that are unlikely to ever be addressed.
Cleaned up comments.
---
Full diff: https://github.com/llvm/llvm-project/pull/70371.diff
2 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (+18-60)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp (+4-8)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 5e7538006757287..215920f8b4607b2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -122,19 +122,10 @@ struct TensorExp final {
///
/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
/// That is, its argument is a `LoopId` identifying the loop-variable
-/// in question, and its value will be the current iteration's value
-/// of that loop-variable. See the `LoopId` documentation for more details.
-///
-/// The `kSynZero` leaf kind is for representing a synthetic zero value, which
-/// can be introduced when sparsifying operations like `arith::cmp` to generate
-/// `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
-//
-// TODO: Modify this definition so that the numeric values already encode
-// the `ExpArity` (while extending the notion of "arity" to include not
-// just the number of `ExprId` children the node has, but also whether the
-// node has a `Value` and/or `Operation*`). Doing this will avoid needing
-// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor,
-// and should help clean up a few other places as well.
+/// in question, and its value will be the current iteration's value.
+/// The `kSynZero` leaf kind is for representing a synthetic zero value,
+/// which can be introduced when sparsifying operations like `arith::cmp`
+/// to generate `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
enum class TensorExp::Kind {
// Leaf.
kTensor = 0,
@@ -253,15 +244,6 @@ class Merger {
///
/// The maxLvlRank specifies the max level rank of all inputs/output tensors.
/// It is used to pre-allocate sufficient memory for internal storage.
- //
- // TODO: we want to make the filter loop more efficient in the future,
- // e.g., by avoiding scanning the full list of stored coordinates (keeping
- // the last position in ordered list) or even apply binary search to find
- // the coordinate.
- //
- // TODO: would be cleaner to understand/document if the first argument
- // gave the number of input tensors, instead of the current number of
- // input+output tensors.
Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
unsigned numFilterLoops, unsigned maxLvlRank);
@@ -383,12 +365,15 @@ class Merger {
/// Gets the total number of loops (native loops + filter loops).
constexpr unsigned getNumLoops() const { return numLoops; }
+
/// Gets the number of native loops.
constexpr unsigned getNumNativeLoops() const { return numNativeLoops; }
+
/// Gets the number of filter loops.
constexpr unsigned getNumFilterLoops() const {
return numLoops - numNativeLoops;
}
+
/// Gets the identifier of the first filter-loop.
constexpr LoopId getStartingFilterLoopId() const {
return getNumNativeLoops();
@@ -473,8 +458,7 @@ class Merger {
lvlTypes[t][i] = dlt;
loopToLvl[t][i] = lvl;
lvlToLoop[t][lvl] = i;
- // TODO: Maybe we should favor a constant loop bound when there are multiple
- // choices.
+ // TODO: favor a constant loop bound when there are multiple choices.
loopBounds[i] = std::make_pair(t, lvl);
}
@@ -600,43 +584,19 @@ class Merger {
/// Checks whether the given expression has an associated value.
bool hasExprValue(ExprId e) const { return static_cast<bool>(exp(e).val); }
- /// Sets the expression to have the associated value. Asserts that
- /// the new value is defined, and that the expression does not already
- /// have a value. If you want to overwrite a previous associated value,
- /// use `updateExprValue` instead.
+ /// Sets the expression to have the associated value. Asserts that the new
+ /// value is defined, and that the expression does not already have a value.
void setExprValue(ExprId e, Value v) {
- assert(isValidExprId(e));
- assert(v && "Got an undefined value");
- auto &val = tensorExps[e].val;
- assert(!val && "Expression already has an associated value");
- val = v;
+ assert(!exp(e).val && "Expression already has an associated value");
+ assert(v && "Trying to assign an undefined value");
+ tensorExps[e].val = v;
}
- /// Clears the value associated with the expression. Asserts that the
+ /// Clears the value associated with the expression. Asserts that the
/// expression does indeed have an associated value before clearing it.
- /// If you don't want to check for a previous associated value first,
- /// then use `updateExprValue` instead.
void clearExprValue(ExprId e) {
- assert(isValidExprId(e));
- auto &val = tensorExps[e].val;
- assert(val && "Expression does not have an associated value to clear");
- val = Value();
- }
-
- /// Unilaterally updates the expression to have the associated value.
- /// That is, unlike `setExprValue` and `clearExprValue`, this method
- /// does not perform any checks on whether the expression had a
- /// previously associated value nor whether the new value is defined.
- //
- // TODO: The unilateral update semantics are required by the
- // current implementation of `CodegenEnv::genLoopBoundary`; however,
- // that implementation seems a bit dubious. We would much rather have
- // the semantics `{ clearExprValue(e); setExprValue(e, v); }` or
- // `{ clearExprValue(e); if (v) setExprValue(e, v); }` since those
- // provide better invariants.
- void updateExprValue(ExprId e, Value v) {
- assert(isValidExprId(e));
- tensorExps[e].val = v;
+ assert(exp(e).val && "Expression does not have an associated value");
+ tensorExps[e].val = Value();
}
#ifndef NDEBUG
@@ -706,12 +666,10 @@ class Merger {
// `operator[]`: `SmallVector` performs OOB checks, whereas `std::vector`
// does not.
- /// Map that converts pair<TensorId, LoopId> to the corresponding
- /// level-type.
+ /// Map that converts pair<TensorId, LoopId> to the corresponding lvl-type.
std::vector<std::vector<DimLevelType>> lvlTypes;
- /// Map that converts pair<TensorId, LoopId> to the corresponding
- /// level.
+ /// Map that converts pair<TensorId, LoopId> to the corresponding lvl.
std::vector<std::vector<std::optional<Level>>> loopToLvl;
/// Map that converts pair<TensorId, Level> to the corresponding LoopId.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 924b0a0dac8113e..5c7cc93737b7fd7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -137,9 +137,6 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
auto r = callback(params); // may update parameters
unsigned i = 0;
if (isReduc()) {
- // FIXME: This requires `updateExprValue` to perform updates without
- // checking for a previous value; but it's not clear whether that's
- // by design or might be a potential source for bugs.
updateReduc(params[i++]);
if (redValidLexInsert)
setValidLexInsert(params[i++]);
@@ -283,16 +280,15 @@ void CodegenEnv::endExpand() {
void CodegenEnv::startReduc(ExprId exp, Value val) {
assert(!isReduc() && exp != detail::kInvalidId);
redExp = exp;
- updateReduc(val);
+ redVal = val;
+ latticeMerger.setExprValue(exp, val);
}
void CodegenEnv::updateReduc(Value val) {
assert(isReduc());
redVal = val;
- // NOTE: `genLoopBoundary` requires that this performs a unilateral
- // update without checking for a previous value first. (It's not
- // clear whether any other callsites also require that.)
- latticeMerger.updateExprValue(redExp, val);
+ latticeMerger.clearExprValue(redExp);
+ latticeMerger.setExprValue(redExp, val);
}
Value CodegenEnv::endReduc() {
``````````
</details>
https://github.com/llvm/llvm-project/pull/70371
More information about the Mlir-commits
mailing list