[Mlir-commits] [mlir] 98ce2de - [mlir][sparse] cleanup ldx/idx/depth/at usage (#74654)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 6 13:23:54 PST 2023


Author: Aart Bik
Date: 2023-12-06T13:23:50-08:00
New Revision: 98ce2debc6ff3f6d31d7b63eb54e10e88a84ee78

URL: https://github.com/llvm/llvm-project/commit/98ce2debc6ff3f6d31d7b63eb54e10e88a84ee78
DIFF: https://github.com/llvm/llvm-project/commit/98ce2debc6ff3f6d31d7b63eb54e10e88a84ee78.diff

LOG: [mlir][sparse] cleanup ldx/idx/depth/at usage (#74654)

This adds a consistent usage with `at` for everything that refers to the
current loop nesting. This cleans up some redundant legacy code from
when we were still using topSort inside sparsifier code.

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d03e9615d340e..6637a26d0e5af 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -44,23 +44,23 @@ using namespace mlir::sparse_tensor;
 // Sparsifier analysis methods.
 //===----------------------------------------------------------------------===//
 
-/// Determines if affine expression is invariant.
-static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
-                              bool &isAtLoop) {
+/// Returns true iff affine expression is invariant. Sets the
+/// parameter `isAtLoop` when expression just became invariant.
+static bool isInvariantAffine(AffineExpr a, LoopId at, bool &isAtLoop) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     const LoopId i = cast<AffineDimExpr>(a).getPosition();
-    if (i == ldx) {
+    if (i + 1 == at) {
       isAtLoop = true;
-      return true; // invariant at given loop
+      return true; // becomes invariant at current loop
     }
-    return i < loopDepth; // invariant when already generated
+    return i < at; // invariant when already generated
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
     auto binOp = cast<AffineBinaryOpExpr>(a);
-    return isInvariantAffine(binOp.getLHS(), loopDepth, ldx, isAtLoop) &&
-           isInvariantAffine(binOp.getRHS(), loopDepth, ldx, isAtLoop);
+    return isInvariantAffine(binOp.getLHS(), at, isAtLoop) &&
+           isInvariantAffine(binOp.getRHS(), at, isAtLoop);
   }
   default: {
     assert(isa<AffineConstantExpr>(a));
@@ -126,8 +126,8 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
     if (coefficient <= 0)
       return false;
 
-    const LoopId ldx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
-    if (!isUndefLT(merger.getLvlType(tensor, ldx)))
+    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
+    if (!isUndefLT(merger.getLvlType(tensor, idx)))
       return false; // used more than once, e.g., A[i][i]
 
     // TODO: Generalizes the following two cases. A[i] (with trivial index
@@ -135,14 +135,14 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
     // not necessarily need to differentiate them.
     if (!isSubExp) {
       assert(coefficient == 1);
-      merger.setLevelAndType(tensor, ldx, lvl, lt);
+      merger.setLevelAndType(tensor, idx, lvl, lt);
     }
 
     if (isSubExp) {
       // The current loops appears in more than one affine expressions on the
       // same tensor. We can not handle this case. e.g., A[i+j][i+k], `i` is
       // used twice.
-      if (merger.hasDependentLvl(ldx, tensor)) {
+      if (merger.hasDependentLvl(idx, tensor)) {
         // TODO: This can be supported by coiterate slices if the loop idx is
         // appeared on affine index for different tensor, or take slice on
         // multiple dimensions when it is on the same tensor.
@@ -154,7 +154,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
         // else increase min(d0_1, d0_2).
         return false;
       }
-      merger.setLoopDependentTensorLevel(ldx, tensor, lvl, lt, coefficient);
+      merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
     }
     return true;
   }
@@ -613,9 +613,9 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
   if (kind == TensorExp::Kind::kReduce)
     env.startCustomReduc(e); // enter custom
 
-  Value v0, v1;
   // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
   // based on the type of the other operand.
+  Value v0, v1;
   if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
       env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
     v1 = genExp(env, rewriter, exp.children.e1);
@@ -655,21 +655,21 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
 
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                          LoopId ldx, bool atStart) {
+                          LoopId at, bool atStart) {
   if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
     return;
   if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
     // Inspect tensor indices.
-    bool isAtLoop = ldx == ::mlir::sparse_tensor::detail::kInvalidId;
     linalg::GenericOp op = env.op();
     OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
     const auto map = op.getMatchingIndexingMap(&t);
     const auto stt = getSparseTensorType(t.get());
     const Level lvlRank = stt.getLvlRank();
     assert(static_cast<Level>(map.getNumResults()) == lvlRank);
+    bool isAtLoop = at == 0; // for scalar tensors
     for (Level l = 0; l < lvlRank; l++) {
       const AffineExpr a = map.getResult(l);
-      if (!isInvariantAffine(a, env.getLoopDepth(), ldx, isAtLoop))
+      if (!isInvariantAffine(a, at, /*out*/ isAtLoop))
         return; // still in play
     }
     // All exhausted at this level (isAtLoop denotes exactly at this LoopId).
@@ -705,8 +705,8 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
       env.startCustomReduc(exp); // enter custom
     const ExprId e0 = env.exp(exp).children.e0;
     const ExprId e1 = env.exp(exp).children.e1;
-    genInvariants(env, builder, e0, ldx, atStart);
-    genInvariants(env, builder, e1, ldx, atStart);
+    genInvariants(env, builder, e0, at, atStart);
+    genInvariants(env, builder, e1, at, atStart);
     if (env.exp(exp).kind == TensorExp::Kind::kReduce)
       env.endCustomReduc(); // exit custom
   }
@@ -782,29 +782,28 @@ static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
 
 /// Whether or not the current loop being generated should be parallized (if
 /// possible) according to the configuration.
-static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
+static bool shouldTryParallize(CodegenEnv &env, LoopId at,
                                ArrayRef<TensorLevel> tidLvls) {
   linalg::GenericOp op = env.op();
   auto iteratorTypes = op.getIteratorTypesArray();
-  bool isSparse = llvm::any_of(tidLvls, [ldx, &env](TensorLevel tidLvl) {
-    // Queries the LT based on the tensor id and loop idx, as requested by
-    // `CodegenEnv::lt(TensorId, LoopIdx)`. The returned LT from CodegenEnv
+  bool isSparse = llvm::any_of(tidLvls, [at, &env](TensorLevel tidLvl) {
+    // Queries the LT based on the tensor and loop id, as requested by
+    // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
     // should be consistent with the LT indexed by <TensorId, Level>.
-    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
+    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, at);
     return isCompressedLT(lt) || isSingletonLT(lt);
   });
-  return isParallelFor(env, isOuter, isSparse);
+  return isParallelFor(env, /*isOuter=*/at == 0, isSparse);
 }
 
 /// Emit a loop to coiterate over the list of tensor levels. The generated loop
 /// can either be a for loop or while loop depending on whether there is at most
 /// one sparse level in the list.
 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
-                                 LoopId idx, ArrayRef<TensorLevel> tidLvls,
+                                 ArrayRef<TensorLevel> tidLvls,
                                  bool tryParallel, bool needsUniv) {
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
-    // Construct the while-loop with a parameter for each
-    // index.
+    // Construct while-loop with a parameter for each index.
     return env.emitter().enterCoIterationOverTensorsAtLvls(
         builder, env.op().getLoc(), tidLvls, reduc, tryParallel,
         /*genDedup=*/true, needsUniv);
@@ -817,12 +816,12 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId at,
                           bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
-  bool tryParallel = shouldTryParallize(env, at, at == 0, tidLvls);
-  return genCoIteration(env, builder, at, tidLvls, tryParallel, needsUniv);
+  bool tryParallel = shouldTryParallize(env, at, tidLvls);
+  return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
 }
 
 /// Generates the induction structure for a while-loop.
-static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
+static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                             bool needsUniv) {
   Location loc = env.op().getLoc();
   // Finalize each else branch of all if statements.
@@ -862,7 +861,7 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
 }
 
 /// Generates a single if-statement within a while-loop.
-static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
+static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId at,
                        LatPointId p) {
   Location loc = env.op().getLoc();
   SmallVector<Type> types;
@@ -880,13 +879,13 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
           auto stt = getSparseTensorType(env.op().getInputs()[tid]);
           lt = stt.getLvlType(*lvl);
         }
-        assert(ldx == env.merger().loop(b));
+        assert(at == env.merger().loop(b));
         Value clause;
         if (isCompressedLT(lt) || isSingletonLT(lt) ||
             isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
           assert(lvl.has_value());
           const Value crd = env.emitter().getCoords()[tid][*lvl];
-          const Value lvar = env.getLoopVar(ldx);
+          const Value lvar = env.getLoopVar(at);
           clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  crd, lvar);
         } else {
@@ -943,12 +942,12 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
 /// Starts a loop sequence at given level. Returns true if
 /// the universal loop index must be maintained at this level.
 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                         LoopId idx, LoopId ldx, LatSetId lts) {
-  assert(!env.getLoopVar(idx));
+                         LoopId at, LatSetId lts) {
+  assert(!env.getLoopVar(at));
   // Emit invariants at this loop sequence level.
-  genInvariants(env, builder, exp, ldx, /*atStart=*/true);
+  genInvariants(env, builder, exp, at, /*atStart=*/true);
   // Emit access pattern expansion for sparse tensor output.
-  genExpand(env, builder, idx, /*atStart=*/true);
+  genExpand(env, builder, at, /*atStart=*/true);
   // Emit further intitialization at this loop sequence level.
   const LatPointId l0 = env.set(lts)[0];
   bool needsUniv = false;
@@ -957,7 +956,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
                                            std::optional<Level> lvl,
                                            LevelType lt, bool isIdxReduc) {
-    assert(env.merger().loop(b) == idx);
+    assert(env.merger().loop(b) == at);
     if (isDenseLT(lt) || isUndefLT(lt)) {
       if (tid == env.merger().getSynTensorID()) {
         // Needs loop emitter to set up loop bounds for synthetic tensor too if
@@ -988,6 +987,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   return false;
 }
 
+// Generates dense affine address for encoding.
 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                              OpBuilder &builder, TensorId tid,
                                              Level startLvl) {
@@ -1013,30 +1013,30 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
   }
 }
 
+// We can generate address for constant affine expression before any loops
+// starting from the first level as they do not depend on any thing.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
 static void genInitConstantDenseAddress(CodegenEnv &env,
                                         RewriterBase &rewriter) {
-  // We can generate address for constant affine expression before any loops
-  // starting from the first level as they do not depend on any thing.
-  // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-  // levels can be determined before loops.
   for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
     genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
 }
 
 /// Return true if the lattices bit can be iterated by a for loop.
 static bool translateBitsToTidLvlPairs(
-    CodegenEnv &env, LatPointId li, LoopId ldx,
+    CodegenEnv &env, LatPointId li, LoopId at,
     SmallVectorImpl<TensorLevel> &tidLvls,
     SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
-  const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);
+  const std::optional<Level> outLvl = env.merger().getLvl(outTid, at);
 
   unsigned numloopCond = 0;
   bool hasNonUnique = false;
-  env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
-                                                std::optional<Level> lvl,
-                                                LevelType lt, bool isIdxReduc) {
+  env.merger().foreachTensorLoopId(li, [&, at](TensorLoopId b, TensorId tid,
+                                               std::optional<Level> lvl,
+                                               LevelType lt, bool isIdxReduc) {
     if (simple[b]) {
       if (isIdxReduc) {
         tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
@@ -1089,11 +1089,11 @@ static bool translateBitsToTidLvlPairs(
         if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
           continue;
 
-        // Constant affine expression are handled in genLoop
+        // Constant affine expression are handled in genLoop.
         if (!isa<AffineConstantExpr>(exp)) {
           bool isAtLoop = false;
-          if (isInvariantAffine(exp, env.getLoopDepth(), ldx, isAtLoop) &&
-              isAtLoop) {
+          assert(at == env.getLoopDepth());
+          if (isInvariantAffine(exp, at + 1, /*out*/ isAtLoop) && isAtLoop) {
             // If the compound affine is invariant and we are right at the
             // level. We need to generate the address according to the
             // affine expression. This is also the best place we can do it
@@ -1105,7 +1105,7 @@ static bool translateBitsToTidLvlPairs(
     }
   });
 
-  if (isDenseLT(env.lt(outTid, ldx))) {
+  if (isDenseLT(env.lt(outTid, at))) {
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
@@ -1131,9 +1131,9 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                               LatPointId li, bool needsUniv) {
   // The set of tensors + lvls to generate loops on
   SmallVector<TensorLevel> tidLvls;
+
   // The set of dense tensors with non-trivial affine expression that just
-  // becomes invariant and the address shall now be generated at the current
-  // level.
+  // becomes invariant and the address are generated at the current level.
   SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
   bool isSingleCond =
       translateBitsToTidLvlPairs(env, li, at, tidLvls, affineTidLvls);
@@ -1161,38 +1161,34 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
 
 /// Ends a single loop in current sequence. Returns new values for needsUniv.
 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
-                    LoopId idx, LatPointId li, bool needsUniv,
-                    bool isSingleCond) {
-
+                    LatPointId li, bool needsUniv, bool isSingleCond) {
+  // Either a for-loop or a while-loop that iterates over a slice.
   if (isSingleCond) {
-    // Either a for-loop or a while-loop that iterates over a slice.
     // Any iteration creates a valid lex insert.
     if (env.isReduc() && env.getValidLexInsert())
       env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
   } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
     // End a while-loop.
-    finalizeWhileOp(env, rewriter, idx, needsUniv);
+    finalizeWhileOp(env, rewriter, needsUniv);
   } else {
     needsUniv = false;
   }
-
   env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
     return std::nullopt;
   });
-
   return needsUniv;
 }
 
 /// Ends a loop sequence at given level.
 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
-                       unsigned idx, unsigned ldx) {
-  assert(!env.getLoopVar(idx));
+                       unsigned at) {
+  assert(!env.getLoopVar(at));
   env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
   // Unmark bookkeeping of invariants and loop index.
-  genInvariants(env, builder, exp, ldx, /*atStart=*/false);
+  genInvariants(env, builder, exp, at, /*atStart=*/false);
   // Finalize access pattern expansion for sparse tensor output.
-  genExpand(env, builder, idx, /*atStart=*/false);
+  genExpand(env, builder, at, /*atStart=*/false);
 }
 
 /// Recursively generates code while computing iteration lattices in order
@@ -1200,6 +1196,8 @@ static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
 /// and intersections of sparse iterations spaces.
 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                     LoopId at) {
+  assert(at == env.getLoopDepth());
+
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
   if (at == env.getLoopNum()) {
     Value rhs = genExp(env, rewriter, exp);
@@ -1207,13 +1205,12 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
     return;
   }
 
-  // Construct iteration lattices for current loop index, with L0 at top.
-  const LoopId ldx = at == 0 ? sparse_tensor::detail::kInvalidId : at - 1;
+  // Construct iteration lattices for current loop index.
   const LatSetId lts =
       env.merger().optimizeSet(env.merger().buildLattices(exp, at));
 
   // Start a loop sequence.
-  bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
+  bool needsUniv = startLoopSeq(env, rewriter, exp, at, lts);
 
   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
   // We cannot change this to `for (const LatPointId li : env.set(lts))`
@@ -1250,11 +1247,12 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
     }
 
     // End a loop.
-    needsUniv = endLoop(env, rewriter, loop, at, li, needsUniv, isSingleCond);
+    needsUniv = endLoop(env, rewriter, loop, at, needsUniv, isSingleCond);
   }
 
   // End a loop sequence.
-  endLoopSeq(env, rewriter, exp, at, ldx);
+  endLoopSeq(env, rewriter, exp, at);
+  assert(at == env.getLoopDepth());
 }
 
 /// Converts the result computed by the sparse kernel into the required form.
@@ -1309,6 +1307,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
           op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
               "before sparsification.");
     }
+
     // Must have been demapped as well if the generic op is sorted.
     assert(!hasAnyNonIdentityOperandsOrResults(op));
 


        


More information about the Mlir-commits mailing list