[Mlir-commits] [mlir] [mlir][sparse] introduce MapRef, unify conversion/codegen for reader (PR #68360)

Aart Bik llvmlistbot at llvm.org
Thu Oct 5 15:18:58 PDT 2023


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/68360

>From 6094912685a0cfa5c13e023e8ec97238a84fca2f Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 5 Oct 2023 13:22:28 -0700
Subject: [PATCH 1/3] [mlir][sparse] introduce MapRef, unify conversion/codegen
 for reader

This revision introduces a MapRef, which will support a future
generalization beyond permutations (e.g. block sparsity). This
revision also unifies the conversion/codegen paths for the
sparse_tensor.new operation from file (e.g. the readers). Note
that more unification is planned as well as general affine
dim2lvl and lvl2dim (all marked with TODOs).
---
 .../mlir/ExecutionEngine/SparseTensor/File.h  | 156 ++++++----------
 .../ExecutionEngine/SparseTensor/MapRef.h     |  96 ++++++++++
 .../ExecutionEngine/SparseTensor/Storage.h    | 108 +----------
 .../ExecutionEngine/SparseTensorRuntime.h     |   8 -
 .../SparseTensor/Transforms/CodegenUtils.cpp  |  89 +++++++++
 .../SparseTensor/Transforms/CodegenUtils.h    |  18 ++
 .../Transforms/SparseTensorCodegen.cpp        |  73 ++------
 .../Transforms/SparseTensorConversion.cpp     | 111 ++---------
 .../SparseTensor/CMakeLists.txt               |   1 +
 .../ExecutionEngine/SparseTensor/MapRef.cpp   |  52 ++++++
 .../ExecutionEngine/SparseTensorRuntime.cpp   |  60 +++---
 mlir/test/Dialect/SparseTensor/codegen.mlir   | 172 +++++++++---------
 .../test/Dialect/SparseTensor/conversion.mlir |  18 +-
 13 files changed, 475 insertions(+), 487 deletions(-)
 create mode 100644 mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h
 create mode 100644 mlir/lib/ExecutionEngine/SparseTensor/MapRef.cpp

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index 78c1a0544e3a521..9157bfa7e773239 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -20,6 +20,7 @@
 #ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_FILE_H
 #define MLIR_EXECUTIONENGINE_SPARSETENSOR_FILE_H
 
+#include "mlir/ExecutionEngine/SparseTensor/MapRef.h"
 #include "mlir/ExecutionEngine/SparseTensor/Storage.h"
 
 #include <fstream>
@@ -75,6 +76,10 @@ inline V readValue(char **linePtr, bool isPattern) {
 
 } // namespace detail
 
+//===----------------------------------------------------------------------===//
+//
+//  Reader class.
+//
 //===----------------------------------------------------------------------===//
 
 /// This class abstracts over the information stored in file headers,
@@ -132,6 +137,7 @@ class SparseTensorReader final {
   /// Reads and parses the file's header.
   void readHeader();
 
+  /// Returns the stored value kind.
   ValueKind getValueKind() const { return valueKind_; }
 
   /// Checks if a header has been successfully read.
@@ -185,58 +191,37 @@ class SparseTensorReader final {
   /// valid after parsing the header.
   void assertMatchesShape(uint64_t rank, const uint64_t *shape) const;
 
-  /// Reads a sparse tensor element from the next line in the input file and
-  /// returns the value of the element. Stores the coordinates of the element
-  /// to the `dimCoords` array.
-  template <typename V>
-  V readElement(uint64_t dimRank, uint64_t *dimCoords) {
-    assert(dimRank == getRank() && "rank mismatch");
-    char *linePtr = readCoords(dimCoords);
-    return detail::readValue<V>(&linePtr, isPattern());
-  }
-
-  /// Allocates a new COO object for `lvlSizes`, initializes it by reading
-  /// all the elements from the file and applying `dim2lvl` to their
-  /// dim-coordinates, and then closes the file. Templated on V only.
-  template <typename V>
-  SparseTensorCOO<V> *readCOO(uint64_t lvlRank, const uint64_t *lvlSizes,
-                              const uint64_t *dim2lvl);
-
   /// Allocates a new sparse-tensor storage object with the given encoding,
   /// initializes it by reading all the elements from the file, and then
   /// closes the file. Templated on P, I, and V.
   template <typename P, typename I, typename V>
   SparseTensorStorage<P, I, V> *
   readSparseTensor(uint64_t lvlRank, const uint64_t *lvlSizes,
-                   const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
-                   const uint64_t *dim2lvl) {
-    auto *lvlCOO = readCOO<V>(lvlRank, lvlSizes, dim2lvl);
+                   const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
+                   const uint64_t *lvl2dim) {
+    const uint64_t dimRank = getRank();
+    MapRef map(dimRank, lvlRank, dim2lvl, lvl2dim);
+    auto *coo = readCOO<V>(map, lvlSizes);
     auto *tensor = SparseTensorStorage<P, I, V>::newFromCOO(
-        getRank(), getDimSizes(), lvlRank, lvlTypes, lvl2dim, *lvlCOO);
-    delete lvlCOO;
+        dimRank, getDimSizes(), lvlRank, lvlTypes, lvl2dim, *coo);
+    delete coo;
     return tensor;
   }
 
   /// Reads the COO tensor from the file, stores the coordinates and values to
   /// the given buffers, returns a boolean value to indicate whether the COO
   /// elements are sorted.
-  /// Precondition: the buffers should have enough space to hold the elements.
   template <typename C, typename V>
   bool readToBuffers(uint64_t lvlRank, const uint64_t *dim2lvl,
-                     C *lvlCoordinates, V *values);
+                     const uint64_t *lvl2dim, C *lvlCoordinates, V *values);
 
 private:
-  /// Attempts to read a line from the file.  Is private because there's
-  /// no reason for client code to call it.
+  /// Attempts to read a line from the file.
   void readLine();
 
   /// Reads the next line of the input file and parses the coordinates
   /// into the `dimCoords` argument.  Returns the position in the `line`
-  /// buffer where the element's value should be parsed from.  This method
-  /// has been factored out from `readElement` to minimize code bloat
-  /// for the generated library.
-  ///
-  /// Precondition: `dimCoords` is valid for `getRank()`.
+  /// buffer where the element's value should be parsed from.
   template <typename C>
   char *readCoords(C *dimCoords) {
     readLine();
@@ -251,24 +236,20 @@ class SparseTensorReader final {
     return linePtr;
   }
 
-  /// The internal implementation of `readCOO`.  We template over
-  /// `IsPattern` in order to perform LICM without needing to duplicate the
-  /// source code.
-  //
-  // TODO: We currently take the `dim2lvl` argument as a `PermutationRef`
-  // since that's what `readCOO` creates.  Once we update `readCOO` to
-  // functionalize the mapping, then this helper will just take that
-  // same function.
+  /// Reads all the elements from the file while applying the given map.
+  template <typename V>
+  SparseTensorCOO<V> *readCOO(const MapRef &map, const uint64_t *lvlSizes);
+
+  /// The implementation of `readCOO`, templated on `IsPattern` in order
+  /// to perform LICM without needing to duplicate the source code.
   template <typename V, bool IsPattern>
-  void readCOOLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
-                   SparseTensorCOO<V> *lvlCOO);
+  void readCOOLoop(const MapRef &map, SparseTensorCOO<V> *coo);
 
-  /// The internal implementation of `readToBuffers`.  We template over
-  /// `IsPattern` in order to perform LICM without needing to duplicate the
-  /// source code.
+  /// The internal implementation of `readToBuffers`. We template over
+  /// `IsPattern` in order to perform LICM without needing to duplicate
+  /// the source code.
   template <typename C, typename V, bool IsPattern>
-  bool readToBuffersLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
-                         C *lvlCoordinates, V *values);
+  bool readToBuffersLoop(const MapRef &map, C *lvlCoordinates, V *values);
 
   /// Reads the MME header of a general sparse matrix of type real.
   void readMMEHeader();
@@ -288,96 +269,76 @@ class SparseTensorReader final {
   char line[kColWidth];
 };
 
+//===----------------------------------------------------------------------===//
+//
+//  Reader class methods.
+//
 //===----------------------------------------------------------------------===//
 
 template <typename V>
-SparseTensorCOO<V> *SparseTensorReader::readCOO(uint64_t lvlRank,
-                                                const uint64_t *lvlSizes,
-                                                const uint64_t *dim2lvl) {
+SparseTensorCOO<V> *SparseTensorReader::readCOO(const MapRef &map,
+                                                const uint64_t *lvlSizes) {
   assert(isValid() && "Attempt to readCOO() before readHeader()");
-  const uint64_t dimRank = getRank();
-  assert(lvlRank == dimRank && "Rank mismatch");
-  detail::PermutationRef d2l(dimRank, dim2lvl);
   // Prepare a COO object with the number of stored elems as initial capacity.
-  auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, getNSE());
-  // Do some manual LICM, to avoid assertions in the for-loop.
-  const bool IsPattern = isPattern();
-  if (IsPattern)
-    readCOOLoop<V, true>(lvlRank, d2l, lvlCOO);
+  auto *coo = new SparseTensorCOO<V>(map.getLvlRank(), lvlSizes, getNSE());
+  // Enter the reading loop.
+  if (isPattern())
+    readCOOLoop<V, true>(map, coo);
   else
-    readCOOLoop<V, false>(lvlRank, d2l, lvlCOO);
+    readCOOLoop<V, false>(map, coo);
   // Close the file and return the COO.
   closeFile();
-  return lvlCOO;
+  return coo;
 }
 
 template <typename V, bool IsPattern>
-void SparseTensorReader::readCOOLoop(uint64_t lvlRank,
-                                     detail::PermutationRef dim2lvl,
-                                     SparseTensorCOO<V> *lvlCOO) {
-  const uint64_t dimRank = getRank();
+void SparseTensorReader::readCOOLoop(const MapRef &map,
+                                     SparseTensorCOO<V> *coo) {
+  const uint64_t dimRank = map.getDimRank();
+  const uint64_t lvlRank = map.getLvlRank();
+  assert(dimRank == getRank());
   std::vector<uint64_t> dimCoords(dimRank);
   std::vector<uint64_t> lvlCoords(lvlRank);
-  for (uint64_t nse = getNSE(), k = 0; k < nse; ++k) {
-    // We inline `readElement` here in order to avoid redundant
-    // assertions, since they're guaranteed by the call to `isValid()`
-    // and the construction of `dimCoords` above.
+  for (uint64_t k = 0, nse = getNSE(); k < nse; k++) {
     char *linePtr = readCoords(dimCoords.data());
     const V value = detail::readValue<V, IsPattern>(&linePtr);
-    dim2lvl.pushforward(dimRank, dimCoords.data(), lvlCoords.data());
-    // TODO: <https://github.com/llvm/llvm-project/issues/54179>
-    lvlCOO->add(lvlCoords, value);
+    map.pushforward(dimCoords.data(), lvlCoords.data());
+    coo->add(lvlCoords, value);
   }
 }
 
 template <typename C, typename V>
 bool SparseTensorReader::readToBuffers(uint64_t lvlRank,
                                        const uint64_t *dim2lvl,
+                                       const uint64_t *lvl2dim,
                                        C *lvlCoordinates, V *values) {
   assert(isValid() && "Attempt to readCOO() before readHeader()");
-  // Construct a `PermutationRef` for the `pushforward` below.
-  // TODO: This specific implementation does not generalize to arbitrary
-  // mappings, but once we functionalize the `dim2lvl` argument we can
-  // simply use that function instead.
-  const uint64_t dimRank = getRank();
-  assert(lvlRank == dimRank && "Rank mismatch");
-  detail::PermutationRef d2l(dimRank, dim2lvl);
-  // Do some manual LICM, to avoid assertions in the for-loop.
+  MapRef map(getRank(), lvlRank, dim2lvl, lvl2dim);
   bool isSorted =
-      isPattern()
-          ? readToBuffersLoop<C, V, true>(lvlRank, d2l, lvlCoordinates, values)
-          : readToBuffersLoop<C, V, false>(lvlRank, d2l, lvlCoordinates,
-                                           values);
-
-  // Close the file and return isSorted.
+      isPattern() ? readToBuffersLoop<C, V, true>(map, lvlCoordinates, values)
+                  : readToBuffersLoop<C, V, false>(map, lvlCoordinates, values);
   closeFile();
   return isSorted;
 }
 
 template <typename C, typename V, bool IsPattern>
-bool SparseTensorReader::readToBuffersLoop(uint64_t lvlRank,
-                                           detail::PermutationRef dim2lvl,
-                                           C *lvlCoordinates, V *values) {
-  const uint64_t dimRank = getRank();
+bool SparseTensorReader::readToBuffersLoop(const MapRef &map, C *lvlCoordinates,
+                                           V *values) {
+  const uint64_t dimRank = map.getDimRank();
+  const uint64_t lvlRank = map.getLvlRank();
   const uint64_t nse = getNSE();
+  assert(dimRank == getRank());
   std::vector<C> dimCoords(dimRank);
-  // Read the first element with isSorted=false as a way to avoid accessing its
-  // previous element.
   bool isSorted = false;
   char *linePtr;
-  // We inline `readElement` here in order to avoid redundant assertions,
-  // since they're guaranteed by the call to `isValid()` and the construction
-  // of `dimCoords` above.
   const auto readNextElement = [&]() {
     linePtr = readCoords<C>(dimCoords.data());
-    dim2lvl.pushforward(dimRank, dimCoords.data(), lvlCoordinates);
+    map.pushforward(dimCoords.data(), lvlCoordinates);
     *values = detail::readValue<V, IsPattern>(&linePtr);
     if (isSorted) {
-      // Note that isSorted was set to false while reading the first element,
+      // Note that isSorted is set to false when reading the first element,
       // to guarantee the safeness of using prevLvlCoords.
       C *prevLvlCoords = lvlCoordinates - lvlRank;
-      // TODO: define a new CoordsLT which is like ElementLT but doesn't have
-      // the V parameter, and use it here.
       for (uint64_t l = 0; l < lvlRank; ++l) {
         if (prevLvlCoords[l] != lvlCoordinates[l]) {
           if (prevLvlCoords[l] > lvlCoordinates[l])
@@ -393,7 +354,6 @@ bool SparseTensorReader::readToBuffersLoop(uint64_t lvlRank,
   isSorted = true;
   for (uint64_t n = 1; n < nse; ++n)
     readNextElement();
-
   return isSorted;
 }
 
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h
new file mode 100644
index 000000000000000..1c155568802e579
--- /dev/null
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h
@@ -0,0 +1,96 @@
+//===- MapRef.h - A dim2lvl/lvl2dim map encoding ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A dim2lvl/lvl2dim map encoding class, with utility methods.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_MAPREF_H
+#define MLIR_EXECUTIONENGINE_SPARSETENSOR_MAPREF_H
+
+#include <cinttypes>
+
+#include <cassert>
+
+namespace mlir {
+namespace sparse_tensor {
+
+/// A class for capturing the sparse tensor type map with a compact encoding.
+///
+/// Currently, the following situations are supported:
+///   (1) map is an identity
+///   (2) map is a permutation
+///   (3) map has affine ops (restricted set, not yet implemented)
+///
+/// The pushforward/backward operations are fast for (1) and (2) but
+/// incur some obvious overhead for situation (3).
+///
+class MapRef final {
+public:
+  MapRef(uint64_t d, uint64_t l, const uint64_t *d2l, const uint64_t *l2d);
+
+  //
+  // Push forward maps from dimensions to levels.
+  //
+
+  template <typename T> inline void pushforward(const T *in, T *out) const {
+    switch (kind) {
+    case MapKind::kIdentity:
+      for (uint64_t i = 0; i < dimRank; ++i)
+        out[i] = in[i]; // TODO: optimize with in == out ?
+      break;
+    case MapKind::kPermutation:
+      for (uint64_t i = 0; i < dimRank; ++i)
+        out[dim2lvl[i]] = in[i];
+      break;
+    case MapKind::kAffine:
+      assert(0 && "coming soon");
+      break;
+    }
+  }
+
+  //
+  // Push backward maps from levels to dimensions.
+  //
+
+  template <typename T> inline void pushbackward(const T *in, T *out) const {
+    switch (kind) {
+    case MapKind::kIdentity:
+      for (uint64_t i = 0; i < lvlRank; ++i)
+        out[i] = in[i];
+      break;
+    case MapKind::kPermutation:
+      for (uint64_t i = 0; i < lvlRank; ++i)
+        out[lvl2dim[i]] = in[i];
+      break;
+    case MapKind::kAffine:
+      assert(0 && "coming soon");
+      break;
+    }
+  }
+
+  uint64_t getDimRank() const { return dimRank; }
+  uint64_t getLvlRank() const { return lvlRank; }
+
+private:
+  enum class MapKind { kIdentity, kPermutation, kAffine };
+
+  bool isIdentity() const;
+  bool isPermutation() const;
+
+  MapKind kind;
+  const uint64_t dimRank;
+  const uint64_t lvlRank;
+  const uint64_t *const dim2lvl; // non-owning pointer
+  const uint64_t *const lvl2dim; // non-owning pointer
+};
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif //  MLIR_EXECUTIONENGINE_SPARSETENSOR_MAPREF_H
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 28c28c28109c3c7..37ad3c1b042313c 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -49,103 +49,6 @@ class SparseTensorEnumeratorBase;
 template <typename P, typename C, typename V>
 class SparseTensorEnumerator;
 
-namespace detail {
-
-/// Checks whether the `perm` array is a permutation of `[0 .. size)`.
-inline bool isPermutation(uint64_t size, const uint64_t *perm) {
-  assert(perm && "Got nullptr for permutation");
-  std::vector<bool> seen(size, false);
-  for (uint64_t i = 0; i < size; ++i) {
-    const uint64_t j = perm[i];
-    if (j >= size || seen[j])
-      return false;
-    seen[j] = true;
-  }
-  for (uint64_t i = 0; i < size; ++i)
-    if (!seen[i])
-      return false;
-  return true;
-}
-
-/// Wrapper around `isPermutation` to ensure consistent error messages.
-inline void assertIsPermutation(uint64_t size, const uint64_t *perm) {
-#ifndef NDEBUG
-  if (!isPermutation(size, perm))
-    MLIR_SPARSETENSOR_FATAL("Not a permutation of [0..%" PRIu64 ")\n", size);
-#endif
-}
-
-/// A class for capturing the knowledge that `isPermutation` is true.
-class PermutationRef final {
-public:
-  /// Asserts `isPermutation` and returns the witness to that being true.
-  explicit PermutationRef(uint64_t size, const uint64_t *perm)
-      : permSize(size), perm(perm) {
-    assertIsPermutation(size, perm);
-  }
-
-  uint64_t size() const { return permSize; }
-
-  const uint64_t *data() const { return perm; }
-
-  const uint64_t &operator[](uint64_t i) const {
-    assert(i < permSize && "index is out of bounds");
-    return perm[i];
-  }
-
-  /// Constructs a pushforward array of values.  This method is the inverse
-  /// of `permute` in the sense that for all `p` and `xs` we have:
-  /// * `p.permute(p.pushforward(xs)) == xs`
-  /// * `p.pushforward(p.permute(xs)) == xs`
-  template <typename T>
-  inline std::vector<T> pushforward(const std::vector<T> &values) const {
-    return pushforward(values.size(), values.data());
-  }
-
-  template <typename T>
-  inline std::vector<T> pushforward(uint64_t size, const T *values) const {
-    std::vector<T> out(permSize);
-    pushforward(size, values, out.data());
-    return out;
-  }
-
-  template <typename T>
-  inline void pushforward(uint64_t size, const T *values, T *out) const {
-    assert(size == permSize && "size mismatch");
-    for (uint64_t i = 0; i < permSize; ++i)
-      out[perm[i]] = values[i];
-  }
-
-  /// Constructs a permuted array of values.  This method is the inverse
-  /// of `pushforward` in the sense that for all `p` and `xs` we have:
-  /// * `p.permute(p.pushforward(xs)) == xs`
-  /// * `p.pushforward(p.permute(xs)) == xs`
-  template <typename T>
-  inline std::vector<T> permute(const std::vector<T> &values) const {
-    return permute(values.size(), values.data());
-  }
-
-  template <typename T>
-  inline std::vector<T> permute(uint64_t size, const T *values) const {
-    std::vector<T> out(permSize);
-    permute(size, values, out.data());
-    return out;
-  }
-
-  template <typename T>
-  inline void permute(uint64_t size, const T *values, T *out) const {
-    assert(size == permSize && "size mismatch");
-    for (uint64_t i = 0; i < permSize; ++i)
-      out[i] = values[perm[i]];
-  }
-
-private:
-  const uint64_t permSize;
-  const uint64_t *const perm; // non-owning pointer.
-};
-
-} // namespace detail
-
 /// Abstract base class for `SparseTensorStorage<P,C,V>`.  This class
 /// takes responsibility for all the `<P,C,V>`-independent aspects
 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
@@ -263,7 +166,7 @@ class SparseTensorStorageBase {
   bool isUniqueLvl(uint64_t l) const { return isUniqueDLT(getLvlType(l)); }
 
   /// Allocates a new enumerator.  Callers must make sure to delete
-  /// the enumerator when they're done with it.  The first argument
+  /// the enumerator when they're done with it. The first argument
   /// is the out-parameter for storing the newly allocated enumerator;
   /// all other arguments are passed along to the `SparseTensorEnumerator`
   /// ctor and must satisfy the preconditions/assertions thereof.
@@ -326,6 +229,7 @@ class SparseTensorStorageBase {
   const std::vector<uint64_t> lvl2dim;
 };
 
+
 /// A memory-resident sparse tensor using a storage scheme based on
 /// per-level sparse/dense annotations.  This data structure provides
 /// a bufferized form of a sparse tensor type.  In contrast to generating
@@ -401,7 +305,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
                       const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
                       const intptr_t *lvlBufs);
 
-  /// Allocates a new empty sparse tensor.  The preconditions/assertions
+  /// Allocates a new empty sparse tensor. The preconditions/assertions
   /// are as per the `SparseTensorStorageBase` ctor; which is to say,
   /// the `dimSizes` and `lvlSizes` must both be "sizes" not "shapes",
   /// since there's nowhere to reconstruct dynamic sizes from.
@@ -577,6 +481,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   SparseTensorCOO<V> *toCOO(uint64_t trgRank, const uint64_t *trgSizes,
                             uint64_t srcRank, const uint64_t *src2trg) const {
     // We inline `newEnumerator` to avoid virtual dispatch and allocation.
+    // TODO: use MapRef here too for the translation
     SparseTensorEnumerator<P, C, V> enumerator(*this, trgRank, trgSizes,
                                                srcRank, src2trg);
     auto *coo = new SparseTensorCOO<V>(trgRank, trgSizes, values.size());
@@ -733,7 +638,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     }
   }
 
-  /// Continues a single insertion path, outer to inner.  The first
+  /// Continues a single insertion path, outer to inner. The first
   /// argument is the level-coordinates for the value being inserted.
   void insPath(const uint64_t *lvlCoords, uint64_t diffLvl, uint64_t full,
                V val) {
@@ -875,7 +780,8 @@ class SparseTensorEnumeratorBase {
 
 //===----------------------------------------------------------------------===//
 template <typename P, typename C, typename V>
-class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
+class SparseTensorEnumerator final
+    : public SparseTensorEnumeratorBase<V> {
   using Base = SparseTensorEnumeratorBase<V>;
   using StorageImpl = SparseTensorStorage<P, C, V>;
 
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
index 8f320f04f23fc84..861b7eff65115b6 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -143,14 +143,6 @@ MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensorFromReader(
 MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderDimSizes(
     StridedMemRefType<index_type, 1> *out, void *p);
 
-/// Returns the next element for the sparse tensor being read.
-#define DECL_GETNEXT(VNAME, V)                                                 \
-  MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderNext##VNAME( \
-      void *p, StridedMemRefType<index_type, 1> *dimCoordsRef,                 \
-      StridedMemRefType<V, 0> *vref);
-MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETNEXT)
-#undef DECL_GETNEXT
-
 /// Reads the sparse tensor, stores the coordinates and values to the given
 /// memrefs. Returns a boolean to indicate whether the COO elements are sorted.
 #define DECL_GETNEXT(VNAME, V, CNAME, C)                                       \
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index ce77d7a519877d6..ffb1a550957edb8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -729,3 +729,92 @@ Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
     return constantIndex(builder, loc, *stride);
   return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
 }
+
+void sparse_tensor::fillDimShape(OpBuilder &builder, Location loc,
+                                 SparseTensorType stt,
+                                 SmallVectorImpl<Value> &out) {
+  out.clear();
+  out.reserve(stt.getDimRank());
+  for (const DynSize sh : stt.getDimShape()) {
+    const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
+    out.push_back(constantIndex(builder, loc, s));
+  }
+}
+
+Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
+                               SparseTensorType stt, Value tensor,
+                               /*out*/ SmallVectorImpl<Value> &dimShapesValues,
+                               /*out*/ Value &dimSizesBuffer) {
+  // Construct the dimShapes buffer. The buffer contains the static size
+  // per dimension, or otherwise a zero for a dynamic size.
+  fillDimShape(builder, loc, stt, dimShapesValues);
+  Value dimShapesBuffer = allocaBuffer(builder, loc, dimShapesValues);
+  // Create the `CheckedSparseTensorReader`. This reader performs a
+  // consistency check on the static sizes, but accepts any size
+  // of each dimension with a dynamic size.
+  Type opaqueTp = getOpaquePointerType(builder);
+  Type eltTp = stt.getElementType();
+  Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp);
+  Value reader =
+      createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp,
+                     {tensor, dimShapesBuffer, valTp}, EmitCInterface::On)
+          .getResult(0);
+  // For static shapes, the shape buffer can be used right away. For dynamic
+  // shapes, use the information from the reader to construct a buffer that
+  // supplies the actual size for each dynamic dimension.
+  dimSizesBuffer = dimShapesBuffer;
+  if (stt.hasDynamicDimShape()) {
+    Type indexTp = builder.getIndexType();
+    auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
+    dimSizesBuffer =
+        createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
+                       reader, EmitCInterface::On)
+            .getResult(0);
+  }
+  return reader;
+}
+
+Value sparse_tensor::genReaderBuffers(OpBuilder &builder, Location loc,
+                                      SparseTensorType stt,
+                                      SmallVectorImpl<Value> &dimShapesValues,
+                                      Value dimSizesBuffer,
+                                      /*out*/ Value &dim2lvlBuffer,
+                                      /*out*/ Value &lvl2dimBuffer) {
+  const Dimension dimRank = stt.getDimRank();
+  const Level lvlRank = stt.getLvlRank();
+  // For an identity mapping, the dim2lvl and lvl2dim mappings are
+  // identical as are dimSizes and lvlSizes, so buffers are reused
+  // as much as possible.
+  if (stt.isIdentity()) {
+    assert(dimRank == lvlRank);
+    SmallVector<Value> iotaValues;
+    iotaValues.reserve(lvlRank);
+    for (Level l = 0; l < lvlRank; l++)
+      iotaValues.push_back(constantIndex(builder, loc, l));
+    dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, iotaValues);
+    return dimSizesBuffer;
+  }
+  // Otherwise, some code needs to be generated to set up the buffers.
+  // TODO: use the lvl2dim once available and deal with non-permutations!
+  const auto dimToLvl = stt.getDimToLvl();
+  assert(dimToLvl.isPermutation());
+  SmallVector<Value> dim2lvlValues(dimRank);
+  SmallVector<Value> lvl2dimValues(lvlRank);
+  SmallVector<Value> lvlSizesValues(lvlRank);
+  for (Level l = 0; l < lvlRank; l++) {
+    // The `d`th source variable occurs in the `l`th result position.
+    Dimension d = dimToLvl.getDimPosition(l);
+    Value lvl = constantIndex(builder, loc, l);
+    Value dim = constantIndex(builder, loc, d);
+    dim2lvlValues[d] = lvl;
+    lvl2dimValues[l] = dim;
+    if (stt.isDynamicDim(d))
+      lvlSizesValues[l] =
+          builder.create<memref::LoadOp>(loc, dimSizesBuffer, dim);
+    else
+      lvlSizesValues[l] = dimShapesValues[d];
+  }
+  dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
+  lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
+  return allocaBuffer(builder, loc, lvlSizesValues);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 8145446751b9938..08ea019d8224a73 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Builders.h"
 
@@ -341,6 +342,23 @@ Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
 Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor,
                                 Dimension dim);
 
+/// Populates the array with the dimension-shape of the given
+/// `SparseTensorType`, where dynamic sizes are represented by zero.
+void fillDimShape(OpBuilder &builder, Location loc, SparseTensorType stt,
+                  SmallVectorImpl<Value> &out);
+
+/// Generates code that opens a reader and sets the dimension sizes.
+Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt,
+                Value tensor,
+                /*out*/ SmallVectorImpl<Value> &dimShapeValues,
+                /*out*/ Value &dimSizesBuffer);
+
+/// Generates code to set up the buffer parameters for a reader.
+Value genReaderBuffers(OpBuilder &builder, Location loc, SparseTensorType stt,
+                       SmallVectorImpl<Value> &dimShapeValues,
+                       Value dimSizesBuffer, /*out*/ Value &dim2lvlBuffer,
+                       /*out*/ Value &lvl2dimBuffer);
+
 //===----------------------------------------------------------------------===//
 // Inlined constant generators.
 //
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 7c362c086623b42..2c03f0a6020e6a8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1428,7 +1428,7 @@ struct SparseDisassembleOpConverter
   }
 };
 
-struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
+struct SparseNewConverter : public OpConversionPattern<NewOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(NewOp op, OpAdaptor adaptor,
@@ -1440,7 +1440,7 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
     if (!dstTp.hasEncoding() || getCOOStart(dstTp.getEncoding()) != 0)
       return failure();
 
-    // Implement the NewOp(filename) as follows:
+    // Implement as follows:
     //   %reader = @createCheckedSparseTensorReader(%filename)
     //   %nse = @getSparseTensorNSE(%reader)
     //   %coo = bufferization.alloc_tensor an ordered COO with
@@ -1451,74 +1451,39 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
     //   if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
     //   update storage specifier
     //   @delSparseTensorReader(%reader)
+    SmallVector<Value> dimShapesValues;
+    Value dimSizesBuffer;
+    Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
+                             dimShapesValues, dimSizesBuffer);
 
-    // Allocate `SparseTensorReader` and perform all initial setup that
-    // does not depend on lvlSizes (nor dimToLvl, lvlToDim, etc).
-    const Type opaqueTp = getOpaquePointerType(rewriter);
-    const Value fileName = op.getSource();
-    SmallVector<Value> dimShapeValues;
-    for (const DynSize sh : dstTp.getDimShape()) {
-      const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
-      dimShapeValues.push_back(constantIndex(rewriter, loc, s));
-    }
-    Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues);
-    Value valTp =
-        constantPrimaryTypeEncoding(rewriter, loc, dstTp.getElementType());
-    Value reader =
-        createFuncCall(rewriter, loc, "createCheckedSparseTensorReader",
-                       opaqueTp, {fileName, dimShapeBuffer, valTp},
-                       EmitCInterface::On)
-            .getResult(0);
+    // Get the number of stored entries.
     const Type indexTp = rewriter.getIndexType();
-    const Dimension dimRank = dstTp.getDimRank();
-    const Level lvlRank = dstTp.getLvlRank();
+    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
+                               {indexTp}, {reader}, EmitCInterface::Off)
+                    .getResult(0);
 
-    // If the result tensor has dynamic dimensions, get the dynamic sizes from
-    // the sparse tensor reader.
+    // Construct allocation for each field.
     SmallVector<Value> dynSizes;
     if (dstTp.hasDynamicDimShape()) {
-      auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
-      Value dimSizesBuffer =
-          createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", memTp,
-                         reader, EmitCInterface::On)
-              .getResult(0);
       for (const auto &d : llvm::enumerate(dstTp.getDimShape()))
         if (ShapedType::isDynamic(d.value()))
           dynSizes.push_back(rewriter.create<memref::LoadOp>(
               loc, dimSizesBuffer, constantIndex(rewriter, loc, d.index())));
     }
-
-    // Get the number of stored entries.
-    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
-                               {indexTp}, {reader}, EmitCInterface::Off)
-                    .getResult(0);
-    // Construct allocation for each field.
     SmallVector<Value> fields;
     createAllocFields(rewriter, loc, dstTp, dynSizes, /*enableInit=*/false,
                       fields, nse);
     MutSparseTensorDescriptor desc(dstTp, fields);
 
-    // Construct the `dimToLvl` buffer for handing off to the runtime library.
-    SmallVector<Value> dimToLvlValues(dimRank);
-    if (!dstTp.isIdentity()) {
-      const auto dimToLvl = dstTp.getDimToLvl();
-      assert(dimToLvl.isPermutation() && "Got non-permutation");
-      for (Level l = 0; l < lvlRank; l++) {
-        const Dimension d = dimToLvl.getDimPosition(l);
-        dimToLvlValues[d] = constantIndex(rewriter, loc, l);
-      }
-    } else {
-      // The `SparseTensorType` ctor already ensures `dimRank == lvlRank`
-      // when `isIdentity`; so no need to re-assert it here.
-      for (Dimension d = 0; d < dimRank; d++)
-        dimToLvlValues[d] = constantIndex(rewriter, loc, d);
-    }
-    Value dimToLvl = allocaBuffer(rewriter, loc, dimToLvlValues);
+    // Now construct the dim2lvl and lvl2dim buffers.
+    Value dim2lvlBuffer;
+    Value lvl2dimBuffer;
+    genReaderBuffers(rewriter, loc, dstTp, dimShapesValues, dimSizesBuffer,
+                     dim2lvlBuffer, lvl2dimBuffer);
 
     // Read the COO tensor data.
     Value xs = desc.getAOSMemRef();
     Value ys = desc.getValMemRef();
-
     const Type boolTp = rewriter.getIntegerType(1);
     const Type elemTp = dstTp.getElementType();
     const Type crdTp = dstTp.getCrdType();
@@ -1527,11 +1492,13 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
                                           primaryTypeFunctionSuffix(elemTp)};
     Value isSorted =
         createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
-                       {reader, dimToLvl, xs, ys}, EmitCInterface::On)
+                       {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
+                       EmitCInterface::On)
             .getResult(0);
 
     // If the destination tensor is a sorted COO, we need to sort the COO tensor
     // data if the input elements aren't sorted yet.
+    const Level lvlRank = dstTp.getLvlRank();
     if (dstTp.isOrderedLvl(lvlRank - 1)) {
       Value kFalse = constantI1(rewriter, loc, false);
       Value notSorted = rewriter.create<arith::CmpIOp>(
@@ -1593,7 +1560,7 @@ void mlir::populateSparseTensorCodegenPatterns(
                                             StorageSpecifierKind::DimStride>,
                SparseToPositionsConverter, SparseToCoordinatesConverter,
                SparseToCoordinatesBufferConverter, SparseToValuesConverter,
-               SparseConvertConverter, SparseNewOpConverter,
+               SparseConvertConverter, SparseNewConverter,
                SparseNumberOfEntriesConverter>(typeConverter,
                                                patterns.getContext());
   patterns.add<SparseTensorDeallocConverter>(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index a3361c2cd48c6dd..eb0c5160e8d6193 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -46,8 +46,7 @@ static std::optional<Type> convertSparseTensorTypes(Type type) {
   return std::nullopt;
 }
 
-/// Replaces the `op` with  a `CallOp` to the function reference returned
-/// by `getFunc()`.
+/// Replaces the `op` with a `CallOp` to the `getFunc()` function reference.
 static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                           StringRef name, TypeRange resultType,
                                           ValueRange operands,
@@ -141,27 +140,6 @@ static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
   return out;
 }
 
-/// Populates the array with the dimension-shape of the given
-/// `SparseTensorType`, where dynamic sizes are represented by zero.
-static void fillDimShape(OpBuilder &builder, Location loc, SparseTensorType stt,
-                         SmallVectorImpl<Value> &out) {
-  out.clear();
-  out.reserve(stt.getDimRank());
-  for (const DynSize sh : stt.getDimShape()) {
-    const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
-    out.push_back(constantIndex(builder, loc, s));
-  }
-}
-
-/// Returns an array with the dimension-shape of the given `SparseTensorType`,
-/// where dynamic sizes are represented by zero.
-static SmallVector<Value> getDimShape(OpBuilder &builder, Location loc,
-                                      SparseTensorType stt) {
-  SmallVector<Value> out;
-  fillDimShape(builder, loc, stt, out);
-  return out;
-}
-
 /// Generates an uninitialized buffer of the given size and type,
 /// but returns it as type `memref<? x $tp>` (rather than as type
 /// `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
@@ -503,84 +481,27 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
     const auto stt = getSparseTensorType(op);
     if (!stt.hasEncoding())
       return failure();
-    const Dimension dimRank = stt.getDimRank();
-    const Level lvlRank = stt.getLvlRank();
-    // Construct the dimShape.
-    SmallVector<Value> dimShapeValues = getDimShape(rewriter, loc, stt);
-    Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues);
-    // Allocate `SparseTensorReader` and perform all initial setup that
-    // does not depend on lvlSizes (nor dimToLvl, lvlToDim, etc).
-    Type opaqueTp = getOpaquePointerType(rewriter);
-    Value valTp =
-        constantPrimaryTypeEncoding(rewriter, loc, stt.getElementType());
-    Value reader =
-        createFuncCall(rewriter, loc, "createCheckedSparseTensorReader",
-                       opaqueTp,
-                       {adaptor.getOperands()[0], dimShapeBuffer, valTp},
-                       EmitCInterface::On)
-            .getResult(0);
-    // Construct the lvlSizes.  If the dimShape is static, then it's
-    // identical to dimSizes: so we can compute lvlSizes entirely at
-    // compile-time.  If dimShape is dynamic, then we'll need to generate
-    // code for computing lvlSizes from the `reader`'s actual dimSizes.
-    //
-    // TODO: For now we're still assuming `dimToLvl` is a permutation.
-    // But since we're computing lvlSizes here (rather than in the runtime),
-    // we can easily generalize that simply by adjusting this code.
-    //
-    // FIXME: reduce redundancy vs `NewCallParams::genBuffers`.
+    // Construct the method calls that open the reader.
+    SmallVector<Value> dimShapesValues;
     Value dimSizesBuffer;
-    if (stt.hasDynamicDimShape()) {
-      Type indexTp = rewriter.getIndexType();
-      auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
-      dimSizesBuffer =
-          createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", memTp,
-                         reader, EmitCInterface::On)
-              .getResult(0);
-    }
-    Value lvlSizesBuffer;
-    Value lvlToDimBuffer;
-    Value dimToLvlBuffer;
-    if (!stt.isIdentity()) {
-      const auto dimToLvl = stt.getDimToLvl();
-      assert(dimToLvl.isPermutation() && "Got non-permutation");
-      // We preinitialize `dimToLvlValues` since we need random-access writing.
-      // And we preinitialize the others for stylistic consistency.
-      SmallVector<Value> lvlSizeValues(lvlRank);
-      SmallVector<Value> lvlToDimValues(lvlRank);
-      SmallVector<Value> dimToLvlValues(dimRank);
-      for (Level l = 0; l < lvlRank; l++) {
-        // The `d`th source variable occurs in the `l`th result position.
-        Dimension d = dimToLvl.getDimPosition(l);
-        Value lvl = constantIndex(rewriter, loc, l);
-        Value dim = constantIndex(rewriter, loc, d);
-        dimToLvlValues[d] = lvl;
-        lvlToDimValues[l] = dim;
-        lvlSizeValues[l] =
-            stt.isDynamicDim(d)
-                ? rewriter.create<memref::LoadOp>(loc, dimSizesBuffer, dim)
-                : dimShapeValues[d];
-      }
-      lvlSizesBuffer = allocaBuffer(rewriter, loc, lvlSizeValues);
-      lvlToDimBuffer = allocaBuffer(rewriter, loc, lvlToDimValues);
-      dimToLvlBuffer = allocaBuffer(rewriter, loc, dimToLvlValues);
-    } else {
-      // The `SparseTensorType` ctor already ensures `dimRank == lvlRank`
-      // when `isIdentity`; so no need to re-assert it here.
-      SmallVector<Value> iotaValues;
-      iotaValues.reserve(lvlRank);
-      for (Level l = 0; l < lvlRank; l++)
-        iotaValues.push_back(constantIndex(rewriter, loc, l));
-      lvlSizesBuffer = dimSizesBuffer ? dimSizesBuffer : dimShapeBuffer;
-      dimToLvlBuffer = lvlToDimBuffer = allocaBuffer(rewriter, loc, iotaValues);
-    }
+    Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
+                             dimShapesValues, dimSizesBuffer);
+    // Now construct the lvlSizes, dim2lvl, and lvl2dim buffers.
+    Value dim2lvlBuffer;
+    Value lvl2dimBuffer;
+    Value lvlSizesBuffer =
+        genReaderBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer,
+                         dim2lvlBuffer, lvl2dimBuffer);
     // Use the `reader` to parse the file.
+    Type opaqueTp = getOpaquePointerType(rewriter);
+    Type eltTp = stt.getElementType();
+    Value valTp = constantPrimaryTypeEncoding(rewriter, loc, eltTp);
     SmallVector<Value, 8> params{
         reader,
         lvlSizesBuffer,
         genLvlTypesBuffer(rewriter, loc, stt),
-        lvlToDimBuffer,
-        dimToLvlBuffer,
+        dim2lvlBuffer,
+        lvl2dimBuffer,
         constantPosTypeEncoding(rewriter, loc, stt.getEncoding()),
         constantCrdTypeEncoding(rewriter, loc, stt.getEncoding()),
         valTp};
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt b/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt
index 085d83634a702a8..c48af17b2d94bb7 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt
@@ -7,6 +7,7 @@
 # that is reserved/intended for shared libraries only.
 add_mlir_library(MLIRSparseTensorRuntime
   File.cpp
+  MapRef.cpp
   NNZ.cpp
   Storage.cpp
 
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/MapRef.cpp b/mlir/lib/ExecutionEngine/SparseTensor/MapRef.cpp
new file mode 100644
index 000000000000000..ed458afeae746bc
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/SparseTensor/MapRef.cpp
@@ -0,0 +1,52 @@
+//===- MapRef.cpp - A dim2lvl/lvl2dim map reference wrapper ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <vector>
+
+#include "mlir/ExecutionEngine/SparseTensor/MapRef.h"
+
+mlir::sparse_tensor::MapRef::MapRef(uint64_t d, uint64_t l, const uint64_t *d2l,
+                                    const uint64_t *l2d)
+    : dimRank(d), lvlRank(l), dim2lvl(d2l), lvl2dim(l2d) {
+  assert(d2l && l2d);
+  // Determine the kind of mapping (and assert on simple inference).
+  if (isIdentity()) {
+    kind = MapKind::kIdentity;
+    for (uint64_t i = 0; i < dimRank; i++)
+      assert(lvl2dim[i] == i);
+  } else if (isPermutation()) {
+    kind = MapKind::kPermutation;
+    for (uint64_t i = 0; i < dimRank; i++)
+      assert(lvl2dim[dim2lvl[i]] == i);
+  } else {
+    kind = MapKind::kAffine;
+  }
+}
+
+bool mlir::sparse_tensor::MapRef::isIdentity() const {
+  if (dimRank != lvlRank)
+    return false;
+  for (uint64_t i = 0; i < dimRank; i++) {
+    if (dim2lvl[i] != i)
+      return false;
+  }
+  return true;
+}
+
+bool mlir::sparse_tensor::MapRef::isPermutation() const {
+  if (dimRank != lvlRank)
+    return false;
+  std::vector<bool> seen(dimRank, false);
+  for (uint64_t i = 0; i < dimRank; i++) {
+    const uint64_t j = dim2lvl[i];
+    if (j >= dimRank || seen[j])
+      return false;
+    seen[j] = true;
+  }
+  return true;
+}
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index 82cb6d3aeefa35f..5b910716c0f9e59 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -226,11 +226,7 @@ extern "C" {
 static_assert(std::is_same<index_type, uint64_t>::value,
               "Expected index_type == uint64_t");
 
-// TODO: this swiss-army-knife should be split up into separate functions
-// for each action, since the various actions don't agree on (1) whether
-// the first two arguments are "sizes" vs "shapes", (2) whether the "lvl"
-// arguments are actually storage-levels vs target tensor-dimensions,
-// (3) whether all the arguments are actually used/required.
+// The Swiss-army-knife for sparse tensor creation.
 void *_mlir_ciface_newSparseTensor( // NOLINT
     StridedMemRefType<index_type, 1> *dimSizesRef,
     StridedMemRefType<index_type, 1> *lvlSizesRef,
@@ -241,18 +237,18 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
   ASSERT_NO_STRIDE(dimSizesRef);
   ASSERT_NO_STRIDE(lvlSizesRef);
   ASSERT_NO_STRIDE(lvlTypesRef);
-  ASSERT_NO_STRIDE(lvl2dimRef);
   ASSERT_NO_STRIDE(dim2lvlRef);
+  ASSERT_NO_STRIDE(lvl2dimRef);
   const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef);
   const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
-  ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
   ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
+  ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
   ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
   const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
   const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
   const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
-  const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
   const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
+  const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
 
   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
   // This is safe because of the static_assert above.
@@ -403,10 +399,7 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
 #undef IMPL_SPARSECOORDINATES
 #undef IMPL_GETOVERHEAD
 
-// TODO: while this API design will work for arbitrary dim2lvl mappings,
-// we should probably move the `dimCoords`-to-`lvlCoords` computation into
-// codegen (since that could enable optimizations to remove the intermediate
-// memref).
+// TODO: use MapRef here for translation of coordinates
 #define IMPL_ADDELT(VNAME, V)                                                  \
   void *_mlir_ciface_addElt##VNAME(                                            \
       void *lvlCOO, StridedMemRefType<V, 0> *vref,                             \
@@ -506,44 +499,33 @@ void _mlir_ciface_getSparseTensorReaderDimSizes(
   aliasIntoMemref(reader.getRank(), dimSizes, *out);
 }
 
-#define IMPL_GETNEXT(VNAME, V)                                                 \
-  void _mlir_ciface_getSparseTensorReaderNext##VNAME(                          \
-      void *p, StridedMemRefType<index_type, 1> *dimCoordsRef,                 \
-      StridedMemRefType<V, 0> *vref) {                                         \
-    assert(p &&vref);                                                          \
-    auto &reader = *static_cast<SparseTensorReader *>(p);                      \
-    ASSERT_NO_STRIDE(dimCoordsRef);                                            \
-    const uint64_t dimRank = MEMREF_GET_USIZE(dimCoordsRef);                   \
-    index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef);                  \
-    V *value = MEMREF_GET_PAYLOAD(vref);                                       \
-    *value = reader.readElement<V>(dimRank, dimCoords);                        \
-  }
-MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
-#undef IMPL_GETNEXT
-
 #define IMPL_GETNEXT(VNAME, V, CNAME, C)                                       \
   bool _mlir_ciface_getSparseTensorReaderReadToBuffers##CNAME##VNAME(          \
       void *p, StridedMemRefType<index_type, 1> *dim2lvlRef,                   \
+      StridedMemRefType<index_type, 1> *lvl2dimRef,                            \
       StridedMemRefType<C, 1> *cref, StridedMemRefType<V, 1> *vref) {          \
     assert(p);                                                                 \
     auto &reader = *static_cast<SparseTensorReader *>(p);                      \
+    ASSERT_NO_STRIDE(dim2lvlRef);                                              \
+    ASSERT_NO_STRIDE(lvl2dimRef);                                              \
     ASSERT_NO_STRIDE(cref);                                                    \
     ASSERT_NO_STRIDE(vref);                                                    \
-    ASSERT_NO_STRIDE(dim2lvlRef);                                              \
+    const uint64_t dimRank = reader.getRank();                                 \
+    const uint64_t lvlRank = MEMREF_GET_USIZE(lvl2dimRef);                     \
     const uint64_t cSize = MEMREF_GET_USIZE(cref);                             \
     const uint64_t vSize = MEMREF_GET_USIZE(vref);                             \
-    const uint64_t lvlRank = reader.getRank();                                 \
-    assert(vSize *lvlRank <= cSize);                                           \
+    ASSERT_USIZE_EQ(dim2lvlRef, dimRank);                                      \
+    assert(cSize >= lvlRank * vSize);                                          \
     assert(vSize >= reader.getNSE() && "Not enough space in buffers");         \
-    ASSERT_USIZE_EQ(dim2lvlRef, lvlRank);                                      \
+    (void)dimRank;                                                             \
     (void)cSize;                                                               \
     (void)vSize;                                                               \
-    (void)lvlRank;                                                             \
+    index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);                      \
+    index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);                      \
     C *lvlCoordinates = MEMREF_GET_PAYLOAD(cref);                              \
     V *values = MEMREF_GET_PAYLOAD(vref);                                      \
-    index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);                      \
-    return reader.readToBuffers<C, V>(lvlRank, dim2lvl, lvlCoordinates,        \
-                                      values);                                 \
+    return reader.readToBuffers<C, V>(lvlRank, dim2lvl, lvl2dim,               \
+                                      lvlCoordinates, values);                 \
   }
 MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
 #undef IMPL_GETNEXT
@@ -551,8 +533,8 @@ MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
 void *_mlir_ciface_newSparseTensorFromReader(
     void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
     StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
-    StridedMemRefType<index_type, 1> *lvl2dimRef,
-    StridedMemRefType<index_type, 1> *dim2lvlRef, OverheadType posTp,
+    StridedMemRefType<index_type, 1> *dim2lvlRef,
+    StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
     OverheadType crdTp, PrimaryType valTp) {
   assert(p);
   SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
@@ -568,13 +550,13 @@ void *_mlir_ciface_newSparseTensorFromReader(
   (void)dimRank;
   const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
   const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
-  const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
   const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
+  const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
 #define CASE(p, c, v, P, C, V)                                                 \
   if (posTp == OverheadType::p && crdTp == OverheadType::c &&                  \
       valTp == PrimaryType::v)                                                 \
     return static_cast<void *>(reader.readSparseTensor<P, C, V>(               \
-        lvlRank, lvlSizes, lvlTypes, lvl2dim, dim2lvl));
+        lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));
 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
   // This is safe because of the static_assert above.
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 4ac768c21aff8fc..ff523e70bfc914a 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -423,7 +423,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
 //   CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
 //   CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
-//       CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]]
+//       CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
 //       CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
 //       CHECK:   %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
 //       CHECK:   %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
@@ -471,7 +471,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
 //       CHECK:     %[[A11:.*]] = arith.constant 0.000000e+00 : f64
 //       CHECK:     %[[A12:.*]] = arith.constant 1 : index
 //       CHECK:     %[[A13:.*]] = arith.constant 0 : index
-//       CHECK:     sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]]
+//       CHECK:     sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
 //       CHECK:     %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
 //       CHECK:       %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
 //       CHECK:       %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
@@ -507,7 +507,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
   return %1 : tensor<8x8xf64, #CSR>
 }
 
-// CHECK-LABEL: func.func private @_insert_dense_compressed_nonordered_8_8_f64_0_0(
+// CHECK-LABEL: func.func private @_insert_dense_compressed_no_8_8_f64_0_0(
 //  CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
 //  CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
 //  CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
@@ -533,7 +533,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
 //       CHECK:     %[[A13:.*]]:4 = scf.for %[[A14:.*]] = %[[A11]] to %[[A7]] step %[[A12]] iter_args(%[[A15:.*]] = %[[A0]], %[[A16:.*]] = %[[A1]], %[[A17:.*]] = %[[A2]], %[[A18:.*]] = %[[A3]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 //       CHECK:       %[[A19:.*]] = memref.load %[[A6]]{{\[}}%[[A14]]] : memref<?xindex>
 //       CHECK:       %[[A20:.*]] = memref.load %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
-//       CHECK:       %[[A21:.*]]:4 = func.call @_insert_dense_compressed_nonordered_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+//       CHECK:       %[[A21:.*]]:4 = func.call @_insert_dense_compressed_no_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 //       CHECK:       memref.store %[[A10]], %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
 //       CHECK:       memref.store %[[A9]], %[[A5]]{{\[}}%[[A19]]] : memref<?xi1>
 //       CHECK:       scf.yield %[[A21]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
@@ -611,7 +611,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
   return %1 : tensor<128xf64, #SparseVector>
 }
 
-// CHECK-LABEL: func.func private @_insert_compressed_nonunique_singleton_5_6_f64_0_0(
+// CHECK-LABEL: func.func private @_insert_compressed_nu_singleton_5_6_f64_0_0(
 //  CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
 //  CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
 //  CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
@@ -627,7 +627,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
 //  CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier
 //  CHECK-SAME: %[[A4:.*4]]: index,
 //  CHECK-SAME: %[[A5:.*5]]: f64)
-//       CHECK: %[[R:.*]]:4 = call @_insert_compressed_nonunique_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
+//       CHECK: %[[R:.*]]:4 = call @_insert_compressed_nu_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
 //       CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
 func.func @sparse_insert_coo(%arg0: tensor<5x6xf64, #Coo>, %arg1: index, %arg2: f64) -> tensor<5x6xf64, #Coo> {
   %0 = sparse_tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>
@@ -665,90 +665,94 @@ func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) ->
 
 // CHECK-LABEL: func.func @sparse_new_coo(
 // CHECK-SAME:  %[[A0:.*]]: !llvm.ptr<i8>) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{{{.*}}}>>) {
-//   CHECK-DAG: %[[A1:.*]] = arith.constant false
-//   CHECK-DAG: %[[A2:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[A3:.*]] = arith.constant 0 : index
-//   CHECK-DAG: %[[A4:.*]] = arith.constant 2 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32
-//       CHECK: %[[D0:.*]] = memref.alloca() : memref<2xindex>
-//       CHECK: %[[D1:.*]] = memref.cast %[[D0]] : memref<2xindex> to memref<?xindex>
-//       CHECK: memref.store %[[A3]], %[[D0]][%[[A3]]] : memref<2xindex
-//       CHECK: memref.store %[[A3]], %[[D0]][%[[A2]]] : memref<2xindex>
-//       CHECK: %[[A5:.*]] = call @createCheckedSparseTensorReader(%[[A0]], %[[D1]], %[[C2]])
-//       CHECK: %[[D2:.*]] = call @getSparseTensorReaderDimSizes(%0) : (!llvm.ptr<i8>) -> memref<?xindex>
-//       CHECK: %[[A8:.*]] = memref.load %[[D2]]{{\[}}%[[A3]]] : memref<?xindex>
-//       CHECK: %[[A9:.*]] = memref.load %[[D2]]{{\[}}%[[A2]]] : memref<?xindex>
-//       CHECK: %[[A10:.*]] = call @getSparseTensorReaderNSE(%[[A5]])
-//       CHECK: %[[A11:.*]] = arith.muli %[[A10]], %[[A4]] : index
-//       CHECK: %[[A12:.*]] = memref.alloc() : memref<2xindex>
-//       CHECK: %[[A13:.*]] = memref.cast %[[A12]] : memref<2xindex> to memref<?xindex>
-//       CHECK: %[[A14:.*]] = memref.alloc(%[[A11]]) : memref<?xindex>
-//       CHECK: %[[A15:.*]] = memref.alloc(%[[A10]]) : memref<?xf32>
-//       CHECK: %[[A16:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{{{.*}}}>>
-//       CHECK: %[[A18:.*]] = sparse_tensor.storage_specifier.set %[[A16]]  lvl_sz at 0 with %[[A8]]
-//       CHECK: %[[A19:.*]] = sparse_tensor.storage_specifier.get %[[A18]]  pos_mem_sz at 0
-//       CHECK: %[[A21:.*]], %[[A22:.*]] = sparse_tensor.push_back %[[A19]], %[[A13]], %[[A3]]
-//       CHECK: %[[A24:.*]] = sparse_tensor.storage_specifier.set %[[A18]]  pos_mem_sz at 0 with %[[A22]]
-//       CHECK: %[[A26:.*]] = sparse_tensor.storage_specifier.set %[[A24]]  lvl_sz at 1 with %[[A9]]
-//       CHECK: %[[A27:.*]], %[[A28:.*]] = sparse_tensor.push_back %[[A22]], %[[A21]], %[[A3]], %[[A2]]
-//       CHECK: %[[A30:.*]] = sparse_tensor.storage_specifier.set %[[A26]]  pos_mem_sz at 0 with %[[A28]]
-//       CHECK: %[[A31:.*]] = memref.alloca() : memref<2xindex>
-//       CHECK: %[[A32:.*]] = memref.cast %[[A31]] : memref<2xindex> to memref<?xindex>
-//       CHECK: memref.store %[[A3]], %[[A31]]{{\[}}%[[A3]]] : memref<2xindex>
-//       CHECK: memref.store %[[A2]], %[[A31]]{{\[}}%[[A2]]] : memref<2xindex>
-//       CHECK: %[[A33:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[A5]], %[[A32]], %[[A14]], %[[A15]])
-//       CHECK: %[[A34:.*]] = arith.cmpi eq, %[[A33]], %[[A1]] : i1
-//       CHECK: scf.if %[[A34]] {
-//       CHECK:   sparse_tensor.sort  hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {ny = 0 : index, perm_map = #{{.*}}} : memref<?xindex> jointly memref<?xf32>
-//       CHECK: }
-//       CHECK: memref.store %[[A10]], %[[A27]]{{\[}}%[[A2]]] : memref<?xindex>
-//       CHECK: %[[A36:.*]] = sparse_tensor.storage_specifier.set %[[A30]]  crd_mem_sz at 0 with %[[A11]]
-//       CHECK: %[[A38:.*]] = sparse_tensor.storage_specifier.set %[[A36]]  val_mem_sz with %[[A10]]
-//       CHECK: call @delSparseTensorReader(%[[A5]]) : (!llvm.ptr<i8>) -> ()
-//       CHECK: return %[[A27]], %[[A14]], %[[A15]], %[[A38]]
+// CHECK-DAG:   %[[VAL_1:.*]] = arith.constant false
+// CHECK-DAG:   %[[VAL_2:.*]] = arith.constant 2 : i32
+// CHECK-DAG:   %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[VAL_5:.*]] = arith.constant 2 : index
+// CHECK:       %[[VAL_6:.*]] = memref.alloca() : memref<2xindex>
+// CHECK:       %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<2xindex> to memref<?xindex>
+// CHECK:       memref.store %[[VAL_4]], %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<2xindex>
+// CHECK:       memref.store %[[VAL_4]], %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<2xindex>
+// CHECK:       %[[VAL_8:.*]] = call @createCheckedSparseTensorReader(%[[A0]], %[[VAL_7]], %[[VAL_2]]) : (!llvm.ptr<i8>, memref<?xindex>, i32) -> !llvm.ptr<i8>
+// CHECK:       %[[VAL_9:.*]] = call @getSparseTensorReaderDimSizes(%[[VAL_8]]) : (!llvm.ptr<i8>) -> memref<?xindex>
+// CHECK:       %[[VAL_10:.*]] = call @getSparseTensorReaderNSE(%[[VAL_8]]) : (!llvm.ptr<i8>) -> index
+// CHECK:       %[[VAL_11:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:       %[[VAL_12:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:       %[[VAL_13:.*]] = arith.muli %[[VAL_10]], %[[VAL_5]] : index
+// CHECK:       %[[VAL_14:.*]] = memref.alloc() : memref<2xindex>
+// CHECK:       %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
+// CHECK:       %[[VAL_16:.*]] = memref.alloc(%[[VAL_13]]) : memref<?xindex>
+// CHECK:       %[[VAL_17:.*]] = memref.alloc(%[[VAL_10]]) : memref<?xf32>
+// CHECK:       %[[VAL_18:.*]] = sparse_tensor.storage_specifier.init
+// CHECK:       %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  lvl_sz at 0 with %[[VAL_11]]
+// CHECK:       %[[VAL_20:.*]] = sparse_tensor.storage_specifier.get %[[VAL_19]]  pos_mem_sz at 0
+// CHECK:       %[[VAL_21:.*]], %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_20]], %[[VAL_15]], %[[VAL_4]]
+// CHECK:       %[[VAL_23:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]]  pos_mem_sz at 0 with %[[VAL_22]]
+// CHECK:       %[[VAL_24:.*]] = sparse_tensor.storage_specifier.set %[[VAL_23]]  lvl_sz at 1 with %[[VAL_12]]
+// CHECK:       %[[VAL_25:.*]], %[[VAL_26:.*]] = sparse_tensor.push_back %[[VAL_22]], %[[VAL_21]], %[[VAL_4]], %[[VAL_3]]
+// CHECK:       %[[VAL_27:.*]] = sparse_tensor.storage_specifier.set %[[VAL_24]]  pos_mem_sz at 0 with %[[VAL_26]]
+// CHECK:       %[[VAL_28:.*]] = memref.alloca() : memref<2xindex>
+// CHECK:       %[[VAL_29:.*]] = memref.cast %[[VAL_28]] : memref<2xindex> to memref<?xindex>
+// CHECK:       memref.store %[[VAL_4]], %[[VAL_28]]{{\[}}%[[VAL_4]]] : memref<2xindex>
+// CHECK:       memref.store %[[VAL_3]], %[[VAL_28]]{{\[}}%[[VAL_3]]] : memref<2xindex>
+// CHECK:       %[[VAL_30:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[VAL_8]], %[[VAL_29]], %[[VAL_29]], %[[VAL_16]], %[[VAL_17]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) -> i1
+// CHECK:       %[[VAL_31:.*]] = arith.cmpi eq, %[[VAL_30]], %[[VAL_1]] : i1
+// CHECK:       scf.if %[[VAL_31]] {
+// CHECK:         sparse_tensor.sort_coo  hybrid_quick_sort %[[VAL_10]], %[[VAL_16]] jointly %[[VAL_17]]
+// CHECK:       }
+// CHECK:       memref.store %[[VAL_10]], %[[VAL_25]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:       %[[VAL_32:.*]] = sparse_tensor.storage_specifier.set %[[VAL_27]]  crd_mem_sz at 0 with %[[VAL_13]]
+// CHECK:       %[[VAL_33:.*]] = sparse_tensor.storage_specifier.set %[[VAL_32]]  val_mem_sz with %[[VAL_10]]
+// CHECK:       call @delSparseTensorReader(%[[VAL_8]]) : (!llvm.ptr<i8>) -> ()
+// CHECK:       return %[[VAL_25]], %[[VAL_16]], %[[VAL_17]], %[[VAL_33]]
 func.func @sparse_new_coo(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #Coo> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #Coo>
   return %0 : tensor<?x?xf32, #Coo>
 }
 
 // CHECK-LABEL: func.func @sparse_new_coo_permute_no(
-//  CHECK-SAME: %[[A0:.*]]: !llvm.ptr<i8>) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{{{.*}}}>>) {
-//   CHECK-DAG: %[[A1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[A2:.*]] = arith.constant 0 : index
-//   CHECK-DAG: %[[A3:.*]] = arith.constant 2 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32
-//       CHECK: %[[D0:.*]] = memref.alloca() : memref<2xindex>
-//       CHECK: %[[D1:.*]] = memref.cast %[[D0]] : memref<2xindex> to memref<?xindex>
-//       CHECK: memref.store %[[A2]], %[[D0]][%[[A2]]] : memref<2xindex
-//       CHECK: memref.store %[[A2]], %[[D0]][%[[A1]]] : memref<2xindex>
-//       CHECK: %[[A4:.*]] = call @createCheckedSparseTensorReader(%[[A0]], %[[D1]], %[[C2]])
-//       CHECK: %[[D2:.*]] = call @getSparseTensorReaderDimSizes(%0) : (!llvm.ptr<i8>) -> memref<?xindex>
-//       CHECK: %[[A7:.*]] = memref.load %[[D2]]{{\[}}%[[A2]]] : memref<?xindex>
-//       CHECK: %[[A8:.*]] = memref.load %[[D2]]{{\[}}%[[A1]]] : memref<?xindex>
-//       CHECK: %[[A9:.*]] = call @getSparseTensorReaderNSE(%[[A4]])
-//       CHECK: %[[A10:.*]] = arith.muli %[[A9]], %[[A3]] : index
-//       CHECK: %[[A11:.*]] = memref.alloc() : memref<2xindex>
-//       CHECK: %[[A12:.*]] = memref.cast %[[A11]] : memref<2xindex> to memref<?xindex>
-//       CHECK: %[[A13:.*]] = memref.alloc(%[[A10]]) : memref<?xindex>
-//       CHECK: %[[A14:.*]] = memref.alloc(%[[A9]]) : memref<?xf32>
-//       CHECK: %[[A15:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{{{.*}}}>>
-//       CHECK: %[[A17:.*]] = sparse_tensor.storage_specifier.set %[[A15]]  lvl_sz at 0 with %[[A8]]
-//       CHECK: %[[A18:.*]] = sparse_tensor.storage_specifier.get %[[A17]]  pos_mem_sz at 0
-//       CHECK: %[[A20:.*]], %[[A21:.*]] = sparse_tensor.push_back %[[A18]], %[[A12]], %[[A2]]
-//       CHECK: %[[A23:.*]] = sparse_tensor.storage_specifier.set %[[A17]]  pos_mem_sz at 0 with %[[A21]]
-//       CHECK: %[[A25:.*]] = sparse_tensor.storage_specifier.set %[[A23]]  lvl_sz at 1 with %[[A7]]
-//       CHECK: %[[A26:.*]], %[[A27:.*]] = sparse_tensor.push_back %[[A21]], %[[A20]], %[[A2]], %[[A1]]
-//       CHECK: %[[A29:.*]] = sparse_tensor.storage_specifier.set %[[A25]]  pos_mem_sz at 0 with %[[A27]]
-//       CHECK: %[[A30:.*]] = memref.alloca() : memref<2xindex>
-//       CHECK: %[[A31:.*]] = memref.cast %[[A30]] : memref<2xindex> to memref<?xindex>
-//       CHECK: memref.store %[[A1]], %[[A30]]{{\[}}%[[A2]]] : memref<2xindex>
-//       CHECK: memref.store %[[A2]], %[[A30]]{{\[}}%[[A1]]] : memref<2xindex>
-//       CHECK: %[[A32:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[A4]], %[[A31]], %[[A13]], %[[A14]])
-//       CHECK: memref.store %[[A9]], %[[A26]]{{\[}}%[[A1]]] : memref<?xindex>
-//       CHECK: %[[A34:.*]] = sparse_tensor.storage_specifier.set %[[A29]]  crd_mem_sz at 0 with %[[A10]]
-//       CHECK: %[[A36:.*]] = sparse_tensor.storage_specifier.set %[[A34]]  val_mem_sz with %[[A9]]
-//       CHECK: call @delSparseTensorReader(%[[A4]]) : (!llvm.ptr<i8>) -> ()
-//       CHECK: return %[[A26]], %[[A13]], %[[A14]], %[[A36]]
+// CHECK-SAME:  %[[A0:.*]]: !llvm.ptr<i8>) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{{{.*}}}>>) {
+// CHECK-DAG:   %[[VAL_1:.*]] = arith.constant 2 : i32
+// CHECK-DAG:   %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK:       %[[VAL_5:.*]] = memref.alloca() : memref<2xindex>
+// CHECK:       %[[VAL_6:.*]] = memref.cast %[[VAL_5]] : memref<2xindex> to memref<?xindex>
+// CHECK:       memref.store %[[VAL_3]], %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<2xindex>
+// CHECK:       memref.store %[[VAL_3]], %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<2xindex>
+// CHECK:       %[[VAL_7:.*]] = call @createCheckedSparseTensorReader(%[[A0]], %[[VAL_6]], %[[VAL_1]]) : (!llvm.ptr<i8>, memref<?xindex>, i32) -> !llvm.ptr<i8>
+// CHECK:       %[[VAL_8:.*]] = call @getSparseTensorReaderDimSizes(%[[VAL_7]]) : (!llvm.ptr<i8>) -> memref<?xindex>
+// CHECK:       %[[VAL_9:.*]] = call @getSparseTensorReaderNSE(%[[VAL_7]]) : (!llvm.ptr<i8>) -> index
+// CHECK:       %[[VAL_10:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:       %[[VAL_11:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:       %[[VAL_12:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index
+// CHECK:       %[[VAL_13:.*]] = memref.alloc() : memref<2xindex>
+// CHECK:       %[[VAL_14:.*]] = memref.cast %[[VAL_13]] : memref<2xindex> to memref<?xindex>
+// CHECK:       %[[VAL_15:.*]] = memref.alloc(%[[VAL_12]]) : memref<?xindex>
+// CHECK:       %[[VAL_16:.*]] = memref.alloc(%[[VAL_9]]) : memref<?xf32>
+// CHECK:       %[[VAL_17:.*]] = sparse_tensor.storage_specifier.init
+// CHECK:       %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]]  lvl_sz at 0 with %[[VAL_11]]
+// CHECK:       %[[VAL_19:.*]] = sparse_tensor.storage_specifier.get %[[VAL_18]]  pos_mem_sz at 0
+// CHECK:       %[[VAL_20:.*]], %[[VAL_21:.*]] = sparse_tensor.push_back %[[VAL_19]], %[[VAL_14]], %[[VAL_3]]
+// CHECK:       %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  pos_mem_sz at 0 with %[[VAL_21]]
+// CHECK:       %[[VAL_23:.*]] = sparse_tensor.storage_specifier.set %[[VAL_22]]  lvl_sz at 1 with %[[VAL_10]]
+// CHECK:       %[[VAL_24:.*]], %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_21]], %[[VAL_20]], %[[VAL_3]], %[[VAL_2]]
+// CHECK:       %[[VAL_26:.*]] = sparse_tensor.storage_specifier.set %[[VAL_23]]  pos_mem_sz at 0 with %[[VAL_25]]
+// CHECK:       %[[VAL_27:.*]] = memref.alloca() : memref<2xindex>
+// CHECK:       %[[VAL_28:.*]] = memref.cast %[[VAL_27]] : memref<2xindex> to memref<?xindex>
+// CHECK:       memref.store %[[VAL_2]], %[[VAL_27]]{{\[}}%[[VAL_3]]] : memref<2xindex>
+// CHECK:       memref.store %[[VAL_3]], %[[VAL_27]]{{\[}}%[[VAL_2]]] : memref<2xindex>
+// CHECK:       %[[VAL_29:.*]] = memref.alloca() : memref<2xindex>
+// CHECK:       %[[VAL_30:.*]] = memref.cast %[[VAL_29]] : memref<2xindex> to memref<?xindex>
+// CHECK:       memref.store %[[VAL_2]], %[[VAL_29]]{{\[}}%[[VAL_3]]] : memref<2xindex>
+// CHECK:       memref.store %[[VAL_3]], %[[VAL_29]]{{\[}}%[[VAL_2]]] : memref<2xindex>
+// CHECK:       %[[VAL_31:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[VAL_7]], %[[VAL_28]], %[[VAL_30]], %[[VAL_15]], %[[VAL_16]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) -> i1
+// CHECK:       memref.store %[[VAL_9]], %[[VAL_24]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:       %[[VAL_32:.*]] = sparse_tensor.storage_specifier.set %[[VAL_26]]  crd_mem_sz at 0 with %[[VAL_12]]
+// CHECK:       %[[VAL_33:.*]] = sparse_tensor.storage_specifier.set %[[VAL_32]]  val_mem_sz with %[[VAL_9]]
+// CHECK:       call @delSparseTensorReader(%[[VAL_7]]) : (!llvm.ptr<i8>) -> ()
+// CHECK:       return %[[VAL_24]], %[[VAL_15]], %[[VAL_16]], %[[VAL_33]]
 func.func @sparse_new_coo_permute_no(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CooPNo> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CooPNo>
   return %0 : tensor<?x?xf32, #CooPNo>
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 9d337b929fa423a..138736e26c1dfdd 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -114,15 +114,15 @@ func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 //       CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-//   CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
-//   CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
-//   CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
-//   CHECK-DAG: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
-//   CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
-//   CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
-//       CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[LvlSizes]], %[[LvlTypes]], %[[Lvl2Dim]], %[[Dim2Lvl]], %{{.*}}, %{{.*}}, %{{.*}})
+//       CHECK: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
+//       CHECK: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
+//       CHECK: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
+//       CHECK: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
+//       CHECK: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
+//       CHECK: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
+//       CHECK: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
+//       CHECK: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+//       CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[LvlSizes]], %[[LvlTypes]], %[[Dim2Lvl]], %[[Lvl2Dim]], %{{.*}}, %{{.*}}, %{{.*}})
 //       CHECK: call @delSparseTensorReader(%[[Reader]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {

>From d54b03e367ed34ebea5a0b06c6c6f2e4a04b93b7 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 5 Oct 2023 15:14:22 -0700
Subject: [PATCH 2/3] fix merge conflict

---
 mlir/test/Dialect/SparseTensor/codegen.mlir | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index ff523e70bfc914a..adefceba7379f99 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -423,7 +423,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
 //   CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
 //   CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
-//       CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
+//       CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]]
 //       CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
 //       CHECK:   %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
 //       CHECK:   %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
@@ -471,7 +471,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
 //       CHECK:     %[[A11:.*]] = arith.constant 0.000000e+00 : f64
 //       CHECK:     %[[A12:.*]] = arith.constant 1 : index
 //       CHECK:     %[[A13:.*]] = arith.constant 0 : index
-//       CHECK:     sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
+//       CHECK:     sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]]
 //       CHECK:     %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
 //       CHECK:       %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
 //       CHECK:       %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
@@ -507,7 +507,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
   return %1 : tensor<8x8xf64, #CSR>
 }
 
-// CHECK-LABEL: func.func private @_insert_dense_compressed_no_8_8_f64_0_0(
+// CHECK-LABEL: func.func private @_insert_dense_compressed_nonordered_8_8_f64_0_0(
 //  CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
 //  CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
 //  CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
@@ -533,7 +533,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
 //       CHECK:     %[[A13:.*]]:4 = scf.for %[[A14:.*]] = %[[A11]] to %[[A7]] step %[[A12]] iter_args(%[[A15:.*]] = %[[A0]], %[[A16:.*]] = %[[A1]], %[[A17:.*]] = %[[A2]], %[[A18:.*]] = %[[A3]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 //       CHECK:       %[[A19:.*]] = memref.load %[[A6]]{{\[}}%[[A14]]] : memref<?xindex>
 //       CHECK:       %[[A20:.*]] = memref.load %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
-//       CHECK:       %[[A21:.*]]:4 = func.call @_insert_dense_compressed_no_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+//       CHECK:       %[[A21:.*]]:4 = func.call @_insert_dense_compressed_nonordered_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 //       CHECK:       memref.store %[[A10]], %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
 //       CHECK:       memref.store %[[A9]], %[[A5]]{{\[}}%[[A19]]] : memref<?xi1>
 //       CHECK:       scf.yield %[[A21]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
@@ -611,7 +611,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
   return %1 : tensor<128xf64, #SparseVector>
 }
 
-// CHECK-LABEL: func.func private @_insert_compressed_nu_singleton_5_6_f64_0_0(
+// CHECK-LABEL: func.func private @_insert_compressed_nonunique_singleton_5_6_f64_0_0(
 //  CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
 //  CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
 //  CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
@@ -627,7 +627,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
 //  CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier
 //  CHECK-SAME: %[[A4:.*4]]: index,
 //  CHECK-SAME: %[[A5:.*5]]: f64)
-//       CHECK: %[[R:.*]]:4 = call @_insert_compressed_nu_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
+//       CHECK: %[[R:.*]]:4 = call @_insert_compressed_nonunique_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
 //       CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
 func.func @sparse_insert_coo(%arg0: tensor<5x6xf64, #Coo>, %arg1: index, %arg2: f64) -> tensor<5x6xf64, #Coo> {
   %0 = sparse_tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>

>From 5ecff8cfae4fb7790d41ac3e07a6b2dbb3a47403 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 5 Oct 2023 15:17:46 -0700
Subject: [PATCH 3/3] clang-format

---
 mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h
index 1c155568802e579..a1bd6798f150b43 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/MapRef.h
@@ -38,7 +38,8 @@ class MapRef final {
   // Push forward maps from dimensions to levels.
   //
 
-  template <typename T> inline void pushforward(const T *in, T *out) const {
+  template <typename T>
+  inline void pushforward(const T *in, T *out) const {
     switch (kind) {
     case MapKind::kIdentity:
       for (uint64_t i = 0; i < dimRank; ++i)
@@ -58,7 +59,8 @@ class MapRef final {
   // Push backward maps from levels to dimensions.
   //
 
-  template <typename T> inline void pushbackward(const T *in, T *out) const {
+  template <typename T>
+  inline void pushbackward(const T *in, T *out) const {
     switch (kind) {
     case MapKind::kIdentity:
       for (uint64_t i = 0; i < lvlRank; ++i)



More information about the Mlir-commits mailing list