[Mlir-commits] [mlir] [mlir][sparse] schedule sparse kernels in a separate pass from sparsification. (PR #72423)
Peiming Liu
llvmlistbot at llvm.org
Wed Nov 15 11:01:30 PST 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/72423
From 3ab8768197445e75b6bc4067ace2c91abf21ca38 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 14 Nov 2023 20:55:50 +0000
Subject: [PATCH 1/4] [mlir][sparse] schedule sparse loops
---
.../Dialect/SparseTensor/IR/SparseTensor.h | 3 +
.../SparseTensor/IR/SparseTensorDialect.cpp | 10 +
.../SparseTensor/Transforms/CMakeLists.txt | 1 +
.../SparseTensor/Transforms/CodegenEnv.cpp | 37 +-
.../SparseTensor/Transforms/LoopScheduler.cpp | 273 ++++++++
.../SparseTensor/Transforms/LoopScheduler.h | 73 +++
.../Transforms/SparseReinterpretMap.cpp | 195 +++++-
.../Transforms/Sparsification.cpp | 587 +-----------------
.../Dialect/SparseTensor/GPU/gpu_combi.mlir | 2 +-
.../Dialect/SparseTensor/GPU/gpu_matmul.mlir | 1 +
.../Dialect/SparseTensor/GPU/gpu_matvec.mlir | 1 +
.../SparseTensor/constant_index_map.mlir | 2 +-
mlir/test/Dialect/SparseTensor/dense.mlir | 2 +-
mlir/test/Dialect/SparseTensor/one_trip.mlir | 2 +-
mlir/test/Dialect/SparseTensor/semi_ring.mlir | 2 +-
.../test/Dialect/SparseTensor/sorted_coo.mlir | 2 +-
mlir/test/Dialect/SparseTensor/sparse_1d.mlir | 2 +-
mlir/test/Dialect/SparseTensor/sparse_2d.mlir | 2 +-
mlir/test/Dialect/SparseTensor/sparse_3d.mlir | 2 +-
.../Dialect/SparseTensor/sparse_affine.mlir | 2 +-
.../SparseTensor/sparse_broadcast.mlir | 2 +-
.../sparse_conv_2d_slice_based.mlir | 2 +-
.../Dialect/SparseTensor/sparse_expand.mlir | 5 +-
.../SparseTensor/sparse_fill_zero.mlir | 2 +-
.../Dialect/SparseTensor/sparse_fp_ops.mlir | 2 +-
.../Dialect/SparseTensor/sparse_index.mlir | 2 +-
.../Dialect/SparseTensor/sparse_int_ops.mlir | 3 +-
.../Dialect/SparseTensor/sparse_kernels.mlir | 54 +-
.../Dialect/SparseTensor/sparse_lower.mlir | 6 +-
.../SparseTensor/sparse_lower_col.mlir | 13 +-
.../SparseTensor/sparse_lower_inplace.mlir | 6 +-
.../SparseTensor/sparse_matmul_codegen.mlir | 2 +-
mlir/test/Dialect/SparseTensor/sparse_nd.mlir | 2 +-
.../test/Dialect/SparseTensor/sparse_out.mlir | 2 +-
.../Dialect/SparseTensor/sparse_outbuf.mlir | 3 +-
.../Dialect/SparseTensor/sparse_parallel.mlir | 10 +-
.../SparseTensor/sparse_parallel_reduce.mlir | 2 +-
.../Dialect/SparseTensor/sparse_perm.mlir | 14 +-
.../SparseTensor/sparse_perm_lower.mlir | 13 +-
.../Dialect/SparseTensor/sparse_scalars.mlir | 2 +-
.../Dialect/SparseTensor/sparse_sddmm.mlir | 2 +-
.../SparseTensor/sparse_sddmm_org.mlir | 2 +-
.../Dialect/SparseTensor/sparse_storage.mlir | 2 +-
.../SparseTensor/sparse_transpose.mlir | 14 +-
.../Dialect/SparseTensor/sparse_vector.mlir | 8 +-
.../SparseTensor/sparse_vector_chain.mlir | 2 +-
.../SparseTensor/sparse_vector_index.mlir | 2 +-
.../SparseTensor/sparse_vector_ops.mlir | 3 +-
.../SparseTensor/sparse_vector_peeled.mlir | 2 +-
mlir/test/Dialect/SparseTensor/spy_sddmm.mlir | 2 +-
.../SparseTensor/unsparsifiable_dense_op.mlir | 2 +-
.../Dialect/SparseTensor/unused-tensor.mlir | 2 +-
.../SparseTensor/vectorize_reduction.mlir | 4 +-
53 files changed, 673 insertions(+), 722 deletions(-)
create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp
create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 2fb61746d048999..bb8df626086bb29 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -127,6 +127,9 @@ inline bool hasAnySparseOperandOrResult(Operation *op) {
return hasAnySparseOperand(op) || hasAnySparseResult(op);
}
+/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
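+/// For example, a CSC operand, whose dimToLvl map is (d0, d1) -> (d1, d0),
+/// is non-identity.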
+bool hasAnyNonIdentityOperandsOrResults(Operation *op);
+
//
// Inference.
//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 92bf8ec6468e532..53a0c2428842f59 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -875,6 +875,16 @@ bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
}
+bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
+ auto hasNonIdentityMap = [](Value v) {
+ auto stt = tryGetSparseTensorType(v);
+ return stt && !stt->isIdentity();
+ };
+
+ return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
+ llvm::any_of(op->getResults(), hasNonIdentityMap);
+}
+
Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
// We only consider COO region with at least two levels for the purpose
// of AOS storage optimization.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index b8a2ff26b6794f6..06d0ec1d7eb7cf3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
CodegenEnv.cpp
CodegenUtils.cpp
LoopEmitter.cpp
+ LoopScheduler.cpp
SparseBufferRewriting.cpp
SparseGPUCodegen.cpp
SparseReinterpretMap.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 5c7cc93737b7fd7..3df9bcc77717994 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -57,7 +57,11 @@ CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u),
insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(),
redExp(detail::kInvalidId), redCustom(detail::kInvalidId),
- redValidLexInsert() {}
+ redValidLexInsert() {
+ // TODO: remove topSort; loops should already be sorted by the previous pass.
+ for (unsigned l = 0; l < latticeMerger.getNumLoops(); l++)
+ topSort.push_back(l);
+}
LogicalResult CodegenEnv::initTensorExp() {
// Builds the tensor expression for the Linalg operation in SSA form.
@@ -181,36 +185,25 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
// Accept "truly dynamic" if the output tensor materializes uninitialized
// into the computation and insertions occur in lexicographic index order.
sparseOut = lhs;
- return isMaterializing(lhs->get());
-}
-bool CodegenEnv::isAdmissibleTopoOrder() {
- if (!hasSparseOutput())
- return true;
-
- OpOperand *lhs = linalgOp.getDpsInitOperand(0);
- // Accept "truly dynamic" if the output tensor materializes uninitialized
- // into the computation and insertions occur in lexicographic index order.
- LoopOrd nest = 0;
+ // Find the outermost parallel nest to determine whether compress/expand is
+ // needed.
+ outerParNest = 0;
const auto iteratorTypes = linalgOp.getIteratorTypesArray();
assert(topSortSize() == latticeMerger.getNumLoops());
for (const LoopId i : topSort) {
if (!latticeMerger.isFilterLoop(i)) {
- // We only count non-filter loops as filter loops should be considered
- // a special type of parallel loops.
if (linalg::isReductionIterator(iteratorTypes[i]))
break; // terminate at first reduction
- nest++;
+ outerParNest++;
}
}
- // Determine admissible dynamic insertion situations:
- // (1) fully injective, since there are no reductions,
- // (2) admissible 1-d expansion in innermost dimension.
- if (static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1) {
- outerParNest = nest;
- return true;
- }
- return false;
+
+ // An inadmissible kernel should have already been rejected by the previous
+ // pass during loop scheduling.
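+ // For example, for matmul (i, j parallel; k reduction) with a rank-2
+ // sparse output, outerParNest == 2 >= rank - 1 == 1, so dynamic insertion
+ // is admissible.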
+ assert(static_cast<int64_t>(outerParNest) >=
+ linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
+ return isMaterializing(lhs->get());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp
new file mode 100644
index 000000000000000..32e0a6f5c03d8a2
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp
@@ -0,0 +1,273 @@
+//===- LoopScheduler.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LoopScheduler.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+/// A helper class that visits an affine expression and tries to find an
+/// AffineDimExpr whose corresponding iterator from a GenericOp matches the
+/// desired iterator type. If no iterator matches, the first AffineDimExpr
+/// in the expression is returned.
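+/// For example, walking d0 + d1 where d1 is a reduction iterator and the
+/// picked iterator type is reduction yields d1; if neither dim matches,
+/// the first one visited (d0) is returned.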
+class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
+public:
+ explicit AffineDimFinder(ArrayRef<utils::IteratorType> itTypes)
+ : iterTypes(itTypes) {}
+
+ // Overrides method from AffineExprVisitor.
+ void visitDimExpr(AffineDimExpr expr) {
+ if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition()])
+ pickedDim = expr;
+ }
+
+ /// Set the desired iterator type that we want to pick.
+ void setPickedIterType(utils::IteratorType iterType) {
+ pickIterType = iterType;
+ }
+
+ /// Get the desired AffineDimExpr.
+ AffineDimExpr getDimExpr() const {
+ return llvm::cast<AffineDimExpr>(pickedDim);
+ }
+
+ void walkPostOrder(AffineExpr expr) {
+ pickedDim = nullptr;
+ AffineExprVisitor<AffineDimFinder>::walkPostOrder(expr);
+ }
+
+private:
+ /// The picked AffineDimExpr after visit.
+ AffineExpr pickedDim;
+ /// The iterator type that we want.
+ utils::IteratorType pickIterType;
+ /// The mapping between dim=>iterator type.
+ ArrayRef<utils::IteratorType> iterTypes;
+};
+
+// Flattens an affine expression into a list of AffineDimExprs.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+ // Overrides method from AffineExprVisitor.
+ void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
+ SmallVector<AffineDimExpr> dims;
+};
+
+} // namespace
+
+inline static bool includesAny(SortMask mask1, SortMask mask2) {
+ return static_cast<unsigned>(mask1) & static_cast<unsigned>(mask2);
+}
+
+inline static bool includesDenseInput(SortMask mask) {
+ return includesAny(mask, SortMask::kIncludeDenseInput);
+}
+
+inline static bool includesDenseOutput(SortMask mask) {
+ return includesAny(mask, SortMask::kIncludeDenseOutput);
+}
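+// (For example, includesDenseInput(SortMask::kIncludeDense) evaluates to
+// true, since b011 & b010 != 0.)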
+
+/// A helper to compute a topological sort. O(n^2) time complexity, since we
+/// use an adjacency matrix for the graph.
+/// The sorted result puts the first reduction iterator at the latest
+/// possible position.
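+/// For example, with constraints {d0 < d1, d0 < d2} where d2 is the only
+/// reduction iterator, the resulting permutation is (d0, d1, d2).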
+AffineMap LoopScheduler::topoSort() {
+ std::vector<unsigned> redIt; // reduce iterator with 0 degree
+ std::vector<unsigned> parIt; // parallel iterator with 0 degree
+ const unsigned numLoops = getNumLoops();
+ for (unsigned i = 0; i < numLoops; i++) {
+ if (inDegree[i] == 0) {
+ if (iterTypes[i] == utils::IteratorType::reduction)
+ redIt.push_back(i);
+ else
+ parIt.push_back(i);
+ }
+ }
+
+ SmallVector<unsigned> loopOrder;
+ while (!redIt.empty() || !parIt.empty()) {
+ // We always prefer a parallel loop over a reduction loop because putting
+ // a reduction loop early might make the loop sequence inadmissible.
+ auto &it = !parIt.empty() ? parIt : redIt;
+ auto src = it.back();
+ loopOrder.push_back(src);
+ it.pop_back();
+ // Update in-degree, and push 0-degree node into worklist.
+ for (unsigned dst = 0; dst < numLoops; dst++) {
+ if (itGraph[src][dst] && --inDegree[dst] == 0) {
+ if (iterTypes[dst] == utils::IteratorType::reduction)
+ redIt.push_back(dst);
+ else
+ parIt.push_back(dst);
+ }
+ }
+ }
+
+ if (loopOrder.size() == numLoops)
+ return AffineMap::getPermutationMap(loopOrder, out.getContext());
+
+ // Cycle detected.
+ return AffineMap();
+}
+
+LoopScheduler LoopScheduler::fromGenericOp(linalg::GenericOp genericOp) {
+ // Must be a demapped sparse kernel.
+ assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
+ hasAnySparseOperandOrResult(genericOp) &&
+ genericOp.getNumDpsInits() == 1);
+
+ SmallVector<AffineMap> loopMap = genericOp.getIndexingMapsArray();
+ SmallVector<Value> ins = genericOp.getDpsInputs();
+
+ AffineMap outMap = loopMap.back();
+ loopMap.pop_back();
+
+ Value out = genericOp.getDpsInitOperand(0)->get();
+ SmallVector<utils::IteratorType> iterTypes =
+ genericOp.getIteratorTypesArray();
+
+ return LoopScheduler(std::move(ins), std::move(loopMap), out, outMap,
+ std::move(iterTypes));
+}
+
+LoopScheduler::LoopScheduler(SmallVector<Value> &&ins,
+ SmallVector<AffineMap> &&loop2InsLvl, Value out,
+ AffineMap loop2OutLvl,
+ SmallVector<utils::IteratorType> &&iterTypes)
+ : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
+ loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
+ // One map per tensor.
+ assert(loop2InsLvl.size() == ins.size());
+ // All the affine maps have the same number of dimensions (loops).
+ assert(llvm::all_equal(llvm::map_range(
+ loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
+ // The number of results of the map should match the rank of the tensor.
+ assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
+ auto [m, v] = mvPair;
+ return m.getNumResults() ==
+ v.getType().template cast<ShapedType>().getRank();
+ }));
+
+ itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
+ inDegree.resize(getNumLoops());
+}
+
+AffineMap LoopScheduler::schedule(SortMask mask, Value ignored) {
+ // Reset the iteration graph.
+ for (auto &row : itGraph)
+ std::fill(row.begin(), row.end(), false);
+ // Reset cached in-degree.
+ std::fill(inDegree.begin(), inDegree.end(), 0);
+
+ for (auto [in, map] : llvm::zip(ins, loop2InsLvl)) {
+ // Get map and encoding.
+ const auto enc = getSparseTensorEncoding(in.getType());
+ // Skips dense inputs when not requested.
+ if ((!enc && !includesDenseInput(mask)) || in == ignored)
+ continue;
+
+ addConstraints(in, map);
+ }
+
+ // Get map and encoding.
+ const auto enc = getSparseTensorEncoding(out.getType());
+ if ((enc || includesDenseOutput(mask)) && out != ignored)
+ addConstraints(out, loop2OutLvl);
+
+ return topoSort();
+}
+
+void LoopScheduler::addConstraints(Value t, AffineMap loop2LvlMap) {
+ auto addIterOrdering = [this](unsigned f, unsigned t) {
+ if (!itGraph[f][t] && f != t) {
+ itGraph[f][t] = true;
+ inDegree[t]++;
+ }
+ };
+
+ AffineDimFinder finder(iterTypes);
+ finder.setPickedIterType(utils::IteratorType::reduction);
+
+ // To compute the iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6],
+ // we require that there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6}
+ // such that d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y.
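+ // For example, for tensor[d0 + d1, d2 + d3], picking d_x = d1 and
+ // d_y = d3 orders d1 before d3, d0 before d1, d2 before d3, and d0
+ // before d2.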
+ const Level lvlRank = loop2LvlMap.getNumResults();
+ for (Level lvl = 1; lvl < lvlRank; lvl++) {
+ const AffineExpr fa = loop2LvlMap.getResult(lvl - 1);
+ const AffineExpr ta = loop2LvlMap.getResult(lvl);
+
+ if (llvm::isa<AffineDimExpr>(fa) || llvm::isa<AffineDimExpr>(ta)) {
+ // Special case when at least one loop2LvlExpr is a simple AffineDimExpr
+ // (say, d0): we require d0 > {d1, d2, ...} or {d1, d2, ...} > d0.
+ AffineDimCollector fCollector;
+ fCollector.walkPostOrder(fa);
+ AffineDimCollector tCollector;
+ tCollector.walkPostOrder(ta);
+
+ for (auto fd : fCollector.dims) {
+ for (auto td : tCollector.dims) {
+ const unsigned f = fd.getPosition();
+ const unsigned t = td.getPosition();
+ addIterOrdering(f, t);
+ }
+ }
+ continue;
+ }
+
+ // When both loop2LvlExprs are compound, we pick an arbitrary reduction loop
+ // from lhs and rhs and use them as d_x and d_y.
+ finder.walkPostOrder(fa);
+ const AffineDimExpr fexp = finder.getDimExpr();
+ const unsigned fldx = fexp.getPosition();
+
+ finder.walkPostOrder(ta);
+ const AffineDimExpr texp = finder.getDimExpr();
+ const unsigned tldx = texp.getPosition();
+
+ // d_x > d_y
+ addIterOrdering(fldx, tldx);
+
+ AffineDimCollector fCollector;
+ fCollector.walkPostOrder(fa);
+ AffineDimCollector tCollector;
+ tCollector.walkPostOrder(ta);
+
+ // Make sure d_x and d_y are the last.
+ for (auto fd : fCollector.dims) {
+ const unsigned f = fd.getPosition();
+ addIterOrdering(f, fldx);
+ }
+ for (auto td : tCollector.dims) {
+ const unsigned t = td.getPosition();
+ addIterOrdering(t, tldx);
+ }
+ // {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
+ // This is to ensure that the affine expressions are reduced in sparse
+ // tensor level ordering.
+ for (auto fd : fCollector.dims) {
+ const unsigned f = fd.getPosition();
+ if (f == fldx) // skip d_x
+ continue;
+ for (auto td : tCollector.dims) {
+ const unsigned t = td.getPosition();
+ if (t == tldx) // skip d_y
+ continue;
+ addIterOrdering(f, t);
+ }
+ }
+ }
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h
new file mode 100644
index 000000000000000..1d5cb1bc6f60e4e
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h
@@ -0,0 +1,73 @@
+//===- LoopScheduler.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AffineMap.h"
+
+namespace mlir {
+
+// Forward declarations.
+class Value;
+namespace utils {
+enum class IteratorType : uint32_t;
+} // namespace utils
+namespace linalg {
+class GenericOp;
+} // namespace linalg
+
+namespace sparse_tensor {
+
+/// Iteration graph sorting.
+enum class SortMask : unsigned {
+ // The individual mask bits.
+ kIncludeDenseOutput = 0x1, // b001
+ kIncludeDenseInput = 0x2, // b010
+ kIncludeUndef = 0x4, // b100
+ // The subsets of mask bits.
+ kIncludeAll = 0x7, // b111
+ kIncludeDense = 0x3, // b011
+ kSparseOnly = 0x0, // b000
+};
+
+class LoopScheduler {
+public:
+ // Constructs a scheduler from a linalg.generic operation.
+ // Maybe reuse this class to schedule foreach as well (to address
+ // non-permutation maps, e.g., traversing CSR in BSR order).
+ static LoopScheduler fromGenericOp(linalg::GenericOp genericOp);
+
+ // Returns a permutation that represents the scheduled loop order.
+ // Note that the returned AffineMap could be null if the kernel cannot be
+ // scheduled due to cycles in the iteration graph.
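+ //
+ // A minimal usage sketch (assuming a demapped sparse linalg.generic `op`):
+ //   LoopScheduler scheduler = LoopScheduler::fromGenericOp(op);
+ //   AffineMap order = scheduler.schedule(SortMask::kIncludeAll);
+ //   if (!order) { /* cycle in the iteration graph; try a weaker mask */ }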
+ [[nodiscard]] AffineMap schedule(SortMask mask, Value ignored = nullptr);
+ unsigned getNumLoops() const { return loop2OutLvl.getNumDims(); }
+
+private:
+ LoopScheduler(SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl,
+ Value out, AffineMap loop2OutLvl,
+ SmallVector<utils::IteratorType> &&iterTypes);
+
+ void addConstraints(Value t, AffineMap loop2LvlMap);
+ AffineMap topoSort();
+
+ // Input tensors and associated loop to level maps.
+ SmallVector<Value> ins;
+ SmallVector<AffineMap> loop2InsLvl;
+ // Output tensor and associated loop to level map.
+ Value out;
+ AffineMap loop2OutLvl;
+ // Loop types.
+ SmallVector<utils::IteratorType> iterTypes;
+
+ // Adjacency matrix that represents the iteration graph.
+ std::vector<std::vector<bool>> itGraph;
+ // In-degree of each node, used for the topological sort.
+ std::vector<unsigned> inDegree;
+};
+
+} // namespace sparse_tensor
+} // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 83e86d137335021..a433b2edc5500e7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -1,4 +1,4 @@
-//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
+//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===/
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
+#include "LoopScheduler.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -351,17 +352,6 @@ static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
return ret;
}
-/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
-static bool hasNonIdentityOperandsOrResults(Operation *op) {
- auto hasNonIdentityMap = [](Value v) {
- auto stt = tryGetSparseTensorType(v);
- return stt && !stt->isIdentity();
- };
-
- return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
- llvm::any_of(op->getResults(), hasNonIdentityMap);
-}
-
namespace {
//===----------------------------------------------------------------------===//
@@ -379,7 +369,7 @@ struct GenericOpReinterpretMap
// semantics.
if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
!hasAnySparseOperandOrResult(linalgOp) ||
- !hasNonIdentityOperandsOrResults(linalgOp))
+ !hasAnyNonIdentityOperandsOrResults(linalgOp))
return failure();
// Try translating the index map.
@@ -411,6 +401,178 @@ struct GenericOpReinterpretMap
}
};
+struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
+ PatternRewriter &rewriter) const override {
+ if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
+ hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
+ !hasAnySparseOperandOrResult(linalgOp)) {
+ return failure();
+ }
+
+ const StringRef sorted = "sorted";
+ if (linalgOp->hasAttr(sorted))
+ return failure();
+
+ auto scheduler = LoopScheduler::fromGenericOp(linalgOp);
+ bool isAdmissible = false;
+ AffineMap order;
+ // A const list of all masks that we use for iteration graph
+ // computation. Must be ordered from the most strict to the least strict.
+ // Ideally (though this might not be guaranteed), the earlier a constraint
+ // mask can be satisfied, the faster the generated kernel will be.
+ const auto allMasks = {
+ SortMask::kIncludeAll, SortMask::kIncludeDense,
+ SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
+ SortMask::kIncludeUndef, SortMask::kSparseOnly};
+ for (const SortMask mask : allMasks) {
+ order = scheduler.schedule(mask);
+ if (order) {
+ if (isAdmissibleOrder(linalgOp, order)) {
+ isAdmissible = true;
+ break;
+ }
+ // else try a set of less strict constraints.
+ }
+ }
+
+ if (!order) {
+ // Cycle detected.
+ if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
+ return rewriter.notifyMatchFailure(
+ linalgOp, "the sparse kernel cannot be scheduled: cycle detected.");
+ }
+ return success();
+ }
+
+ if (!isAdmissible) {
+ return rewriter.notifyMatchFailure(
+ linalgOp, "the sparse kernel cannot be scheduled.");
+ }
+
+ // Marks the GenericOp to avoid recursive matching.
+ linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+
+ // Already sorted.
+ if (order.isIdentity())
+ return failure();
+
+ assert(order.isPermutation());
+ // `order` is the original loop -> sorted loop map.
+ ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
+ SmallVector<Attribute> curItTypes;
+ curItTypes.reserve(preItTypes.size());
+ for (AffineExpr expr : order.getResults()) {
+ unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
+ curItTypes.push_back(preItTypes[loopID]);
+ }
+
+ // Invert `order` to get the sorted loop -> original loop map.
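+ // For example, if the scheduler returned (d0, d1, d2) -> (d2, d0, d1),
+ // the inverse is (d0, d1, d2) -> (d1, d2, d0); composing each indexing
+ // map with it re-expresses the tensor levels in terms of the sorted loops.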
+ order = inversePermutation(order);
+ SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
+ for (AffineMap &idxMap : idxMaps)
+ idxMap = idxMap.compose(order); // sorted loop -> lvl map
+
+ rewriter.startRootUpdate(linalgOp);
+ linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
+ linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
+ rewriter.finalizeRootUpdate(linalgOp);
+
+ return success();
+ }
+
+private:
+ /// Whether the loop order is admissible for sparsification.
+ static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
+ if (!hasAnySparseResult(linalgOp))
+ return true;
+
+ OpOperand *lhs = linalgOp.getDpsInitOperand(0);
+ unsigned nest = 0;
+ const auto iteratorTypes = linalgOp.getIteratorTypesArray();
+ for (const AffineExpr l : order.getResults()) {
+ unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
+ auto itTp =
+ linalgOp.getIteratorTypes()[loopId].cast<linalg::IteratorTypeAttr>();
+ if (linalg::isReductionIterator(itTp.getValue()))
+ break; // terminate at first reduction
+ nest++;
+ }
+ // Determine admissible dynamic insertion situations:
+ // (1) fully injective, since there are no reductions,
+ // (2) admissible 1-d expansion in innermost dimension.
+ return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
+ };
+
+ // Last resort cycle resolution.
+ static LogicalResult resolveCycle(LoopScheduler &scheduler,
+ linalg::LinalgOp linalgOp,
+ PatternRewriter &rewriter) {
+ // Compute a topological sort while leaving out every sparse input tensor
+ // in succession until an acyclic iteration graph results.
+ for (OpOperand *t : linalgOp.getDpsInputOperands()) {
+ Value tval = t->get();
+ auto srcEnc = getSparseTensorEncoding(tval.getType());
+ // The constraints introduced by compound index expressions are
+ // complicated. Skip them.
+ AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
+ bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
+ return !llvm::isa<AffineDimExpr>(exp);
+ });
+ if (!srcEnc || hasCompExpr)
+ continue;
+
+ // Try scheduling loop without constraints from `tval`.
+ AffineMap order = scheduler.schedule(SortMask::kSparseOnly, tval);
+ if (!order) // still cyclic
+ continue;
+
+ // Found an input tensor that resolves the cycle by inserting a
+ // conversion into a sparse tensor that adheres to the iteration
+ // graph order.
+ auto stt = getSparseTensorType(tval);
+ assert(stt.isIdentity());
+ order = inversePermutation(order);
+ // sorted loop -> lvl map.
+ idxMap = idxMap.compose(order);
+
+ // Find a permutation such that the results of `idxMap` become sorted.
+ // For example, given
+ //   (d0, d1, d2, d3) -> (d2, d1, d0)
+ // with loops scheduled in the order d0->d1->d2->d3, we resolve the cycle
+ // by finding a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that
+ // the transposed tensor's levels are visited in the same order as the
+ // loop scheduling order.
+ SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
+ for (AffineExpr expr : idxMap.getResults()) {
+ unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
+ lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
+ }
+ std::sort(lvlSeq.begin(), lvlSeq.end(), [](auto &lhs, auto &rhs) -> bool {
+ return lhs.first < rhs.first;
+ });
+ SmallVector<unsigned> perm =
+ llvm::to_vector(llvm::make_second_range(lvlSeq));
+ auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
+ // The results of `idxMap` must be unsorted.
+ assert(!dimToLvl.isIdentity());
+
+ // Insert the transpose.
+ rewriter.setInsertionPoint(linalgOp);
+ RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
+ Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
+ rewriter.updateRootInPlace(linalgOp, [&]() {
+ linalgOp->setOperand(t->getOperandNumber(), dst);
+ });
+ return success();
+ }
+ // Cannot be resolved with a single conversion.
+ // TODO: convert more than one?
+ return failure();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//
@@ -420,7 +582,7 @@ struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
using OpRewritePattern<AllocOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AllocOp op,
PatternRewriter &rewriter) const override {
- if (!hasNonIdentityOperandsOrResults(op))
+ if (!hasAnyNonIdentityOperandsOrResults(op))
return failure();
Location loc = op.getLoc();
@@ -493,7 +655,7 @@ struct ForeachOpDemapper
PatternRewriter &rewriter) const {
// Only handle operations with sparse input/output with non-identity dim2lvl
// maps.
- if (!hasNonIdentityOperandsOrResults(op))
+ if (!hasAnyNonIdentityOperandsOrResults(op))
return failure();
// TODO: demap constant as well.
@@ -582,7 +744,8 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
ReinterpretMapScope scope) {
if (scope == ReinterpretMapScope::kAll ||
scope == ReinterpretMapScope::kGenericOnly) {
- patterns.add<GenericOpReinterpretMap>(patterns.getContext());
+ patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
+ patterns.getContext());
}
if (scope == ReinterpretMapScope::kAll ||
scope == ReinterpretMapScope::kExceptGeneric) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 093bda9ca28efff..cd6f689a04acb53 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -39,90 +39,6 @@
using namespace mlir;
using namespace mlir::sparse_tensor;
-//===----------------------------------------------------------------------===//
-// Declarations
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-/// Iteration graph sorting.
-enum class SortMask : unsigned {
- // The individual mask bits.
- kIncludeDenseOutput = 0x1, // b001
- kIncludeDenseInput = 0x2, // b010
- kIncludeUndef = 0x4, // b100
- // The subsets of mask bits.
- kIncludeAll = 0x7, // b111
- kIncludeDense = 0x3, // b011
- kSparseOnly = 0x0, // b000
-};
-
-inline static bool includesAny(SortMask mask1, SortMask mask2) {
- return static_cast<unsigned>(mask1) & static_cast<unsigned>(mask2);
-}
-
-inline static bool includesDenseInput(SortMask mask) {
- return includesAny(mask, SortMask::kIncludeDenseInput);
-}
-
-inline static bool includesDenseOutput(SortMask mask) {
- return includesAny(mask, SortMask::kIncludeDenseOutput);
-}
-
-inline static bool includesDense(SortMask mask) {
- return includesAny(mask, SortMask::kIncludeDense);
-}
-
-inline static bool includesUndef(SortMask mask) {
- return includesAny(mask, SortMask::kIncludeUndef);
-}
-
-/// A helper class that visits an affine expression and tries to find an
-/// AffineDimExpr to which the corresponding iterator from a GenericOp matches
-/// the desired iterator type.
-class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
-public:
- explicit AffineDimFinder(linalg::GenericOp op)
- : iterTypes(op.getIteratorTypes()) {}
-
- // Overrides method from AffineExprVisitor.
- void visitDimExpr(AffineDimExpr expr) {
- if (pickedDim == nullptr ||
- pickIterType ==
- cast<linalg::IteratorTypeAttr>(iterTypes[expr.getPosition()])
- .getValue()) {
- pickedDim = expr;
- }
- }
-
- /// Set the desired iterator type that we want to pick.
- void setPickedIterType(utils::IteratorType iterType) {
- pickIterType = iterType;
- }
-
- /// Get the desired AffineDimExpr.
- AffineDimExpr getDimExpr() const { return cast<AffineDimExpr>(pickedDim); }
-
-private:
- /// The picked AffineDimExpr after visit. This must be stored as
- /// `AffineExpr` rather than `AffineDimExpr`, because the latter
- /// doesn't have a default ctor.
- AffineExpr pickedDim;
- /// The iterator type that we want.
- utils::IteratorType pickIterType;
- /// The mapping between dim=>iterator type.
- ArrayAttr iterTypes;
-};
-
-// Flattens an affine expression into a list of AffineDimExprs.
-struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
- // Overrides method from AffineExprVisitor.
- void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
- SmallVector<AffineDimExpr> dims;
-};
-
-} // namespace
-
//===----------------------------------------------------------------------===//
// Sparsifier analysis methods.
//===----------------------------------------------------------------------===//
@@ -170,56 +86,6 @@ static bool isInvariantAffine(CodegenEnv &env, AffineExpr a, LoopId ldx,
return isInvariantAffine(a, env.getCurrentLoopStack(), ldx, isAtLoop);
}
-/// Helper method to construct a permuted dimension ordering
-/// that adheres to the given topological sort.
-//
-// FIXME: does the above actually mean "dimensions", or should it say
-// "level ordering"? The same dim/lvl confusion applies to all the code
-// and comments in the definition below.
-static AffineMap permute(CodegenEnv &env, AffineMap m) {
- assert(m.getNumDims() + env.merger().getNumFilterLoops() ==
- env.topSortSize() &&
- "size mismatch");
- // Construct the inverse of `m`; to avoid the asymptotic complexity
- // of calling `m.getPermutedPosition` repeatedly.
- //
- // The variable `perm` must use `unsigned` rather than `Dimension`/`Level`,
- // because that's what `AffineMap::getPermutationMap` requires.
- // TODO: however, `perm` should be renamed to make clear what exactly
- // it's storing a permutation of.
- SmallVector<unsigned> perm;
- const unsigned numResults = m.getNumResults();
- BitVector worklist(numResults, true);
- LoopOrd loopDepth = 1;
-
- // Construct the permutation.
- while (worklist.any() && loopDepth <= env.topSortSize()) {
- const unsigned preSize = perm.size();
- for (unsigned dim : worklist.set_bits()) {
- bool isAtLoop = false;
- if (isa<AffineConstantExpr>(m.getResult(dim)) ||
- (isInvariantAffine(m.getResult(dim), env.getLoopStackUpTo(loopDepth),
- env.topSortAt(loopDepth - 1), isAtLoop) &&
- isAtLoop)) {
- // If the matching affine is constant expression or just become
- // invariant. We can visit the dimension now without breaking the
- // topSort constraint.
- perm.push_back(dim);
- }
- }
-
- // Removes resolved dimension.
- for (unsigned i = preSize, e = perm.size(); i < e; i++)
- worklist.reset(perm[i]);
-
- // Try entering the next loop in the stack.
- loopDepth++;
- }
-
- assert(perm.size() == numResults);
- return AffineMap::getPermutationMap(perm, env.op().getContext());
-}
-
/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects compound affine
/// expressions in sparse dimensions.
@@ -471,378 +337,6 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
return annotated;
}
-/// A helper to compute a topological sort. O(n^2) time complexity
-/// as we use adj matrix for the graph.
-/// The sorted result will put the first Reduction iterator to the
-/// latest possible `LoopOrd`.
-///
-/// The `inDegree` is indexed by `LoopId`, and the `adjM` is indexed by
-/// `(LoopId,LoopId)`.
-static bool topSortOptimal(CodegenEnv &env,
- ArrayRef<utils::IteratorType> iteratorTypes,
- std::vector<unsigned> &inDegree,
- std::vector<std::vector<bool>> &adjM) {
- std::vector<LoopId> redIt; // reduce iterator with 0 degree
- std::vector<LoopId> parIt; // parallel iterator with 0 degree
- std::vector<LoopId> filterIt; // filter loop with 0 degree
- const LoopId numLoops = env.merger().getNumLoops();
- for (LoopId i = 0; i < numLoops; i++) {
- if (inDegree[i] == 0) {
- if (env.merger().isFilterLoop(i))
- filterIt.push_back(i);
- else if (linalg::isReductionIterator(iteratorTypes[i]))
- redIt.push_back(i);
- else
- parIt.push_back(i);
- }
- }
-
- while (!redIt.empty() || !parIt.empty() || !filterIt.empty()) {
- // We always choose in order of filter loop -> parallel loop -> reduction
- // loop because
- // 1. Putting reduction loop early might make the loop sequence
- // inadmissible.
- // 2. Filter loops should be put as early as possible for better
- // performance, since only one (if any) iteration will carry the
- // computation. E.g., for (1 to N)
- // for (1 to M)
- // for (1 to K)
- // if (xxx)
- // O(X) computation => O(NMK+NMX) time complexity
- //
- // By putting the filter loop one level up, we got
- //
- // for (1 to N)
- // for (1 to K)
- // if (xxx)
- // for (1 to M)
- // O(X) computation => O(NK+NMX) time complexity
- auto &it = !filterIt.empty() ? filterIt : (!parIt.empty() ? parIt : redIt);
- auto src = it.back();
- env.topSortPushBack(src);
- it.pop_back();
- // Update in-degree, and push 0-degree node into worklist.
- for (LoopId dst = 0; dst < numLoops; dst++) {
- if (adjM[src][dst] && --inDegree[dst] == 0) {
- if (env.merger().isFilterLoop(dst))
- filterIt.push_back(dst);
- else if (linalg::isReductionIterator(iteratorTypes[dst]))
- redIt.push_back(dst);
- else
- parIt.push_back(dst);
- }
- }
- }
- return env.topSortSize() == numLoops;
-}
-
-static void addIterOrdering(LoopId f, LoopId t,
- std::vector<std::vector<bool>> &adjM,
- std::vector<unsigned> &inDegree) {
- if (!adjM[f][t] && f != t) {
- adjM[f][t] = true;
- inDegree[t]++;
- }
-}
-
-/// Helper method to add all constraints from the indices in one affine
-/// expression before all indices in the other affine expression. For
-/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
-/// The affine expression `a` is empty iff `fidx` have a value, leading to
-/// b = (i0 + i1) < fidx => i0 < fidx, i1 < fidx.
-/// The affine expression `b` is empty iff `tidx` have a value, leading to
-/// tidx < a = (i0 + i1) => tidx < i0, tidx < i1.
-///
-/// The `inDegree` is indexed by `LoopId`, and the `adjM` is indexed by
-/// `(LoopId,LoopId)`.
-static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
- std::vector<unsigned> &inDegree, AffineExpr a,
- AffineExpr b, std::optional<LoopId> fidx,
- std::optional<LoopId> tidx) {
- if (!a && !b) {
- // Recursion leaf.
- assert(fidx && tidx);
- const LoopId f = *fidx, t = *tidx;
- addIterOrdering(f, t, adjM, inDegree);
- return;
- }
- // Picks an affine expression and expand (recurse into) it.
- const auto toExpand = a ? a : b;
- switch (toExpand.getKind()) {
- case AffineExprKind::DimId: {
- const std::optional<LoopId> idx{
- cast<AffineDimExpr>(toExpand).getPosition()};
- if (toExpand == a)
- addAffineOrderings(adjM, inDegree, AffineExpr(), b, idx, tidx);
- else // toExpand == b
- addAffineOrderings(adjM, inDegree, a, AffineExpr(), fidx, idx);
- break;
- }
- case AffineExprKind::Add:
- case AffineExprKind::Mul: {
- auto binOp = cast<AffineBinaryOpExpr>(toExpand);
- if (toExpand == a) {
- addAffineOrderings(adjM, inDegree, binOp.getLHS(), b, fidx, tidx);
- addAffineOrderings(adjM, inDegree, binOp.getRHS(), b, fidx, tidx);
- } else {
- addAffineOrderings(adjM, inDegree, a, binOp.getLHS(), fidx, tidx);
- addAffineOrderings(adjM, inDegree, a, binOp.getRHS(), fidx, tidx);
- }
- break;
- }
- default:
- break;
- }
-}
-
-static void tryRelaxAffineConstraints(linalg::GenericOp op,
- std::optional<LoopId> &fldx,
- AffineExpr &fa,
- std::optional<LoopId> &tldx,
- AffineExpr &ta) {
- // We use a heuristic here to only pick one dim expression from each
- // compound affine expression to establish the order between two dense
- // dimensions.
- if (!tldx) {
- AffineDimFinder finder(op);
- // NOTE: The ordering can only be loosen when the destination level is
- // dense (when !tldx), for [dense, sparse] -> (d0 + d1, d2), we still
- // require both d0 < d2 and d1 < d2 to ensure correct ordering (i.e.,
- // no ordering like d0->d2->d1).
- // TODO: this is obviously a sub optimal solution.
- if (!fldx && !isa<AffineConstantExpr>(fa)) {
- // Heuristic: we prefer parallel loop for lhs to reduce the chance
- // we add reduce < parallel ordering.
- finder.setPickedIterType(utils::IteratorType::parallel);
- finder.walkPostOrder(fa);
- fa = finder.getDimExpr();
- fldx = finder.getDimExpr().getPosition();
- }
- if (!isa<AffineConstantExpr>(ta)) {
- // Heuristic: we prefer reduction loop for rhs to reduce the chance
- // adding reduce < parallel ordering.
- finder.setPickedIterType(utils::IteratorType::reduction);
- finder.walkPostOrder(ta);
- ta = finder.getDimExpr();
- tldx = finder.getDimExpr().getPosition();
- }
- }
-}
-
-static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
- OpOperand *skip, SortMask mask,
- std::vector<std::vector<bool>> &adjM,
- std::vector<unsigned> &inDegree) {
- // Get map, encoding, and tensor-identifier.
- const auto map = env.op().getMatchingIndexingMap(&t);
- const auto enc = getSparseTensorEncoding(t.get().getType());
- const TensorId tid = env.makeTensorId(t.getOperandNumber());
-
- // Each tensor expression and optional dimension ordering (row-major
- // by default) puts an ordering constraint on the loop indices. For
- // example, the tensor expression A_ijk forces the ordering i < j < k
- // on the loop indices if no explicit dimension ordering is given.
- const Level lvlRank = map.getNumResults();
- assert(!enc || lvlRank == enc.getLvlRank());
- for (Level lvl = 0; lvl < lvlRank; lvl++) {
- // FIXME: `toOrigDim` is deprecated.
- AffineExpr ta = map.getResult(toOrigDim(enc, lvl));
- std::optional<LoopId> tldx = env.merger().getLoopId(tid, lvl);
- // Filter loops should be constructed after all the dependent loops,
- // i.e., d0 + d1 < filter_loop(d0 + d1)
- if (tldx && env.merger().isFilterLoop(*tldx)) {
- assert(!isa<AffineDimExpr>(ta) && !isDenseDLT(enc.getLvlTypes()[lvl]));
- addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx);
- // Now that the ordering of affine expression is captured by filter
- // loop idx, we only need to ensure the affine ordering against filter
- // loop. Thus, we reset the affine express to nil here to mark it as
- // resolved.
- ta = AffineExpr();
- }
-
- // Skip tensor during cycle resolution, though order between filter loop
- // and dependent loops need to be guaranteed unconditionally.
- if (&t == skip)
- continue;
-
- if (lvl > 0) {
- // FIXME: `toOrigDim` is deprecated.
- AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1));
- std::optional<LoopId> fldx = env.merger().getLoopId(tid, lvl - 1);
-
- // Applying order constraints on every pair of dimExpr between two
- // compound affine expressions can sometime too strict:
- // E.g., for [dense, dense] -> (d0 + d1, d2 + d3).
- // It is totally fine to have loop sequence d0->d2->d1->d3 instead of
- // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
- // We also relax the affine constraint when use slice-based algorithm
- // as there is no filter loop for affine index on sparse dimension.
- // TODO: do we really need the condition?
- if (!includesDense(mask))
- tryRelaxAffineConstraints(env.op(), fldx, fa, tldx, ta);
-
- // (d0 + d1) < (d2 + d3), or
- // filter_loop_d-1 < (d2 + d3), or
- // (d0 + d1) < filter_loop_d, or
- // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset
- // above.
- addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
- }
- }
-}
-
-static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t,
- OpOperand *skip, SortMask mask,
- std::vector<std::vector<bool>> &adjM,
- std::vector<unsigned> &inDegree) {
- // Get map and encoding.
- const auto map = env.op().getMatchingIndexingMap(&t);
- const auto enc = getSparseTensorEncoding(t.get().getType());
-
- // No special treatment for simple indices.
- if (getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) == 0)
- return addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree);
-
- // Skip tensor during cycle resolution, though order between filter loop
- // and dependent loops need to be guaranteed unconditionally.
- if (&t == skip)
- return;
-
- AffineDimFinder finder(env.op());
- finder.setPickedIterType(utils::IteratorType::reduction);
- // To compute iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6],
- // we requires there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6},
- // and d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
- const Level lvlRank = map.getNumResults();
- assert(!enc || lvlRank == enc.getLvlRank());
- for (Level lvl = 1; lvl < lvlRank; lvl++) {
- // FIXME: `toOrigDim` is deprecated.
- const AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1));
- const AffineExpr ta = map.getResult(toOrigDim(enc, lvl));
-
- if (isa<AffineDimExpr>(fa) || isa<AffineDimExpr>(ta)) {
- AffineDimCollector fCollector;
- fCollector.walkPostOrder(fa);
-
- AffineDimCollector tCollector;
- tCollector.walkPostOrder(ta);
- for (auto fd : fCollector.dims) {
- for (auto td : tCollector.dims) {
- const LoopId f = env.makeLoopId(fd.getPosition());
- const LoopId t = env.makeLoopId(td.getPosition());
- addIterOrdering(f, t, adjM, inDegree);
- }
- }
- continue;
- }
-
- // This is a heuristic, we pick an abitrary reduction loop from lhs and
- // rhs and use them as d_x and d_y.
- finder.walkPostOrder(fa);
- const AffineDimExpr fexp = finder.getDimExpr();
- const LoopId fldx = env.makeLoopId(fexp.getPosition());
-
- finder.walkPostOrder(ta);
- const AffineDimExpr texp = finder.getDimExpr();
- const LoopId tldx = env.makeLoopId(texp.getPosition());
-
- // d_x > d_y
- addIterOrdering(fldx, tldx, adjM, inDegree);
-
- AffineDimCollector fCollector;
- fCollector.walkPostOrder(fa);
- AffineDimCollector tCollector;
- tCollector.walkPostOrder(ta);
-
- // make sure dx and dy is the last;
- for (auto fd : fCollector.dims) {
- const LoopId f = env.makeLoopId(fd.getPosition());
- addIterOrdering(f, fldx, adjM, inDegree);
- }
- for (auto td : tCollector.dims) {
- const LoopId t = env.makeLoopId(td.getPosition());
- addIterOrdering(t, tldx, adjM, inDegree);
- }
- // Since we only support affine addition, the order between two dim
- // expression does not really matters.
- // {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
- // This is to ensure that the affine expressions are reduced in sparse
- // tensor level ordering.
- // TODO: this ordering could probably be loosen if we support out-of-order
- // reduction.
- // TODO: the evaluation order need to be ensure to
- // support affine multiplication.
- for (auto fd : fCollector.dims) {
- const LoopId f = env.makeLoopId(fd.getPosition());
- if (f == fldx) // skip d_x
- continue;
- for (auto td : tCollector.dims) {
- const LoopId t = env.makeLoopId(td.getPosition());
- if (t == tldx) // skip d_y
- continue;
- addIterOrdering(f, t, adjM, inDegree);
- }
- }
- }
-}
-
-/// Computes a topologically sorted iteration graph for the linalg operation.
-/// Ensures all tensors are visited in natural index order. This is
-/// essential for sparse storage formats since these only support access
-/// along fixed dimensions. Even for dense storage formats, however, the natural
-/// index order yields innermost unit-stride access with better spatial
-/// locality.
-static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
- OpOperand *skip, bool idxReducBased = false) {
- // Set up an n x n from/to adjacency matrix of the iteration graph
- // for the implicit loop indices i_0 .. i_n-1.
- const unsigned numLoops = env.merger().getNumLoops();
- std::vector<std::vector<bool>> adjM(numLoops,
- std::vector<bool>(numLoops, false));
- std::vector<unsigned> inDegree(numLoops, 0); // in-degree of each node.
- const auto iteratorTypes = env.op().getIteratorTypesArray();
- // Iterate over the indexing maps of every tensor in the tensor expression.
- for (OpOperand &t : env.op()->getOpOperands()) {
- // Get map and encoding.
- const auto enc = getSparseTensorEncoding(t.get().getType());
- // Skips dense inputs/outputs when not requested.
- const bool isDenseInput = !enc && env.op().isDpsInput(&t);
- const bool isDenseOutput = !enc && !isDenseInput;
- if ((isDenseInput && !includesDenseInput(mask)) ||
- (isDenseOutput && !includesDenseOutput(mask)))
- continue;
-
- // Push unrelated loops into sparse iteration space, so these
- // will be skipped more often.
- // TODO: Do we really need this?
- if (includesUndef(mask)) {
- const TensorId tid = env.makeTensorId(t.getOperandNumber());
- for (LoopId i = 0; i < numLoops; i++) {
- const auto dltI = env.dlt(tid, i);
- if (isCompressedDLT(dltI) || isLooseCompressedDLT(dltI) ||
- isSingletonDLT(dltI) || is2OutOf4DLT(dltI)) {
- for (LoopId j = 0; j < numLoops; j++)
- if (isUndefDLT(env.dlt(tid, j))) {
- addIterOrdering(i, j, adjM, inDegree);
- }
- } else {
- assert(isDenseDLT(dltI) || isUndefDLT(dltI));
- }
- }
- }
- // Push unrelated loops into sparse iteration space, so these
- // will be skipped more often.
- if (idxReducBased)
- addSliceBasedConstraints(env, t, skip, mask, adjM, inDegree);
- else
- addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree);
- }
- // Topologically sort the iteration graph to determine loop order.
- // Report failure for a cyclic iteration graph.
- env.topSortClear(numLoops);
- return topSortOptimal(env, iteratorTypes, inDegree, adjM);
-}
-
//===----------------------------------------------------------------------===//
// Sparsifier synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//
@@ -1951,6 +1445,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
if (hasNonTrivialAffineOnSparseOut(op))
return failure();
+ if (!op->hasAttr("sorted")) {
+ return rewriter.notifyMatchFailure(
+ op, "Loops not yet scheduled, try running --sparse-reinterpret-map "
+ "before sparsification.");
+ }
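+ // (The "sorted" attribute is set by the scheduling pattern in
+ // --sparse-reinterpret-map; see the updated RUN lines below.)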
+
// Sets up a code generation environment.
const unsigned numTensors = op->getNumOperands();
const unsigned numLoops = op.getNumLoops();
@@ -1970,6 +1470,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
}
}
+
// A slice based algorithm for affine indices does not need filter loops.
CodegenEnv env(op, options, numTensors, numLoops,
/*numFilterLoops=*/idxReducBased ? 0 : numFilterLoops,
@@ -2003,39 +1504,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
if (failed(env.initTensorExp()))
return failure();
- // Computes a topologically sorted iteration graph to ensure tensors
- // are visited in natural index order. Gradually relaxes the considered
- // constraints until an acyclic iteration graph results, such that sparse
- // code generation can proceed. As a last resort, an attempt is made
- // to resolve cycles by inserting a conversion.
- bool isAdmissible = false;
- bool hasCycle = true;
- // A const list of all masks that we used for iteration graph
- // computation. Must be ordered from more strict to less strict.
- // Ideally (though might not be guaranteed), the earlier a constraint mask
- // can be satisfied, the faster the generated kernel will be.
- const auto allMasks = {
- SortMask::kIncludeAll, SortMask::kIncludeDense,
- SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
- SortMask::kIncludeUndef, SortMask::kSparseOnly};
- for (const SortMask mask : allMasks) {
- if (computeIterationGraph(env, mask, nullptr, idxReducBased)) {
- hasCycle = false;
- if (env.isAdmissibleTopoOrder()) {
- isAdmissible = true;
- break;
- }
- // else try a set of less strict constraints.
- }
- }
- if (hasCycle) {
- return idxReducBased
- ? failure() // TODO: should cycle be resolved differently?
- : resolveCycle(env, rewriter); // one last shot
- }
- if (!isAdmissible)
- return failure(); // inadmissible expression, reject
-
// Recursively generates code if admissible.
env.startEmit();
genBuffers(env, rewriter);
@@ -2049,47 +1517,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
}
private:
- // Last resort cycle resolution.
- LogicalResult resolveCycle(CodegenEnv &env, PatternRewriter &rewriter) const {
- // Compute topological sort while leaving out every
- // sparse input tensor in succession until an acylic
- // iteration graph results.
- for (OpOperand *t : env.op().getDpsInputOperands()) {
- const TensorId tid = env.makeTensorId(t->getOperandNumber());
- Value tval = t->get();
- auto srcEnc = getSparseTensorEncoding(tval.getType());
- if (!srcEnc || !computeIterationGraph(env, SortMask::kSparseOnly, t))
- continue;
- // Found an input tensor that resolves the cycle by inserting a
- // conversion into a sparse tensor that adheres to the iteration
- // graph order. Also releases the temporary sparse tensor.
- //
- // TODO: investigate fusing the conversion with computation,
- // especially if it is a direct yield!
- //
- auto srcTp = getRankedTensorType(tval);
- // TODO: This assertion is to match the behavior from prior to
- // merging dimOrdering and higherOrdering into dimToLvl. However,
- // since `permute` returns a permutation, we can remove this
- // restriction by instead composing the result of `permute`
- // with `srcEnc.getDimToLvl`.
- assert(srcEnc.isPermutation());
- auto dstEnc =
- srcEnc.withDimToLvl(permute(env, env.op().getMatchingIndexingMap(t)));
- auto dstTp = RankedTensorType::get(srcTp.getShape(),
- srcTp.getElementType(), dstEnc);
- auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
- rewriter.updateRootInPlace(env.op(),
- [&]() { env.op()->setOperand(tid, convert); });
- rewriter.setInsertionPointAfter(env.op());
- rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
- return success();
- }
- // Cannot be resolved with a single conversion.
- // TODO: convert more than one?
- return failure();
- }
-
/// Options to control sparse code generation.
SparsificationOptions options;
};
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
index 0979884cbd502a5..b12bad685b49b97 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN: --pre-sparsification-rewrite \
+// RUN: --sparse-reinterpret-map \
// RUN: --sparsification="parallelization-strategy=dense-outer-loop" \
// RUN: --sparse-gpu-codegen | FileCheck %s
@@ -60,4 +61,3 @@ func.func @matmuls(%A: tensor<1024x8xf64>,
outs(%Z: tensor<1024x1024xf64>) -> tensor<1024x1024xf64>
return %D : tensor<1024x1024xf64>
}
-
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
index 8dc6afa320af7d5..e477dcc0169fc3d 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN: --pre-sparsification-rewrite \
+// RUN: --sparse-reinterpret-map \
// RUN: --sparsification="parallelization-strategy=dense-outer-loop" \
// RUN: --sparse-gpu-codegen | FileCheck %s
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
index ab267062e41693a..0c5ff55dd863c08 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN: --pre-sparsification-rewrite \
+// RUN: --sparse-reinterpret-map \
// RUN: --sparsification="parallelization-strategy=dense-outer-loop" \
// RUN: --sparse-gpu-codegen | FileCheck %s
diff --git a/mlir/test/Dialect/SparseTensor/constant_index_map.mlir b/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
index 9eb535385790146..eaef6a315852920 100644
--- a/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
@@ -1,6 +1,6 @@
// Reported by https://github.com/llvm/llvm-project/issues/61530
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#map1 = affine_map<(d0) -> (0, d0)>
#map2 = affine_map<(d0) -> (d0)>
diff --git a/mlir/test/Dialect/SparseTensor/dense.mlir b/mlir/test/Dialect/SparseTensor/dense.mlir
index 485a5cbb178af94..52db814572a15fb 100644
--- a/mlir/test/Dialect/SparseTensor/dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/dense.mlir
@@ -1,5 +1,5 @@
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
// Test to demonstrate the difference between non-annotated dense tensors
// and all-dense-annotated "sparse" tensors. The former class remains as
diff --git a/mlir/test/Dialect/SparseTensor/one_trip.mlir b/mlir/test/Dialect/SparseTensor/one_trip.mlir
index 35d7878633b58a5..b8ab177357492d8 100644
--- a/mlir/test/Dialect/SparseTensor/one_trip.mlir
+++ b/mlir/test/Dialect/SparseTensor/one_trip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification -cse | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse | FileCheck %s
#Dense = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : dense, d1 : dense)
diff --git a/mlir/test/Dialect/SparseTensor/semi_ring.mlir b/mlir/test/Dialect/SparseTensor/semi_ring.mlir
index c69efcae3b08ec2..5a936bfa1e597a7 100644
--- a/mlir/test/Dialect/SparseTensor/semi_ring.mlir
+++ b/mlir/test/Dialect/SparseTensor/semi_ring.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#SM = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
index 3ce5dd1dcfdffc8..1a8af35358bf62c 100644
--- a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
+++ b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification --canonicalize | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --canonicalize | FileCheck %s
#SortedCOO = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
diff --git a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
index 7c5e1a372d44fca..d29758e4fa97121 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
@@ -1,5 +1,5 @@
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#DV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : dense) }>
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index 076e9201a1c053e..22681daa94148ef 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1,5 +1,5 @@
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#Tdd = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : dense) }>
#Tds = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
index c245e612be37f12..f2daa77a0c24c8a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
@@ -1,5 +1,5 @@
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#Td = #sparse_tensor.encoding<{ map = (d0) -> (d0 : dense) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
index d6c3075a8732574..ca55d2eb0177819 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#SpVec = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir b/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
index 1a5f79d23cba29a..278450fabd74ec2 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparsification --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification --canonicalize --cse | FileCheck %s
#DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
#SparseTensor = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index 7576a95c0fd0aac..de0582e5c688575 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparsification --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification --canonicalize --cse | FileCheck %s
#map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
index 4817771a6044ce2..bdbbf52d86286ab 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -1,9 +1,11 @@
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN: --linalg-fuse-elementwise-ops \
+// RUN: --sparse-reinterpret-map \
// RUN: --sparsification | \
// RUN: FileCheck %s --check-prefix=CHECK-SPARSE
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN: --linalg-fuse-elementwise-ops \
+// RUN: --sparse-reinterpret-map \
// RUN: --sparsification --lower-sparse-ops-to-foreach \
// RUN: --lower-sparse-foreach-to-scf \
// RUN: --sparse-tensor-conversion --cse | \
@@ -142,7 +144,8 @@ func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
// CHECK-SPARSE: }
// CHECK-SPARSE: sparse_tensor.compress %[[A]], %[[B]], %[[C]], %[[COUNT]]
// CHECK-SPARSE: }
-// CHECK-SPARSE: %[[RET:.*]] = sparse_tensor.load %[[T]] hasInserts
+// CHECK-SPARSE: %[[DEMAP:.*]] = sparse_tensor.load %[[T]] hasInserts
+// CHECK-SPARSE: %[[RET:.*]] = sparse_tensor.reinterpret_map %[[DEMAP]]
// CHECK-SPARSE: return %[[RET]]
//
// CHECK-CONVERT-LABEL: func @matmul2(
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 988ab7f85be4134..40367f12f85a472 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --linalg-generalize-named-ops --pre-sparsification-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --linalg-generalize-named-ops --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
#DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
index dac34da30b49fda..f25a56469d7e742 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
index 34e04c03529036f..e49c89e01790c08 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#DenseMatrix = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : dense, d1 : dense)
diff --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
index 0bc396278357656..da7d45bd63e17d4 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
@@ -1,5 +1,5 @@
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
@@ -506,4 +506,3 @@ func.func @lslbyc(%arga: tensor<32xi64, #SV>,
} -> tensor<32xi64>
return %0 : tensor<32xi64>
}
-
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
index fa0617b9cb6f58e..1cb23cfd5b9eb2f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s \
// RUN: --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
-// RUN: --sparsification | FileCheck %s
+// RUN: --sparse-reinterpret-map --sparsification | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
@@ -195,10 +195,11 @@ func.func @matmul2(%A: tensor<4x8xf64, #DCSR>,
return %D: tensor<4x4xf64, #DCSR>
}
+
// CHECK-LABEL: func.func @conv2d(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xi32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<6x6xi32>) -> tensor<6x6xi32> {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<6x6xi32>) -> tensor<6x6xi32> {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
@@ -208,33 +209,34 @@ func.func @matmul2(%A: tensor<4x8xf64, #DCSR>,
// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi32>
-// CHECK: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x6xi32>
+// CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x6xi32>
// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_5]] {
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_18:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_18]]] : memref<6x6xi32>
-// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_16]], %[[VAL_5]] : index
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK: %[[VAL_23:.*]] = scf.for %[[VAL_24:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] iter_args(%[[VAL_25:.*]] = %[[VAL_19]]) -> (i32) {
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref<?xindex>
-// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_13]], %[[VAL_17]] : index
-// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_26]] : index
-// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_27]], %[[VAL_28]]] : memref<8x8xi32>
-// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_24]]] : memref<?xi32>
-// CHECK: %[[VAL_31:.*]] = arith.muli %[[VAL_29]], %[[VAL_30]] : i32
-// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_25]], %[[VAL_31]] : i32
-// CHECK: scf.yield %[[VAL_32]] : i32
+// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_14]]] : memref<6x6xi32>
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_16]] to %[[VAL_17]] step %[[VAL_5]] iter_args(%[[VAL_20:.*]] = %[[VAL_15]]) -> (i32) {
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK: %[[VAL_25:.*]] = scf.for %[[VAL_26:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] iter_args(%[[VAL_27:.*]] = %[[VAL_20]]) -> (i32) {
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref<?xindex>
+// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_13]], %[[VAL_21]] : index
+// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_14]], %[[VAL_28]] : index
+// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_29]], %[[VAL_30]]] : memref<8x8xi32>
+// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_26]]] : memref<?xi32>
+// CHECK: %[[VAL_33:.*]] = arith.muli %[[VAL_31]], %[[VAL_32]] : i32
+// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_27]], %[[VAL_33]] : i32
+// CHECK: scf.yield %[[VAL_34]] : i32
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: memref.store %[[VAL_33:.*]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_18]]] : memref<6x6xi32>
+// CHECK: scf.yield %[[VAL_25]] : i32
// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: memref.store %[[VAL_18]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_14]]] : memref<6x6xi32>
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<6x6xi32>
-// CHECK: return %[[VAL_34]] : tensor<6x6xi32>
+// CHECK: %[[VAL_35:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<6x6xi32>
+// CHECK: return %[[VAL_35]] : tensor<6x6xi32>
// CHECK: }
func.func @conv2d(%input: tensor<8x8xi32>,
%filter: tensor<3x3xi32, #DCSR>,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
index 13245f427a7970f..788dd9c26e6cd38 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s --check-prefix=CHECK-HIR
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s --check-prefix=CHECK-HIR
//
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse | \
// RUN: FileCheck %s --check-prefix=CHECK-MIR
//
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --cse \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse \
// RUN: --func-bufferize --arith-bufferize \
// RUN: --tensor-bufferize --finalizing-bufferize | \
// RUN: FileCheck %s --check-prefix=CHECK-LIR
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
index a987d59f6773139..da08f53bd089bd5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s --check-prefix=CHECK-HIR
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s --check-prefix=CHECK-HIR
//
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse | \
// RUN: FileCheck %s --check-prefix=CHECK-MIR
//
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --cse \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse \
// RUN: --func-bufferize --arith-bufferize \
// RUN: --tensor-bufferize --finalizing-bufferize | \
// RUN: FileCheck %s --check-prefix=CHECK-LIR
@@ -29,9 +29,10 @@
// CHECK-HIR-DAG: %[[VAL_3:.*]] = arith.constant 64 : index
// CHECK-HIR-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK-HIR-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-HIR-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-HIR-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-HIR-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK-HIR: %[[DEMAP:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]]
+// CHECK-HIR-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[DEMAP]] {level = 1 : index}
+// CHECK-HIR-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[DEMAP]] {level = 1 : index}
+// CHECK-HIR-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[DEMAP]]
// CHECK-HIR-DAG: %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<64xf64>
// CHECK-HIR-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
// CHECK-HIR: scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
index 0e09f658fa15caa..eb6920304ed8479 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s --check-prefix=CHECK-HIR
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s --check-prefix=CHECK-HIR
//
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse | \
// RUN: FileCheck %s --check-prefix=CHECK-MIR
//
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --cse \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse \
// RUN: --func-bufferize --arith-bufferize \
// RUN: --tensor-bufferize --finalizing-bufferize | \
// RUN: FileCheck %s --check-prefix=CHECK-LIR
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index 12d443ca679207e..5145d6c1dcfc322 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -1,7 +1,7 @@
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// RUN: mlir-opt %s --linalg-generalize-named-ops \
-// RUN: --sparsification --sparse-tensor-codegen \
+// RUN: --sparse-reinterpret-map --sparsification --sparse-tensor-codegen \
// RUN: --canonicalize --cse | FileCheck %s
#CSR = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
index d5dc16bed20c7af..27716a82c164fa0 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
@@ -1,5 +1,5 @@
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
// Example with cyclic iteration graph with sparse and dense constraints,
// but an acyclic iteration graph using sparse constraints only.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 794eb39cb2fcb58..8bf0625425ead11 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#CSR = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : dense, d1 : compressed),
diff --git a/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir b/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir
index 7dbb1c7d9490877..1028b58be37df30 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
@@ -110,4 +110,3 @@ func.func @update_inplace(%arga: tensor<10xf32, #SV>,
} -> tensor<10xf32>
return %0 : tensor<10xf32>
}
-
diff --git a/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir b/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
index dab70e6a3e6f139..1db7cf85ea5fb78 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
@@ -1,12 +1,12 @@
-// RUN: mlir-opt %s -sparsification="parallelization-strategy=none" | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=none" | \
// RUN: FileCheck %s --check-prefix=CHECK-PAR0
-// RUN: mlir-opt %s -sparsification="parallelization-strategy=dense-outer-loop" | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=dense-outer-loop" | \
// RUN: FileCheck %s --check-prefix=CHECK-PAR1
-// RUN: mlir-opt %s -sparsification="parallelization-strategy=any-storage-outer-loop" | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=any-storage-outer-loop" | \
// RUN: FileCheck %s --check-prefix=CHECK-PAR2
-// RUN: mlir-opt %s -sparsification="parallelization-strategy=dense-any-loop" | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=dense-any-loop" | \
// RUN: FileCheck %s --check-prefix=CHECK-PAR3
-// RUN: mlir-opt %s -sparsification="parallelization-strategy=any-storage-any-loop" | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=any-storage-any-loop" | \
// RUN: FileCheck %s --check-prefix=CHECK-PAR4
#DenseMatrix = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir b/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir
index 26f5a05b64697ac..8118eeceb6a62ee 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification="parallelization-strategy=any-storage-any-loop" | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=any-storage-any-loop" | \
// RUN: FileCheck %s
#CSR = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
index 186030aa2ca0873..83cbfa771f1d626 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
@@ -1,5 +1,5 @@
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#X = #sparse_tensor.encoding<{
map = (d0, d1, d2) -> (d2 : dense, d0 : dense, d1 : dense)
@@ -22,7 +22,8 @@
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 10 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[DEMAP:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]]
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<30x10x20xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG: %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<20x30x10xf32>
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_9]] : memref<20x30x10xf32>)
// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
@@ -58,10 +59,11 @@ func.func @sparse_static_dims(%arga: tensor<10x20x30xf32, #X>,
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[DEMAP:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]]
+// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?x?xf32>
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_10]] : memref<?x?x?xf32>)
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_4]] {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index 02738a9e4544ac3..1078e5c282df11a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -sparsification --canonicalize | FileCheck %s --check-prefix=CHECK-HIR
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --canonicalize | FileCheck %s --check-prefix=CHECK-HIR
//
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --canonicalize | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --canonicalize | \
// RUN: FileCheck %s --check-prefix=CHECK-MIR
#X = #sparse_tensor.encoding<{
@@ -21,10 +21,11 @@
// CHECK-HIR-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK-HIR-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-HIR-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
-// CHECK-HIR-DAG: %[[VAL_5:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR-DAG: %[[VAL_6:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR-DAG: %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR:          %[[DEMAP:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]]
+// CHECK-HIR-DAG: %[[VAL_5:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR-DAG: %[[VAL_6:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR-DAG: %[[VAL_7:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-HIR-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
// CHECK-HIR: %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
// CHECK-HIR: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir b/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
index 85ab2a654098ad6..4ca577b9f729031 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
@@ -1,5 +1,5 @@
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
index 55deb91c5a1bb4f..1d95fe8d0569a23 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-tensor-copy-insertion --pre-sparsification-rewrite --sparsification --cse | FileCheck %s
+// RUN: mlir-opt %s --test-tensor-copy-insertion --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification --cse | FileCheck %s
#SM = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir
index ed5b72d0d872cbc..79bedcf5a49e16c 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pre-sparsification-rewrite --sparsification --cse | FileCheck %s
+// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification --cse | FileCheck %s
#SM = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_storage.mlir
index 65ea2d131277980..a8cc056139e49e6 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_storage.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_storage.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification= | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification= | FileCheck %s
#SparseVector64 = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed),
diff --git a/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir b/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir
index b61136553d2679a..fd41fedf91c2cbb 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#DCSR = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : compressed, d1 : compressed)
@@ -21,11 +21,12 @@
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_3:.*]] = tensor.empty() : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG: %[[VAL_4:.*]] = sparse_tensor.convert %[[VAL_0]] : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_4]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_4]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_4]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_4]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_4]] : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK: %[[DEMAP:.*]] = sparse_tensor.reinterpret_map %[[VAL_4]] : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[DEMAP]] {level = 0 : index} : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[DEMAP]] {level = 0 : index} : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[DEMAP]] {level = 1 : index} : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[DEMAP]] {level = 1 : index} : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_1]]] : memref<?xindex>
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_3]]) -> (tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
@@ -42,7 +43,6 @@
// CHECK: scf.yield %[[VAL_25:.*]] : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: }
// CHECK: %[[VAL_26:.*]] = sparse_tensor.load %[[VAL_27:.*]] hasInserts : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: bufferization.dealloc_tensor %[[VAL_4]] : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: return %[[VAL_26]] : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: }
func.func @sparse_transpose_auto(%arga: tensor<3x4xf64, #DCSR>)
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 48ba9119c4f44ea..364ba6e71ff3bb9 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt %s -sparsification -cse -split-input-file | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -split-input-file | \
// RUN: FileCheck %s --check-prefix=CHECK-SCALAR
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=16" -cse -split-input-file | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=16" -cse -split-input-file | \
// RUN: FileCheck %s --check-prefix=CHECK-VEC16
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=16 enable-simd-index32=true" -cse -split-input-file | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=16 enable-simd-index32=true" -cse -split-input-file | \
// RUN: FileCheck %s --check-prefix=CHECK-VEC16-IDX32
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=4 enable-vla-vectorization=true" -cse -split-input-file | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=4 enable-vla-vectorization=true" -cse -split-input-file | \
// RUN: FileCheck %s --check-prefix=CHECK-VEC4-SVE
#DenseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : dense) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index b8a4563bd63d87f..269a4926c9ec375 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=8" -cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=8" -cse | \
// RUN: FileCheck %s
#SparseMatrix = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
index 31a78ceb0e2b36b..9bc24fd02b827d0 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=8" -cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=8" -cse | \
// RUN: FileCheck %s
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
index 67841beaa6933f6..fc6741c1f69a6cd 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=8" -cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=8" -cse | \
// RUN: FileCheck %s
#DenseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : dense) }>
@@ -84,4 +84,3 @@ func.func @vops(%arga: tensor<1024xf32, #DenseVector>,
} -> tensor<1024xf32>
return %0 : tensor<1024xf32>
}
-
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
index db830f29d7ca655..99d6a3dc390e086 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparsification -cse -sparse-vectorization="vl=16" -scf-for-loop-peeling -canonicalize -cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification -cse -sparse-vectorization="vl=16" -scf-for-loop-peeling -canonicalize -cse | \
// RUN: FileCheck %s
#SparseVector = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir b/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
index eaa15d7f83bc47f..a4211991f260833 100644
--- a/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
//
// A SDDMM implementation with "spy" function and
diff --git a/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
index 7112d7dc8ca900b..ea29f1d677eff7b 100644
--- a/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
+++ b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
#trait = {
indexing_maps = [
diff --git a/mlir/test/Dialect/SparseTensor/unused-tensor.mlir b/mlir/test/Dialect/SparseTensor/unused-tensor.mlir
index 5f169dd989bdc48..330bb9fa57483f5 100644
--- a/mlir/test/Dialect/SparseTensor/unused-tensor.mlir
+++ b/mlir/test/Dialect/SparseTensor/unused-tensor.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
//
// A contrived example where the sparse tensor B is only
diff --git a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
index a4f58a4fa913f0c..ef3c0f88d4222f2 100644
--- a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=8" -cse -split-input-file | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=8" -cse -split-input-file | \
// RUN: FileCheck %s --check-prefix=CHECK-ON
-// RUN: mlir-opt %s -sparsification -cse -split-input-file | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -split-input-file | \
// RUN: FileCheck %s --check-prefix=CHECK-OFF
// -----
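Taken together, the RUN-line churn in this first patch boils down to one pipeline change: the new --sparse-reinterpret-map pass now runs immediately before --sparsification. A representative invocation mirroring the RUN lines above (the input file name is illustrative):

    mlir-opt kernel.mlir \
      --linalg-generalize-named-ops \
      --pre-sparsification-rewrite \
      --sparse-reinterpret-map \
      --sparsification \
      --sparse-tensor-conversion --cse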
>From b74de428b0dad7bcc8880b8019de241e58f73fe0 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 15 Nov 2023 18:45:49 +0000
Subject: [PATCH 2/4] address comments.
---
.../Dialect/SparseTensor/IR/SparseTensor.h | 3 ++-
.../SparseTensor/Transforms/CMakeLists.txt | 2 +-
...Scheduler.cpp => IterationGraphSorter.cpp} | 22 +++++++++----------
...LoopScheduler.h => IterationGraphSorter.h} | 13 ++++++-----
.../Transforms/SparseReinterpretMap.cpp | 10 ++++-----
.../Dialect/SparseTensor/GPU/gpu_matmul.mlir | 2 +-
.../Dialect/SparseTensor/sparse_kernels.mlir | 6 ++---
7 files changed, 30 insertions(+), 28 deletions(-)
rename mlir/lib/Dialect/SparseTensor/Transforms/{LoopScheduler.cpp => IterationGraphSorter.cpp} (93%)
rename mlir/lib/Dialect/SparseTensor/Transforms/{LoopScheduler.h => IterationGraphSorter.h} (82%)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index bb8df626086bb29..eb7c50ae2efdf8c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -127,7 +127,8 @@ inline bool hasAnySparseOperandOrResult(Operation *op) {
return hasAnySparseOperand(op) || hasAnySparseResult(op);
}
-/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
+/// Returns true iff the MLIR operation has any sparse tensor with non-identity
+/// dim2lvl maps.
bool hasAnyNonIdentityOperandsOrResults(Operation *op);
//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 06d0ec1d7eb7cf3..8459e46e5814f06 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -2,8 +2,8 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
BufferizableOpInterfaceImpl.cpp
CodegenEnv.cpp
CodegenUtils.cpp
+ IterationGraphSorter.cpp
LoopEmitter.cpp
- LoopScheduler.cpp
SparseBufferRewriting.cpp
SparseGPUCodegen.cpp
SparseReinterpretMap.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.cpp
similarity index 93%
rename from mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp
rename to mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.cpp
index 32e0a6f5c03d8a2..7b1e22dcba5493c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "LoopScheduler.h"
+#include "IterationGraphSorter.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -85,7 +85,7 @@ inline static bool includesDenseOutput(SortMask mask) {
/// as we use an adjacency matrix for the graph.
/// The sorted result will put the first reduction iterator at the
/// latest possible position.
-AffineMap LoopScheduler::topoSort() {
+AffineMap IterationGraphSorter::topoSort() {
std::vector<unsigned> redIt; // reduction iterators with in-degree 0
std::vector<unsigned> parIt; // parallel iterators with in-degree 0
const unsigned numLoops = getNumLoops();
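The renamed body (otherwise unchanged by this patch) is a worklist-based topological sort over that adjacency matrix. A simplified sketch of the idea, reusing the member names visible in this file — an illustration, not literal patch content:

    // Drain zero-in-degree loops, preferring parallel iterators so that
    // reduction iterators are scheduled as late as possible.
    SmallVector<unsigned> loopOrder;
    while (!redIt.empty() || !parIt.empty()) {
      std::vector<unsigned> &worklist = parIt.empty() ? redIt : parIt;
      const unsigned src = worklist.back();
      worklist.pop_back();
      loopOrder.push_back(src);
      // Remove the outgoing edges of `src`; newly freed loops join a list.
      for (unsigned dst = 0; dst < numLoops; dst++) {
        if (itGraph[src][dst] && --inDegree[dst] == 0) {
          if (iterTypes[dst] == utils::IteratorType::reduction)
            redIt.push_back(dst);
          else
            parIt.push_back(dst);
        }
      }
    }
    // loopOrder has numLoops entries iff the iteration graph is acyclic.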
@@ -124,7 +124,8 @@ AffineMap LoopScheduler::topoSort() {
return AffineMap();
}
-LoopScheduler LoopScheduler::fromGenericOp(linalg::GenericOp genericOp) {
+IterationGraphSorter
+IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
// Must be a demapped sparse kernel.
assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
hasAnySparseOperandOrResult(genericOp) &&
@@ -140,14 +141,13 @@ LoopScheduler LoopScheduler::fromGenericOp(linalg::GenericOp genericOp) {
SmallVector<utils::IteratorType> iterTypes =
genericOp.getIteratorTypesArray();
- return LoopScheduler(std::move(ins), std::move(loopMap), out, outMap,
- std::move(iterTypes));
+ return IterationGraphSorter(std::move(ins), std::move(loopMap), out, outMap,
+ std::move(iterTypes));
}
-LoopScheduler::LoopScheduler(SmallVector<Value> &&ins,
- SmallVector<AffineMap> &&loop2InsLvl, Value out,
- AffineMap loop2OutLvl,
- SmallVector<utils::IteratorType> &&iterTypes)
+IterationGraphSorter::IterationGraphSorter(
+ SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out,
+ AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes)
: ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
// One map per tensor.
@@ -166,7 +166,7 @@ LoopScheduler::LoopScheduler(SmallVector<Value> &&ins,
inDegree.resize(getNumLoops());
}
-AffineMap LoopScheduler::schedule(SortMask mask, Value ignored) {
+AffineMap IterationGraphSorter::sort(SortMask mask, Value ignored) {
// Reset the iteration graph.
for (auto &row : itGraph)
std::fill(row.begin(), row.end(), false);
@@ -191,7 +191,7 @@ AffineMap LoopScheduler::schedule(SortMask mask, Value ignored) {
return topoSort();
}
-void LoopScheduler::addConstraints(Value t, AffineMap loop2LvlMap) {
+void IterationGraphSorter::addConstraints(Value t, AffineMap loop2LvlMap) {
auto addIterOrdering = [this](unsigned f, unsigned t) {
if (!itGraph[f][t] && f != t) {
itGraph[f][t] = true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
similarity index 82%
rename from mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h
rename to mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
index 1d5cb1bc6f60e4e..613a8609ac0973a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
@@ -33,23 +33,24 @@ enum class SortMask : unsigned {
kSparseOnly = 0x0, // b000
};
-class LoopScheduler {
+class IterationGraphSorter {
public:
// Constructs a scheduler from a linalg.generic op.
// Maybe reuse the class to schedule foreach as well (to address
// non-permutation, e.g., traversing CSR in BSR order).
- static LoopScheduler fromGenericOp(linalg::GenericOp genericOp);
+ static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp);
// Returns a permutation that represents the scheduled loop order.
// Note that the returned AffineMap could be null if the kernel cannot be
// scheduled due to cycles in the iteration graph.
- [[nodiscard]] AffineMap schedule(SortMask mask, Value ignored = nullptr);
+ [[nodiscard]] AffineMap sort(SortMask mask, Value ignored = nullptr);
unsigned getNumLoops() const { return loop2OutLvl.getNumDims(); }
private:
- LoopScheduler(SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl,
- Value out, AffineMap loop2OutLvl,
- SmallVector<utils::IteratorType> &&iterTypes);
+ IterationGraphSorter(SmallVector<Value> &&ins,
+ SmallVector<AffineMap> &&loop2InsLvl, Value out,
+ AffineMap loop2OutLvl,
+ SmallVector<utils::IteratorType> &&iterTypes);
void addConstraints(Value t, AffineMap loop2LvlMap);
AffineMap topoSort();
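After the rename, the public surface of the class is just fromGenericOp plus sort. A minimal usage sketch, assuming the mask-relaxation loop that SparseReinterpretMap.cpp applies below:

    // Try progressively weaker constraint sets until one of them yields
    // an acyclic iteration graph.
    IterationGraphSorter sorter = IterationGraphSorter::fromGenericOp(genericOp);
    AffineMap order;
    for (SortMask mask :
         {SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
          SortMask::kIncludeUndef, SortMask::kSparseOnly}) {
      order = sorter.sort(mask);
      if (order) // non-null: an admissible loop order was found
        break;
    }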
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index a433b2edc5500e7..5789f9aa12e1b80 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
-#include "LoopScheduler.h"
+#include "IterationGraphSorter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -415,7 +415,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
if (linalgOp->hasAttr(sorted))
return failure();
- auto scheduler = LoopScheduler::fromGenericOp(linalgOp);
+ auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
bool isAdmissible = false;
AffineMap order;
// A const list of all masks that we used for iteration graph
@@ -427,7 +427,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
SortMask::kIncludeUndef, SortMask::kSparseOnly};
for (const SortMask mask : allMasks) {
- order = scheduler.schedule(mask);
+ order = scheduler.sort(mask);
if (order) {
if (isAdmissibleOrder(linalgOp, order)) {
isAdmissible = true;
@@ -506,7 +506,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
};
// Last resort cycle resolution.
- static LogicalResult resolveCycle(LoopScheduler &scheduler,
+ static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) {
// Compute topological sort while leaving out every sparse input tensor in
@@ -524,7 +524,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
continue;
// Try scheduling the loop nest without constraints from `tval`.
- AffineMap order = scheduler.schedule(SortMask::kSparseOnly, tval);
+ AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
if (!order) // still cyclic
continue;
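This is the same last-resort idea as the Sparsification.cpp code deleted at the top of this mail: drop one sparse input from the graph, and if the remainder sorts, convert that tensor to an order-compatible encoding. A condensed sketch of the control flow, with the conversion step elided into comments (names taken from the patch, not literal code):

    // For each sparse input, retry the sort with that tensor's
    // constraints ignored; the first success picks the tensor to convert.
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      Value tval = opOperand.get();
      if (!getSparseTensorEncoding(tval.getType()))
        continue; // dense operands introduce no level-order constraints
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, /*ignored=*/tval);
      if (!order)
        continue; // still cyclic without this tensor
      // Success: insert a ConvertOp of `tval` (plus a matching
      // bufferization.dealloc_tensor) and let the pattern re-run.
      return success();
    }
    return failure(); // TODO: convert more than one?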
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
index e477dcc0169fc3d..a7d2565cff747ea 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN: --pre-sparsification-rewrite \
-// RUN: --sparse-reinterpret-map \
+// RUN: --sparse-reinterpret-map \
// RUN: --sparsification="parallelization-strategy=dense-outer-loop" \
// RUN: --sparse-gpu-codegen | FileCheck %s
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
index 1cb23cfd5b9eb2f..276a864327daa9c 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
@@ -197,9 +197,9 @@ func.func @matmul2(%A: tensor<4x8xf64, #DCSR>,
// CHECK-LABEL: func.func @conv2d(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xi32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<6x6xi32>) -> tensor<6x6xi32> {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<6x6xi32>) -> tensor<6x6xi32> {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
>From 0d8dd6335fb838e474ddfc9e9155331e4422a561 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 15 Nov 2023 18:49:59 +0000
Subject: [PATCH 3/4] cleanup
---
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 3df9bcc77717994..3a02d5634586070 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -192,11 +192,9 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
const auto iteratorTypes = linalgOp.getIteratorTypesArray();
assert(topSortSize() == latticeMerger.getNumLoops());
for (const LoopId i : topSort) {
- if (!latticeMerger.isFilterLoop(i)) {
- if (linalg::isReductionIterator(iteratorTypes[i]))
- break; // terminate at first reduction
- outerParNest++;
- }
+ if (linalg::isReductionIterator(iteratorTypes[i]))
+ break; // terminate at first reduction
+ outerParNest++;
}
// An inadmissible kernel should have already been rejected by the previous
>From 2a4682e11b50a7db1a423ff5cc927b6650cf1f9b Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 15 Nov 2023 19:01:00 +0000
Subject: [PATCH 4/4] address comments.
---
.../SparseTensor/Transforms/IterationGraphSorter.cpp | 10 +++++-----
.../SparseTensor/Transforms/SparseReinterpretMap.cpp | 2 +-
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.cpp
index 7b1e22dcba5493c..b6011727f4127cf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.cpp
@@ -30,7 +30,7 @@ class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
explicit AffineDimFinder(ArrayRef<utils::IteratorType> itTypes)
: iterTypes(itTypes) {}
- // Overrides method from AffineExprVisitor.
+ // Override method from AffineExprVisitor.
void visitDimExpr(AffineDimExpr expr) {
if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition()])
pickedDim = expr;
@@ -176,7 +176,7 @@ AffineMap IterationGraphSorter::sort(SortMask mask, Value ignored) {
for (auto [in, map] : llvm::zip(ins, loop2InsLvl)) {
// Get map and encoding.
const auto enc = getSparseTensorEncoding(in.getType());
- // Skips dense inputs when not requested.
+ // Skip dense inputs when not requested.
if ((!enc && !includesDenseInput(mask)) || in == ignored)
continue;
@@ -203,7 +203,7 @@ void IterationGraphSorter::addConstraints(Value t, AffineMap loop2LvlMap) {
finder.setPickedIterType(utils::IteratorType::reduction);
// To compute the iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6],
- // we requires there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6},
+ // we require there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6},
// and d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
const Level lvlRank = loop2LvlMap.getNumResults();
for (Level lvl = 1; lvl < lvlRank; lvl++) {
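Spelled out, the constraint in the comment above can be read as follows (a restatement for clarity, not part of the patch; > denotes the loop-order relation encoded in the iteration graph, applied elementwise to the remaining sets):

    \exists\, d_x \in \{d_0, d_1, d_3\},\; d_y \in \{d_4, d_5, d_6\} :\quad
    d_x > d_y \;\wedge\; (\{d_0, d_1, d_3\} \setminus \{d_x\}) > (\{d_4, d_5, d_6\} \setminus \{d_y\})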
@@ -211,7 +211,7 @@ void IterationGraphSorter::addConstraints(Value t, AffineMap loop2LvlMap) {
const AffineExpr ta = loop2LvlMap.getResult(lvl);
if (llvm::isa<AffineDimExpr>(fa) || llvm::isa<AffineDimExpr>(ta)) {
- // Special case when at least one loop2LvlExp is an simple AffineDimExpr
+ // Special case when at least one loop2LvlExp is a simple AffineDimExpr
// (say, d0) and we require d0 > {d1, d2, ...} or {d1, d2, ...} > d0
AffineDimCollector fCollector;
fCollector.walkPostOrder(fa);
@@ -246,7 +246,7 @@ void IterationGraphSorter::addConstraints(Value t, AffineMap loop2LvlMap) {
AffineDimCollector tCollector;
tCollector.walkPostOrder(ta);
- // make sure dx and dy is the last;
+  // Make sure dx and dy are the last.
for (auto fd : fCollector.dims) {
const unsigned f = fd.getPosition();
addIterOrdering(f, fldx);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 5789f9aa12e1b80..268bd8fbe27387f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -515,7 +515,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
Value tval = t->get();
auto srcEnc = getSparseTensorEncoding(tval.getType());
// The constraints introduced by compound index expressions are
- // complicated. Skips them.
+ // complicated. Skip them.
AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
return !llvm::isa<AffineDimExpr>(exp);