[Mlir-commits] [mlir] [mlir][sparse] use a common util function to query the tensor level s… (PR #76764)

Yinying Li llvmlistbot at llvm.org
Tue Jan 2 15:22:15 PST 2024


================
@@ -1120,22 +1035,99 @@ static bool translateBitsToTidLvlPairs(
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
-    tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
+    callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
   }
 
   if (numloopCond == 0) {
     // Corner case where the loop bound is defined by an *unused* operand: in
     // this case, we just generate a dense "fake" loop by iterating over the
     // synthetic tensor.
-    tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
-                                          env.getCurrentDepth()));
+    callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
     numloopCond++;
   }
   // If we need just one loop condition and that condition is not imposed on a
   // non-unique level, the loop can be generated by a for loop.
   return numloopCond == 1 && !hasNonUnique;
 }
 
+/// Starts a loop sequence for loop `curr` over lattice set `lts`. Returns true
+/// if the universal loop index must be maintained at this level.
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+                         LoopId curr, LatSetId lts) {
+  assert(!env.getLoopVar(curr)); // No loop variable may exist for `curr` yet.
+  // Emit invariants at this loop sequence level.
+  genInvariants(env, builder, exp, curr, /*isStart=*/true);
+  // Emit access pattern expansion for sparse tensor output.
+  genExpand(env, builder, curr, /*isStart=*/true);
+  // Emit further initialization, driven by the first lattice point of the set.
+  const LatPointId l0 = env.set(lts)[0];
+
+  SmallVector<TensorLevel> tidLvls;
+  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+    tidLvls.emplace_back(tl);
+  });
+
+  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
+
+  // Maintain the universal index only if it is actually consumed by a
+  // subsequent lattice point, i.e. one whose `simple` bits have no sparse level.
+  for (const LatPointId li : env.set(lts).drop_front())
+    if (!env.merger().hasAnySparse(env.lat(li).simple))
+      return true;
+
+  return false;
+}
+
+// Generates dense affine addresses for input `tid`, level by level from `startLvl`, while each level is dense and constant-indexed.
+static void genConstantDenseAddressFromLevel(CodegenEnv &env,
+                                             OpBuilder &builder, TensorId tid,
+                                             Level startLvl) {
+  // TODO: Handle affine expression on output tensor.
+  linalg::GenericOp op = env.op();
+  assert(tid < op.getNumDpsInputs());
+  OpOperand *input = op.getDpsInputOperands()[tid];
+  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
+  const auto enc = getSparseTensorEncoding(input->get().getType());
+  if (enc) { // Inputs without a sparse tensor encoding are left untouched.
+    const Location loc = op.getLoc();
+    const TensorId tid = env.makeTensorId(input->getOperandNumber()); // NOTE(review): shadows the `tid` parameter.
+    const Level lvlRank = enc.getLvlRank();
+    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
+    for (Level l = startLvl; l < lvlRank; l++) {
+      AffineExpr lvlExpr = lvlExprs[l];
+      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+        env.emitter().genDenseAffineAddress(
+            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
+      else
+        return; // stop at the first level that is not both dense and constant-indexed
+    }
+  }
+}
+
+// We can generate addresses for constant affine expressions before any loops,
+// starting from the first level, as they do not depend on anything.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
+static void genInitConstantDenseAddress(CodegenEnv &env,
+                                        RewriterBase &rewriter) {
+  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
+    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
+}
+
+/// Return true if the lattices bit can be iterated by a for loop.
----------------
yinying-lisa-li wrote:

nit: Returns. (consistent with all the other outermost comments)

https://github.com/llvm/llvm-project/pull/76764


More information about the Mlir-commits mailing list