[Mlir-commits] [mlir] [mlir][sparse] schedule sparse kernels in a separate pass from sparsification. (PR #72423)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 15 10:23:09 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>



---

Patch is 98.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72423.diff


53 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (+3) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+10) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp (+15-22) 
- (added) mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp (+273) 
- (added) mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h (+73) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (+179-16) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+7-580) 
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir (+1) 
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir (+1) 
- (modified) mlir/test/Dialect/SparseTensor/constant_index_map.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/dense.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/one_trip.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/semi_ring.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sorted_coo.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_1d.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_2d.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_3d.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_affine.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_expand.mlir (+4-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_index.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir (+1-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_kernels.mlir (+28-26) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower.mlir (+3-3) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir (+7-6) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir (+3-3) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_nd.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_out.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir (+1-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_parallel.mlir (+5-5) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_perm.mlir (+8-6) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir (+7-6) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_scalars.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_storage.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_transpose.mlir (+7-7) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector.mlir (+4-4) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir (+1-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/spy_sddmm.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/unused-tensor.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 2fb61746d048999..bb8df626086bb29 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -127,6 +127,9 @@ inline bool hasAnySparseOperandOrResult(Operation *op) {
   return hasAnySparseOperand(op) || hasAnySparseResult(op);
 }
 
+/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
+bool hasAnyNonIdentityOperandsOrResults(Operation *op);
+
 //
 // Inference.
 //
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 92bf8ec6468e532..53a0c2428842f59 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -875,6 +875,16 @@ bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
   return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
 }
 
+bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
+  auto hasNonIdentityMap = [](Value v) {
+    auto stt = tryGetSparseTensorType(v);
+    return stt && !stt->isIdentity();
+  };
+
+  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
+         llvm::any_of(op->getResults(), hasNonIdentityMap);
+}
+
 Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
   // We only consider COO region with at least two levels for the purpose
   // of AOS storage optimization.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index b8a2ff26b6794f6..06d0ec1d7eb7cf3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   CodegenEnv.cpp
   CodegenUtils.cpp
   LoopEmitter.cpp
+  LoopScheduler.cpp
   SparseBufferRewriting.cpp
   SparseGPUCodegen.cpp
   SparseReinterpretMap.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 5c7cc93737b7fd7..3df9bcc77717994 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -57,7 +57,11 @@ CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
       loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u),
       insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(),
       redExp(detail::kInvalidId), redCustom(detail::kInvalidId),
-      redValidLexInsert() {}
+      redValidLexInsert() {
+  // TODO: remove topSort, loops should be already sorted by previous pass.
+  for (unsigned l = 0; l < latticeMerger.getNumLoops(); l++)
+    topSort.push_back(l);
+}
 
 LogicalResult CodegenEnv::initTensorExp() {
   // Builds the tensor expression for the Linalg operation in SSA form.
@@ -181,36 +185,25 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
   // Accept "truly dynamic" if the output tensor materializes uninitialized
   // into the computation and insertions occur in lexicographic index order.
   sparseOut = lhs;
-  return isMaterializing(lhs->get());
-}
 
-bool CodegenEnv::isAdmissibleTopoOrder() {
-  if (!hasSparseOutput())
-    return true;
-
-  OpOperand *lhs = linalgOp.getDpsInitOperand(0);
-  // Accept "truly dynamic" if the output tensor materializes uninitialized
-  // into the computation and insertions occur in lexicographic index order.
-  LoopOrd nest = 0;
+  // Find the outermost parallel nest to determine whether compress/expand is
+  // needed.
+  outerParNest = 0;
   const auto iteratorTypes = linalgOp.getIteratorTypesArray();
   assert(topSortSize() == latticeMerger.getNumLoops());
   for (const LoopId i : topSort) {
     if (!latticeMerger.isFilterLoop(i)) {
-      // We only count non-filter loops as filter loops should be considered
-      // a special type of parallel loops.
       if (linalg::isReductionIterator(iteratorTypes[i]))
         break; // terminate at first reduction
-      nest++;
+      outerParNest++;
     }
   }
-  // Determine admissible dynamic insertion situations:
-  // (1) fully injective, since there are no reductions,
-  // (2) admissible 1-d expansion in innermost dimension.
-  if (static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1) {
-    outerParNest = nest;
-    return true;
-  }
-  return false;
+
+  // Inadmissible kernel should have already been rejected by the previous
+  // pass during loop scheduling.
+  assert(static_cast<int64_t>(outerParNest) >=
+         linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
+  return isMaterializing(lhs->get());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp
new file mode 100644
index 000000000000000..32e0a6f5c03d8a2
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp
@@ -0,0 +1,273 @@
+//===- LoopScheduler.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LoopScheduler.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+/// A helper class that visits an affine expression and tries to find an
+/// AffineDimExpr to which the corresponding iterator from a GenericOp matches
+/// the desired iterator type.
+/// If there is no matched iterator type, returns the first DimExpr in the
+/// expression.
+class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
+public:
+  explicit AffineDimFinder(ArrayRef<utils::IteratorType> itTypes)
+      : iterTypes(itTypes) {}
+
+  // Overrides method from AffineExprVisitor.
+  void visitDimExpr(AffineDimExpr expr) {
+    if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition()])
+      pickedDim = expr;
+  }
+
+  /// Set the desired iterator type that we want to pick.
+  void setPickedIterType(utils::IteratorType iterType) {
+    pickIterType = iterType;
+  }
+
+  /// Get the desired AffineDimExpr.
+  AffineDimExpr getDimExpr() const {
+    return llvm::cast<AffineDimExpr>(pickedDim);
+  }
+
+  void walkPostOrder(AffineExpr expr) {
+    pickedDim = nullptr;
+    AffineExprVisitor<AffineDimFinder>::walkPostOrder(expr);
+  }
+
+private:
+  /// The picked AffineDimExpr after visit.
+  AffineExpr pickedDim;
+  /// The iterator type that we want.
+  utils::IteratorType pickIterType;
+  /// The mapping between dim=>iterator type.
+  ArrayRef<utils::IteratorType> iterTypes;
+};
+
+// Flattens an affine expression into a list of AffineDimExprs.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+  // Overrides method from AffineExprVisitor.
+  void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
+  SmallVector<AffineDimExpr> dims;
+};
+
+} // namespace
+
+inline static bool includesAny(SortMask mask1, SortMask mask2) {
+  return static_cast<unsigned>(mask1) & static_cast<unsigned>(mask2);
+}
+
+inline static bool includesDenseInput(SortMask mask) {
+  return includesAny(mask, SortMask::kIncludeDenseInput);
+}
+
+inline static bool includesDenseOutput(SortMask mask) {
+  return includesAny(mask, SortMask::kIncludeDenseOutput);
+}
+
+/// A helper to compute a topological sort. O(n^2) time complexity
+/// as we use adj matrix for the graph.
+/// The sorted result will put the first Reduction iterator to the
+/// latest possible position.
+AffineMap LoopScheduler::topoSort() {
+  std::vector<unsigned> redIt; // reduce iterator with 0 degree
+  std::vector<unsigned> parIt; // parallel iterator with 0 degree
+  const unsigned numLoops = getNumLoops();
+  for (unsigned i = 0; i < numLoops; i++) {
+    if (inDegree[i] == 0) {
+      if (iterTypes[i] == utils::IteratorType::reduction)
+        redIt.push_back(i);
+      else
+        parIt.push_back(i);
+    }
+  }
+
+  SmallVector<unsigned> loopOrder;
+  while (!redIt.empty() || !parIt.empty()) {
+    // We always prefer parallel loop over reduction loop because putting
+    // reduction loop early might make the loop sequence inadmissible.
+    auto &it = !parIt.empty() ? parIt : redIt;
+    auto src = it.back();
+    loopOrder.push_back(src);
+    it.pop_back();
+    // Update in-degree, and push 0-degree node into worklist.
+    for (unsigned dst = 0; dst < numLoops; dst++) {
+      if (itGraph[src][dst] && --inDegree[dst] == 0) {
+        if (iterTypes[dst] == utils::IteratorType::reduction)
+          redIt.push_back(dst);
+        else
+          parIt.push_back(dst);
+      }
+    }
+  }
+
+  if (loopOrder.size() == numLoops)
+    return AffineMap::getPermutationMap(loopOrder, out.getContext());
+
+  // Cycle detected.
+  return AffineMap();
+}
+
+LoopScheduler LoopScheduler::fromGenericOp(linalg::GenericOp genericOp) {
+  // Must be a demapped sparse kernel.
+  assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
+         hasAnySparseOperandOrResult(genericOp) &&
+         genericOp.getNumDpsInits() == 1);
+
+  SmallVector<AffineMap> loopMap = genericOp.getIndexingMapsArray();
+  SmallVector<Value> ins = genericOp.getDpsInputs();
+
+  AffineMap outMap = loopMap.back();
+  loopMap.pop_back();
+
+  Value out = genericOp.getDpsInitOperand(0)->get();
+  SmallVector<utils::IteratorType> iterTypes =
+      genericOp.getIteratorTypesArray();
+
+  return LoopScheduler(std::move(ins), std::move(loopMap), out, outMap,
+                       std::move(iterTypes));
+}
+
+LoopScheduler::LoopScheduler(SmallVector<Value> &&ins,
+                             SmallVector<AffineMap> &&loop2InsLvl, Value out,
+                             AffineMap loop2OutLvl,
+                             SmallVector<utils::IteratorType> &&iterTypes)
+    : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
+      loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
+  // One map per tensor (use `this->`: the shadowing params are moved-from).
+  assert(this->loop2InsLvl.size() == this->ins.size());
+  // All the affine maps have the same number of dimensions (loops).
+  assert(llvm::all_equal(llvm::map_range(
+      this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
+  // The number of results of the map should match the rank of the tensor.
+  assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) {
+    auto [m, v] = mvPair;
+    return m.getNumResults() ==
+           v.getType().template cast<ShapedType>().getRank();
+  }));
+
+  itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
+  inDegree.resize(getNumLoops());
+}
+
+AffineMap LoopScheduler::schedule(SortMask mask, Value ignored) {
+  // Reset the iteration graph.
+  for (auto &row : itGraph)
+    std::fill(row.begin(), row.end(), false);
+  // Reset cached in-degree.
+  std::fill(inDegree.begin(), inDegree.end(), 0);
+
+  for (auto [in, map] : llvm::zip(ins, loop2InsLvl)) {
+    // Get map and encoding.
+    const auto enc = getSparseTensorEncoding(in.getType());
+    // Skips dense inputs when not requested.
+    if ((!enc && !includesDenseInput(mask)) || in == ignored)
+      continue;
+
+    addConstraints(in, map);
+  }
+
+  // Get map and encoding.
+  const auto enc = getSparseTensorEncoding(out.getType());
+  if ((enc || includesDenseOutput(mask)) && out != ignored)
+    addConstraints(out, loop2OutLvl);
+
+  return topoSort();
+}
+
+void LoopScheduler::addConstraints(Value t, AffineMap loop2LvlMap) {
+  auto addIterOrdering = [this](unsigned f, unsigned t) {
+    if (!itGraph[f][t] && f != t) {
+      itGraph[f][t] = true;
+      inDegree[t]++;
+    }
+  };
+
+  AffineDimFinder finder(iterTypes);
+  finder.setPickedIterType(utils::IteratorType::reduction);
+
+  // To compute iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6],
+  // we require that there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6},
+  // and d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
+  const Level lvlRank = loop2LvlMap.getNumResults();
+  for (Level lvl = 1; lvl < lvlRank; lvl++) {
+    const AffineExpr fa = loop2LvlMap.getResult(lvl - 1);
+    const AffineExpr ta = loop2LvlMap.getResult(lvl);
+
+    if (llvm::isa<AffineDimExpr>(fa) || llvm::isa<AffineDimExpr>(ta)) {
+      // Special case when at least one loop2LvlExp is a simple AffineDimExpr
+      // (say, d0) and we require d0 > {d1, d2, ...} or {d1, d2, ...} > d0
+      AffineDimCollector fCollector;
+      fCollector.walkPostOrder(fa);
+      AffineDimCollector tCollector;
+      tCollector.walkPostOrder(ta);
+
+      for (auto fd : fCollector.dims) {
+        for (auto td : tCollector.dims) {
+          const unsigned f = fd.getPosition();
+          const unsigned t = td.getPosition();
+          addIterOrdering(f, t);
+        }
+      }
+      continue;
+    }
+
+    // When both loop2LvlExprs are compound, we pick an arbitrary reduction loop
+    // from lhs and rhs and use them as d_x and d_y.
+    finder.walkPostOrder(fa);
+    const AffineDimExpr fexp = finder.getDimExpr();
+    const unsigned fldx = fexp.getPosition();
+
+    finder.walkPostOrder(ta);
+    const AffineDimExpr texp = finder.getDimExpr();
+    const unsigned tldx = texp.getPosition();
+
+    // d_x > d_y
+    addIterOrdering(fldx, tldx);
+
+    AffineDimCollector fCollector;
+    fCollector.walkPostOrder(fa);
+    AffineDimCollector tCollector;
+    tCollector.walkPostOrder(ta);
+
+    // Make sure d_x and d_y are the last.
+    for (auto fd : fCollector.dims) {
+      const unsigned f = fd.getPosition();
+      addIterOrdering(f, fldx);
+    }
+    for (auto td : tCollector.dims) {
+      const unsigned t = td.getPosition();
+      addIterOrdering(t, tldx);
+    }
+    // {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
+    // This is to ensure that the affine expressions are reduced in sparse
+    // tensor level ordering.
+    for (auto fd : fCollector.dims) {
+      const unsigned f = fd.getPosition();
+      if (f == fldx) // skip d_x
+        continue;
+      for (auto td : tCollector.dims) {
+        const unsigned t = td.getPosition();
+        if (t == tldx) // skip d_y
+          continue;
+        addIterOrdering(f, t);
+      }
+    }
+  }
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h
new file mode 100644
index 000000000000000..1d5cb1bc6f60e4e
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h
@@ -0,0 +1,73 @@
+//===- LoopScheduler.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_LIB_DIALECT_SPARSETENSOR_TRANSFORMS_LOOPSCHEDULER_H_
+#define MLIR_LIB_DIALECT_SPARSETENSOR_TRANSFORMS_LOOPSCHEDULER_H_
+
+#include "mlir/IR/AffineMap.h"
+
+namespace mlir {
+// Forward declarations.
+class Value;
+namespace utils {
+enum class IteratorType : uint32_t;
+} // namespace utils
+namespace linalg {
+class GenericOp;
+} // namespace linalg
+
+namespace sparse_tensor {
+/// Iteration graph sorting.
+enum class SortMask : unsigned {
+  // The individual mask bits.
+  kIncludeDenseOutput = 0x1, // b001
+  kIncludeDenseInput = 0x2,  // b010
+  kIncludeUndef = 0x4,       // b100
+  // The subsets of mask bits.
+  kIncludeAll = 0x7,   // b111
+  kIncludeDense = 0x3, // b011
+  kSparseOnly = 0x0,   // b000
+};
+
+class LoopScheduler {
+public:
+  // Constructs a scheduler from linalg.generic
+  // Maybe reuses the class to schedule foreach as well (to address
+  // non-permutation, e.g, traverse CSR in BSR order).
+  static LoopScheduler fromGenericOp(linalg::GenericOp genericOp);
+
+  // Returns a permutation that represents the scheduled loop order.
+  // Note that the returned AffineMap could be null if the kernel cannot be
+  // scheduled due to cycles in the iteration graph.
+  [[nodiscard]] AffineMap schedule(SortMask mask, Value ignored = nullptr);
+  unsigned getNumLoops() const { return loop2OutLvl.getNumDims(); }
+
+private:
+  LoopScheduler(SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl,
+                Value out, AffineMap loop2OutLvl,
+                SmallVector<utils::IteratorType> &&iterTypes);
+
+  void addConstraints(Value t, AffineMap loop2LvlMap);
+  AffineMap topoSort();
+
+  // Input tensors and associated loop to level maps.
+  SmallVector<Value> ins;
+  SmallVector<AffineMap> loop2InsLvl;
+  // Output tensor and associated loop to level map.
+  Value out;
+  AffineMap loop2OutLvl;
+  // Loop types.
+  SmallVector<utils::IteratorType> iterTypes;
+
+  // Adjacency matrix that represents the iteration graph.
+  std::vector<std::vector<bool>> itGraph;
+  // InDegree used for topo sort.
+  std::vector<unsigned> inDegree;
+};
+} // namespace sparse_tensor
+} // namespace mlir
+#endif // MLIR_LIB_DIALECT_SPARSETENSOR_TRANSFORMS_LOOPSCHEDULER_H_
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 83e86d137335021..a433b2edc5500e7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -1,4 +1,4 @@
-//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
+//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "CodegenUtils.h"
+#include "LoopScheduler.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -351,17 +352,6 @@ static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
   return ret;
 }
 
-/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
-static bool hasNonIdentityOperandsOrResults(Operation *op) {
-  auto hasNonIdentityMap = [](Value v) {
-    auto stt = tryGetSparseTensorType(v);
-    return stt && !stt->isIdentity();
-  };
-
-  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
-         llvm::any_of(op->getResults(), hasNonIdentityMap);
-}
-
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -379,7 +369,7 @@ struct GenericOpReinterpretMap
     // semantics.
     if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
         !hasAnySparseOperandOrResult(linalgOp) ||
-        !hasNonIdentityOperandsOrResults(linalgOp))
+        !hasAnyNonIdentityOperandsOrResults(linalgOp))
       return failure();
 
     // Try translating the index map.
@@ -411,6 +401,178 @@ struct GenericOpReinterpretMap
   }
 };
 
+struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
+        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
+        !hasAnySparseOperandOrResult(linalgOp)) {
+      return failure();
+    }
+
+    const StringRef sorted = "sorted";
+    if (linalgOp->hasAttr(sorted))
+      return failure();
+
+    auto scheduler = LoopScheduler::fromGenericOp(linalgOp);
+    bool isAdmissible = false;
+    AffineMap order;
+    // A const list of all masks that we used for iteration graph
+    // computation. Must be ordered from more strict to less strict.
+    // Ideally (though might not be guaranteed), the earlier a constraint mask...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/72423


More information about the Mlir-commits mailing list