[Mlir-commits] [mlir] [mlir][sparse] schedule sparse kernels in a separate pass from sparsification. (PR #72423)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 15 10:23:09 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
Author: Peiming Liu (PeimingLiu)
<details>
<summary>Changes</summary>
---
Patch is 98.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72423.diff
53 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (+3)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+10)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt (+1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp (+15-22)
- (added) mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp (+273)
- (added) mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h (+73)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (+179-16)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+7-580)
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir (+1)
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir (+1)
- (modified) mlir/test/Dialect/SparseTensor/constant_index_map.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/dense.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/one_trip.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/semi_ring.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sorted_coo.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_1d.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_2d.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_3d.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_affine.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_expand.mlir (+4-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_index.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir (+1-2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_kernels.mlir (+28-26)
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower.mlir (+3-3)
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir (+7-6)
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir (+3-3)
- (modified) mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_nd.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_out.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir (+1-2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_parallel.mlir (+5-5)
- (modified) mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_perm.mlir (+8-6)
- (modified) mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir (+7-6)
- (modified) mlir/test/Dialect/SparseTensor/sparse_scalars.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_storage.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_transpose.mlir (+7-7)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector.mlir (+4-4)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir (+1-2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/spy_sddmm.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/unused-tensor.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+2-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 2fb61746d048999..bb8df626086bb29 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -127,6 +127,9 @@ inline bool hasAnySparseOperandOrResult(Operation *op) {
return hasAnySparseOperand(op) || hasAnySparseResult(op);
}
+/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
+bool hasAnyNonIdentityOperandsOrResults(Operation *op);
+
//
// Inference.
//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 92bf8ec6468e532..53a0c2428842f59 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -875,6 +875,16 @@ bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
}
+bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
+ auto hasNonIdentityMap = [](Value v) {
+ auto stt = tryGetSparseTensorType(v);
+ return stt && !stt->isIdentity();
+ };
+
+ return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
+ llvm::any_of(op->getResults(), hasNonIdentityMap);
+}
+
Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
// We only consider COO region with at least two levels for the purpose
// of AOS storage optimization.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index b8a2ff26b6794f6..06d0ec1d7eb7cf3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
CodegenEnv.cpp
CodegenUtils.cpp
LoopEmitter.cpp
+ LoopScheduler.cpp
SparseBufferRewriting.cpp
SparseGPUCodegen.cpp
SparseReinterpretMap.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 5c7cc93737b7fd7..3df9bcc77717994 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -57,7 +57,11 @@ CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u),
insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(),
redExp(detail::kInvalidId), redCustom(detail::kInvalidId),
- redValidLexInsert() {}
+ redValidLexInsert() {
+ // TODO: remove topSort, loops should be already sorted by previous pass.
+ for (unsigned l = 0; l < latticeMerger.getNumLoops(); l++)
+ topSort.push_back(l);
+}
LogicalResult CodegenEnv::initTensorExp() {
// Builds the tensor expression for the Linalg operation in SSA form.
@@ -181,36 +185,25 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
// Accept "truly dynamic" if the output tensor materializes uninitialized
// into the computation and insertions occur in lexicographic index order.
sparseOut = lhs;
- return isMaterializing(lhs->get());
-}
-bool CodegenEnv::isAdmissibleTopoOrder() {
- if (!hasSparseOutput())
- return true;
-
- OpOperand *lhs = linalgOp.getDpsInitOperand(0);
- // Accept "truly dynamic" if the output tensor materializes uninitialized
- // into the computation and insertions occur in lexicographic index order.
- LoopOrd nest = 0;
+ // Find the outermost parallel nest to determine whether compress/expand is
+ // needed.
+ outerParNest = 0;
const auto iteratorTypes = linalgOp.getIteratorTypesArray();
assert(topSortSize() == latticeMerger.getNumLoops());
for (const LoopId i : topSort) {
if (!latticeMerger.isFilterLoop(i)) {
- // We only count non-filter loops as filter loops should be considered
- // a special type of parallel loops.
if (linalg::isReductionIterator(iteratorTypes[i]))
break; // terminate at first reduction
- nest++;
+ outerParNest++;
}
}
- // Determine admissible dynamic insertion situations:
- // (1) fully injective, since there are no reductions,
- // (2) admissible 1-d expansion in innermost dimension.
- if (static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1) {
- outerParNest = nest;
- return true;
- }
- return false;
+
+ // Inadmissible kernels should have already been rejected by the previous
+ // pass during loop scheduling.
+ assert(static_cast<int64_t>(outerParNest) >=
+ linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
+ return isMaterializing(lhs->get());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp
new file mode 100644
index 000000000000000..32e0a6f5c03d8a2
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp
@@ -0,0 +1,273 @@
+//===- LoopScheduler.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LoopScheduler.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+/// A helper class that visits an affine expression and tries to find an
+/// AffineDimExpr to which the corresponding iterator from a GenericOp matches
+/// the desired iterator type.
+/// If there is no matched iterator type, returns the first DimExpr in the
+/// expression.
+class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
+public:
+ explicit AffineDimFinder(ArrayRef<utils::IteratorType> itTypes)
+ : iterTypes(itTypes) {}
+
+ // Overrides method from AffineExprVisitor.
+ void visitDimExpr(AffineDimExpr expr) {
+ if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition()])
+ pickedDim = expr;
+ }
+
+ /// Set the desired iterator type that we want to pick.
+ void setPickedIterType(utils::IteratorType iterType) {
+ pickIterType = iterType;
+ }
+
+ /// Get the desired AffineDimExpr.
+ AffineDimExpr getDimExpr() const {
+ return llvm::cast<AffineDimExpr>(pickedDim);
+ }
+
+ void walkPostOrder(AffineExpr expr) {
+ pickedDim = nullptr;
+ AffineExprVisitor<AffineDimFinder>::walkPostOrder(expr);
+ }
+
+private:
+ /// The picked AffineDimExpr after visit.
+ AffineExpr pickedDim;
+ /// The iterator type that we want.
+ utils::IteratorType pickIterType;
+ /// The mapping between dim=>iterator type.
+ ArrayRef<utils::IteratorType> iterTypes;
+};
+
+// Flattens an affine expression into a list of AffineDimExprs.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+ // Overrides method from AffineExprVisitor.
+ void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
+ SmallVector<AffineDimExpr> dims;
+};
+
+} // namespace
+
+inline static bool includesAny(SortMask mask1, SortMask mask2) {
+ return static_cast<unsigned>(mask1) & static_cast<unsigned>(mask2);
+}
+
+inline static bool includesDenseInput(SortMask mask) {
+ return includesAny(mask, SortMask::kIncludeDenseInput);
+}
+
+inline static bool includesDenseOutput(SortMask mask) {
+ return includesAny(mask, SortMask::kIncludeDenseOutput);
+}
+
+/// A helper to compute a topological sort. O(n^2) time complexity
+/// as we use adj matrix for the graph.
+/// The sorted result will put the first Reduction iterator to the
+/// latest possible position.
+AffineMap LoopScheduler::topoSort() {
+ std::vector<unsigned> redIt; // reduce iterator with 0 degree
+ std::vector<unsigned> parIt; // parallel iterator with 0 degree
+ const unsigned numLoops = getNumLoops();
+ for (unsigned i = 0; i < numLoops; i++) {
+ if (inDegree[i] == 0) {
+ if (iterTypes[i] == utils::IteratorType::reduction)
+ redIt.push_back(i);
+ else
+ parIt.push_back(i);
+ }
+ }
+
+ SmallVector<unsigned> loopOrder;
+ while (!redIt.empty() || !parIt.empty()) {
+ // We always prefer parallel loop over reduction loop because putting
+ // reduction loop early might make the loop sequence inadmissible.
+ auto &it = !parIt.empty() ? parIt : redIt;
+ auto src = it.back();
+ loopOrder.push_back(src);
+ it.pop_back();
+ // Update in-degree, and push 0-degree node into worklist.
+ for (unsigned dst = 0; dst < numLoops; dst++) {
+ if (itGraph[src][dst] && --inDegree[dst] == 0) {
+ if (iterTypes[dst] == utils::IteratorType::reduction)
+ redIt.push_back(dst);
+ else
+ parIt.push_back(dst);
+ }
+ }
+ }
+
+ if (loopOrder.size() == numLoops)
+ return AffineMap::getPermutationMap(loopOrder, out.getContext());
+
+ // Cycle detected.
+ return AffineMap();
+}
+
+LoopScheduler LoopScheduler::fromGenericOp(linalg::GenericOp genericOp) {
+ // Must be a demapped sparse kernel.
+ assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
+ hasAnySparseOperandOrResult(genericOp) &&
+ genericOp.getNumDpsInits() == 1);
+
+ SmallVector<AffineMap> loopMap = genericOp.getIndexingMapsArray();
+ SmallVector<Value> ins = genericOp.getDpsInputs();
+
+ AffineMap outMap = loopMap.back();
+ loopMap.pop_back();
+
+ Value out = genericOp.getDpsInitOperand(0)->get();
+ SmallVector<utils::IteratorType> iterTypes =
+ genericOp.getIteratorTypesArray();
+
+ return LoopScheduler(std::move(ins), std::move(loopMap), out, outMap,
+ std::move(iterTypes));
+}
+
+LoopScheduler::LoopScheduler(SmallVector<Value> &&ins,
+ SmallVector<AffineMap> &&loop2InsLvl, Value out,
+ AffineMap loop2OutLvl,
+ SmallVector<utils::IteratorType> &&iterTypes)
+ : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
+ loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
+ // One map per tensor.
+ assert(loop2InsLvl.size() == ins.size());
+ // All the affine maps have the same number of dimensions (loops).
+ assert(llvm::all_equal(llvm::map_range(
+ loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
+ // The number of results of the map should match the rank of the tensor.
+ assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
+ auto [m, v] = mvPair;
+ return m.getNumResults() ==
+ v.getType().template cast<ShapedType>().getRank();
+ }));
+
+ itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
+ inDegree.resize(getNumLoops());
+}
+
+AffineMap LoopScheduler::schedule(SortMask mask, Value ignored) {
+ // Reset the iteration graph.
+ for (auto &row : itGraph)
+ std::fill(row.begin(), row.end(), false);
+ // Reset cached in-degree.
+ std::fill(inDegree.begin(), inDegree.end(), 0);
+
+ for (auto [in, map] : llvm::zip(ins, loop2InsLvl)) {
+ // Get map and encoding.
+ const auto enc = getSparseTensorEncoding(in.getType());
+ // Skips dense inputs when not requested.
+ if ((!enc && !includesDenseInput(mask)) || in == ignored)
+ continue;
+
+ addConstraints(in, map);
+ }
+
+ // Get map and encoding.
+ const auto enc = getSparseTensorEncoding(out.getType());
+ if ((enc || includesDenseOutput(mask)) && out != ignored)
+ addConstraints(out, loop2OutLvl);
+
+ return topoSort();
+}
+
+void LoopScheduler::addConstraints(Value t, AffineMap loop2LvlMap) {
+ auto addIterOrdering = [this](unsigned f, unsigned t) {
+ if (!itGraph[f][t] && f != t) {
+ itGraph[f][t] = true;
+ inDegree[t]++;
+ }
+ };
+
+ AffineDimFinder finder(iterTypes);
+ finder.setPickedIterType(utils::IteratorType::reduction);
+
+ // To compute iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6],
+ // we require there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6},
+ // and d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
+ const Level lvlRank = loop2LvlMap.getNumResults();
+ for (Level lvl = 1; lvl < lvlRank; lvl++) {
+ const AffineExpr fa = loop2LvlMap.getResult(lvl - 1);
+ const AffineExpr ta = loop2LvlMap.getResult(lvl);
+
+ if (llvm::isa<AffineDimExpr>(fa) || llvm::isa<AffineDimExpr>(ta)) {
+ // Special case when at least one loop2LvlExp is a simple AffineDimExpr
+ // (say, d0) and we require d0 > {d1, d2, ...} or {d1, d2, ...} > d0
+ AffineDimCollector fCollector;
+ fCollector.walkPostOrder(fa);
+ AffineDimCollector tCollector;
+ tCollector.walkPostOrder(ta);
+
+ for (auto fd : fCollector.dims) {
+ for (auto td : tCollector.dims) {
+ const unsigned f = fd.getPosition();
+ const unsigned t = td.getPosition();
+ addIterOrdering(f, t);
+ }
+ }
+ continue;
+ }
+
+ // When both loop2LvlExprs are compound, we pick an arbitrary reduction loop
+ // from lhs and rhs and use them as d_x and d_y.
+ finder.walkPostOrder(fa);
+ const AffineDimExpr fexp = finder.getDimExpr();
+ const unsigned fldx = fexp.getPosition();
+
+ finder.walkPostOrder(ta);
+ const AffineDimExpr texp = finder.getDimExpr();
+ const unsigned tldx = texp.getPosition();
+
+ // d_x > d_y
+ addIterOrdering(fldx, tldx);
+
+ AffineDimCollector fCollector;
+ fCollector.walkPostOrder(fa);
+ AffineDimCollector tCollector;
+ tCollector.walkPostOrder(ta);
+
+ // Make sure d_x and d_y are the last.
+ for (auto fd : fCollector.dims) {
+ const unsigned f = fd.getPosition();
+ addIterOrdering(f, fldx);
+ }
+ for (auto td : tCollector.dims) {
+ const unsigned t = td.getPosition();
+ addIterOrdering(t, tldx);
+ }
+ // {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
+ // This is to ensure that the affine expressions are reduced in sparse
+ // tensor level ordering.
+ for (auto fd : fCollector.dims) {
+ const unsigned f = fd.getPosition();
+ if (f == fldx) // skip d_x
+ continue;
+ for (auto td : tCollector.dims) {
+ const unsigned t = td.getPosition();
+ if (t == tldx) // skip d_y
+ continue;
+ addIterOrdering(f, t);
+ }
+ }
+ }
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h
new file mode 100644
index 000000000000000..1d5cb1bc6f60e4e
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h
@@ -0,0 +1,73 @@
+//===- LoopScheduler.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AffineMap.h"
+
+namespace mlir {
+
+// Forward declarations.
+class Value;
+namespace utils {
+enum class IteratorType : uint32_t;
+} // namespace utils
+namespace linalg {
+class GenericOp;
+} // namespace linalg
+
+namespace sparse_tensor {
+
+/// Iteration graph sorting.
+enum class SortMask : unsigned {
+ // The individual mask bits.
+ kIncludeDenseOutput = 0x1, // b001
+ kIncludeDenseInput = 0x2, // b010
+ kIncludeUndef = 0x4, // b100
+ // The subsets of mask bits.
+ kIncludeAll = 0x7, // b111
+ kIncludeDense = 0x3, // b011
+ kSparseOnly = 0x0, // b000
+};
+
+class LoopScheduler {
+public:
+ // Constructs a scheduler from linalg.generic
+ // Maybe reuses the class to schedule foreach as well (to address
+ // non-permutation, e.g, traverse CSR in BSR order).
+ static LoopScheduler fromGenericOp(linalg::GenericOp genericOp);
+
+ // Returns a permutation that represents the scheduled loop order.
+ // Note that the returned AffineMap could be null if the kernel cannot be
+ // scheduled due to cycles in the iteration graph.
+ [[nodiscard]] AffineMap schedule(SortMask mask, Value ignored = nullptr);
+ unsigned getNumLoops() const { return loop2OutLvl.getNumDims(); }
+
+private:
+ LoopScheduler(SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl,
+ Value out, AffineMap loop2OutLvl,
+ SmallVector<utils::IteratorType> &&iterTypes);
+
+ void addConstraints(Value t, AffineMap loop2LvlMap);
+ AffineMap topoSort();
+
+ // Input tensors and associated loop to level maps.
+ SmallVector<Value> ins;
+ SmallVector<AffineMap> loop2InsLvl;
+ // Output tensor and associated loop to level map.
+ Value out;
+ AffineMap loop2OutLvl;
+ // Loop type;
+ SmallVector<utils::IteratorType> iterTypes;
+
+ // Adjacency matrix that represents the iteration graph.
+ std::vector<std::vector<bool>> itGraph;
+ // InDegree used for topo sort.
+ std::vector<unsigned> inDegree;
+};
+
+} // namespace sparse_tensor
+} // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 83e86d137335021..a433b2edc5500e7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -1,4 +1,4 @@
-//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
+//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===/
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
+#include "LoopScheduler.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -351,17 +352,6 @@ static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
return ret;
}
-/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
-static bool hasNonIdentityOperandsOrResults(Operation *op) {
- auto hasNonIdentityMap = [](Value v) {
- auto stt = tryGetSparseTensorType(v);
- return stt && !stt->isIdentity();
- };
-
- return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
- llvm::any_of(op->getResults(), hasNonIdentityMap);
-}
-
namespace {
//===----------------------------------------------------------------------===//
@@ -379,7 +369,7 @@ struct GenericOpReinterpretMap
// semantics.
if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
!hasAnySparseOperandOrResult(linalgOp) ||
- !hasNonIdentityOperandsOrResults(linalgOp))
+ !hasAnyNonIdentityOperandsOrResults(linalgOp))
return failure();
// Try translating the index map.
@@ -411,6 +401,178 @@ struct GenericOpReinterpretMap
}
};
+struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
+ PatternRewriter &rewriter) const override {
+ if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
+ hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
+ !hasAnySparseOperandOrResult(linalgOp)) {
+ return failure();
+ }
+
+ const StringRef sorted = "sorted";
+ if (linalgOp->hasAttr(sorted))
+ return failure();
+
+ auto scheduler = LoopScheduler::fromGenericOp(linalgOp);
+ bool isAdmissible = false;
+ AffineMap order;
+ // A const list of all masks that we used for iteration graph
+ // computation. Must be ordered from more strict to less strict.
+ // Ideally (though might not be guaranteed), the earlier a constraint mask...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/72423
More information about the Mlir-commits
mailing list