[Mlir-commits] [mlir] d03805f - [mlir][sparse] add merger/topo sort support for slice-based affine sparse index codegen

Peiming Liu llvmlistbot at llvm.org
Mon Mar 20 14:24:15 PDT 2023


Author: Peiming Liu
Date: 2023-03-20T21:24:10Z
New Revision: d03805f2ee0bdaa2513fbc3efb9e404e128bdbb3

URL: https://github.com/llvm/llvm-project/commit/d03805f2ee0bdaa2513fbc3efb9e404e128bdbb3
DIFF: https://github.com/llvm/llvm-project/commit/d03805f2ee0bdaa2513fbc3efb9e404e128bdbb3.diff

LOG: [mlir][sparse] add merger/topo sort support for slice-based affine sparse index codegen

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D142928

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 59c5b78fda7b8..0e6c2f1553f1c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -435,19 +435,58 @@ class Merger {
 
   /// Iterates over a set of `TensorLoopId`s, invoking the callback
   /// for each `TensorLoopId` and passing it the corresponding tensor
-  /// identifier, level, and level-type.
-  void
-  foreachTensorLoopId(LatPointId p,
-                      function_ref<void(TensorLoopId, TensorId,
-                                        std::optional<Level>, DimLevelType)>
-                          callback) const {
-    for (const TensorLoopId b : latPoints[p].bits.set_bits())
-      callback(b, tensor(b), getLvl(b), getDimLevelType(b));
+  /// identifier, level, and level-type, following with a boolean value
+  /// indicating whether it is a dependent index reduction loop condition.
+  void foreachTensorLoopId(
+      LatPointId p, function_ref<void(TensorLoopId, TensorId,
+                                      std::optional<Level>, DimLevelType, bool)>
+                        callback) {
+    for (const TensorLoopId b : latPoints[p].bits.set_bits()) {
+      TensorId t = tensor(b);
+      if (isLvlWithNonTrivialIdxExp(b)) {
+        // This must be an undefined level.
+        assert(!getLvl(b).has_value());
+        // Slice the tid along the dependent level to iterate current loop.
+        callback(b, t, loopToDependencies[loop(b)][t], getDimLevelType(b),
+                 /*isIdxReduc=*/true);
+      } else {
+        callback(b, t, getLvl(b), getDimLevelType(b), /*isIdxReduc=*/false);
+      }
+    }
   }
 
   /// Sets whether the output tensor is sparse or not.
   void setHasSparseOut(bool s) { hasSparseOut = s; }
 
+  /// Establishes the two-way map that i <-> <t, lvl>.
+  void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl) {
+    assert(lvl < numLoops);
+    loopToDependencies[i][t] = lvl;
+    levelToDependentIdx[t][lvl].push_back(i);
+  }
+
+  /// Whether the loop has dependent slice.
+  bool hasDependentLvl(LoopId i, TensorId tid) {
+    return loopToDependencies[i][tid].has_value();
+  }
+
+  /// Returns the list of loop indices which appear in the non-trivial index
+  /// expression on t_l, e.g., A[i+j] => {i, j}
+  std::vector<LoopId> &getDependentLoops(TensorId t, Level lvl) {
+    return levelToDependentIdx[t][lvl];
+  }
+
+  /// Returns the defining [tid, lvl] for the loop.
+  std::pair<TensorId, Level> getLoopDefiningLvl(LoopId i) const {
+    return loopBounds[i];
+  }
+
+  /// Checks whether the TensorLoopId represents a tensor level with
+  /// non-trivial index expression on it.
+  bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const {
+    return loopToDependencies[loop(b)][tensor(b)].has_value();
+  }
+
   /// Convenience getters to immediately access the stored nodes.
   /// Typically it is inadvisible to keep the reference around, as in
   /// `TensorExpr &te = merger.exp(e)`, since insertions into the merger
@@ -511,6 +550,20 @@ class Merger {
   // Map that converts pair<TensorId, Level> to the corresponding LoopId.
   std::vector<std::vector<std::optional<LoopId>>> lvlToLoop;
 
+  // Map from a loop to its dependencies if any.
+  // The dependencies of a loop is a set of (tensor, level) pairs.
+  // It is currently only set for non-trivial index expressions.
+  // E.g., A[i+j] => i and j will have dependencies {A0} to indicate that
+  // i and j are used in the non-trivial index expression on A0.
+  std::vector<std::vector<std::optional<Level>>> loopToDependencies;
+  // The inverse map of ldxToDependencies from tensor level -> dependent loop
+  // E.g., A[i+j], we have A0 => {i, j}, to indicate that A0 uses both {i, j}
+  // to compute its indices.
+  std::vector<std::vector<std::vector<LoopId>>> levelToDependentIdx;
+
+  // Map from a loop to the [tid, lvl] pair that defines the loop boundary.
+  std::vector<std::pair<TensorId, Level>> loopBounds;
+
   llvm::SmallVector<TensorExp> tensorExps;
   llvm::SmallVector<LatPoint> latPoints;
   llvm::SmallVector<SmallVector<LatPointId>> latSets;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index 8c6a7bd6433db..776d7f7f47ece 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -99,6 +99,7 @@ class CodegenEnv {
     topSort.reserve(capacity);
   }
 
+  ArrayRef<LoopId> getTopSort() const { return topSort; };
   ArrayRef<LoopId> getTopSortSlice(LoopOrd n, LoopOrd m) const;
   ArrayRef<LoopId> getLoopStackUpTo(LoopOrd n) const;
   ArrayRef<LoopId> getCurrentLoopStack() const;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index f119ac3ba7ae5..d7ce2b7f63f5c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -109,6 +109,12 @@ class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
   SmallVector<utils::IteratorType> iterTypes;
 };
 
+// Flattens an affine expression into a list of AffineDimExprs.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+  void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
+  SmallVector<AffineDimExpr> dims;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -254,6 +260,69 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
   }
 }
 
+/// Helper method to inspect affine expressions for index variable reduction
+/// based codegen. It finds the dependent index set for all tensor levels in the
+/// current expression we are generating.
+///
+/// For example, when handling A[i+j][j+k], we build the two way mapping in
+/// merger between (tensor, level) pairs and their dependent index variable set:
+/// A_0 <=> [i, j] and A_1 <=> [j, k]
+///
+/// It rejects cases (returns false)
+/// 1st, when the same index is used more than once, e.g., A[i+j][i]
+/// 2nd, when multiplication is used in the non-trivial index expression.
+/// 3rd, when a constant operand is used in the non-trivial index expression.
+///
+/// TODO: constant should be easy to handle.
+static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
+                          AffineExpr a, DimLevelType dlt,
+                          bool isSubExp = false) {
+  switch (a.getKind()) {
+  case AffineExprKind::DimId: {
+    LoopId ldx = a.cast<AffineDimExpr>().getPosition();
+    if (!isUndefDLT(merger.getDimLevelType(tensor, ldx)))
+      return false; // used more than once, e.g., A[i][i]
+
+    // TODO: Generalizes the following two cases. A[i] (with trivial index
+    // expression) can be treated as a special affine index expression. We do
+    // not necessarily need to differentiate them.
+    if (!isSubExp)
+      merger.setLevelAndType(tensor, ldx, lvl, dlt);
+
+    if (isSubExp) {
+      // The current loops appears in more than one affine expressions on the
+      // same tensor. We can not handle this case. e.g., A[i+j][i+k], `i` is
+      // used twice.
+      if (merger.hasDependentLvl(ldx, tensor)) {
+        // TODO: This can be supported by coiterate slices if the loop idx is
+        // appeared on affine index for different tensor, or take slice on
+        // mulitple dimensions when it is on the same tensor.
+        // E.g.,
+        // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0]
+        // d0_1 = getNextSliceOffset t0 along lvl0
+        // d0_2 = getNextSliceOffset t1 along lvl0
+        // if d0_1 == d0_2 then d0 = d0_1 = d0_1
+        // else increase min(d0_1, d0_2).
+        return false;
+      }
+      merger.setLoopDependentTensorLevel(ldx, tensor, lvl);
+    }
+    return true;
+  }
+  case AffineExprKind::Constant:
+  case AffineExprKind::Mul:
+    // TODO: Support Mul and Constant AffineExp for slice-based codegen
+    return false;
+  case AffineExprKind::Add: {
+    auto binOp = a.cast<AffineBinaryOpExpr>();
+    return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), dlt, true) &&
+           findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), dlt, true);
+  }
+  default:
+    return false;
+  }
+}
+
 /// Get the total number of compound affine expressions in the
 /// `getMatchingIndexingMap` for the given tensor.  For the following inputs:
 ///
@@ -262,7 +331,8 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
 ///
 /// Returns 1 (because the first level is compressed and its corresponding
 /// indexing-expression is `d0 + d1`)
-static unsigned getNumCompoundAffineOnSparseLvls(AffineMap map, Value tensor) {
+static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
+                                                   Value tensor) {
   // The `tensor` is not guaranted to have `RankedTensorType`, therefore
   // we can't use `getRankedTensorType`/`getSparseTensorType` here.
   // However, we don't need to handle `StorageSpecifierType`, so we
@@ -305,20 +375,20 @@ static unsigned getNumCompoundAffineOnSparseLvls(AffineMap map, Value tensor) {
 
 /// Get the total number of sparse levels with compound affine
 /// expressions, summed over all operands of the `GenericOp`.
-static unsigned getNumCompoundAffineOnSparseLvls(linalg::GenericOp op) {
+static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
   unsigned num = 0;
   for (OpOperand &t : op->getOpOperands())
-    num += getNumCompoundAffineOnSparseLvls(op.getMatchingIndexingMap(&t),
-                                            t.get());
+    num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
+                                              t.get());
   return num;
 }
 
-static bool hasCompoundAffineOnSparseOut(linalg::GenericOp op) {
+static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
   OpOperand *out = op.getDpsInitOperand(0);
   if (getSparseTensorType(out->get()).isAllDense())
     return false;
-  return getNumCompoundAffineOnSparseLvls(op.getMatchingIndexingMap(out),
-                                          out->get());
+  return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
+                                            out->get());
 }
 
 /// Helper method to inspect sparse encodings in the tensor types.
@@ -326,7 +396,14 @@ static bool hasCompoundAffineOnSparseOut(linalg::GenericOp op) {
 /// Returns true if the sparse annotations and affine subscript
 /// expressions of all tensors are admissible. Returns false if
 /// no annotations are found or inadmissible constructs occur.
-static bool findSparseAnnotations(CodegenEnv &env) {
+/// We currently support two different ways to handle non-trivial index
+/// expression on sparse tensors, and they accept different affine expressions.
+/// When using filter-loop-based approach, it accept (almost) arbitrary affine
+/// index expression on sparse tensor but it is much less efficient, and will be
+/// gradually removed from the codebase.
+/// When using dependent index reducton-based approach, it currently only
+/// supports affine addition index expression.
+static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
   bool annotated = false;
   // `filterLdx` may be mutated by `findAffine`.
   LoopId filterLdx = env.merger().getStartingFilterLoopId();
@@ -335,17 +412,30 @@ static bool findSparseAnnotations(CodegenEnv &env) {
     const auto enc = getSparseTensorEncoding(t.get().getType());
     if (enc)
       annotated = true;
+
     const Level lvlRank = map.getNumResults();
     assert(!enc || lvlRank == enc.getLvlRank());
     assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
+
+    // We only need to do index reduction if there is at least one non-trivial
+    // index expression on sparse levels.
+    // If all non-trivial index expression is on dense levels, we can
+    // efficiently rely on the random access to locate the element.
+    bool needIdxReduc =
+        enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
+    // If then current tensor being inspected requires affine index, it need
+    // to be sliced.
     for (Level l = 0; l < lvlRank; l++) {
       const TensorId tid = t.getOperandNumber();
-      // FIXME: `toOrigDim` is deprecated.
-      // FIXME: above we asserted that there are `lvlRank` many results,
-      // but this is assuming there are in fact `dimRank` many results instead.
-      const AffineExpr a = map.getResult(toOrigDim(enc, l));
-      if (!findAffine(env.merger(), tid, l, a, enc.getLvlType(l), filterLdx))
-        return false; // inadmissible affine expression
+      AffineExpr a = map.getResult(toOrigDim(enc, l));
+      DimLevelType dlt = enc.getLvlType(l);
+      if (idxReducBased && needIdxReduc) {
+        if (!findDepIdxSet(env.merger(), tid, l, a, dlt))
+          return false; // inadmissible affine expression
+      } else {
+        if (!findAffine(env.merger(), tid, l, a, dlt, filterLdx))
+          return false; // inadmissible affine expression
+      }
     }
   }
   assert(filterLdx == env.merger().getNumLoops());
@@ -469,11 +559,11 @@ static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
   }
 }
 
-static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
-                                            std::optional<LoopId> &fldx,
-                                            AffineExpr &fa,
-                                            std::optional<LoopId> &tldx,
-                                            AffineExpr &ta) {
+static void tryRelaxAffineConstraints(linalg::GenericOp op,
+                                      std::optional<LoopId> &fldx,
+                                      AffineExpr &fa,
+                                      std::optional<LoopId> &tldx,
+                                      AffineExpr &ta) {
   // We use a heuristic here to only pick one dim expression from each
   // compound affine expression to establish the order between two dense
   // dimensions.
@@ -494,7 +584,7 @@ static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
     }
     if (!ta.isa<AffineConstantExpr>()) {
       // Heuristic: we prefer reduction loop for rhs to reduce the chance
-      // addint reduce < parallel ordering.
+      // adding reduce < parallel ordering.
       finder.setPickedIterType(utils::IteratorType::reduction);
       finder.walkPostOrder(ta);
       ta = finder.getDimExpr();
@@ -503,14 +593,183 @@ static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
   }
 }
 
+/// Makes target array's elements appear in the same order as the `order` array.
+static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
+                                  ArrayRef<LoopId> order) {
+  std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) {
+    assert(l != r);
+    int idxL = -1, idxR = -1;
+    for (int i = 0, e = order.size(); i < e; i++) {
+      if (order[i] == l)
+        idxL = i;
+      if (order[i] == r)
+        idxR = i;
+    }
+    assert(idxL >= 0 && idxR >= 0);
+    return idxL < idxR;
+  });
+}
+
+static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
+                                          OpOperand *skip, SortMask mask,
+                                          std::vector<std::vector<bool>> &adjM,
+                                          std::vector<unsigned> &inDegree) {
+  // Get map and encoding.
+  auto map = env.op().getMatchingIndexingMap(&t);
+  auto enc = getSparseTensorEncoding(t.get().getType());
+
+  // Each tensor expression and optional dimension ordering (row-major
+  // by default) puts an ordering constraint on the loop indices. For
+  // example, the tensor expresion A_ijk forces the ordering i < j < k
+  // on the loop indices if no explicit dimension ordering is given.
+  for (Level l = 0, rank = map.getNumResults(); l < rank; l++) {
+    AffineExpr ta = map.getResult(toOrigDim(enc, l));
+    std::optional<LoopId> tldx =
+        env.merger().getLoopId(t.getOperandNumber(), l);
+    // Filter loops should be constructed after all the dependent loops,
+    // i.e., d0 + d1 < filter_loop(d0 + d1)
+    if (tldx && env.merger().isFilterLoop(*tldx)) {
+      assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getDimLevelType()[l]));
+      addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx);
+      // Now that the ordering of affine expression is captured by filter
+      // loop idx, we only need to ensure the affine ordering against filter
+      // loop. Thus, we reset the affine express to nil here to mark it as
+      // resolved.
+      ta = AffineExpr();
+    }
+
+    // Skip tensor during cycle resolution, though order between filter loop
+    // and dependent loops need to be guaranteed unconditionally.
+    if (&t == skip)
+      continue;
+
+    if (l > 0) {
+      AffineExpr fa = map.getResult(toOrigDim(enc, l - 1));
+      std::optional<LoopId> fldx =
+          env.merger().getLoopId(t.getOperandNumber(), l - 1);
+
+      // Applying order constraints on every pair of dimExpr between two
+      // compound affine expressions can sometime too strict:
+      // E.g, for [dense, dense] -> (d0 + d1, d2 + d3).
+      // It is totally fine to have loop sequence d0->d2->d1->d3 instead of
+      // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
+      // We also relax the affine constraint when use slice-based algorithm
+      // as there is no filter loop for affine index on sparse dimension.
+      // TODO: do we really need the condition?
+      if (!includesDense(mask))
+        tryRelaxAffineConstraints(env.op(), fldx, fa, tldx, ta);
+
+      // (d0 + d1) < (d2 + d3), or
+      // filter_loop_d-1 < (d2 + d3), or
+      // (d0 + d1) < filter_loop_d, or
+      // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset
+      // above.
+      addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
+    }
+  }
+}
+
+static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t,
+                                     OpOperand *skip, SortMask mask,
+                                     std::vector<std::vector<bool>> &adjM,
+                                     std::vector<unsigned> &inDegree) {
+  // Get map and encoding.
+  auto map = env.op().getMatchingIndexingMap(&t);
+  auto enc = getSparseTensorEncoding(t.get().getType());
+
+  // No special treatment for simple indices.
+  if (getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) == 0)
+    return addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree);
+
+  // Skip tensor during cycle resolution, though order between filter loop
+  // and dependent loops need to be guaranteed unconditionally.
+  if (&t == skip)
+    return;
+
+  AffineDimFinder finder(env.op());
+  finder.setPickedIterType(utils::IteratorType::reduction);
+  // To compute iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6],
+  // we requires there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6},
+  // and d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
+  for (Level lvl = 1, rank = map.getNumResults(); lvl < rank; lvl++) {
+    AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1));
+    AffineExpr ta = map.getResult(toOrigDim(enc, lvl));
+
+    // This is a heuristic, we pick an abitrary reduction loop from lhs and
+    // rhs and use them as d_x and d_y.
+    finder.walkPostOrder(fa);
+    AffineDimExpr fexp = finder.getDimExpr();
+    LoopId fldx = fexp.getPosition();
+
+    finder.walkPostOrder(ta);
+    AffineDimExpr texp = finder.getDimExpr();
+    LoopId tldx = texp.getPosition();
+
+    // d_x > d_y
+    if (!adjM[fldx][tldx]) {
+      adjM[fldx][tldx] = true;
+      inDegree[tldx]++;
+    }
+
+    AffineDimCollector fCollector;
+    fCollector.walkPostOrder(fa);
+    AffineDimCollector tCollector;
+    tCollector.walkPostOrder(ta);
+
+    // make sure dx and dy is the last;
+    for (auto fd : fCollector.dims) {
+      LoopId f = fd.getPosition();
+      if (f == fldx)
+        continue;
+      if (!adjM[f][fldx]) {
+        adjM[f][fldx] = true;
+        inDegree[fldx]++;
+      }
+    }
+    for (auto td : tCollector.dims) {
+      LoopId t = td.getPosition();
+      if (t == tldx)
+        continue;
+      if (!adjM[t][tldx]) {
+        adjM[t][tldx] = true;
+        inDegree[tldx]++;
+      }
+    }
+    // Since we only support affine addition, the order between two dim
+    // expression does not really matters.
+    // {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
+    // This is to ensure that the affine expressions are reduced in sparse
+    // tensor level ordering.
+    // TODO: this ordering could probably be loosen if we support out-of-order
+    // reduction.
+    // TODO: the evaluation order need to be ensure to
+    // support affine multiplication.
+    for (auto fd : fCollector.dims) {
+      LoopId f = fd.getPosition();
+      if (f == fldx) // skip d_x
+        continue;
+
+      for (auto td : tCollector.dims) {
+        LoopId t = td.getPosition();
+        if (t == tldx) // skip d_y
+          continue;
+        if (!adjM[f][t]) {
+          adjM[f][t] = true;
+          inDegree[t]++;
+        }
+      }
+    }
+  }
+}
+
 /// Computes a topologically sorted iteration graph for the linalg operation.
-/// Ensures all tensors are visited in natural coordinate order.  This is
+/// Ensures all tensors are visited in natural index order. This is
 /// essential for sparse storage formats since these only support access
-/// along fixed levels.  Even for dense storage formats, however, the natural
-/// coordinate order yields innermost unit-stride access with better spatial
+/// along fixed dimensions. Even for dense storage formats, however, the natural
+/// index order yields innermost unit-stride access with better spatial
 /// locality.
 static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
-                                  OpOperand *skip = nullptr) {
+                                  OpOperand *skip, bool idxReducBased = false) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
   const LoopId n = env.merger().getNumLoops();
@@ -522,7 +781,8 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
     // Get map and encoding.
     const auto map = env.op().getMatchingIndexingMap(&t);
     const auto enc = getSparseTensorEncoding(t.get().getType());
-    assert(map.getNumDims() + getNumCompoundAffineOnSparseLvls(env.op()) == n);
+    assert(map.getNumDims() + getNumNonTrivialIdxExpOnSparseLvls(env.op()) ==
+           n);
 
     // Skips dense inputs/outputs when not requested.
     const bool isDenseInput = !enc && env.op().isDpsInput(&t);
@@ -549,63 +809,12 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
         }
       }
     }
-
-    // Each tensor expression and optional dimension ordering (row-major
-    // by default) puts an ordering constraint on the loop indices. For
-    // example, the tensor expresion A_ijk forces the ordering i < j < k
-    // on the loop indices if no explicit dimension ordering is given.
-    const Level lvlRank = map.getNumResults();
-    assert(!enc || lvlRank == enc.getLvlRank());
-    for (Level l = 0; l < lvlRank; l++) {
-      // FIXME: `toOrigDim` is deprecated.
-      // FIXME: above we asserted that there are `lvlRank` many results,
-      // but this is assuming there are in fact `dimRank` many results instead.
-      AffineExpr ta = map.getResult(toOrigDim(enc, l));
-      std::optional<LoopId> tldx =
-          env.merger().getLoopId(t.getOperandNumber(), l);
-
-      // Filter loops should be constructed after all the dependent loops,
-      // i.e., d0 + d1 < filter_loop(d0 + d1)
-      if (tldx && env.merger().isFilterLoop(*tldx)) {
-        assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getLvlType(l)));
-        addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt,
-                           tldx);
-        // Now that the ordering of affine expression is captured by filter
-        // loop idx, we only need to ensure the affine ordering against filter
-        // loop. Thus, we reset the affine express to nil here to mark it as
-        // resolved.
-        ta = AffineExpr();
-      }
-
-      // Skip tensor during cycle resolution, though order between filter loop
-      // and dependent loops need to be guaranteed unconditionally.
-      if (&t == skip)
-        continue;
-
-      if (l > 0) {
-        // FIXME: `toOrigDim` is deprecated.
-        // FIXME: above we asserted that there are `lvlRank` many results,
-        // but this is assuming there are in fact `dimRank` many results.
-        AffineExpr fa = map.getResult(toOrigDim(enc, l - 1));
-        std::optional<LoopId> fldx =
-            env.merger().getLoopId(t.getOperandNumber(), l - 1);
-
-        // Applying order constraints on every pair of dimExpr between two
-        // compound affine expressions can sometime too strict:
-        // E.g, for [dense, dense] -> (d0 + d1, d2 + d3).
-        // It is totally fine to have loop sequence d0->d2->d1->d3 instead of
-        // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
-        if (!includesDense(mask))
-          tryLoosenAffineDenseConstraints(env.op(), fldx, fa, tldx, ta);
-
-        // (d0 + d1) < (d2 + d3), or
-        // filter_loop_d-1 < (d2 + d3), or
-        // (d0 + d1) < filter_loop_d, or
-        // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset
-        // above.
-        addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
-      }
-    }
+    // Push unrelated loops into sparse iteration space, so these
+    // will be skipped more often.
+    if (idxReducBased)
+      addSliceBasedConstraints(env, t, skip, mask, adjM, inDegree);
+    else
+      addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree);
   }
   // Topologically sort the iteration graph to determine loop order.
   // Report failure for a cyclic iteration graph.
@@ -1275,7 +1484,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   SmallVector<Level> lvls;
   env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
                                            std::optional<Level> lvl,
-                                           DimLevelType dlt) {
+                                           DimLevelType dlt, bool) {
     assert(env.merger().loop(b) == idx);
     if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
       needsUniv = true;
@@ -1350,7 +1559,7 @@ static bool translateBitsToTidLvlPairs(
   bool hasNonUnique = false;
   env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
                                                 std::optional<Level> lvl,
-                                                DimLevelType dlt) {
+                                                DimLevelType dlt, bool) {
     if (simple.test(b)) {
       if (isUndefDLT(dlt)) {
         // An undefined dlt in the lattices, we probably mean to
@@ -1596,21 +1805,25 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
                                 PatternRewriter &rewriter) const override {
     // Only accept single output operations without affine index on sparse
     // output.
-    if (op.getNumDpsInits() != 1 || hasCompoundAffineOnSparseOut(op))
+    if (op.getNumDpsInits() != 1 || hasNonTrivialAffineOnSparseOut(op))
       return failure();
 
-    if (options.enableIndexReduction)
-      llvm_unreachable("not yet implemented");
-
     // Sets up a code generation environment.
     const unsigned numTensors = op->getNumOperands();
     const unsigned numLoops = op.getNumLoops();
-    const unsigned numFilterLoops = getNumCompoundAffineOnSparseLvls(op);
-    CodegenEnv env(op, options, numTensors, numLoops, numFilterLoops);
+    const unsigned numFilterLoops = getNumNonTrivialIdxExpOnSparseLvls(op);
+    // TODO: we should probably always use slice-based codegen whenever
+    // possible, we can even intermix slice-based and filter-loop based codegen.
+    bool idxReducBased = options.enableIndexReduction && numFilterLoops != 0;
+
+    // If we uses slice based algorithm for affine index, we do not need filter
+    // loop.
+    CodegenEnv env(op, options, numTensors, numLoops,
+                   /*numFilterLoops=*/idxReducBased ? 0 : numFilterLoops);
 
     // Detects sparse annotations and translates the per-level sparsity
     // information for all tensors to loop indices in the kernel.
-    if (!findSparseAnnotations(env))
+    if (!findSparseAnnotations(env, idxReducBased))
       return failure();
 
     // Constructs the tensor expressions tree from `op`, returns failure if the
@@ -1635,7 +1848,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
         SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
         SortMask::kIncludeUndef,      SortMask::kSparseOnly};
     for (const SortMask mask : allMasks) {
-      if (computeIterationGraph(env, mask)) {
+      if (computeIterationGraph(env, mask, nullptr, idxReducBased)) {
         hasCycle = false;
         if (env.isAdmissibleTopoOrder()) {
           isAdmissible = true;
@@ -1644,11 +1857,24 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
         // else try a set of less strict constraints.
       }
     }
-    if (hasCycle)
-      return resolveCycle(env, rewriter); // one last shot
+    if (hasCycle) {
+      return idxReducBased
+                 ? failure() // TODO: should cycle be resolved differently?
+                 : resolveCycle(env, rewriter); // one last shot
+    }
+
     if (!isAdmissible)
       return failure(); // inadmissible expression, reject
 
+    for (OpOperand &t : env.op()->getOpOperands()) {
+      Level rank = env.op().getMatchingIndexingMap(&t).getNumResults();
+      for (Level lvl = 0; lvl < rank; lvl++) {
+        sortArrayBasedOnOrder(
+            env.merger().getDependentLoops(t.getOperandNumber(), lvl),
+            env.getTopSort());
+      }
+    }
+
     // Recursively generates code if admissible.
     env.startEmit();
     genBuffers(env, rewriter);

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 029ce3f3f91ec..7f4400188cf14 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -220,7 +220,12 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
       loopToLvl(numTensors,
                 std::vector<std::optional<Level>>(numLoops, std::nullopt)),
       lvlToLoop(numTensors,
-                std::vector<std::optional<LoopId>>(numLoops, std::nullopt)) {}
+                std::vector<std::optional<LoopId>>(numLoops, std::nullopt)),
+      loopToDependencies(numLoops, std::vector<std::optional<Level>>(
+                                       numTensors, std::nullopt)),
+      levelToDependentIdx(numTensors, std::vector<std::vector<LoopId>>(
+                                          numLoops, std::vector<LoopId>())),
+      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
 
 //===----------------------------------------------------------------------===//
 // Lattice methods.
@@ -762,7 +767,10 @@ void Merger::dumpBits(const BitVector &bits) const {
       const TensorId t = tensor(b);
       const LoopId i = loop(b);
       const auto dlt = lvlTypes[t][i];
-      llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
+      if (isLvlWithNonTrivialIdxExp(b))
+        llvm::dbgs() << " DEP_" << t << "_" << i;
+      else
+        llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
     }
   }
 }


        


More information about the Mlir-commits mailing list