[Mlir-commits] [mlir] [mlir][sparse] use a common util function to query the tensor level set in a lattice point (PR #76764)

Peiming Liu llvmlistbot at llvm.org
Tue Jan 2 15:25:37 PST 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/76764

>From 5bc57435fc5b29aa20136615629c8629b2c56627 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 2 Jan 2024 23:11:19 +0000
Subject: [PATCH] [mlir][sparse] use a common util function to query the tensor
 level set in a lattice point.

---
 .../Transforms/Sparsification.cpp             | 180 +++++++++---------
 1 file changed, 86 insertions(+), 94 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 934e1e559f44d6..35eb4b4f6e47f8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -949,94 +949,9 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
 // Sparsifier synthesis methods (loop sequence).
 //===----------------------------------------------------------------------===//
 
-/// Starts a loop sequence at given level. Returns true if
-/// the universal loop index must be maintained at this level.
-static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                         LoopId curr, LatSetId lts) {
-  assert(!env.getLoopVar(curr));
-  // Emit invariants at this loop sequence level.
-  genInvariants(env, builder, exp, curr, /*isStart=*/true);
-  // Emit access pattern expansion for sparse tensor output.
-  genExpand(env, builder, curr, /*isStart=*/true);
-  // Emit further intitialization at this loop sequence level.
-  const LatPointId l0 = env.set(lts)[0];
-  bool needsUniv = false;
-
-  SmallVector<TensorLevel> tidLvls;
-  env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
-                                           std::optional<Level> lvl,
-                                           LevelType lt, bool isIdxReduc) {
-    assert(env.merger().loop(b) == curr);
-    if (isDenseLT(lt) || isUndefLT(lt)) {
-      if (tid == env.merger().getSynTensorID()) {
-        // Needs loop emitter to set up loop bounds for synthetic tensor too if
-        // there is a loop condition imposed on the synthetic tensor.
-        tidLvls.push_back(env.makeTensorLevel(tid, env.getCurrentDepth()));
-      }
-      needsUniv = true;
-    }
-    if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-        is2OutOf4LT(lt) || isIdxReduc) {
-      // Only when this is a index reduction loop, can the lt be undefined.
-      assert(!isUndefLT(lt) || isIdxReduc);
-      // sparse/singleton levels, or a dense/sparse index reduction loop.
-      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
-    }
-  });
-
-  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
-
-  // Maintain the universal index only if it is actually
-  // consumed by a subsequent lattice point.
-  if (needsUniv) {
-    for (const LatPointId li : env.set(lts).drop_front())
-      if (!env.merger().hasAnySparse(env.lat(li).simple))
-        return true;
-  }
-  return false;
-}
-
-// Generates dense affine address for encoding.
-static void genConstantDenseAddressFromLevel(CodegenEnv &env,
-                                             OpBuilder &builder, TensorId tid,
-                                             Level startLvl) {
-  // TODO: Handle affine expression on output tensor.
-  linalg::GenericOp op = env.op();
-  assert(tid < op.getNumDpsInputs());
-  OpOperand *input = op.getDpsInputOperands()[tid];
-  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
-  const auto enc = getSparseTensorEncoding(input->get().getType());
-  if (enc) {
-    const Location loc = op.getLoc();
-    const TensorId tid = env.makeTensorId(input->getOperandNumber());
-    const Level lvlRank = enc.getLvlRank();
-    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
-    for (Level l = startLvl; l < lvlRank; l++) {
-      AffineExpr lvlExpr = lvlExprs[l];
-      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
-        env.emitter().genDenseAffineAddress(
-            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
-      else
-        return; // break on first non-dense non-constant level
-    }
-  }
-}
-
-// We can generate address for constant affine expression before any loops
-// starting from the first level as they do not depend on any thing.
-// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-// levels can be determined before loops.
-static void genInitConstantDenseAddress(CodegenEnv &env,
-                                        RewriterBase &rewriter) {
-  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
-    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
-}
-
-/// Return true if the lattices bit can be iterated by a for loop.
-static bool translateBitsToTidLvlPairs(
+static bool getAllTidLvlsInLatPoints(
     CodegenEnv &env, LatPointId li, LoopId curr,
-    SmallVectorImpl<TensorLevel> &tidLvls,
-    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
   const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
@@ -1048,7 +963,7 @@ static bool translateBitsToTidLvlPairs(
                     LevelType lt, bool isIdxReduc) {
         if (simple[b]) {
           if (isIdxReduc) {
-            tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+            callback(env.makeTensorLevel(tid, *lvl), nullptr);
             numloopCond++;
             return;
           }
@@ -1072,10 +987,10 @@ static bool translateBitsToTidLvlPairs(
             }
           }
           hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
-          tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+          callback(env.makeTensorLevel(tid, *lvl), nullptr);
           numloopCond++;
         } else if (isDenseLT(lt) || isIdxReduc) {
-          tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+          callback(env.makeTensorLevel(tid, *lvl), nullptr);
         } else {
           assert(isUndefLT(lt));
           linalg::GenericOp op = env.op();
@@ -1109,7 +1024,7 @@ static bool translateBitsToTidLvlPairs(
                 // level. We need to generate the address according to the
                 // affine expression. This is also the best place we can do it
                 // to avoid putting it inside inner loops.
-                affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
+                callback(env.makeTensorLevel(tid, l), exp);
               }
             }
           }
@@ -1120,15 +1035,14 @@ static bool translateBitsToTidLvlPairs(
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
-    tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
+    callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
   }
 
   if (numloopCond == 0) {
     // Corner cases where the loop bound is defined by a *unused* operand, in
     // this case, we just generate a dense "fake" loop by iterating over the
     // synthetic tensor.
-    tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
-                                          env.getCurrentDepth()));
+    callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
     numloopCond++;
   }
   // If we just need to one loop conditions and the conditions is not imposed on
@@ -1136,6 +1050,84 @@ static bool translateBitsToTidLvlPairs(
   return numloopCond == 1 && !hasNonUnique;
 }
 
+/// Starts a loop sequence at given level. Returns true if
+/// the universal loop index must be maintained at this level.
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+                         LoopId curr, LatSetId lts) {
+  assert(!env.getLoopVar(curr));
+  // Emit invariants at this loop sequence level.
+  genInvariants(env, builder, exp, curr, /*isStart=*/true);
+  // Emit access pattern expansion for sparse tensor output.
+  genExpand(env, builder, curr, /*isStart=*/true);
+  // Emit further initialization at this loop sequence level.
+  const LatPointId l0 = env.set(lts)[0];
+
+  SmallVector<TensorLevel> tidLvls;
+  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+    tidLvls.emplace_back(tl);
+  });
+
+  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
+
+  // Maintain the universal index only if it is actually
+  // consumed by a subsequent lattice point.
+  for (const LatPointId li : env.set(lts).drop_front())
+    if (!env.merger().hasAnySparse(env.lat(li).simple))
+      return true;
+
+  return false;
+}
+
+// Generates dense affine address for encoding.
+static void genConstantDenseAddressFromLevel(CodegenEnv &env,
+                                             OpBuilder &builder, TensorId tid,
+                                             Level startLvl) {
+  // TODO: Handle affine expression on output tensor.
+  linalg::GenericOp op = env.op();
+  assert(tid < op.getNumDpsInputs());
+  OpOperand *input = op.getDpsInputOperands()[tid];
+  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
+  const auto enc = getSparseTensorEncoding(input->get().getType());
+  if (enc) {
+    const Location loc = op.getLoc();
+    const TensorId tid = env.makeTensorId(input->getOperandNumber());
+    const Level lvlRank = enc.getLvlRank();
+    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
+    for (Level l = startLvl; l < lvlRank; l++) {
+      AffineExpr lvlExpr = lvlExprs[l];
+      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+        env.emitter().genDenseAffineAddress(
+            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
+      else
+        return; // break on first non-dense non-constant level
+    }
+  }
+}
+
+// We can generate address for constant affine expression before any loops
+// starting from the first level as they do not depend on anything.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
+static void genInitConstantDenseAddress(CodegenEnv &env,
+                                        RewriterBase &rewriter) {
+  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
+    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
+}
+
+/// Returns true if the lattice bit can be iterated by a for loop.
+static bool translateBitsToTidLvlPairs(
+    CodegenEnv &env, LatPointId li, LoopId curr,
+    SmallVectorImpl<TensorLevel> &tidLvls,
+    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+  return getAllTidLvlsInLatPoints(env, li, curr,
+                                  [&](TensorLevel tl, AffineExpr exp) {
+                                    if (exp)
+                                      affineTidLvls.emplace_back(tl, exp);
+                                    else
+                                      tidLvls.emplace_back(tl);
+                                  });
+}
+
 /// Starts a single loop in current sequence.
 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                               OpBuilder &builder, LoopId curr,



More information about the Mlir-commits mailing list