[Mlir-commits] [mlir] [mlir][sparse] schedule sparse kernels in a separate pass from sparsification. (PR #72423)

Peiming Liu llvmlistbot at llvm.org
Wed Nov 15 10:22:37 PST 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/72423

None

>From 3ab8768197445e75b6bc4067ace2c91abf21ca38 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 14 Nov 2023 20:55:50 +0000
Subject: [PATCH] [mlir][sparse] schedule sparse loops

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |   3 +
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  10 +
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../SparseTensor/Transforms/CodegenEnv.cpp    |  37 +-
 .../SparseTensor/Transforms/LoopScheduler.cpp | 273 ++++++++
 .../SparseTensor/Transforms/LoopScheduler.h   |  73 +++
 .../Transforms/SparseReinterpretMap.cpp       | 195 +++++-
 .../Transforms/Sparsification.cpp             | 587 +-----------------
 .../Dialect/SparseTensor/GPU/gpu_combi.mlir   |   2 +-
 .../Dialect/SparseTensor/GPU/gpu_matmul.mlir  |   1 +
 .../Dialect/SparseTensor/GPU/gpu_matvec.mlir  |   1 +
 .../SparseTensor/constant_index_map.mlir      |   2 +-
 mlir/test/Dialect/SparseTensor/dense.mlir     |   2 +-
 mlir/test/Dialect/SparseTensor/one_trip.mlir  |   2 +-
 mlir/test/Dialect/SparseTensor/semi_ring.mlir |   2 +-
 .../test/Dialect/SparseTensor/sorted_coo.mlir |   2 +-
 mlir/test/Dialect/SparseTensor/sparse_1d.mlir |   2 +-
 mlir/test/Dialect/SparseTensor/sparse_2d.mlir |   2 +-
 mlir/test/Dialect/SparseTensor/sparse_3d.mlir |   2 +-
 .../Dialect/SparseTensor/sparse_affine.mlir   |   2 +-
 .../SparseTensor/sparse_broadcast.mlir        |   2 +-
 .../sparse_conv_2d_slice_based.mlir           |   2 +-
 .../Dialect/SparseTensor/sparse_expand.mlir   |   5 +-
 .../SparseTensor/sparse_fill_zero.mlir        |   2 +-
 .../Dialect/SparseTensor/sparse_fp_ops.mlir   |   2 +-
 .../Dialect/SparseTensor/sparse_index.mlir    |   2 +-
 .../Dialect/SparseTensor/sparse_int_ops.mlir  |   3 +-
 .../Dialect/SparseTensor/sparse_kernels.mlir  |  54 +-
 .../Dialect/SparseTensor/sparse_lower.mlir    |   6 +-
 .../SparseTensor/sparse_lower_col.mlir        |  13 +-
 .../SparseTensor/sparse_lower_inplace.mlir    |   6 +-
 .../SparseTensor/sparse_matmul_codegen.mlir   |   2 +-
 mlir/test/Dialect/SparseTensor/sparse_nd.mlir |   2 +-
 .../test/Dialect/SparseTensor/sparse_out.mlir |   2 +-
 .../Dialect/SparseTensor/sparse_outbuf.mlir   |   3 +-
 .../Dialect/SparseTensor/sparse_parallel.mlir |  10 +-
 .../SparseTensor/sparse_parallel_reduce.mlir  |   2 +-
 .../Dialect/SparseTensor/sparse_perm.mlir     |  14 +-
 .../SparseTensor/sparse_perm_lower.mlir       |  13 +-
 .../Dialect/SparseTensor/sparse_scalars.mlir  |   2 +-
 .../Dialect/SparseTensor/sparse_sddmm.mlir    |   2 +-
 .../SparseTensor/sparse_sddmm_org.mlir        |   2 +-
 .../Dialect/SparseTensor/sparse_storage.mlir  |   2 +-
 .../SparseTensor/sparse_transpose.mlir        |  14 +-
 .../Dialect/SparseTensor/sparse_vector.mlir   |   8 +-
 .../SparseTensor/sparse_vector_chain.mlir     |   2 +-
 .../SparseTensor/sparse_vector_index.mlir     |   2 +-
 .../SparseTensor/sparse_vector_ops.mlir       |   3 +-
 .../SparseTensor/sparse_vector_peeled.mlir    |   2 +-
 mlir/test/Dialect/SparseTensor/spy_sddmm.mlir |   2 +-
 .../SparseTensor/unsparsifiable_dense_op.mlir |   2 +-
 .../Dialect/SparseTensor/unused-tensor.mlir   |   2 +-
 .../SparseTensor/vectorize_reduction.mlir     |   4 +-
 53 files changed, 673 insertions(+), 722 deletions(-)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 2fb61746d048999..bb8df626086bb29 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -127,6 +127,9 @@ inline bool hasAnySparseOperandOrResult(Operation *op) {
   return hasAnySparseOperand(op) || hasAnySparseResult(op);
 }
 
+/// True if some sparse operand/result of `op` has a non-identity dim2lvl map.
+bool hasAnyNonIdentityOperandsOrResults(Operation *op);
+
 //
 // Inference.
 //
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 92bf8ec6468e532..53a0c2428842f59 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -875,6 +875,16 @@ bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
   return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
 }
 
+bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
+  // Values that are not sparse tensors never qualify.
+  const auto isNonIdentitySparse = [](Value val) {
+    auto stt = tryGetSparseTensorType(val);
+    return stt && !stt->isIdentity();
+  };
+  return llvm::any_of(op->getOperands(), isNonIdentitySparse) ||
+         llvm::any_of(op->getResults(), isNonIdentitySparse);
+}
+
 Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
   // We only consider COO region with at least two levels for the purpose
   // of AOS storage optimization.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index b8a2ff26b6794f6..06d0ec1d7eb7cf3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   CodegenEnv.cpp
   CodegenUtils.cpp
   LoopEmitter.cpp
+  LoopScheduler.cpp
   SparseBufferRewriting.cpp
   SparseGPUCodegen.cpp
   SparseReinterpretMap.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 5c7cc93737b7fd7..3df9bcc77717994 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -57,7 +57,11 @@ CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
       loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u),
       insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(),
       redExp(detail::kInvalidId), redCustom(detail::kInvalidId),
-      redValidLexInsert() {}
+      redValidLexInsert() {
+  // TODO: remove topSort, loops should be already sorted by previous pass.
+  for (unsigned l = 0; l < latticeMerger.getNumLoops(); l++)
+    topSort.push_back(l);
+}
 
 LogicalResult CodegenEnv::initTensorExp() {
   // Builds the tensor expression for the Linalg operation in SSA form.
@@ -181,36 +185,25 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
   // Accept "truly dynamic" if the output tensor materializes uninitialized
   // into the computation and insertions occur in lexicographic index order.
   sparseOut = lhs;
-  return isMaterializing(lhs->get());
-}
 
-bool CodegenEnv::isAdmissibleTopoOrder() {
-  if (!hasSparseOutput())
-    return true;
-
-  OpOperand *lhs = linalgOp.getDpsInitOperand(0);
-  // Accept "truly dynamic" if the output tensor materializes uninitialized
-  // into the computation and insertions occur in lexicographic index order.
-  LoopOrd nest = 0;
+  // Find the outermost parallel nest to determine whether compress/expand is
+  // needed.
+  outerParNest = 0;
   const auto iteratorTypes = linalgOp.getIteratorTypesArray();
   assert(topSortSize() == latticeMerger.getNumLoops());
   for (const LoopId i : topSort) {
     if (!latticeMerger.isFilterLoop(i)) {
-      // We only count non-filter loops as filter loops should be considered
-      // a special type of parallel loops.
       if (linalg::isReductionIterator(iteratorTypes[i]))
         break; // terminate at first reduction
-      nest++;
+      outerParNest++;
     }
   }
-  // Determine admissible dynamic insertion situations:
-  // (1) fully injective, since there are no reductions,
-  // (2) admissible 1-d expansion in innermost dimension.
-  if (static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1) {
-    outerParNest = nest;
-    return true;
-  }
-  return false;
+
+  // Inadmissible kernels should have already been rejected by the previous
+  // pass during loop scheduling.
+  assert(static_cast<int64_t>(outerParNest) >=
+         linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
+  return isMaterializing(lhs->get());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp
new file mode 100644
index 000000000000000..32e0a6f5c03d8a2
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.cpp
@@ -0,0 +1,273 @@
+//===- LoopScheduler.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LoopScheduler.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+/// Visits an affine expression and picks one of its AffineDimExprs,
+/// preferring a dim whose iterator type (looked up by dim position) equals
+/// the requested one. Among several matches the last one visited wins; if
+/// none matches, the first AffineDimExpr visited is kept. The expression must
+/// contain an AffineDimExpr: getDimExpr() casts the pick unconditionally.
+class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
+public:
+  explicit AffineDimFinder(ArrayRef<utils::IteratorType> itTypes)
+      : iterTypes(itTypes) {}
+
+  // Called by AffineExprVisitor for every AffineDimExpr in the expression.
+  void visitDimExpr(AffineDimExpr expr) {
+    if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition()])
+      pickedDim = expr;
+  }
+
+  /// Sets the iterator type that the picked dimension should preferably have.
+  void setPickedIterType(utils::IteratorType iterType) {
+    pickIterType = iterType;
+  }
+
+  /// Returns the AffineDimExpr picked by the most recent walk.
+  AffineDimExpr getDimExpr() const {
+    return llvm::cast<AffineDimExpr>(pickedDim);
+  }
+
+  void walkPostOrder(AffineExpr expr) {
+    pickedDim = nullptr; // Forget the pick of any previous walk.
+    AffineExprVisitor<AffineDimFinder>::walkPostOrder(expr);
+  }
+
+private:
+  /// The AffineDimExpr picked so far (null until a dim has been visited).
+  AffineExpr pickedDim;
+  /// The preferred iterator type; defaults to parallel so it is never unset.
+  utils::IteratorType pickIterType = utils::IteratorType::parallel;
+  /// The iterator type of each loop, indexed by AffineDimExpr position.
+  ArrayRef<utils::IteratorType> iterTypes;
+};
+
+/// Collects every AffineDimExpr visited during a walk (duplicates included).
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+  // Called by AffineExprVisitor for every AffineDimExpr in the expression.
+  void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
+  SmallVector<AffineDimExpr> dims;
+};
+
+} // namespace
+
+// Mask-bit predicates deciding which tensors contribute ordering constraints.
+static inline bool includesAny(SortMask lhs, SortMask rhs) {
+  return (static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)) != 0;
+}
+
+static inline bool includesDenseInput(SortMask m) {
+  return includesAny(m, SortMask::kIncludeDenseInput);
+}
+static inline bool includesDenseOutput(SortMask m) {
+  return includesAny(m, SortMask::kIncludeDenseOutput);
+}
+
+/// Topologically sorts the iteration graph with Kahn's algorithm over the
+/// adjacency matrix, hence O(n^2) time. Whenever both a parallel and a
+/// reduction loop are ready, the parallel one is emitted first, which delays
+/// the first reduction loop as long as the constraints allow. Returns the
+/// order as a permutation map, or a null AffineMap if the graph has a cycle.
+AffineMap LoopScheduler::topoSort() {
+  const unsigned numLoops = getNumLoops();
+
+  // Loops whose in-degree dropped to zero, bucketed by iterator type.
+  std::vector<unsigned> readyRed;
+  std::vector<unsigned> readyPar;
+  auto markReady = [&](unsigned loop) {
+    if (iterTypes[loop] == utils::IteratorType::reduction)
+      readyRed.push_back(loop);
+    else
+      readyPar.push_back(loop);
+  };
+
+  // Seed the worklists with the loops that have no predecessor.
+  for (unsigned loop = 0; loop < numLoops; ++loop)
+    if (inDegree[loop] == 0)
+      markReady(loop);
+
+  SmallVector<unsigned> order;
+  while (!readyPar.empty() || !readyRed.empty()) {
+    // Prefer a parallel loop: scheduling a reduction loop too early might
+    // make the resulting loop sequence inadmissible.
+    std::vector<unsigned> &bucket = readyPar.empty() ? readyRed : readyPar;
+    const unsigned src = bucket.back();
+    bucket.pop_back();
+    order.push_back(src);
+    // Detaching `src` may make some of its successors ready.
+    for (unsigned dst = 0; dst < numLoops; ++dst)
+      if (itGraph[src][dst] && --inDegree[dst] == 0)
+        markReady(dst);
+  }
+
+  // All loops were emitted iff the iteration graph is acyclic.
+  if (order.size() != numLoops)
+    return AffineMap();
+  return AffineMap::getPermutationMap(order, out.getContext());
+}
+
+LoopScheduler LoopScheduler::fromGenericOp(linalg::GenericOp genericOp) {
+  // Only demapped sparse kernels (no non-identity dim2lvl maps) with a single
+  // output are supported.
+  assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
+         hasAnySparseOperandOrResult(genericOp) &&
+         genericOp.getNumDpsInits() == 1);
+
+  // The indexing maps of the inputs come first; the output map is the last.
+  SmallVector<AffineMap> inMaps = genericOp.getIndexingMapsArray();
+  const AffineMap outMap = inMaps.pop_back_val();
+
+  SmallVector<Value> inputs = genericOp.getDpsInputs();
+  Value output = genericOp.getDpsInitOperand(0)->get();
+  SmallVector<utils::IteratorType> loopTypes =
+      genericOp.getIteratorTypesArray();
+
+  return LoopScheduler(std::move(inputs), std::move(inMaps), output, outMap,
+                       std::move(loopTypes));
+}
+
+LoopScheduler::LoopScheduler(SmallVector<Value> &&ins,
+                             SmallVector<AffineMap> &&loop2InsLvl, Value out,
+                             AffineMap loop2OutLvl,
+                             SmallVector<utils::IteratorType> &&iterTypes)
+    : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
+      loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
+  // The parameters shadow the members and are moved-from: check `this->`.
+  // One map per input tensor.
+  assert(this->loop2InsLvl.size() == this->ins.size());
+  // Every map ranges over the same loops as the output map.
+  assert(llvm::all_of(this->loop2InsLvl, [this](AffineMap m) {
+    return m.getNumDims() == getNumLoops();
+  }));
+  // The number of results of each map matches the rank of its tensor.
+  assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mv) {
+    auto [m, v] = mv;
+    return m.getNumResults() == llvm::cast<ShapedType>(v.getType()).getRank();
+  }));
+  itGraph.assign(getNumLoops(), std::vector<bool>(getNumLoops(), false));
+  inDegree.assign(getNumLoops(), 0);
+}
+
+AffineMap LoopScheduler::schedule(SortMask mask, Value ignored) {
+  // Start over from an empty iteration graph with all in-degrees zeroed.
+  for (std::vector<bool> &row : itGraph)
+    row.assign(row.size(), false);
+  inDegree.assign(inDegree.size(), 0);
+
+  // Inputs: sparse tensors always add constraints, dense ones only when the
+  // mask includes dense inputs. The `ignored` tensor is skipped entirely.
+  for (auto [in, map] : llvm::zip(ins, loop2InsLvl)) {
+    if (in == ignored)
+      continue;
+    if (getSparseTensorEncoding(in.getType()) || includesDenseInput(mask))
+      addConstraints(in, map);
+  }
+
+  // Output: same rule, gated by the dense-output bit of the mask.
+  const bool outIsSparse =
+      static_cast<bool>(getSparseTensorEncoding(out.getType()));
+  if (out != ignored && (outIsSparse || includesDenseOutput(mask)))
+    addConstraints(out, loop2OutLvl);
+
+  // Derive the loop order from the collected constraints.
+  return topoSort();
+}
+
+void LoopScheduler::addConstraints(Value t, AffineMap loop2LvlMap) {
+  auto addIterOrdering = [this](unsigned f, unsigned t) {
+    if (!itGraph[f][t] && f != t) { // Ignore duplicate edges and self-loops.
+      itGraph[f][t] = true;
+      inDegree[t]++;
+    }
+  };
+  // addIterOrdering(a, b) records "a > b": loop a is ordered before loop b.
+  AffineDimFinder finder(iterTypes);
+  finder.setPickedIterType(utils::IteratorType::reduction);
+
+  // Order the loops of every pair of adjacent levels of `t`. For example,
+  // for tensor[d0 + d1 + d3, d4 + d5 + d6] we require some d_x in {d0, d1,
+  // d3}, d_y in {d4, d5, d6}: d_x > d_y && {d0,d1,d3}-d_x > {d4,d5,d6}-d_y.
+  const Level lvlRank = loop2LvlMap.getNumResults();
+  for (Level lvl = 1; lvl < lvlRank; lvl++) {
+    const AffineExpr fa = loop2LvlMap.getResult(lvl - 1);
+    const AffineExpr ta = loop2LvlMap.getResult(lvl);
+
+    if (llvm::isa<AffineDimExpr>(fa) || llvm::isa<AffineDimExpr>(ta)) {
+      // Special case: at least one level expression is a plain AffineDimExpr;
+      // then every loop of `fa` is ordered before every loop of `ta`.
+      AffineDimCollector fCollector;
+      fCollector.walkPostOrder(fa);
+      AffineDimCollector tCollector;
+      tCollector.walkPostOrder(ta);
+
+      for (auto fd : fCollector.dims) {
+        for (auto td : tCollector.dims) {
+          const unsigned f = fd.getPosition();
+          const unsigned t = td.getPosition();
+          addIterOrdering(f, t);
+        }
+      }
+      continue;
+    }
+
+    // Both level expressions are compound: pick d_x from `fa` and d_y from
+    // `ta`, preferring reduction loops (both are assumed to contain a dim).
+    finder.walkPostOrder(fa);
+    const AffineDimExpr fexp = finder.getDimExpr();
+    const unsigned fldx = fexp.getPosition();
+
+    finder.walkPostOrder(ta);
+    const AffineDimExpr texp = finder.getDimExpr();
+    const unsigned tldx = texp.getPosition();
+
+    // d_x > d_y
+    addIterOrdering(fldx, tldx);
+
+    AffineDimCollector fCollector;
+    fCollector.walkPostOrder(fa);
+    AffineDimCollector tCollector;
+    tCollector.walkPostOrder(ta);
+
+    // Make d_x (resp. d_y) the innermost of the loops in `fa` (resp. `ta`).
+    for (auto fd : fCollector.dims) {
+      const unsigned f = fd.getPosition();
+      addIterOrdering(f, fldx);
+    }
+    for (auto td : tCollector.dims) {
+      const unsigned t = td.getPosition();
+      addIterOrdering(t, tldx);
+    }
+    // {d0, d1, d3} - d_x > {d4, d5, d6} - d_y: the remaining loops of `fa`
+    // are ordered before those of `ta`, so the level expressions are
+    // resolved following the sparse tensor level order.
+    for (auto fd : fCollector.dims) {
+      const unsigned f = fd.getPosition();
+      if (f == fldx) // skip d_x
+        continue;
+      for (auto td : tCollector.dims) {
+        const unsigned t = td.getPosition();
+        if (t == tldx) // skip d_y
+          continue;
+        addIterOrdering(f, t);
+      }
+    }
+  }
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h
new file mode 100644
index 000000000000000..1d5cb1bc6f60e4e
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopScheduler.h
@@ -0,0 +1,91 @@
+//===- LoopScheduler.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_DIALECT_SPARSETENSOR_TRANSFORMS_LOOPSCHEDULER_H_
+#define MLIR_LIB_DIALECT_SPARSETENSOR_TRANSFORMS_LOOPSCHEDULER_H_
+
+#include "mlir/IR/AffineMap.h"
+
+#include <vector>
+
+namespace mlir {
+
+// Forward declarations.
+class Value;
+namespace utils {
+enum class IteratorType : uint32_t;
+} // namespace utils
+namespace linalg {
+class GenericOp;
+} // namespace linalg
+
+namespace sparse_tensor {
+
+/// Masks selecting which tensors contribute ordering constraints when the
+/// iteration graph is built. Sparse tensors contribute irrespective of it.
+/// NOTE(review): kIncludeUndef is not consulted by LoopScheduler::schedule.
+enum class SortMask : unsigned {
+  // The individual mask bits.
+  kIncludeDenseOutput = 0x1, // b001
+  kIncludeDenseInput = 0x2,  // b010
+  kIncludeUndef = 0x4,       // b100
+  // The subsets of mask bits.
+  kIncludeAll = 0x7,   // b111
+  kIncludeDense = 0x3, // b011
+  kSparseOnly = 0x0,   // b000
+};
+
+/// Computes a loop order for a sparse kernel from the ordering constraints
+/// that the level orders of its tensors impose on the loops.
+class LoopScheduler {
+public:
+  /// Constructs a scheduler from a linalg.generic operation.
+  // TODO: maybe reuse the class to schedule foreach as well (to address
+  // non-permutation, e.g, traverse CSR in BSR order).
+  static LoopScheduler fromGenericOp(linalg::GenericOp genericOp);
+
+  /// Returns a permutation that represents the scheduled loop order. Dense
+  /// tensors only contribute constraints when requested by `mask`, and the
+  /// tensor `ignored` (if any) never contributes. The returned AffineMap is
+  /// null if the kernel cannot be scheduled due to cycles in the iteration
+  /// graph.
+  [[nodiscard]] AffineMap schedule(SortMask mask, Value ignored = nullptr);
+
+  /// Returns the number of loops of the kernel.
+  unsigned getNumLoops() const { return loop2OutLvl.getNumDims(); }
+
+private:
+  LoopScheduler(SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl,
+                Value out, AffineMap loop2OutLvl,
+                SmallVector<utils::IteratorType> &&iterTypes);
+
+  /// Adds to the iteration graph the loop orderings implied by the level
+  /// order of tensor `t`, accessed through `loop2LvlMap`.
+  void addConstraints(Value t, AffineMap loop2LvlMap);
+  /// Topologically sorts the iteration graph; returns null on a cycle.
+  AffineMap topoSort();
+
+  // Input tensors and associated loop to level maps.
+  SmallVector<Value> ins;
+  SmallVector<AffineMap> loop2InsLvl;
+  // Output tensor and associated loop to level map.
+  Value out;
+  AffineMap loop2OutLvl;
+  // Iterator type of each loop.
+  SmallVector<utils::IteratorType> iterTypes;
+
+  // Adjacency matrix that represents the iteration graph.
+  std::vector<std::vector<bool>> itGraph;
+  // In-degree of each loop in the iteration graph, used for topological sort.
+  std::vector<unsigned> inDegree;
+};
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_LIB_DIALECT_SPARSETENSOR_TRANSFORMS_LOOPSCHEDULER_H_
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 83e86d137335021..a433b2edc5500e7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -1,4 +1,4 @@
-//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
+//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "CodegenUtils.h"
+#include "LoopScheduler.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -351,17 +352,6 @@ static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
   return ret;
 }
 
-/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
-static bool hasNonIdentityOperandsOrResults(Operation *op) {
-  auto hasNonIdentityMap = [](Value v) {
-    auto stt = tryGetSparseTensorType(v);
-    return stt && !stt->isIdentity();
-  };
-
-  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
-         llvm::any_of(op->getResults(), hasNonIdentityMap);
-}
-
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -379,7 +369,7 @@ struct GenericOpReinterpretMap
     // semantics.
     if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
         !hasAnySparseOperandOrResult(linalgOp) ||
-        !hasNonIdentityOperandsOrResults(linalgOp))
+        !hasAnyNonIdentityOperandsOrResults(linalgOp))
       return failure();
 
     // Try translating the index map.
@@ -411,6 +401,178 @@ struct GenericOpReinterpretMap
   }
 };
 
+/// Schedules the loops of a demapped sparse linalg.generic into an order
+/// admissible for sparsification, or breaks a cyclic iteration graph by
+/// converting one sparse input into a suitably transposed sparse tensor.
+struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
+        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
+        !hasAnySparseOperandOrResult(linalgOp))
+      return failure();
+
+    const StringRef sorted = "sorted";
+    if (linalgOp->hasAttr(sorted))
+      return failure();
+
+    auto scheduler = LoopScheduler::fromGenericOp(linalgOp);
+    bool isAdmissible = false;
+    AffineMap order;
+    // A const list of all masks that we used for iteration graph
+    // computation. Must be ordered from more strict to less strict.
+    // Ideally (though might not be guaranteed), the earlier a constraint mask
+    // can be satisfied, the faster the generated kernel will be.
+    const auto allMasks = {
+        SortMask::kIncludeAll,        SortMask::kIncludeDense,
+        SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
+        SortMask::kIncludeUndef,      SortMask::kSparseOnly};
+    for (const SortMask mask : allMasks) {
+      order = scheduler.schedule(mask);
+      if (order) {
+        if (isAdmissibleOrder(linalgOp, order)) {
+          isAdmissible = true;
+          break;
+        }
+        // else try a set of less strict constraints.
+      }
+    }
+
+    if (!order) {
+      // Cycles detected.
+      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
+        return rewriter.notifyMatchFailure(
+            linalgOp, "the sparse kernel can not be scheduled: loop detected.");
+      }
+      return success();
+    }
+
+    if (!isAdmissible)
+      return rewriter.notifyMatchFailure(
+          linalgOp, "the sparse kernel can not be scheduled.");
+
+    // Mark the op (through the rewriter) so that it is not scheduled again.
+    rewriter.updateRootInPlace(linalgOp, [&]() {
+      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+    });
+    if (order.isIdentity()) // Already sorted; only the marker was added.
+      return success();     // The IR changed, so this must report success.
+
+    assert(order.isPermutation());
+    // `order` is original loop -> sorted loop map
+    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
+    SmallVector<Attribute> curItTypes;
+    curItTypes.reserve(preItTypes.size());
+    for (unsigned i = 0, e = order.getNumResults(); i < e; ++i)
+      curItTypes.push_back(preItTypes[order.getDimPosition(i)]);
+
+    // Inverse `order` to get sorted loop -> original loop map
+    order = inversePermutation(order);
+    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
+    for (AffineMap &idxMap : idxMaps)
+      idxMap = idxMap.compose(order); // sorted loop -> lvl map
+
+    rewriter.startRootUpdate(linalgOp);
+    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
+    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
+    rewriter.finalizeRootUpdate(linalgOp);
+
+    return success();
+  }
+
+private:
+  /// Whether the loop order is admissible by sparsification.
+  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
+    if (!hasAnySparseResult(linalgOp))
+      return true;
+
+    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
+    unsigned nest = 0;
+    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
+    for (const AffineExpr l : order.getResults()) {
+      const unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
+      // Count the leading parallel loops of the scheduled order; stop at the
+      // first reduction loop.
+      if (linalg::isReductionIterator(iteratorTypes[loopId]))
+        break;
+      nest++;
+    }
+    // Determine admissible dynamic insertion situations:
+    // (1) fully injective, since there are no reductions,
+    // (2) admissible 1-d expansion in innermost dimension.
+    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
+  }
+
+  /// Last resort cycle resolution: retries scheduling without the constraints
+  /// of one sparse input at a time and, on success, converts that input.
+  static LogicalResult resolveCycle(LoopScheduler &scheduler,
+                                    linalg::LinalgOp linalgOp,
+                                    PatternRewriter &rewriter) {
+    // Compute topological sort while leaving out every sparse input tensor in
+    // succession until an acyclic iteration graph results.
+    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
+      Value tval = t->get();
+      auto srcEnc = getSparseTensorEncoding(tval.getType());
+      // The constraints introduced by compound index expression are
+      // complicated. Skips them.
+      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
+      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
+        return !llvm::isa<AffineDimExpr>(exp);
+      });
+      if (!srcEnc || hasCompExpr)
+        continue;
+
+      // Try scheduling loop without constraints from `tval`.
+      AffineMap order = scheduler.schedule(SortMask::kSparseOnly, tval);
+      if (!order) // still cyclic
+        continue;
+
+      // Found an input tensor that resolves the cycle by inserting a
+      // conversion into a sparse tensor that adheres to the iteration
+      // graph order.
+      auto stt = getSparseTensorType(tval);
+      assert(stt.isIdentity());
+      order = inversePermutation(order);
+      // sorted loop -> lvl map.
+      idxMap = idxMap.compose(order);
+
+      // Find a permutation such that the results in `idxMap` are sorted.
+      // For example,
+      //  (d0, d1, d2, d3) -> (d2, d1, d0)
+      // loops are scheduled in order of d0->d1->d2->d3, to resolve the cycle,
+      // we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that the
+      // transposed tensor's levels are visited in the same order as the loop
+      // scheduling order.
+      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
+      for (AffineExpr expr : idxMap.getResults()) {
+        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
+        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
+      }
+      std::sort(lvlSeq.begin(), lvlSeq.end(), [](auto &lhs, auto &rhs) -> bool {
+        return lhs.first < rhs.first;
+      });
+      SmallVector<unsigned> perm =
+          llvm::to_vector(llvm::make_second_range(lvlSeq));
+      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
+      // The result of the idxMap must be unsorted.
+      assert(!dimToLvl.isIdentity());
+
+      // Inserting the transpose
+      rewriter.setInsertionPoint(linalgOp);
+      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
+      Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
+      rewriter.updateRootInPlace(linalgOp, [&]() {
+        linalgOp->setOperand(t->getOperandNumber(), dst);
+      });
+      return success();
+    }
+    // Cannot be resolved with a single conversion.
+    // TODO: convert more than one?
+    return failure();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Reinterpret Map Rewriters for operations other than linalg.generics
 //===----------------------------------------------------------------------===//
@@ -420,7 +582,7 @@ struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
   using OpRewritePattern<AllocOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AllocOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!hasNonIdentityOperandsOrResults(op))
+    if (!hasAnyNonIdentityOperandsOrResults(op))
       return failure();
 
     Location loc = op.getLoc();
@@ -493,7 +655,7 @@ struct ForeachOpDemapper
                           PatternRewriter &rewriter) const {
     // Only handle operations with sparse input/output with non-identity dim2lvl
     // maps.
-    if (!hasNonIdentityOperandsOrResults(op))
+    if (!hasAnyNonIdentityOperandsOrResults(op))
       return failure();
 
     // TODO: demap constant as well.
@@ -582,7 +744,8 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                         ReinterpretMapScope scope) {
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kGenericOnly) {
-    patterns.add<GenericOpReinterpretMap>(patterns.getContext());
+    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
+        patterns.getContext());
   }
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 093bda9ca28efff..cd6f689a04acb53 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -39,90 +39,6 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
-//===----------------------------------------------------------------------===//
-// Declarations
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-/// Iteration graph sorting.
-enum class SortMask : unsigned {
-  // The individual mask bits.
-  kIncludeDenseOutput = 0x1, // b001
-  kIncludeDenseInput = 0x2,  // b010
-  kIncludeUndef = 0x4,       // b100
-  // The subsets of mask bits.
-  kIncludeAll = 0x7,   // b111
-  kIncludeDense = 0x3, // b011
-  kSparseOnly = 0x0,   // b000
-};
-
-inline static bool includesAny(SortMask mask1, SortMask mask2) {
-  return static_cast<unsigned>(mask1) & static_cast<unsigned>(mask2);
-}
-
-inline static bool includesDenseInput(SortMask mask) {
-  return includesAny(mask, SortMask::kIncludeDenseInput);
-}
-
-inline static bool includesDenseOutput(SortMask mask) {
-  return includesAny(mask, SortMask::kIncludeDenseOutput);
-}
-
-inline static bool includesDense(SortMask mask) {
-  return includesAny(mask, SortMask::kIncludeDense);
-}
-
-inline static bool includesUndef(SortMask mask) {
-  return includesAny(mask, SortMask::kIncludeUndef);
-}
-
-/// A helper class that visits an affine expression and tries to find an
-/// AffineDimExpr to which the corresponding iterator from a GenericOp matches
-/// the desired iterator type.
-class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
-public:
-  explicit AffineDimFinder(linalg::GenericOp op)
-      : iterTypes(op.getIteratorTypes()) {}
-
-  // Overrides method from AffineExprVisitor.
-  void visitDimExpr(AffineDimExpr expr) {
-    if (pickedDim == nullptr ||
-        pickIterType ==
-            cast<linalg::IteratorTypeAttr>(iterTypes[expr.getPosition()])
-                .getValue()) {
-      pickedDim = expr;
-    }
-  }
-
-  /// Set the desired iterator type that we want to pick.
-  void setPickedIterType(utils::IteratorType iterType) {
-    pickIterType = iterType;
-  }
-
-  /// Get the desired AffineDimExpr.
-  AffineDimExpr getDimExpr() const { return cast<AffineDimExpr>(pickedDim); }
-
-private:
-  /// The picked AffineDimExpr after visit.  This must be stored as
-  /// `AffineExpr` rather than `AffineDimExpr`, because the latter
-  /// doesn't have a default ctor.
-  AffineExpr pickedDim;
-  /// The iterator type that we want.
-  utils::IteratorType pickIterType;
-  /// The mapping between dim=>iterator type.
-  ArrayAttr iterTypes;
-};
-
-// Flattens an affine expression into a list of AffineDimExprs.
-struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
-  // Overrides method from AffineExprVisitor.
-  void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
-  SmallVector<AffineDimExpr> dims;
-};
-
-} // namespace
-
 //===----------------------------------------------------------------------===//
 // Sparsifier analysis methods.
 //===----------------------------------------------------------------------===//
@@ -170,56 +86,6 @@ static bool isInvariantAffine(CodegenEnv &env, AffineExpr a, LoopId ldx,
   return isInvariantAffine(a, env.getCurrentLoopStack(), ldx, isAtLoop);
 }
 
-/// Helper method to construct a permuted dimension ordering
-/// that adheres to the given topological sort.
-//
-// FIXME: does the above actually mean "dimensions", or should it say
-// "level ordering"?  The same dim/lvl confusion applies to all the code
-// and comments in the definition below.
-static AffineMap permute(CodegenEnv &env, AffineMap m) {
-  assert(m.getNumDims() + env.merger().getNumFilterLoops() ==
-             env.topSortSize() &&
-         "size mismatch");
-  // Construct the inverse of `m`; to avoid the asymptotic complexity
-  // of calling `m.getPermutedPosition` repeatedly.
-  //
-  // The variable `perm` must use `unsigned` rather than `Dimension`/`Level`,
-  // because that's what `AffineMap::getPermutationMap` requires.
-  // TODO: however, `perm` should be renamed to make clear what exactly
-  // it's storing a permutation of.
-  SmallVector<unsigned> perm;
-  const unsigned numResults = m.getNumResults();
-  BitVector worklist(numResults, true);
-  LoopOrd loopDepth = 1;
-
-  // Construct the permutation.
-  while (worklist.any() && loopDepth <= env.topSortSize()) {
-    const unsigned preSize = perm.size();
-    for (unsigned dim : worklist.set_bits()) {
-      bool isAtLoop = false;
-      if (isa<AffineConstantExpr>(m.getResult(dim)) ||
-          (isInvariantAffine(m.getResult(dim), env.getLoopStackUpTo(loopDepth),
-                             env.topSortAt(loopDepth - 1), isAtLoop) &&
-           isAtLoop)) {
-        // If the matching affine is constant expression or just become
-        // invariant. We can visit the dimension now without breaking the
-        // topSort constraint.
-        perm.push_back(dim);
-      }
-    }
-
-    // Removes resolved dimension.
-    for (unsigned i = preSize, e = perm.size(); i < e; i++)
-      worklist.reset(perm[i]);
-
-    // Try entering the next loop in the stack.
-    loopDepth++;
-  }
-
-  assert(perm.size() == numResults);
-  return AffineMap::getPermutationMap(perm, env.op().getContext());
-}
-
 /// Helper method to inspect affine expressions. Rejects cases where the
 /// same index is used more than once. Also rejects compound affine
 /// expressions in sparse dimensions.
@@ -471,378 +337,6 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
   return annotated;
 }
 
-/// A helper to compute a topological sort. O(n^2) time complexity
-/// as we use adj matrix for the graph.
-/// The sorted result will put the first Reduction iterator to the
-/// latest possible `LoopOrd`.
-///
-/// The `inDegree` is indexed by `LoopId`, and the `adjM` is indexed by
-/// `(LoopId,LoopId)`.
-static bool topSortOptimal(CodegenEnv &env,
-                           ArrayRef<utils::IteratorType> iteratorTypes,
-                           std::vector<unsigned> &inDegree,
-                           std::vector<std::vector<bool>> &adjM) {
-  std::vector<LoopId> redIt;    // reduce iterator with 0 degree
-  std::vector<LoopId> parIt;    // parallel iterator with 0 degree
-  std::vector<LoopId> filterIt; // filter loop with 0 degree
-  const LoopId numLoops = env.merger().getNumLoops();
-  for (LoopId i = 0; i < numLoops; i++) {
-    if (inDegree[i] == 0) {
-      if (env.merger().isFilterLoop(i))
-        filterIt.push_back(i);
-      else if (linalg::isReductionIterator(iteratorTypes[i]))
-        redIt.push_back(i);
-      else
-        parIt.push_back(i);
-    }
-  }
-
-  while (!redIt.empty() || !parIt.empty() || !filterIt.empty()) {
-    // We always choose in order of filter loop -> parallel loop -> reduction
-    // loop because
-    // 1. Putting reduction loop early might make the loop sequence
-    // inadmissible.
-    // 2. Filter loops should be put as early as possible for better
-    // performance, since only one (if any) iteration will carry the
-    // computation. E.g., for (1 to N)
-    //  for (1 to M)
-    //    for (1 to K)
-    //      if (xxx)
-    //        O(X) computation  => O(NMK+NMX) time complexity
-    //
-    // By putting the filter loop one level up, we got
-    //
-    // for (1 to N)
-    //  for (1 to K)
-    //    if (xxx)
-    //      for (1 to M)
-    //        O(X) computation  => O(NK+NMX) time complexity
-    auto &it = !filterIt.empty() ? filterIt : (!parIt.empty() ? parIt : redIt);
-    auto src = it.back();
-    env.topSortPushBack(src);
-    it.pop_back();
-    // Update in-degree, and push 0-degree node into worklist.
-    for (LoopId dst = 0; dst < numLoops; dst++) {
-      if (adjM[src][dst] && --inDegree[dst] == 0) {
-        if (env.merger().isFilterLoop(dst))
-          filterIt.push_back(dst);
-        else if (linalg::isReductionIterator(iteratorTypes[dst]))
-          redIt.push_back(dst);
-        else
-          parIt.push_back(dst);
-      }
-    }
-  }
-  return env.topSortSize() == numLoops;
-}
-
-static void addIterOrdering(LoopId f, LoopId t,
-                            std::vector<std::vector<bool>> &adjM,
-                            std::vector<unsigned> &inDegree) {
-  if (!adjM[f][t] && f != t) {
-    adjM[f][t] = true;
-    inDegree[t]++;
-  }
-}
-
-/// Helper method to add all constraints from the indices in one affine
-/// expression before all indices in the other affine expression. For
-/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
-/// The affine expression `a` is empty iff `fidx` have a value, leading to
-/// b = (i0 + i1) < fidx => i0 < fidx, i1 < fidx.
-/// The affine expression `b` is empty iff `tidx` have a value, leading to
-/// tidx < a = (i0 + i1) => tidx < i0, tidx < i1.
-///
-/// The `inDegree` is indexed by `LoopId`, and the `adjM` is indexed by
-/// `(LoopId,LoopId)`.
-static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
-                               std::vector<unsigned> &inDegree, AffineExpr a,
-                               AffineExpr b, std::optional<LoopId> fidx,
-                               std::optional<LoopId> tidx) {
-  if (!a && !b) {
-    // Recursion leaf.
-    assert(fidx && tidx);
-    const LoopId f = *fidx, t = *tidx;
-    addIterOrdering(f, t, adjM, inDegree);
-    return;
-  }
-  // Picks an affine expression and expand (recurse into) it.
-  const auto toExpand = a ? a : b;
-  switch (toExpand.getKind()) {
-  case AffineExprKind::DimId: {
-    const std::optional<LoopId> idx{
-        cast<AffineDimExpr>(toExpand).getPosition()};
-    if (toExpand == a)
-      addAffineOrderings(adjM, inDegree, AffineExpr(), b, idx, tidx);
-    else // toExpand == b
-      addAffineOrderings(adjM, inDegree, a, AffineExpr(), fidx, idx);
-    break;
-  }
-  case AffineExprKind::Add:
-  case AffineExprKind::Mul: {
-    auto binOp = cast<AffineBinaryOpExpr>(toExpand);
-    if (toExpand == a) {
-      addAffineOrderings(adjM, inDegree, binOp.getLHS(), b, fidx, tidx);
-      addAffineOrderings(adjM, inDegree, binOp.getRHS(), b, fidx, tidx);
-    } else {
-      addAffineOrderings(adjM, inDegree, a, binOp.getLHS(), fidx, tidx);
-      addAffineOrderings(adjM, inDegree, a, binOp.getRHS(), fidx, tidx);
-    }
-    break;
-  }
-  default:
-    break;
-  }
-}
-
-static void tryRelaxAffineConstraints(linalg::GenericOp op,
-                                      std::optional<LoopId> &fldx,
-                                      AffineExpr &fa,
-                                      std::optional<LoopId> &tldx,
-                                      AffineExpr &ta) {
-  // We use a heuristic here to only pick one dim expression from each
-  // compound affine expression to establish the order between two dense
-  // dimensions.
-  if (!tldx) {
-    AffineDimFinder finder(op);
-    // NOTE: The ordering can only be loosen when the destination level is
-    // dense (when !tldx), for [dense, sparse] -> (d0 + d1, d2), we still
-    // require both d0 < d2 and d1 < d2 to ensure correct ordering (i.e.,
-    // no ordering like d0->d2->d1).
-    // TODO: this is obviously a sub optimal solution.
-    if (!fldx && !isa<AffineConstantExpr>(fa)) {
-      // Heuristic: we prefer parallel loop for lhs to reduce the chance
-      // we add reduce < parallel ordering.
-      finder.setPickedIterType(utils::IteratorType::parallel);
-      finder.walkPostOrder(fa);
-      fa = finder.getDimExpr();
-      fldx = finder.getDimExpr().getPosition();
-    }
-    if (!isa<AffineConstantExpr>(ta)) {
-      // Heuristic: we prefer reduction loop for rhs to reduce the chance
-      // adding reduce < parallel ordering.
-      finder.setPickedIterType(utils::IteratorType::reduction);
-      finder.walkPostOrder(ta);
-      ta = finder.getDimExpr();
-      tldx = finder.getDimExpr().getPosition();
-    }
-  }
-}
-
-static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
-                                          OpOperand *skip, SortMask mask,
-                                          std::vector<std::vector<bool>> &adjM,
-                                          std::vector<unsigned> &inDegree) {
-  // Get map, encoding, and tensor-identifier.
-  const auto map = env.op().getMatchingIndexingMap(&t);
-  const auto enc = getSparseTensorEncoding(t.get().getType());
-  const TensorId tid = env.makeTensorId(t.getOperandNumber());
-
-  // Each tensor expression and optional dimension ordering (row-major
-  // by default) puts an ordering constraint on the loop indices. For
-  // example, the tensor expression A_ijk forces the ordering i < j < k
-  // on the loop indices if no explicit dimension ordering is given.
-  const Level lvlRank = map.getNumResults();
-  assert(!enc || lvlRank == enc.getLvlRank());
-  for (Level lvl = 0; lvl < lvlRank; lvl++) {
-    // FIXME: `toOrigDim` is deprecated.
-    AffineExpr ta = map.getResult(toOrigDim(enc, lvl));
-    std::optional<LoopId> tldx = env.merger().getLoopId(tid, lvl);
-    // Filter loops should be constructed after all the dependent loops,
-    // i.e., d0 + d1 < filter_loop(d0 + d1)
-    if (tldx && env.merger().isFilterLoop(*tldx)) {
-      assert(!isa<AffineDimExpr>(ta) && !isDenseDLT(enc.getLvlTypes()[lvl]));
-      addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx);
-      // Now that the ordering of affine expression is captured by filter
-      // loop idx, we only need to ensure the affine ordering against filter
-      // loop. Thus, we reset the affine express to nil here to mark it as
-      // resolved.
-      ta = AffineExpr();
-    }
-
-    // Skip tensor during cycle resolution, though order between filter loop
-    // and dependent loops need to be guaranteed unconditionally.
-    if (&t == skip)
-      continue;
-
-    if (lvl > 0) {
-      // FIXME: `toOrigDim` is deprecated.
-      AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1));
-      std::optional<LoopId> fldx = env.merger().getLoopId(tid, lvl - 1);
-
-      // Applying order constraints on every pair of dimExpr between two
-      // compound affine expressions can sometime too strict:
-      // E.g., for [dense, dense] -> (d0 + d1, d2 + d3).
-      // It is totally fine to have loop sequence d0->d2->d1->d3 instead of
-      // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
-      // We also relax the affine constraint when use slice-based algorithm
-      // as there is no filter loop for affine index on sparse dimension.
-      // TODO: do we really need the condition?
-      if (!includesDense(mask))
-        tryRelaxAffineConstraints(env.op(), fldx, fa, tldx, ta);
-
-      // (d0 + d1) < (d2 + d3), or
-      // filter_loop_d-1 < (d2 + d3), or
-      // (d0 + d1) < filter_loop_d, or
-      // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset
-      // above.
-      addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
-    }
-  }
-}
-
-static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t,
-                                     OpOperand *skip, SortMask mask,
-                                     std::vector<std::vector<bool>> &adjM,
-                                     std::vector<unsigned> &inDegree) {
-  // Get map and encoding.
-  const auto map = env.op().getMatchingIndexingMap(&t);
-  const auto enc = getSparseTensorEncoding(t.get().getType());
-
-  // No special treatment for simple indices.
-  if (getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) == 0)
-    return addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree);
-
-  // Skip tensor during cycle resolution, though order between filter loop
-  // and dependent loops need to be guaranteed unconditionally.
-  if (&t == skip)
-    return;
-
-  AffineDimFinder finder(env.op());
-  finder.setPickedIterType(utils::IteratorType::reduction);
-  // To compute iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6],
-  // we requires there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6},
-  // and d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
-  const Level lvlRank = map.getNumResults();
-  assert(!enc || lvlRank == enc.getLvlRank());
-  for (Level lvl = 1; lvl < lvlRank; lvl++) {
-    // FIXME: `toOrigDim` is deprecated.
-    const AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1));
-    const AffineExpr ta = map.getResult(toOrigDim(enc, lvl));
-
-    if (isa<AffineDimExpr>(fa) || isa<AffineDimExpr>(ta)) {
-      AffineDimCollector fCollector;
-      fCollector.walkPostOrder(fa);
-
-      AffineDimCollector tCollector;
-      tCollector.walkPostOrder(ta);
-      for (auto fd : fCollector.dims) {
-        for (auto td : tCollector.dims) {
-          const LoopId f = env.makeLoopId(fd.getPosition());
-          const LoopId t = env.makeLoopId(td.getPosition());
-          addIterOrdering(f, t, adjM, inDegree);
-        }
-      }
-      continue;
-    }
-
-    // This is a heuristic, we pick an abitrary reduction loop from lhs and
-    // rhs and use them as d_x and d_y.
-    finder.walkPostOrder(fa);
-    const AffineDimExpr fexp = finder.getDimExpr();
-    const LoopId fldx = env.makeLoopId(fexp.getPosition());
-
-    finder.walkPostOrder(ta);
-    const AffineDimExpr texp = finder.getDimExpr();
-    const LoopId tldx = env.makeLoopId(texp.getPosition());
-
-    // d_x > d_y
-    addIterOrdering(fldx, tldx, adjM, inDegree);
-
-    AffineDimCollector fCollector;
-    fCollector.walkPostOrder(fa);
-    AffineDimCollector tCollector;
-    tCollector.walkPostOrder(ta);
-
-    // make sure dx and dy is the last;
-    for (auto fd : fCollector.dims) {
-      const LoopId f = env.makeLoopId(fd.getPosition());
-      addIterOrdering(f, fldx, adjM, inDegree);
-    }
-    for (auto td : tCollector.dims) {
-      const LoopId t = env.makeLoopId(td.getPosition());
-      addIterOrdering(t, tldx, adjM, inDegree);
-    }
-    // Since we only support affine addition, the order between two dim
-    // expression does not really matters.
-    // {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
-    // This is to ensure that the affine expressions are reduced in sparse
-    // tensor level ordering.
-    // TODO: this ordering could probably be loosen if we support out-of-order
-    // reduction.
-    // TODO: the evaluation order need to be ensure to
-    // support affine multiplication.
-    for (auto fd : fCollector.dims) {
-      const LoopId f = env.makeLoopId(fd.getPosition());
-      if (f == fldx) // skip d_x
-        continue;
-      for (auto td : tCollector.dims) {
-        const LoopId t = env.makeLoopId(td.getPosition());
-        if (t == tldx) // skip d_y
-          continue;
-        addIterOrdering(f, t, adjM, inDegree);
-      }
-    }
-  }
-}
-
-/// Computes a topologically sorted iteration graph for the linalg operation.
-/// Ensures all tensors are visited in natural index order. This is
-/// essential for sparse storage formats since these only support access
-/// along fixed dimensions. Even for dense storage formats, however, the natural
-/// index order yields innermost unit-stride access with better spatial
-/// locality.
-static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
-                                  OpOperand *skip, bool idxReducBased = false) {
-  // Set up an n x n from/to adjacency matrix of the iteration graph
-  // for the implicit loop indices i_0 .. i_n-1.
-  const unsigned numLoops = env.merger().getNumLoops();
-  std::vector<std::vector<bool>> adjM(numLoops,
-                                      std::vector<bool>(numLoops, false));
-  std::vector<unsigned> inDegree(numLoops, 0); // in-degree of each node.
-  const auto iteratorTypes = env.op().getIteratorTypesArray();
-  // Iterate over the indexing maps of every tensor in the tensor expression.
-  for (OpOperand &t : env.op()->getOpOperands()) {
-    // Get map and encoding.
-    const auto enc = getSparseTensorEncoding(t.get().getType());
-    // Skips dense inputs/outputs when not requested.
-    const bool isDenseInput = !enc && env.op().isDpsInput(&t);
-    const bool isDenseOutput = !enc && !isDenseInput;
-    if ((isDenseInput && !includesDenseInput(mask)) ||
-        (isDenseOutput && !includesDenseOutput(mask)))
-      continue;
-
-    // Push unrelated loops into sparse iteration space, so these
-    // will be skipped more often.
-    // TODO: Do we really need this?
-    if (includesUndef(mask)) {
-      const TensorId tid = env.makeTensorId(t.getOperandNumber());
-      for (LoopId i = 0; i < numLoops; i++) {
-        const auto dltI = env.dlt(tid, i);
-        if (isCompressedDLT(dltI) || isLooseCompressedDLT(dltI) ||
-            isSingletonDLT(dltI) || is2OutOf4DLT(dltI)) {
-          for (LoopId j = 0; j < numLoops; j++)
-            if (isUndefDLT(env.dlt(tid, j))) {
-              addIterOrdering(i, j, adjM, inDegree);
-            }
-        } else {
-          assert(isDenseDLT(dltI) || isUndefDLT(dltI));
-        }
-      }
-    }
-    // Push unrelated loops into sparse iteration space, so these
-    // will be skipped more often.
-    if (idxReducBased)
-      addSliceBasedConstraints(env, t, skip, mask, adjM, inDegree);
-    else
-      addFilterLoopBasedConstraints(env, t, skip, mask, adjM, inDegree);
-  }
-  // Topologically sort the iteration graph to determine loop order.
-  // Report failure for a cyclic iteration graph.
-  env.topSortClear(numLoops);
-  return topSortOptimal(env, iteratorTypes, inDegree, adjM);
-}
-
 //===----------------------------------------------------------------------===//
 // Sparsifier synthesis methods (statements and expressions).
 //===----------------------------------------------------------------------===//
@@ -1951,6 +1445,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     if (hasNonTrivialAffineOnSparseOut(op))
       return failure();
 
+    if (!op->hasAttr("sorted")) {
+      return rewriter.notifyMatchFailure(
+          op, "Loops not yet scheduled, try running --sparse-reinterpret-map "
+              "before sparsification.");
+    }
+
     // Sets up a code generation environment.
     const unsigned numTensors = op->getNumOperands();
     const unsigned numLoops = op.getNumLoops();
@@ -1970,6 +1470,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
         maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
       }
     }
+
     // A slice based algorithm for affine indices does not need filter loops.
     CodegenEnv env(op, options, numTensors, numLoops,
                    /*numFilterLoops=*/idxReducBased ? 0 : numFilterLoops,
@@ -2003,39 +1504,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     if (failed(env.initTensorExp()))
       return failure();
 
-    // Computes a topologically sorted iteration graph to ensure tensors
-    // are visited in natural index order. Gradually relaxes the considered
-    // constraints until an acyclic iteration graph results, such that sparse
-    // code generation can proceed. As a last resort, an attempt is made
-    // to resolve cycles by inserting a conversion.
-    bool isAdmissible = false;
-    bool hasCycle = true;
-    // A const list of all masks that we used for iteration graph
-    // computation. Must be ordered from more strict to less strict.
-    // Ideally (though might not be guaranteed), the earlier a constraint mask
-    // can be satisfied, the faster the generated kernel will be.
-    const auto allMasks = {
-        SortMask::kIncludeAll,        SortMask::kIncludeDense,
-        SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
-        SortMask::kIncludeUndef,      SortMask::kSparseOnly};
-    for (const SortMask mask : allMasks) {
-      if (computeIterationGraph(env, mask, nullptr, idxReducBased)) {
-        hasCycle = false;
-        if (env.isAdmissibleTopoOrder()) {
-          isAdmissible = true;
-          break;
-        }
-        // else try a set of less strict constraints.
-      }
-    }
-    if (hasCycle) {
-      return idxReducBased
-                 ? failure() // TODO: should cycle be resolved differently?
-                 : resolveCycle(env, rewriter); // one last shot
-    }
-    if (!isAdmissible)
-      return failure(); // inadmissible expression, reject
-
     // Recursively generates code if admissible.
     env.startEmit();
     genBuffers(env, rewriter);
@@ -2049,47 +1517,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
   }
 
 private:
-  // Last resort cycle resolution.
-  LogicalResult resolveCycle(CodegenEnv &env, PatternRewriter &rewriter) const {
-    // Compute topological sort while leaving out every
-    // sparse input tensor in succession until an acylic
-    // iteration graph results.
-    for (OpOperand *t : env.op().getDpsInputOperands()) {
-      const TensorId tid = env.makeTensorId(t->getOperandNumber());
-      Value tval = t->get();
-      auto srcEnc = getSparseTensorEncoding(tval.getType());
-      if (!srcEnc || !computeIterationGraph(env, SortMask::kSparseOnly, t))
-        continue;
-      // Found an input tensor that resolves the cycle by inserting a
-      // conversion into a sparse tensor that adheres to the iteration
-      // graph order. Also releases the temporary sparse tensor.
-      //
-      // TODO: investigate fusing the conversion with computation,
-      //       especially if it is a direct yield!
-      //
-      auto srcTp = getRankedTensorType(tval);
-      // TODO: This assertion is to match the behavior from prior to
-      // merging dimOrdering and higherOrdering into dimToLvl.  However,
-      // since `permute` returns a permutation, we can remove this
-      // restriction by instead composing the result of `permute`
-      // with `srcEnc.getDimToLvl`.
-      assert(srcEnc.isPermutation());
-      auto dstEnc =
-          srcEnc.withDimToLvl(permute(env, env.op().getMatchingIndexingMap(t)));
-      auto dstTp = RankedTensorType::get(srcTp.getShape(),
-                                         srcTp.getElementType(), dstEnc);
-      auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
-      rewriter.updateRootInPlace(env.op(),
-                                 [&]() { env.op()->setOperand(tid, convert); });
-      rewriter.setInsertionPointAfter(env.op());
-      rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
-      return success();
-    }
-    // Cannot be resolved with a single conversion.
-    // TODO: convert more than one?
-    return failure();
-  }
-
   /// Options to control sparse code generation.
   SparsificationOptions options;
 };
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
index 0979884cbd502a5..b12bad685b49b97 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s --linalg-generalize-named-ops \
 // RUN:             --pre-sparsification-rewrite \
+// RUN:             --sparse-reinterpret-map \
 // RUN:             --sparsification="parallelization-strategy=dense-outer-loop" \
 // RUN:             --sparse-gpu-codegen | FileCheck %s
 
@@ -60,4 +61,3 @@ func.func @matmuls(%A: tensor<1024x8xf64>,
       outs(%Z: tensor<1024x1024xf64>) -> tensor<1024x1024xf64>
   return %D : tensor<1024x1024xf64>
 }
-
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
index 8dc6afa320af7d5..e477dcc0169fc3d 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s --linalg-generalize-named-ops \
 // RUN:             --pre-sparsification-rewrite \
+// RUN:             --sparse-reinterpret-map \
 // RUN:             --sparsification="parallelization-strategy=dense-outer-loop" \
 // RUN:             --sparse-gpu-codegen | FileCheck %s
 
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
index ab267062e41693a..0c5ff55dd863c08 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s --linalg-generalize-named-ops \
 // RUN:             --pre-sparsification-rewrite \
+// RUN:             --sparse-reinterpret-map \
 // RUN:             --sparsification="parallelization-strategy=dense-outer-loop" \
 // RUN:             --sparse-gpu-codegen | FileCheck %s
 
diff --git a/mlir/test/Dialect/SparseTensor/constant_index_map.mlir b/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
index 9eb535385790146..eaef6a315852920 100644
--- a/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
@@ -1,6 +1,6 @@
 // Reported by https://github.com/llvm/llvm-project/issues/61530
 
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #map1 = affine_map<(d0) -> (0, d0)>
 #map2 = affine_map<(d0) -> (d0)>
diff --git a/mlir/test/Dialect/SparseTensor/dense.mlir b/mlir/test/Dialect/SparseTensor/dense.mlir
index 485a5cbb178af94..52db814572a15fb 100644
--- a/mlir/test/Dialect/SparseTensor/dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/dense.mlir
@@ -1,5 +1,5 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 // Test to demonstrate the difference between non-annotated dense tensors
 // and all-dense-annotated "sparse" tensors. The former class remains as
diff --git a/mlir/test/Dialect/SparseTensor/one_trip.mlir b/mlir/test/Dialect/SparseTensor/one_trip.mlir
index 35d7878633b58a5..b8ab177357492d8 100644
--- a/mlir/test/Dialect/SparseTensor/one_trip.mlir
+++ b/mlir/test/Dialect/SparseTensor/one_trip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification -cse | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse | FileCheck %s
 
 #Dense = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : dense)
diff --git a/mlir/test/Dialect/SparseTensor/semi_ring.mlir b/mlir/test/Dialect/SparseTensor/semi_ring.mlir
index c69efcae3b08ec2..5a936bfa1e597a7 100644
--- a/mlir/test/Dialect/SparseTensor/semi_ring.mlir
+++ b/mlir/test/Dialect/SparseTensor/semi_ring.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #SM = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
 
diff --git a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
index 3ce5dd1dcfdffc8..1a8af35358bf62c 100644
--- a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
+++ b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification --canonicalize | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --canonicalize | FileCheck %s
 
 #SortedCOO = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
diff --git a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
index 7c5e1a372d44fca..d29758e4fa97121 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
@@ -1,5 +1,5 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #DV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : dense) }>
 #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index 076e9201a1c053e..22681daa94148ef 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1,5 +1,5 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #Tdd = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : dense) }>
 #Tds = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
index c245e612be37f12..f2daa77a0c24c8a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
@@ -1,5 +1,5 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #Td = #sparse_tensor.encoding<{ map = (d0) -> (d0 : dense) }>
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
index d6c3075a8732574..ca55d2eb0177819 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #SpVec = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 #CSR   = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir b/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
index 1a5f79d23cba29a..278450fabd74ec2 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparsification --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification --canonicalize --cse | FileCheck %s
 
 #DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
 #SparseTensor = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index 7576a95c0fd0aac..de0582e5c688575 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparsification --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification --canonicalize --cse | FileCheck %s
 
 #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
index 4817771a6044ce2..bdbbf52d86286ab 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -1,9 +1,11 @@
 // RUN: mlir-opt %s --linalg-generalize-named-ops \
 // RUN:             --linalg-fuse-elementwise-ops \
+// RUN:             --sparse-reinterpret-map \
 // RUN:             --sparsification | \
 // RUN:   FileCheck %s --check-prefix=CHECK-SPARSE
 // RUN: mlir-opt %s --linalg-generalize-named-ops \
 // RUN:             --linalg-fuse-elementwise-ops \
+// RUN:             --sparse-reinterpret-map \
 // RUN:             --sparsification --lower-sparse-ops-to-foreach \
 // RUN:             --lower-sparse-foreach-to-scf \
 // RUN:             --sparse-tensor-conversion --cse | \
@@ -142,7 +144,8 @@ func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
 // CHECK-SPARSE:   }
 // CHECK-SPARSE:   sparse_tensor.compress %[[A]], %[[B]], %[[C]], %[[COUNT]]
 // CHECK-SPARSE: }
-// CHECK-SPARSE: %[[RET:.*]] = sparse_tensor.load %[[T]] hasInserts
+// CHECK-SPARSE: %[[DEMAP:.*]] = sparse_tensor.load %[[T]] hasInserts
+// CHECK-SPARSE: %[[RET:.*]] = sparse_tensor.reinterpret_map %[[DEMAP]]
 // CHECK-SPARSE: return %[[RET]]
 //
 // CHECK-CONVERT-LABEL: func @matmul2(
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 988ab7f85be4134..40367f12f85a472 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --linalg-generalize-named-ops --pre-sparsification-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --linalg-generalize-named-ops --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
 
 #DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
index dac34da30b49fda..f25a56469d7e742 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
index 34e04c03529036f..e49c89e01790c08 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #DenseMatrix = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : dense)
diff --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
index 0bc396278357656..da7d45bd63e17d4 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
@@ -1,5 +1,5 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 
@@ -506,4 +506,3 @@ func.func @lslbyc(%arga: tensor<32xi64, #SV>,
   } -> tensor<32xi64>
   return %0 : tensor<32xi64>
 }
-
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
index fa0617b9cb6f58e..1cb23cfd5b9eb2f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s \
 // RUN: --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
-// RUN: --sparsification | FileCheck %s
+// RUN: --sparse-reinterpret-map --sparsification | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 
@@ -195,10 +195,11 @@ func.func @matmul2(%A: tensor<4x8xf64, #DCSR>,
   return %D: tensor<4x4xf64, #DCSR>
 }
 
+
 // CHECK-LABEL:   func.func @conv2d(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x8xi32>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:      %[[VAL_2:.*]]: tensor<6x6xi32>) -> tensor<6x6xi32> {
+// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<8x8xi32>,
+// CHECK-SAME:                      %[[VAL_1:.*]]: tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:                      %[[VAL_2:.*]]: tensor<6x6xi32>) -> tensor<6x6xi32> {
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 6 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
@@ -208,33 +209,34 @@ func.func @matmul2(%A: tensor<4x8xf64, #DCSR>,
 // CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi32>
-// CHECK:           %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x6xi32>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x6xi32>
 // CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:             scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_5]] {
-// CHECK:               %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK:               scf.for %[[VAL_18:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK:                 %[[VAL_19:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_18]]] : memref<6x6xi32>
-// CHECK:                 %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK:                 %[[VAL_21:.*]] = arith.addi %[[VAL_16]], %[[VAL_5]] : index
-// CHECK:                 %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK:                 %[[VAL_23:.*]] = scf.for %[[VAL_24:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] iter_args(%[[VAL_25:.*]] = %[[VAL_19]]) -> (i32) {
-// CHECK:                   %[[VAL_26:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref<?xindex>
-// CHECK:                   %[[VAL_27:.*]] = arith.addi %[[VAL_13]], %[[VAL_17]] : index
-// CHECK:                   %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_26]] : index
-// CHECK:                   %[[VAL_29:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_27]], %[[VAL_28]]] : memref<8x8xi32>
-// CHECK:                   %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_24]]] : memref<?xi32>
-// CHECK:                   %[[VAL_31:.*]] = arith.muli %[[VAL_29]], %[[VAL_30]] : i32
-// CHECK:                   %[[VAL_32:.*]] = arith.addi %[[VAL_25]], %[[VAL_31]] : i32
-// CHECK:                   scf.yield %[[VAL_32]] : i32
+// CHECK:             scf.for %[[VAL_14:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK:               %[[VAL_15:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_14]]] : memref<6x6xi32>
+// CHECK:               %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:               %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:               %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_16]] to %[[VAL_17]] step %[[VAL_5]] iter_args(%[[VAL_20:.*]] = %[[VAL_15]]) -> (i32) {
+// CHECK:                 %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:                 %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:                 %[[VAL_23:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
+// CHECK:                 %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK:                 %[[VAL_25:.*]] = scf.for %[[VAL_26:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] iter_args(%[[VAL_27:.*]] = %[[VAL_20]]) -> (i32) {
+// CHECK:                   %[[VAL_28:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref<?xindex>
+// CHECK:                   %[[VAL_29:.*]] = arith.addi %[[VAL_13]], %[[VAL_21]] : index
+// CHECK:                   %[[VAL_30:.*]] = arith.addi %[[VAL_14]], %[[VAL_28]] : index
+// CHECK:                   %[[VAL_31:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_29]], %[[VAL_30]]] : memref<8x8xi32>
+// CHECK:                   %[[VAL_32:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_26]]] : memref<?xi32>
+// CHECK:                   %[[VAL_33:.*]] = arith.muli %[[VAL_31]], %[[VAL_32]] : i32
+// CHECK:                   %[[VAL_34:.*]] = arith.addi %[[VAL_27]], %[[VAL_33]] : i32
+// CHECK:                   scf.yield %[[VAL_34]] : i32
 // CHECK:                 } {"Emitted from" = "linalg.generic"}
-// CHECK:                 memref.store %[[VAL_33:.*]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_18]]] : memref<6x6xi32>
+// CHECK:                 scf.yield %[[VAL_25]] : i32
 // CHECK:               } {"Emitted from" = "linalg.generic"}
+// CHECK:               memref.store %[[VAL_18]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_14]]] : memref<6x6xi32>
 // CHECK:             } {"Emitted from" = "linalg.generic"}
 // CHECK:           } {"Emitted from" = "linalg.generic"}
-// CHECK:           %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<6x6xi32>
-// CHECK:           return %[[VAL_34]] : tensor<6x6xi32>
+// CHECK:           %[[VAL_35:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<6x6xi32>
+// CHECK:           return %[[VAL_35]] : tensor<6x6xi32>
 // CHECK:         }
 func.func @conv2d(%input:  tensor<8x8xi32>,
              %filter: tensor<3x3xi32, #DCSR>,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
index 13245f427a7970f..788dd9c26e6cd38 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s --check-prefix=CHECK-HIR
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s --check-prefix=CHECK-HIR
 //
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse | \
 // RUN: FileCheck %s --check-prefix=CHECK-MIR
 //
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --cse \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse \
 // RUN: --func-bufferize --arith-bufferize           \
 // RUN: --tensor-bufferize --finalizing-bufferize |  \
 // RUN: FileCheck %s --check-prefix=CHECK-LIR
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
index a987d59f6773139..da08f53bd089bd5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s --check-prefix=CHECK-HIR
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s --check-prefix=CHECK-HIR
 //
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse | \
 // RUN: FileCheck %s --check-prefix=CHECK-MIR
 //
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --cse \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse \
 // RUN: --func-bufferize --arith-bufferize           \
 // RUN: --tensor-bufferize --finalizing-bufferize |  \
 // RUN: FileCheck %s --check-prefix=CHECK-LIR
@@ -29,9 +29,10 @@
 // CHECK-HIR-DAG:       %[[VAL_3:.*]] = arith.constant 64 : index
 // CHECK-HIR-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-HIR-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-HIR-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-HIR-DAG:       %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-HIR-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK-HIR:           %[[DEMAP:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]]
+// CHECK-HIR-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[DEMAP]] {level = 1 : index}
+// CHECK-HIR-DAG:       %[[VAL_7:.*]] = sparse_tensor.coordinates %[[DEMAP]] {level = 1 : index}
+// CHECK-HIR-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[DEMAP]]
 // CHECK-HIR-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<64xf64>
 // CHECK-HIR-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
 // CHECK-HIR:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
index 0e09f658fa15caa..eb6920304ed8479 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s --check-prefix=CHECK-HIR
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s --check-prefix=CHECK-HIR
 //
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse | \
 // RUN: FileCheck %s --check-prefix=CHECK-MIR
 //
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --cse \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --cse \
 // RUN: --func-bufferize --arith-bufferize           \
 // RUN: --tensor-bufferize --finalizing-bufferize |  \
 // RUN: FileCheck %s --check-prefix=CHECK-LIR
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index 12d443ca679207e..5145d6c1dcfc322 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -1,7 +1,7 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 
 // RUN: mlir-opt %s --linalg-generalize-named-ops \
-// RUN:  --sparsification --sparse-tensor-codegen \
+// RUN:  --sparse-reinterpret-map --sparsification --sparse-tensor-codegen \
 // RUN:  --canonicalize --cse | FileCheck %s
 
 #CSR = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
index d5dc16bed20c7af..27716a82c164fa0 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
@@ -1,5 +1,5 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 // Example with cyclic iteration graph with sparse and dense constraints,
 // but an acyclic iteration graph using sparse constraints only.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 794eb39cb2fcb58..8bf0625425ead11 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #CSR = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : compressed),
diff --git a/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir b/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir
index 7dbb1c7d9490877..1028b58be37df30 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 
@@ -110,4 +110,3 @@ func.func @update_inplace(%arga: tensor<10xf32, #SV>,
   } -> tensor<10xf32>
   return %0 : tensor<10xf32>
 }
-
diff --git a/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir b/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
index dab70e6a3e6f139..1db7cf85ea5fb78 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
@@ -1,12 +1,12 @@
-// RUN: mlir-opt %s -sparsification="parallelization-strategy=none" | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=none" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-PAR0
-// RUN: mlir-opt %s -sparsification="parallelization-strategy=dense-outer-loop" | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=dense-outer-loop" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-PAR1
-// RUN: mlir-opt %s -sparsification="parallelization-strategy=any-storage-outer-loop" | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=any-storage-outer-loop" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-PAR2
-// RUN: mlir-opt %s -sparsification="parallelization-strategy=dense-any-loop" | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=dense-any-loop" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-PAR3
-// RUN: mlir-opt %s -sparsification="parallelization-strategy=any-storage-any-loop" | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=any-storage-any-loop" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-PAR4
 
 #DenseMatrix = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir b/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir
index 26f5a05b64697ac..8118eeceb6a62ee 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification="parallelization-strategy=any-storage-any-loop" | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=any-storage-any-loop" | \
 // RUN:   FileCheck %s
 
 #CSR = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
index 186030aa2ca0873..83cbfa771f1d626 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
@@ -1,5 +1,5 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #X = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d2 : dense, d0 : dense, d1 : dense)
@@ -22,7 +22,8 @@
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 10 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[DEMAP:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]]
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<30x10x20xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<20x30x10xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_9]] : memref<20x30x10xf32>)
 // CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
@@ -58,10 +59,11 @@ func.func @sparse_static_dims(%arga: tensor<10x20x30xf32, #X>,
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 2 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[DEMAP:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]]
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?x?xf32>
 // CHECK:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_10]] : memref<?x?x?xf32>)
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_4]] {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index 02738a9e4544ac3..1078e5c282df11a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -sparsification --canonicalize | FileCheck %s --check-prefix=CHECK-HIR
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --canonicalize | FileCheck %s --check-prefix=CHECK-HIR
 //
-// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion --canonicalize | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --sparse-tensor-conversion --canonicalize | \
 // RUN: FileCheck %s --check-prefix=CHECK-MIR
 
 #X = #sparse_tensor.encoding<{
@@ -21,10 +21,11 @@
 // CHECK-HIR-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
 // CHECK-HIR-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-HIR-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
-// CHECK-HIR-DAG:       %[[VAL_5:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR-DAG:       %[[VAL_6:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-HIR-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR:           %[[DEMAP:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]]
+// CHECK-HIR-DAG:       %[[VAL_5:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_3]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR-DAG:       %[[VAL_6:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_2]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-HIR-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-HIR-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
 // CHECK-HIR:           %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
 // CHECK-HIR:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir b/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
index 85ab2a654098ad6..4ca577b9f729031 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
@@ -1,5 +1,5 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
index 55deb91c5a1bb4f..1d95fe8d0569a23 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s  --test-tensor-copy-insertion --pre-sparsification-rewrite --sparsification --cse | FileCheck %s
+// RUN: mlir-opt %s  --test-tensor-copy-insertion --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification --cse | FileCheck %s
 
 #SM = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir
index ed5b72d0d872cbc..79bedcf5a49e16c 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s  --pre-sparsification-rewrite --sparsification --cse | FileCheck %s
+// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification --cse | FileCheck %s
 
 #SM = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_storage.mlir
index 65ea2d131277980..a8cc056139e49e6 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_storage.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_storage.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification= | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification= | FileCheck %s
 
 #SparseVector64 = #sparse_tensor.encoding<{
   map = (d0) -> (d0 : compressed),
diff --git a/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir b/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir
index b61136553d2679a..fd41fedf91c2cbb 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #DCSR = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : compressed, d1 : compressed)
@@ -21,11 +21,12 @@
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = tensor.empty() : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_4:.*]] = sparse_tensor.convert %[[VAL_0]] : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_4]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_4]] {level = 0 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_4]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_4]] {level = 1 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_4]] : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[DEMAP:.*]] = sparse_tensor.reinterpret_map %[[VAL_4]] : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.positions %[[DEMAP]] {level = 0 : index} : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.coordinates %[[DEMAP]] {level = 0 : index} : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.positions %[[DEMAP]] {level = 1 : index} : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.coordinates %[[DEMAP]] {level = 1 : index} : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_1]]] : memref<?xindex>
 // CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_3]]) -> (tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
@@ -42,7 +43,6 @@
 // CHECK:             scf.yield %[[VAL_25:.*]] : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           }
 // CHECK:           %[[VAL_26:.*]] = sparse_tensor.load %[[VAL_27:.*]] hasInserts : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           bufferization.dealloc_tensor %[[VAL_4]] : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           return %[[VAL_26]] : tensor<4x3xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:         }
 func.func @sparse_transpose_auto(%arga: tensor<3x4xf64, #DCSR>)
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 48ba9119c4f44ea..364ba6e71ff3bb9 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt %s -sparsification -cse -split-input-file | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -split-input-file | \
 // RUN:   FileCheck %s --check-prefix=CHECK-SCALAR
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=16" -cse -split-input-file | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=16" -cse -split-input-file | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC16
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=16 enable-simd-index32=true" -cse -split-input-file | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=16 enable-simd-index32=true" -cse -split-input-file | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC16-IDX32
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=4 enable-vla-vectorization=true" -cse -split-input-file | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=4 enable-vla-vectorization=true" -cse -split-input-file | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC4-SVE
 
 #DenseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : dense) }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index b8a4563bd63d87f..269a4926c9ec375 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=8" -cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=8" -cse | \
 // RUN:   FileCheck %s
 
 #SparseMatrix = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
index 31a78ceb0e2b36b..9bc24fd02b827d0 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=8" -cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=8" -cse | \
 // RUN:   FileCheck %s
 
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
index 67841beaa6933f6..fc6741c1f69a6cd 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=8" -cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=8" -cse | \
 // RUN:   FileCheck %s
 
 #DenseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : dense) }>
@@ -84,4 +84,3 @@ func.func @vops(%arga: tensor<1024xf32, #DenseVector>,
   } -> tensor<1024xf32>
   return %0 : tensor<1024xf32>
 }
-
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
index db830f29d7ca655..99d6a3dc390e086 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparsification -cse -sparse-vectorization="vl=16" -scf-for-loop-peeling -canonicalize -cse | \
+// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification -cse -sparse-vectorization="vl=16" -scf-for-loop-peeling -canonicalize -cse | \
 // RUN:   FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir b/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
index eaa15d7f83bc47f..a4211991f260833 100644
--- a/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 //
 // A SDDMM implementation with "spy" function and
diff --git a/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
index 7112d7dc8ca900b..ea29f1d677eff7b 100644
--- a/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
+++ b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 #trait = {
   indexing_maps = [
diff --git a/mlir/test/Dialect/SparseTensor/unused-tensor.mlir b/mlir/test/Dialect/SparseTensor/unused-tensor.mlir
index 5f169dd989bdc48..330bb9fa57483f5 100644
--- a/mlir/test/Dialect/SparseTensor/unused-tensor.mlir
+++ b/mlir/test/Dialect/SparseTensor/unused-tensor.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
 
 //
 // A contrived example where the sparse tensor B is only
diff --git a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
index a4f58a4fa913f0c..ef3c0f88d4222f2 100644
--- a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=8" -cse -split-input-file | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=8" -cse -split-input-file | \
 // RUN:   FileCheck %s --check-prefix=CHECK-ON
-// RUN: mlir-opt %s -sparsification -cse -split-input-file | \
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -split-input-file | \
 // RUN:   FileCheck %s --check-prefix=CHECK-OFF
 
 // -----



More information about the Mlir-commits mailing list