[Mlir-commits] [mlir] 7c7c10a - [mlir][sparse] Updating the `Merger::{exp, lat, set}` methods to return const

wren romano llvmlistbot at llvm.org
Fri Mar 24 14:48:41 PDT 2023


Author: wren romano
Date: 2023-03-24T14:48:33-07:00
New Revision: 7c7c10a0233fc8060aab4082094a189803cbe5ac

URL: https://github.com/llvm/llvm-project/commit/7c7c10a0233fc8060aab4082094a189803cbe5ac
DIFF: https://github.com/llvm/llvm-project/commit/7c7c10a0233fc8060aab4082094a189803cbe5ac.diff

LOG: [mlir][sparse] Updating the `Merger::{exp,lat,set}` methods to return const

This helps the `Merger` maintain its invariants and clarifies that the underlying objects are immutable to client code (the one exception being `TensorExp::val`, which is now mutated only through dedicated methods).

Depends On: D146559

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D146083

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 8b1e91ae8df56..7e83dfb6bce65 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -498,12 +498,60 @@ class Merger {
   }
 
   /// Convenience getters to immediately access the stored nodes.
-  /// Typically it is inadvisible to keep the reference around, as in
-  /// `TensorExpr &te = merger.exp(e)`, since insertions into the merger
-  /// may cause data movement and invalidate the underlying memory address.
-  TensorExp &exp(ExprId e) { return tensorExps[e]; }
-  LatPoint &lat(LatPointId p) { return latPoints[p]; }
-  SmallVector<LatPointId> &set(LatSetId s) { return latSets[s]; }
+  /// These methods return `const&` because the underlying objects must
+  /// not be mutated by client code.  The only exception is for mutating
+  /// the value associated with an expression, for which there are
+  /// dedicated methods below.
+  ///
+  /// NOTE: It is inadvisable to keep the reference alive for a long
+  /// time (e.g., as in `TensorExp &te = merger.exp(e)`), since insertions
+  /// into the merger can cause data movement which will invalidate the
+  /// underlying memory address.  This isn't just a problem with the `&`
+  /// references, but also applies to the `ArrayRef`.  In particular,
+  /// using `for (LatPointId p : merger.set(s))` will run into the same
+  /// dangling-reference problems if the loop body inserts new sets.
+  const TensorExp &exp(ExprId e) const { return tensorExps[e]; }
+  const LatPoint &lat(LatPointId p) const { return latPoints[p]; }
+  ArrayRef<LatPointId> set(LatSetId s) const { return latSets[s]; }
+
+  /// Checks whether the given expression has an associated value.
+  bool hasExprValue(ExprId e) const {
+    return static_cast<bool>(tensorExps[e].val);
+  }
+
+  /// Sets the expression to have the associated value.  Asserts that
+  /// the new value is defined, and that the expression does not already
+  /// have a value.  If you want to overwrite a previous associated value,
+  /// use `updateExprValue` instead.
+  void setExprValue(ExprId e, Value v) {
+    assert(v && "Got an undefined value");
+    auto &val = tensorExps[e].val;
+    assert(!val && "Expression already has an associated value");
+    val = v;
+  }
+
+  /// Clears the value associated with the expression.  Asserts that the
+  /// expression does indeed have an associated value before clearing it.
+  /// If you don't want to check for a previous associated value first,
+  /// then use `updateExprValue` instead.
+  void clearExprValue(ExprId e) {
+    auto &val = tensorExps[e].val;
+    assert(val && "Expression does not have an associated value to clear");
+    val = Value();
+  }
+
+  /// Unilaterally updates the expression to have the associated value.
+  /// That is, unlike `setExprValue` and `clearExprValue`, this method
+  /// does not perform any checks on whether the expression had a
+  /// previously associated value nor whether the new value is defined.
+  //
+  // TODO: The unilateral update semantics are required by the
+  // current implementation of `CodegenEnv::genLoopBoundary`; however,
+  // that implementation seems a bit dubious.  We would much rather have
+  // the semantics `{ clearExprValue(e); setExprValue(e, v); }` or
+  // `{ clearExprValue(e); if (v) setExprValue(e, v); }` since those
+  // provide better invariants.
+  void updateExprValue(ExprId e, Value v) { tensorExps[e].val = v; }
 
 #ifndef NDEBUG
   /// Print methods (for debugging).

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 974c86d1fab5a..5d9c347b62327 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -130,6 +130,9 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
   auto r = callback(params); // may update parameters
   unsigned i = 0;
   if (isReduc()) {
+    // FIXME: This requires `updateExprValue` to perform updates without
+    // checking for a previous value; but it's not clear whether that's
+    // by design or might be a potential source for bugs.
     updateReduc(params[i++]);
     if (redValidLexInsert)
       setValidLexInsert(params[i++]);
@@ -281,12 +284,18 @@ void CodegenEnv::startReduc(ExprId exp, Value val) {
 
 void CodegenEnv::updateReduc(Value val) {
   assert(isReduc());
-  redVal = exp(redExp).val = val;
+  redVal = val;
+  // NOTE: `genLoopBoundary` requires that this performs a unilateral
+  // update without checking for a previous value first.  (It's not
+  // clear whether any other callsites also require that.)
+  latticeMerger.updateExprValue(redExp, val);
 }
 
 Value CodegenEnv::endReduc() {
+  assert(isReduc());
   Value val = redVal;
-  updateReduc(Value());
+  redVal = val;
+  latticeMerger.clearExprValue(redExp);
   redExp = kInvalidId;
   return val;
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index 0041ad0a272cb..e11e2428d86c9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -66,9 +66,9 @@ class CodegenEnv {
   // Merger delegates.
   //
 
-  TensorExp &exp(ExprId e) { return latticeMerger.exp(e); }
-  LatPoint &lat(LatPointId l) { return latticeMerger.lat(l); }
-  SmallVector<LatPointId> &set(LatSetId s) { return latticeMerger.set(s); }
+  const TensorExp &exp(ExprId e) const { return latticeMerger.exp(e); }
+  const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); }
+  ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); }
   DimLevelType dlt(TensorId t, LoopId i) const {
     return latticeMerger.getDimLevelType(t, i);
   }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index f760244d59d8b..3343a5103671e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1057,7 +1057,7 @@ static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
       assert(env.exp(exp).val);
       Value v0 = env.exp(exp).val;
       genInsertionStore(env, builder, t, v0);
-      env.exp(exp).val = Value();
+      env.merger().clearExprValue(exp);
       // Yield modified insertion chain along true branch.
       Value mchain = env.getInsertionChain();
       builder.create<scf::YieldOp>(op.getLoc(), mchain);
@@ -1137,10 +1137,8 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
   if (kind == TensorExp::Kind::kReduce)
     env.endCustomReduc(); // exit custom
 
-  if (kind == TensorExp::Kind::kSelect) {
-    assert(!exp.val);
-    env.exp(e).val = v0; // Preserve value for later use.
-  }
+  if (kind == TensorExp::Kind::kSelect)
+    env.merger().setExprValue(e, v0); // Preserve value for later use.
 
   return ee;
 }
@@ -1192,7 +1190,10 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
       }
     } else {
       // Start or end loop invariant hoisting of a tensor load.
-      env.exp(exp).val = atStart ? genTensorLoad(env, builder, exp) : Value();
+      if (atStart)
+        env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
+      else
+        env.merger().clearExprValue(exp);
     }
   } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
              env.exp(exp).kind != TensorExp::Kind::kLoopVar) {
@@ -1346,8 +1347,7 @@ static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
 
 /// Generates the induction structure for a while-loop.
 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
-                            bool needsUniv, BitVector &induction,
-                            scf::WhileOp whileOp) {
+                            bool needsUniv, scf::WhileOp whileOp) {
   Location loc = env.op().getLoc();
   // Finalize each else branch of all if statements.
   if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
@@ -1386,7 +1386,7 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
 
 /// Generates a single if-statement within a while-loop.
 static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
-                       BitVector &conditions) {
+                       const BitVector &conditions) {
   Location loc = env.op().getLoc();
   SmallVector<Type> types;
   Value cond;
@@ -1486,13 +1486,10 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   // Maintain the universal index only if it is actually
   // consumed by a subsequent lattice point.
   if (needsUniv) {
-    unsigned lsize = env.set(lts).size();
-    for (unsigned i = 1; i < lsize; i++) {
-      const LatPointId li = env.set(lts)[i];
+    for (const LatPointId li : env.set(lts).drop_front())
       if (!env.merger().hasAnySparse(env.lat(li).simple) &&
           !env.merger().hasSparseIdxReduction(env.lat(li).simple))
         return true;
-    }
   }
   return false;
 }
@@ -1675,7 +1672,7 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
                     LoopId idx, LatPointId li, bool needsUniv) {
   // End a while-loop.
   if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
-    finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp);
+    finalizeWhileOp(env, rewriter, idx, needsUniv, whileOp);
   } else if (auto forOp = dyn_cast<scf::ForOp>(loop)) {
     // Any iteration of a reduction for-loop creates a valid lex insert.
     if (env.isReduc() && env.getValidLexInsert())
@@ -1726,10 +1723,14 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
   bool needsUniv = startLoopSeq(env, rewriter, exp, at, idx, ldx, lts);
 
   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
-  unsigned lsize = env.set(lts).size();
+  //
+  // NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))`
+  // because the loop body causes data-movement which invalidates
+  // the iterator.
+  const unsigned lsize = env.set(lts).size();
   for (unsigned i = 0; i < lsize; i++) {
-    // Start a loop.
     const LatPointId li = env.set(lts)[i];
+    // Start a loop.
     auto [loop, isSingleCond] = startLoop(env, rewriter, at, li, needsUniv);
 
     // Visit all lattice points with Li >= Lj to generate the
@@ -1737,6 +1738,9 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
     Value redInput = env.getReduc();
     Value cntInput = env.getExpandCount();
     Value insInput = env.getInsertionChain();
+    // NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
+    // because the loop body causes data-movement which invalidates the
+    // iterator.
     for (unsigned j = 0; j < lsize; j++) {
       const LatPointId lj = env.set(lts)[j];
       const ExprId ej = env.lat(lj).exp;


        


More information about the Mlir-commits mailing list