[Mlir-commits] [mlir] [mlir][sparse] merger cleanup (PR #70371)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 26 12:36:52 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Aart Bik (aartbik)

<details>
<summary>Changes</summary>

Implemented some TODOs and removed those that are unlikely to be addressed.
Comment cleanup.

---
Full diff: https://github.com/llvm/llvm-project/pull/70371.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (+18-60) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp (+4-8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 5e7538006757287..215920f8b4607b2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -122,19 +122,10 @@ struct TensorExp final {
 ///
 /// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
 /// That is, its argument is a `LoopId` identifying the loop-variable
-/// in question, and its value will be the current iteration's value
-/// of that loop-variable.  See the `LoopId` documentation for more details.
-///
-/// The `kSynZero` leaf kind is for representing a synthetic zero value, which
-/// can be introduced when sparsifying operations like `arith::cmp` to generate
-/// `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
-//
-// TODO: Modify this definition so that the numeric values already encode
-// the `ExpArity` (while extending the notion of "arity" to include not
-// just the number of `ExprId` children the node has, but also whether the
-// node has a `Value` and/or `Operation*`).  Doing this will avoid needing
-// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor,
-// and should help clean up a few other places as well.
+/// in question, and its value will be the current iteration's value.
+/// The `kSynZero` leaf kind is for representing a synthetic zero value,
+/// which can be introduced when sparsifying operations like `arith::cmp`
+/// to generate `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
 enum class TensorExp::Kind {
   // Leaf.
   kTensor = 0,
@@ -253,15 +244,6 @@ class Merger {
   ///
   /// The maxLvlRank specifies the max level rank of all inputs/output tensors.
   /// It is used to pre-allocate sufficient memory for internal storage.
-  //
-  // TODO: we want to make the filter loop more efficient in the future,
-  // e.g., by avoiding scanning the full list of stored coordinates (keeping
-  // the last position in ordered list) or even apply binary search to find
-  // the coordinate.
-  //
-  // TODO: would be cleaner to understand/document if the first argument
-  // gave the number of input tensors, instead of the current number of
-  // input+output tensors.
   Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
          unsigned numFilterLoops, unsigned maxLvlRank);
 
@@ -383,12 +365,15 @@ class Merger {
 
   /// Gets the total number of loops (native loops + filter loops).
   constexpr unsigned getNumLoops() const { return numLoops; }
+
   /// Gets the number of native loops.
   constexpr unsigned getNumNativeLoops() const { return numNativeLoops; }
+
   /// Gets the number of filter loops.
   constexpr unsigned getNumFilterLoops() const {
     return numLoops - numNativeLoops;
   }
+
   /// Gets the identifier of the first filter-loop.
   constexpr LoopId getStartingFilterLoopId() const {
     return getNumNativeLoops();
@@ -473,8 +458,7 @@ class Merger {
     lvlTypes[t][i] = dlt;
     loopToLvl[t][i] = lvl;
     lvlToLoop[t][lvl] = i;
-    // TODO: Maybe we should favor a constant loop bound when there are multiple
-    // choices.
+    // TODO: favor a constant loop bound when there are multiple choices.
     loopBounds[i] = std::make_pair(t, lvl);
   }
 
@@ -600,43 +584,19 @@ class Merger {
   /// Checks whether the given expression has an associated value.
   bool hasExprValue(ExprId e) const { return static_cast<bool>(exp(e).val); }
 
-  /// Sets the expression to have the associated value.  Asserts that
-  /// the new value is defined, and that the expression does not already
-  /// have a value.  If you want to overwrite a previous associated value,
-  /// use `updateExprValue` instead.
+  /// Sets the expression to have the associated value. Asserts that the new
+  /// value is defined, and that the expression does not already have a value.
   void setExprValue(ExprId e, Value v) {
-    assert(isValidExprId(e));
-    assert(v && "Got an undefined value");
-    auto &val = tensorExps[e].val;
-    assert(!val && "Expression already has an associated value");
-    val = v;
+    assert(!exp(e).val && "Expression already has an associated value");
+    assert(v && "Trying to assign an undefined value");
+    tensorExps[e].val = v;
   }
 
-  /// Clears the value associated with the expression.  Asserts that the
+  /// Clears the value associated with the expression. Asserts that the
   /// expression does indeed have an associated value before clearing it.
-  /// If you don't want to check for a previous associated value first,
-  /// then use `updateExprValue` instead.
   void clearExprValue(ExprId e) {
-    assert(isValidExprId(e));
-    auto &val = tensorExps[e].val;
-    assert(val && "Expression does not have an associated value to clear");
-    val = Value();
-  }
-
-  /// Unilaterally updates the expression to have the associated value.
-  /// That is, unlike `setExprValue` and `clearExprValue`, this method
-  /// does not perform any checks on whether the expression had a
-  /// previously associated value nor whether the new value is defined.
-  //
-  // TODO: The unilateral update semantics are required by the
-  // current implementation of `CodegenEnv::genLoopBoundary`; however,
-  // that implementation seems a bit dubious.  We would much rather have
-  // the semantics `{ clearExprValue(e); setExprValue(e, v); }` or
-  // `{ clearExprValue(e); if (v) setExprValue(e, v); }` since those
-  // provide better invariants.
-  void updateExprValue(ExprId e, Value v) {
-    assert(isValidExprId(e));
-    tensorExps[e].val = v;
+    assert(exp(e).val && "Expression does not have an associated value");
+    tensorExps[e].val = Value();
   }
 
 #ifndef NDEBUG
@@ -706,12 +666,10 @@ class Merger {
   // `operator[]`: `SmallVector` performs OOB checks, whereas `std::vector`
   // does not.
 
-  /// Map that converts pair<TensorId, LoopId> to the corresponding
-  /// level-type.
+  /// Map that converts pair<TensorId, LoopId> to the corresponding lvl-type.
   std::vector<std::vector<DimLevelType>> lvlTypes;
 
-  /// Map that converts pair<TensorId, LoopId> to the corresponding
-  /// level.
+  /// Map that converts pair<TensorId, LoopId> to the corresponding lvl.
   std::vector<std::vector<std::optional<Level>>> loopToLvl;
 
   /// Map that converts pair<TensorId, Level> to the corresponding LoopId.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 924b0a0dac8113e..5c7cc93737b7fd7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -137,9 +137,6 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
   auto r = callback(params); // may update parameters
   unsigned i = 0;
   if (isReduc()) {
-    // FIXME: This requires `updateExprValue` to perform updates without
-    // checking for a previous value; but it's not clear whether that's
-    // by design or might be a potential source for bugs.
     updateReduc(params[i++]);
     if (redValidLexInsert)
       setValidLexInsert(params[i++]);
@@ -283,16 +280,15 @@ void CodegenEnv::endExpand() {
 void CodegenEnv::startReduc(ExprId exp, Value val) {
   assert(!isReduc() && exp != detail::kInvalidId);
   redExp = exp;
-  updateReduc(val);
+  redVal = val;
+  latticeMerger.setExprValue(exp, val);
 }
 
 void CodegenEnv::updateReduc(Value val) {
   assert(isReduc());
   redVal = val;
-  // NOTE: `genLoopBoundary` requires that this performs a unilateral
-  // update without checking for a previous value first.  (It's not
-  // clear whether any other callsites also require that.)
-  latticeMerger.updateExprValue(redExp, val);
+  latticeMerger.clearExprValue(redExp);
+  latticeMerger.setExprValue(redExp, val);
 }
 
 Value CodegenEnv::endReduc() {

``````````

</details>


https://github.com/llvm/llvm-project/pull/70371


More information about the Mlir-commits mailing list